In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [None]:

# Hyperparameters and runtime configuration shared by the cells below.
batch_size = 128   # mini-batch size for the train and test DataLoaders
latent_dim = 10    # size of the VAE latent code z
epochs = 50        # passes over the training set per training run
lr = 1e-3          # Adam learning rate
seed = 42          # global RNG seed so weight init, shuffling and sampling are repeatable

# torch.manual_seed seeds the default generators on all devices; weight
# initialisation, DataLoader shuffling and torch.randn/randn_like draw from them.
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:

# ToTensor turns each PIL image into a float tensor with values in [0, 1],
# matching the range of the decoder's sigmoid output.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same cache directory, download flag and transform.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; keep the test order fixed so the
# visualisation cells always show the same digits.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    Encoder: input -> hidden (ReLU) -> (mu, logvar), each of size ``latent_dim``.
    Decoder: latent -> hidden (ReLU) -> input (sigmoid), so reconstructions
    lie in (0, 1).

    Args:
        latent_dim: size of the latent code z. When omitted, the notebook-level
            ``latent_dim`` config value is read at construction time, which is
            exactly what the original argument-less ``VAE()`` did.
        input_dim: number of features in one flattened sample (784 = 28*28 for MNIST).
        hidden_dim: width of the single hidden layer in encoder and decoder.
    """

    def __init__(self, latent_dim=None, input_dim=28 * 28, hidden_dim=400):
        super().__init__()
        if latent_dim is None:
            # Backward-compatible default: fall back to the notebook config value.
            latent_dim = globals()["latent_dim"]

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map x of shape (..., input_dim) to (mu, logvar), each (..., latent_dim)."""
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I).

        Taking the randomness from an external noise tensor keeps z
        differentiable with respect to mu and logvar.
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (..., latent_dim) to intensities in (0, 1)."""
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        """Encode, sample z, decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [None]:
# Build the network, then move its parameters onto the configured device.
model = VAE()
model = model.to(device)


In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO with a Bernoulli decoder: summed BCE + KL divergence.

    This is the baseline objective. ``vae_loss_mse`` replaces the
    reconstruction term with a summed squared error, so the two training
    runs in this notebook compare genuinely different losses.

    Args:
        recon_x: decoder output; values must lie in [0, 1] (BCE requirement).
        x: target pixels in [0, 1], same shape as ``recon_x``.
        mu, logvar: encoder outputs parameterising q(z|x) = N(mu, exp(logvar)).

    Returns:
        Scalar tensor summed over every element of the batch (not averaged);
        callers divide by the number of samples.
    """
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


In [None]:
optimizer = optim.Adam(model.parameters(), lr=lr)
train_losses = []  # per-sample average loss, one entry per epoch

num_train = len(train_loader.dataset)
for epoch in range(epochs):
    model.train()
    epoch_loss_sum = 0

    for images, _ in train_loader:
        batch = images.view(-1, 28 * 28).to(device)  # flatten to (B, 784)

        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        loss = vae_loss(recon, batch, mu, logvar)
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()

    # vae_loss sums over the batch, so normalise by the number of samples
    # rather than the number of batches.
    avg_loss = epoch_loss_sum / num_train
    train_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.2f}")


In [None]:
# Top row: first n test digits. Bottom row: their reconstructions.
model.eval()
test_images, _ = next(iter(test_loader))
data = test_images.view(-1, 28 * 28).to(device)

with torch.no_grad():
    recon = model(data)[0]

n = 10
fig, axes = plt.subplots(2, n, figsize=(15, 4))
for col in range(n):
    for row, batch in enumerate((data, recon)):
        ax = axes[row, col]
        ax.imshow(batch[col].view(28, 28).cpu(), cmap='gray')
        ax.axis('off')

plt.show()


In [None]:
# Decode 16 latent vectors drawn from the N(0, I) prior and show them as a 4x4 grid.
num_samples = 16
with torch.no_grad():
    z = torch.randn(num_samples, latent_dim).to(device)
    samples = model.decode(z)

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img.view(28, 28).cpu(), cmap='gray')
    ax.axis('off')
plt.show()


In [None]:
# Per-sample training loss recorded after each epoch of the first run.
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(xlabel="Epochs", ylabel="Loss", title="VAE Training Loss")
plt.show()


In [None]:
def vae_loss_mse(recon_x, x, mu, logvar):
    """Summed squared-error reconstruction loss plus the analytic KL term.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target pixels.
        mu, logvar: encoder outputs parameterising q(z|x) = N(mu, exp(logvar)).

    Returns:
        Scalar tensor summed over every element of the batch (not averaged).
    """
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, summed over batch and latent dims.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + kl_divergence


In [None]:
# Second experiment: same architecture, epoch budget and learning rate, but
# with the MSE reconstruction term. Start from a freshly initialised model.
# Re-using `model` would continue training weights already optimised by the
# first run, which makes the two loss curves incomparable.
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

train_losses_mse = []  # per-sample average loss, one entry per epoch

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for data, _ in train_loader:
        data = data.view(-1, 28*28).to(device)  # flatten to (B, 784)

        optimizer.zero_grad()
        recon, mu, logvar = model(data)
        loss = vae_loss_mse(recon, data, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # The loss is summed over the batch, so divide by the sample count.
    avg_loss = total_loss / len(train_loader.dataset)
    train_losses_mse.append(avg_loss)
    print(f"[MSE+KL] Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.2f}")


In [None]:
# Per-sample training loss recorded after each epoch of the MSE + KL run.
fig, ax = plt.subplots()
ax.plot(train_losses_mse)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("VAE Training Loss (MSE + KL)")
plt.show()


In [None]:
# Reconstructions of the first test batch after the MSE + KL training run.
model.eval()
batch_images, _ = next(iter(test_loader))
data = batch_images.view(-1, 28 * 28).to(device)

with torch.no_grad():
    recon, _, _ = model(data)

n = 10
rows = {0: data, 1: recon}  # row 0: originals, row 1: reconstructions
fig, axes = plt.subplots(2, n, figsize=(15, 4))
for row_idx, images in rows.items():
    for col_idx in range(n):
        cell = axes[row_idx][col_idx]
        cell.imshow(images[col_idx].view(28, 28).cpu(), cmap='gray')
        cell.axis('off')

plt.show()


In [None]:
# Decode latent vectors drawn from the N(0, I) prior into a 4x4 grid of digits.
grid_rows, grid_cols = 4, 4
with torch.no_grad():
    z = torch.randn(grid_rows * grid_cols, latent_dim).to(device)
    samples = model.decode(z)

plt.figure(figsize=(6, 6))
for idx, img in enumerate(samples, start=1):
    plt.subplot(grid_rows, grid_cols, idx)
    plt.imshow(img.view(28, 28).cpu(), cmap='gray')
    plt.axis('off')
plt.show()
