<a href="https://colab.research.google.com/github/wslbooth/vae-m1-mnist/blob/main/MNIST_SSL_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Latent Feature Extraction (M1) on the MNIST Dataset

The goal of this project is to implement the latent-feature-extraction model with a classifier on top (M1) from [1], and to see whether we can achieve similar results.

## Section 1: Setup

In [None]:
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random

In [None]:
def accuracy_fn(y_true, y_pred):
  """Return classification accuracy as a percentage.

  Args:
    y_true: tensor of integer class labels, one per sample.
    y_pred: tensor of per-class scores (e.g. logits); the predicted class for
      each row is the index of its maximum along dim 1.

  Returns:
    Accuracy as a Python float in [0, 100].
  """
  predicted = torch.argmax(y_pred, dim=1)
  n_correct = (predicted == y_true).sum().item()
  return (n_correct / len(y_pred)) * 100

In [None]:
seed = 22
# Seed every RNG the notebook may draw from (Python, NumPy, PyTorch CPU and
# CUDA) so that runs are repeatable.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
  _seed_rng(seed)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Standard MNIST train/test splits (downloaded under ./data if missing), with
# images converted to tensors by ToTensor(); labels are left untransformed.
train_data, test_data = (
    datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=transforms.ToTensor(),
        target_transform=None,
    )
    for is_train in (True, False)
)

In [None]:
from torch.utils.data import DataLoader, random_split

# Labelled subset: 1% of the training images; the remainder forms the
# "unlabelled" pool. random_split draws from the global torch RNG seeded above.
l_size = int(0.01 * len(train_data))
u_size = len(train_data) - l_size
l_data, u_data = random_split(train_data, [l_size, u_size])

batch_size = 32
# Shuffled loaders over the labelled subset (classifier training) and the full
# training set (VAE training).
l_data_loader, train_loader = (
    DataLoader(dataset=ds, batch_size=batch_size, shuffle=True)
    for ds in (l_data, train_data)
)
# Evaluation order is irrelevant, so the test loader is not shuffled.
test_loader = DataLoader(dataset=test_data, batch_size=100)

In [None]:
from collections import Counter

# Per-digit counts in the labelled subset. random_split is not stratified, so
# the classes need not be equally represented.
l_data_labels = train_data.targets[l_data.indices].tolist()
l_data_class_counts = Counter(l_data_labels)
l_data_class_counts

## Section 2: Creating Model Classes




In [None]:
class MLP(nn.Module):
  """One-hidden-layer perceptron: Flatten -> Linear -> ReLU -> Linear.

  Args:
    in_dim: number of input features after flattening all non-batch dims.
    hid_dim: width of the hidden layer.
    out_dim: number of outputs (raw class scores when used as a classifier).
  """

  def __init__(self, in_dim, hid_dim, out_dim):
    super().__init__()
    layers = [
        nn.Flatten(),
        nn.Linear(in_dim, hid_dim),
        nn.ReLU(),
        nn.Linear(hid_dim, out_dim),
    ]
    self.layer_stack = nn.Sequential(*layers)

  def forward(self, x):
    logits = self.layer_stack(x)
    return logits

In [None]:
class Encoder(nn.Module):
  """Encode inputs into the parameters of a diagonal Gaussian latent.

  One Softplus hidden layer feeds two linear heads. forward returns
  ``(mu, logvar)``, which ``reparameterize`` consumes as the mean and the
  log-variance.

  Args:
    in_dim: size of the last dimension of the input (callers in this notebook
      flatten images before calling).
    hid_dim: width of the shared hidden layer.
    latent_dim: dimensionality of ``mu`` and ``logvar``.
  """

  def __init__(self, in_dim, hid_dim, latent_dim):
    super().__init__()
    self.fc1 = nn.Linear(in_dim, hid_dim)
    self.fc_mu = nn.Linear(hid_dim, latent_dim)
    self.fc_logvar = nn.Linear(hid_dim, latent_dim)

  def forward(self, x):
    hidden = F.softplus(self.fc1(x))
    mu = self.fc_mu(hidden)
    logvar = self.fc_logvar(hidden)
    return mu, logvar

def reparameterize(mu, logvar):
  """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I).

  The reparameterization trick expresses the sample as a differentiable
  function of ``mu`` and ``logvar``, so gradients flow back to the encoder.
  """
  std = (0.5 * logvar).exp()
  return mu + torch.randn_like(std) * std

class Decoder(nn.Module):
  """Map a latent code back to input space.

  Linear(latent_dim -> hid_dim) -> Softplus -> Linear(hid_dim -> out_dim) ->
  Sigmoid, so every output value lies in [0, 1].
  """

  def __init__(self, latent_dim, hid_dim, out_dim):
    super().__init__()
    hidden = [nn.Linear(latent_dim, hid_dim), nn.Softplus()]
    output = [nn.Linear(hid_dim, out_dim), nn.Sigmoid()]
    self.net = nn.Sequential(*hidden, *output)

  def forward(self, z):
    reconstruction = self.net(z)
    return reconstruction

class VAE(nn.Module):
  """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``.

  forward(x) returns ``(x_hat, mu, logvar)``. ``x_hat`` is the decoding of a
  single reparameterized sample of the latent code. ``mu`` and ``logvar`` are
  the encoder outputs that ``vae_loss`` needs for its KL term.
  """

  def __init__(self, in_dim, hid_dim, latent_dim, out_dim):
    super().__init__()
    # Submodules are built encoder-first; the build order fixes how weight
    # initialisation consumes the global RNG.
    self.encoder = Encoder(in_dim, hid_dim, latent_dim)
    self.decoder = Decoder(latent_dim, hid_dim, out_dim)

  def forward(self, x):
    mu, logvar = self.encoder(x)
    x_hat = self.decoder(reparameterize(mu, logvar))
    return x_hat, mu, logvar

In [None]:
class EncoderClassifier(nn.Module):
  """Classifier head on top of a frozen encoder (the M1 set-up).

  Inputs are flattened from dim 1 onward. The encoder's mean output ``mu`` is
  used as the feature vector for ``mlp``; the log-variance is ignored.

  The constructor freezes the encoder *in place*: the module object passed in
  has ``requires_grad`` turned off on all of its parameters. That also affects
  any other model that shares it (e.g. the VAE it came from).

  Args:
    encoder: module whose forward returns a ``(mu, logvar)`` pair.
    mlp: classifier head applied to ``mu``.
  """

  def __init__(self, encoder, mlp):
    super().__init__()
    self.encoder = encoder
    self.mlp = mlp
    # Freeze the feature extractor so only the head is trainable.
    self.encoder.requires_grad_(False)

  def forward(self, x):
    mu, _ = self.encoder(x.flatten(start_dim=1))
    return self.mlp(mu)

## Section 3: Model Training and Results

### 3.1 Baseline Model Training

In [None]:
from tqdm.auto import tqdm

# Baseline (M0): an MLP trained on the raw pixels of the labelled subset only.
m0 = MLP(
    in_dim=784,
    hid_dim=64,
    out_dim=10
).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=m0.parameters(), lr=0.1)

# Per-epoch histories read by the plotting/summary cells that follow.
avg_loss_vals = []   # mean training loss per labelled sample
train_acc_vals = []  # training accuracy (%) over the labelled subset
test_acc_vals = []   # accuracy (%) over the whole test set

epochs = 50
n_train = len(l_data_loader.dataset)
n_test = len(test_loader.dataset)

for epoch in tqdm(range(epochs)):
  m0.train()
  train_loss = 0.0
  train_correct = 0
  for X, y in l_data_loader:
    X, y = X.to(device), y.to(device)
    y_pred = m0(X)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Accumulate sample-weighted totals. The final batch can be smaller than
    # batch_size, and averaging per-batch means would over-weight it.
    train_loss += loss.item() * y.size(0)
    train_correct += (y_pred.argmax(dim=1) == y).sum().item()
  avg_loss_vals.append(train_loss / n_train)
  train_acc_vals.append(100 * train_correct / n_train)

  # Evaluate in eval mode. This MLP has no dropout or batch-norm, so it is a
  # no-op today, but the evaluation stays correct if such layers are added.
  m0.eval()
  with torch.no_grad():
    test_correct = 0
    for X, y in test_loader:
      X, y = X.to(device), y.to(device)
      preds = m0(X)
      test_correct += (preds.argmax(dim=1) == y).sum().item()
  test_acc_vals.append(100 * test_correct / n_test)

#### Baseline Model Results

In [None]:
# Baseline learning curves: one figure per metric; the x-axis is the 0-based
# epoch index.
for history, y_label in (
    (avg_loss_vals, "Training Loss"),
    (train_acc_vals, "Training Accuracy (%)"),
    (test_acc_vals, "Testing Accuracy (%)"),
):
  plt.figure()
  plt.plot(history)
  plt.xlabel("Epoch")
  plt.ylabel(y_label)

In [None]:
# Summarise the baseline run: the per-epoch histories, the best (1-based)
# epoch for each split, and the number of trainable parameters in m0.
best_train_epoch = np.argmax(train_acc_vals) + 1
best_test_epoch = np.argmax(test_acc_vals) + 1
print(train_acc_vals)
print(test_acc_vals)
print(f"Max Training Accuracy of {max(train_acc_vals)} after Epoch {best_train_epoch}")
print(f"Max Test Accuracy of {max(test_acc_vals)} after Epoch {best_test_epoch}")
num_params = sum(param.numel() for param in m0.parameters() if param.requires_grad)
print(num_params)

### 3.2 VAE Training

In [None]:
def vae_loss(x, x_hat, mu, logvar, beta=1.0, recon="mse"):
  """VAE objective: reconstruction term + beta * KL(q(z|x) || N(0, I)).

  Both terms are summed over every element in the batch (no averaging).

  Args:
    x: target inputs (flattened images in the training loop).
    x_hat: decoder reconstructions, same shape as ``x``.
    mu: encoder means of the approximate posterior.
    logvar: encoder log-variances of the approximate posterior.
    beta: weight on the KL term (1.0 is the standard VAE objective).
    recon: reconstruction term. "mse" (the default) is the summed squared
      error. "bce" is the summed binary cross-entropy, i.e. a Bernoulli
      likelihood; it requires ``x`` and ``x_hat`` to lie in [0, 1].

  Returns:
    ``(loss, recon_loss, kld)``. ``loss`` is a tensor to call ``backward()``
    on; the two components are Python floats for logging.

  Raises:
    ValueError: if ``recon`` is not "mse" or "bce".
  """
  if recon == "mse":
    recon_loss = F.mse_loss(x_hat, x, reduction='sum')
  elif recon == "bce":
    recon_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
  else:
    raise ValueError(f"recon must be 'mse' or 'bce', got {recon!r}")
  # Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent
  # dims and batch.
  kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  return recon_loss + beta * kld, recon_loss.item(), kld.item()

from torch.optim.lr_scheduler import ExponentialLR

# 784 inputs (flattened 28x28 images) -> 400 hidden -> 100-d latent -> 400 -> 784.
model_vae = VAE(in_dim=784, hid_dim=400, latent_dim=100, out_dim=784).to(device)

# Adam with exponential decay: each schedular.step() multiplies the learning
# rate by gamma=0.99.
optimizer = torch.optim.Adam(model_vae.parameters(), lr=0.001)
schedular = ExponentialLR(optimizer, gamma=0.99)

In [None]:
epochs = 50

# Per-epoch, per-sample averages of the objective and its two components.
loss_vals = []
kld_vals = []
recon_loss_vals = []

n_train_samples = len(train_loader.dataset)

for epoch in tqdm(range(epochs)):
  model_vae.train()
  total_loss, total_kld, total_recon_loss = 0, 0, 0
  # VAE training is unsupervised: labels are discarded and every training
  # image is used.
  for images, _ in train_loader:
    x = images.view(-1, 784).to(device)

    x_hat, mu, logvar = model_vae(x)
    loss, recon_loss, kld = vae_loss(x, x_hat, mu, logvar)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    total_kld += kld
    total_recon_loss += recon_loss

  # Decay the learning rate once per epoch.
  schedular.step()
  loss_vals.append(total_loss / n_train_samples)
  kld_vals.append(total_kld / n_train_samples)
  recon_loss_vals.append(total_recon_loss / n_train_samples)


In [None]:
# VAE training curves. Each value is a per-sample average for one epoch, so the
# epoch index belongs on the x-axis and the quantity on the y-axis.
for history, y_label in (
    (loss_vals, "Training Loss"),
    (kld_vals, "KL-Divergence"),
    (recon_loss_vals, "Reconstruction Loss (MSE)"),
):
  plt.figure()
  plt.plot(history)
  plt.xlabel("Epoch")
  plt.ylabel(y_label)

### 3.3 M1 Training

In [None]:
# Classifier head whose input width matches the encoder's latent size.
mlp = MLP(in_dim=model_vae.encoder.fc_mu.out_features, hid_dim=64, out_dim=10)

# M1: the MLP head on the VAE encoder's latent means. EncoderClassifier freezes
# the encoder, so only the head is trained.
m1 = EncoderClassifier(encoder=model_vae.encoder,
                       mlp=mlp).to(device)
loss_fn = nn.CrossEntropyLoss()
# Give the optimizer only the trainable (head) parameters; the frozen encoder
# weights never receive gradients.
optimizer = torch.optim.SGD(params=[p for p in m1.parameters() if p.requires_grad], lr=0.1)

# Per-epoch histories read by the plotting/summary cells that follow.
avg_loss_vals = []   # mean training loss per labelled sample
train_acc_vals = []  # training accuracy (%) over the labelled subset
test_acc_vals = []   # accuracy (%) over the whole test set

epochs = 50
n_train = len(l_data_loader.dataset)
n_test = len(test_loader.dataset)

for epoch in tqdm(range(epochs)):
  m1.train()
  train_loss = 0.0
  train_correct = 0
  for X, y in l_data_loader:
    X, y = X.to(device), y.to(device)
    y_pred = m1(X)
    loss = loss_fn(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Accumulate sample-weighted totals. The final batch can be smaller than
    # batch_size, and averaging per-batch means would over-weight it.
    train_loss += loss.item() * y.size(0)
    train_correct += (y_pred.argmax(dim=1) == y).sum().item()
  avg_loss_vals.append(train_loss / n_train)
  train_acc_vals.append(100 * train_correct / n_train)

  # Evaluate in eval mode. The current layers have no dropout or batch-norm,
  # so it is a no-op today, but it stays correct if such layers are added.
  m1.eval()
  with torch.no_grad():
    test_correct = 0
    for X, y in test_loader:
      X, y = X.to(device), y.to(device)
      preds = m1(X)
      test_correct += (preds.argmax(dim=1) == y).sum().item()
  test_acc_vals.append(100 * test_correct / n_test)

#### M1 Results

In [None]:
# M1 learning curves: one figure per metric; the x-axis is the 0-based epoch
# index.
for history, y_label in (
    (avg_loss_vals, "Training Loss"),
    (train_acc_vals, "Training Accuracy (%)"),
    (test_acc_vals, "Test Accuracy (%)"),
):
  plt.figure()
  plt.plot(history)
  plt.xlabel("Epoch")
  plt.ylabel(y_label)

In [None]:
# Summarise the M1 run: the per-epoch histories, the best (1-based) epoch for
# each split, and M1's trainable parameter count.
print(train_acc_vals)
print(test_acc_vals)
print(f"Max Training Accuracy of {max(train_acc_vals)} after Epoch {np.argmax(train_acc_vals)+1}")
print(f"Max Test Accuracy of {max(test_acc_vals)} after Epoch {np.argmax(test_acc_vals)+1}")
# Count m1 (not m0). The encoder is frozen, so only the MLP head's parameters
# have requires_grad=True and are counted here.
num_params = sum(p.numel() for p in m1.parameters() if p.requires_grad)
print(num_params)

### Discussion of Results

As we have seen, our classifier with the latent feature extractor beats our baseline model. After training each model for 50 epochs on the same labelled subset (1% of the training set), the baseline model achieved a maximum test accuracy of 86.21% after epoch 41, and our M1 model achieved a maximum test accuracy of 88.8% after epoch 49.

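Therefore, the latent features extracted by the encoder appear to improve the classifier's ability to generalize to unseen data; that is, our VAE is learning useful discriminative latent features. This conclusion rests on a single run with one random seed and one labelled subset, so repeating the comparison over several seeds would make it more robust.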

## References
[1] Kingma, D. P., Rezende, D. J., Mohamed, S., & Welling, M. (2014).  
Semi-Supervised Learning with Deep Generative Models.  
_NeurIPS 27_.  
https://arxiv.org/abs/1406.5298

[2] Doersch, C. (2016).
Tutorial on Variational Autoencoders.
arXiv preprint arXiv:1606.05908.
https://arxiv.org/abs/1606.05908

[3] Kingma, D. P., & Welling, M. (2014).
Auto-Encoding Variational Bayes.
arXiv preprint arXiv:1312.6114.
https://arxiv.org/abs/1312.6114