In [None]:
# Colab only: attach the user's Google Drive under /content/drive so that the
# dataset cache, samples and weights written below survive the session.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
cd drive/My \Drive/ML/

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation settings
lr = 0.001          # Adam learning rate
num_epochs = 100    # full passes over the training set
latent_dim = 100    # size of the latent vector z

# Data settings (MNIST: single-channel 28x28 digits)
batch_size = 128
image_size = 28
channels = 1

# Preprocessing: resize, convert to a tensor, then shift/scale with
# mean 0.5 / std 0.5 so pixel values match the range of the decoder's Tanh.
transform = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split of MNIST, downloaded to ./data on first use, served in
# shuffled mini-batches.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

# VAE Encoder
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to Gaussian parameters.

    Three stride-2 convolutions shrink the input (28x28 -> 14 -> 7 -> 4 for
    the image size used in this file) while widening it to 256 channels. The
    flattened 4*4*256 feature vector then feeds two independent linear heads
    that output the mean and the log-variance of the latent distribution.
    """

    def __init__(self):
        super().__init__()
        width = 64  # channel count after the first convolution

        conv_stack = [
            nn.Conv2d(channels, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(width * 2, width * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.main = nn.Sequential(*conv_stack)

        # The heads are sized for a 4x4 spatial map with width * 4 channels.
        flat_features = 4 * 4 * width * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        features = self.main(x)
        flat = features.view(features.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

# VAE Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder that turns latent vectors into images.

    A linear layer expands the latent vector to a 256-channel 4x4 map, which
    three stride-2 transposed convolutions upsample (4 -> 7 -> 14 -> 28).
    The final Tanh keeps every output pixel inside [-1, 1].
    """

    def __init__(self):
        super().__init__()
        width = 64  # channel count just before the output layer

        self.fc = nn.Linear(latent_dim, 4 * 4 * width * 4)

        upsample_stack = [
            nn.ConvTranspose2d(width * 4, width * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(width * 2, width, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(width, channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*upsample_stack)

    def forward(self, x):
        """Decode latent vectors ``x`` into an image batch."""
        hidden = self.fc(x)
        # Reinterpret the flat activations as a (256, 4, 4) feature map.
        grid = hidden.view(hidden.size(0), 64 * 4, 4, 4)
        return self.main(grid)

# VAE
class VAE(nn.Module):
    """Variational autoencoder assembled from ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``. ``sigma`` is recovered from the
        log-variance as ``exp(0.5 * logvar)``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns the reconstruction together with ``mu`` and ``logvar`` so the
        caller can compute the KL term.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

vae = VAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=lr)

# Create the output folders up front so a long run cannot fail at the first
# save because a directory is missing.
os.makedirs("./Samples/VAE", exist_ok=True)
os.makedirs("./Plots", exist_ok=True)
os.makedirs("./Weights", exist_ok=True)

# Training Loop
LossArr = []  # mean training loss per epoch
for epoch in tqdm(range(num_epochs)):
    # Sampling below switches to eval mode, so re-enable training behaviour
    # (BatchNorm batch statistics) at the start of every epoch.
    vae.train()
    loss_epoch = 0.0
    num_batches = 0

    for images, _ in data_loader:
        images = images.to(device)
        x_reconst, mu, logvar = vae(images)

        # Reconstruction loss (mean over every pixel of the batch) and
        # KL-divergence between q(z|x) and the standard normal prior.
        reconst_loss = nn.functional.mse_loss(x_reconst, images)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Normalise by the number of pixels actually in this batch. The last
        # batch is smaller than batch_size, so dividing by the nominal size
        # would under-weight the KL term relative to the mean-reduced MSE.
        kl_div = kl_div / images.numel()

        # Total loss
        loss = reconst_loss + kl_div
        loss_epoch += loss.item()
        num_batches += 1

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_epoch /= num_batches
    LossArr.append(loss_epoch)

    # Generate and save sample images every 5 epochs.
    if (epoch + 1) % 5 == 0:
        # Eval mode makes BatchNorm use its running statistics instead of the
        # statistics of the random batch, and stops it updating them.
        vae.eval()
        with torch.no_grad():
            z = torch.randn(batch_size, latent_dim, device=device)
            out = vae.decoder(z)
            out = out.view(out.size(0), channels, image_size, image_size)
            vutils.save_image(out, f"./Samples/VAE/MNIST_samples_VAE_epoch_{epoch+1}.png", nrow=16, normalize=True)

# Save Loss
np.save('./Plots/VAE_Loss.npy', LossArr)

# Save Model
torch.save(vae.state_dict(), "./Weights/MNIST_VAE.pth")