In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# Seed PyTorch's global RNG so repeated runs start from the same state
torch.manual_seed(42)

# Hyperparameters
batch_size = 128
latent_dim = 20
epochs = 50

# MNIST training split; ToTensor is the only preprocessing step applied
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
# drop_last=True discards the final partial batch, so every batch has
# exactly `batch_size` images
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    The encoder maps a flattened image (784 values) to ``2 * latent_dim``
    numbers: the first half is read as the posterior mean and the second half
    as the posterior log-variance. The decoder maps a latent vector back to
    784 values squashed into (0, 1) by a sigmoid.

    Args:
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self,latent_dim):
        super(VAE, self).__init__()

        # Keep the latent size on the instance so `forward` never has to rely
        # on a module-level variable of the same name.
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 400),
            nn.ReLU(),
            nn.Linear(400, 2 * latent_dim)  # Two times latent_dim for mean and log-variance
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 28 * 28),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling ``eps`` separately keeps ``z`` differentiable with respect to
        ``mu`` and ``logvar``.

        Args:
            mu: Posterior means.
            logvar: Posterior log-variances, same shape as ``mu``.

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Input tensor; it is reshaped to ``(-1, 784)`` before encoding.

        Returns:
            Tuple ``(decoded, mu, logvar)`` where ``decoded`` has 784 columns
            and ``mu`` / ``logvar`` each have ``latent_dim`` columns.
        """
        x = x.view(-1, 28 * 28)
        enc_output = self.encoder(x)
        # Split on the instance's own latent size, not a global.
        mu = enc_output[:, :self.latent_dim]
        logvar = enc_output[:, self.latent_dim:]
        z = self.reparameterize(mu, logvar)
        decoded = self.decoder(z)
        return decoded, mu, logvar

# Loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return reconstruction loss plus KL divergence, summed over the batch.

    Args:
        recon_x: Reconstructed images with 784 columns.
        x: Target images; reshaped to ``(-1, 784)`` before comparison.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor: summed binary cross-entropy between ``recon_x`` and
        the flattened ``x``, plus the summed KL term.
    """
    flat_target = x.view(-1, 28 * 28)
    reconstruction = nn.functional.binary_cross_entropy(
        recon_x, flat_target, reduction='sum'
    )

    # KL term in closed form; see Appendix B of
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    return reconstruction + divergence



def visualize_manifold(model, lr, save_path='vae_lr_manifold.png', n=256):
    """Decode random latent vectors and save the results as one image grid.

    Args:
        model: Module with a ``decoder`` submodule whose output can be viewed
            as ``(-1, 1, 28, 28)``.
        lr: Tag inserted into the output filename in place of the ``lr``
            placeholder. It is only formatted with ``str()``, so any value
            identifying the run can be passed.
        save_path: Output path. If it contains ``'_lr_'``, the first
            occurrence becomes ``'_<lr>_'``; otherwise the path is used as is.
        n: Number of latent vectors to sample and decode.

    Raises:
        ValueError: If the latent size cannot be determined from ``model``.
    """
    # Take the latent size from the model itself instead of a module-level
    # global: prefer an explicit attribute, else the decoder's first Linear.
    z_dim = getattr(model, 'latent_dim', None)
    if z_dim is None:
        first_linear = next(
            (m for m in model.decoder.modules() if isinstance(m, nn.Linear)),
            None,
        )
        if first_linear is None:
            raise ValueError('cannot infer the latent dimension from model.decoder')
        z_dim = first_linear.in_features

    # Sample on the CPU, then move to wherever the model's weights live.
    target_device = next(model.parameters()).device

    # Tolerate paths without the '_lr_' placeholder instead of raising.
    head, placeholder, tail = save_path.partition('_lr_')
    if placeholder:
        save_path = head + '_' + str(lr) + '_' + tail

    # A private generator gives the same fixed samples on every call without
    # reseeding the global RNG that training relies on.
    sampler = torch.Generator().manual_seed(0)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Sample points from the latent space
            latent_points = torch.randn(n, z_dim, generator=sampler).to(target_device)

            # Decode them and reshape into single-channel 28x28 images
            generated_images = model.decoder(latent_points).view(-1, 1, 28, 28)

            # Save the generated images as a roughly square grid
            save_image(generated_images, save_path, nrow=max(1, int(np.sqrt(n))))
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)


# Specify the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sweep settings: one fixed learning rate, several latent sizes
lr = 0.01
latent_dims = [2,5,10,20,50]

# Train one fresh VAE per latent size.
# NOTE: the loop variable deliberately rebinds the module-level `latent_dim`;
# other definitions in this notebook may read that name as a global.
for latent_dim in latent_dims:
    # Initialize the model and optimizer, and move the model to the device
    model = VAE(latent_dim)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    tbar = tqdm(range(epochs))
    for epoch in tbar:
        model.train()
        total_loss = 0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = vae_loss(recon_batch, data, mu, logvar)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()

        # The reported loss is the mean of the per-batch loss values
        tbar.set_postfix({'latent_dimension':latent_dim,'epoch':epoch,'loss':total_loss / len(train_loader)})

    # Save each model under its own name; a single fixed filename would be
    # overwritten on every sweep iteration, keeping only the last model.
    torch.save(model.state_dict(), 'vae_mnist_latent{}.pth'.format(latent_dim))

    # Visualize the learned manifold; the `lr` argument is only used as a
    # filename tag, so the latent size is passed to label the image.
    visualize_manifold(model,lr=latent_dim)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 124311064.60it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 121500194.41it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 14734130.41it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18658696.15it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



100%|██████████| 50/50 [06:21<00:00,  7.62s/it, latent_dimension=2, epoch=49, loss=1.91e+4]
100%|██████████| 50/50 [06:24<00:00,  7.69s/it, latent_dimension=5, epoch=49, loss=1.6e+4]
100%|██████████| 50/50 [06:17<00:00,  7.54s/it, latent_dimension=10, epoch=49, loss=1.42e+4]
100%|██████████| 50/50 [06:13<00:00,  7.48s/it, latent_dimension=20, epoch=49, loss=1.43e+4]
100%|██████████| 50/50 [06:19<00:00,  7.60s/it, latent_dimension=50, epoch=49, loss=1.48e+4]
