In [1]:
import torch
import torch.nn as nn

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder flattens each sample and maps it through two hidden layers
    (256 and 64 units); two separate linear heads then produce the mean and
    the log-variance of the approximate posterior. The decoder mirrors the
    encoder and ends in a sigmoid, so reconstructions lie in (0, 1) and are
    returned flattened as ``(batch, input_dim)``.

    Args:
        latent_dim: Size of the latent vector ``z``.
        input_dim: Number of features per sample after flattening. Defaults
            to ``1*28*28`` (784), the size hard-coded in the original model.
    """

    def __init__(self, latent_dim = 20, input_dim = 1*28*28):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
        )

        # Two independent heads on top of the shared 64-d encoder features.
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        hidden = self.encoder(x)
        mu = self.mu(hidden)
        # Must use the log-variance head here; reusing ``self.mu`` would make
        # logvar identical to mu and leave ``self.logvar`` untrained.
        logvar = self.logvar(hidden)
        return mu, logvar

    def decode(self, x):
        """Map latent vectors ``x`` to flattened reconstructions in (0, 1)."""
        return self.decoder(x)

    def reparametrized(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparametrized(mu, logvar)
        return self.decode(z), mu, logvar



In [None]:
# Training setup: choose the compute device, build the model on it and
# attach an Adam optimizer to the model's parameters.
latent_dim = 20
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VAE(latent_dim)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def train(model, train_loader, epochs=10, optimizer=None, device=None):
    """Train a VAE-style model with the summed BCE + KL objective.

    Args:
        model: Module whose forward pass returns ``(x_recon, mu, log_var)``,
            with ``x_recon`` values in [0, 1].
        train_loader: Iterable yielding ``(x, label)`` batches and exposing a
            ``dataset`` attribute with a length (used to average the loss).
            Labels are ignored.
        epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer to step. When omitted, the notebook-level
            ``optimizer`` variable is used, as before.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used.

    Raises:
        NameError: If ``optimizer`` is omitted and no notebook-level
            ``optimizer`` exists.

    Prints the average per-sample loss after every epoch; returns nothing.
    """
    if optimizer is None:
        try:
            optimizer = globals()["optimizer"]
        except KeyError:
            raise NameError(
                "name 'optimizer' is not defined; pass optimizer=... or create it before calling train()"
            ) from None
    if device is None:
        device = next(model.parameters()).device

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for x, _ in train_loader:
            x = x.to(device)

            optimizer.zero_grad()
            x_recon, mu, log_var = model(x)

            # Reconstruction loss (BCE). The model returns flattened outputs,
            # while the batch is usually image-shaped; binary_cross_entropy
            # requires identical shapes, so reshape the target to match.
            target = x.reshape(x_recon.shape)
            recon_loss = nn.functional.binary_cross_entropy(x_recon, target, reduction='sum')

            # KL divergence between N(mu, sigma^2) and the standard normal prior.
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

            # Total loss
            loss = recon_loss + kl_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader.dataset):.4f}")

# Run 10 training epochs on `model`.
# NOTE(review): `train_loader` is not defined in any cell shown here — presumably
# built in a data-loading cell; confirm it exists before this cell on a fresh kernel.
train(model, train_loader, epochs=10)

In [None]:
# Visual check: show the first ten images of one test batch next to their
# reconstructions. Gradients are not needed, so everything runs under no_grad.
with torch.no_grad():
    # Pull a single batch; the labels are not used.
    dataiter = iter(test_loader)
    images, _ = next(dataiter)

    images = images.to(device)
    outputs, _, _ = model(images)  # keep only the first of the three returned values

    # Bring both tensors back to the CPU and view them as 28x28 images for imshow.
    images = images.cpu().view(-1, 28, 28)
    outputs = outputs.cpu().view(-1, 28, 28)

    # One row per sample: original on the left, reconstruction on the right.
    fig, axes = plt.subplots(10, 2, figsize=(4, 20))
    for i in range(10):
        for col, (batch, title) in enumerate(((images, "Original"), (outputs, "Reconstructed"))):
            ax = axes[i, col]
            ax.imshow(batch[i], cmap='gray')
            ax.axis('off')
            ax.set_title(title)

    plt.tight_layout()
    plt.show()
