In [None]:
from __future__ import division
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os 

from sklearn.model_selection import train_test_split

import torch
import torchvision
import torch.nn.functional as fun
import torch.nn as nn 
import torch.utils.data
from torchvision import transforms
from torchvision.datasets import MNIST

In [None]:
!pip install torchsummary

## Preprocess data

In [None]:
train_df = pd.read_csv('../input/digit-recognizer/train.csv')
# Read alongside the training file, but not used anywhere below in this notebook.
test_df = pd.read_csv('../input/digit-recognizer/test.csv')

# Every column after the first (the `label` column) is a flattened pixel.
pixels = train_df.iloc[:, 1:].values
# uint8 storage is 8x smaller than the default int64; make sure nothing would
# wrap around when narrowed.
assert pixels.min() >= 0 and pixels.max() <= 255, 'pixel values must fit in uint8'

y = torch.tensor(train_df.label.values)
x = torch.tensor(pixels, dtype=torch.uint8)

# Hold out 10% of the labelled data; stratify so every digit keeps the same
# share in both splits.
x_tr, x_ts, y_tr, y_ts = train_test_split(
    x, y, test_size=0.1, random_state=42, stratify=train_df.label.values)

print(f'Train size: {len(y_tr)} \nTest size: {len(y_ts)}')

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing rows of pixel values with their labels.

    Args:
        inputs: indexable tensor of raw pixel values, one row per sample;
            each returned row is cast to float32 and multiplied by 1/255.
        labels: indexable tensor of labels aligned with ``inputs``.
        transform: optional callable applied to each scaled sample.
    """

    def __init__(self, inputs, labels, transform=None):
        'Initialization'
        self.labels = labels
        self.inp_feats = inputs
        self.transform = transform

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(X, y)``: the scaled float32 sample and its label.

        The scaling is deliberately out-of-place. ``inp_feats[index]`` is a
        view into the stored tensor, and a dtype conversion returns that same
        view when the data is already float32. An in-place ``*=`` would then
        rescale the dataset itself on every access.
        """
        X = self.inp_feats[index].float() * (1 / 255.0)
        y = self.labels[index]

        if self.transform is not None:
            X = self.transform(X)

        return X, y

In [None]:
# Loader settings shared by both splits (shuffling is overridden for evaluation).
params = {'batch_size': 128,
          'shuffle': True,
          'num_workers': 0}

# drop_last guards against a size-1 final batch, which BatchNorm1d cannot
# normalise in training mode.
training_set = Dataset(x_tr, y_tr)
train_gen = torch.utils.data.DataLoader(training_set, **params, drop_last=True)

# Evaluation runs with model.eval() (BatchNorm uses running statistics), so
# keep every held-out sample and a fixed order for reproducible plots.
eval_params = dict(params, shuffle=False)
testing_set = Dataset(x_ts, y_ts)
test_gen = torch.utils.data.DataLoader(testing_set, **eval_params, drop_last=False)

In [None]:
# Seed PyTorch's CPU and CUDA generators so weight initialisation,
# reparameterisation noise and DataLoader shuffling repeat across runs.
torch.manual_seed(seed=1)
torch.cuda.manual_seed(seed=1)  # silently ignored when CUDA is unavailable

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## Vanilla Variational Autoencoder (VAE):

In [None]:
# Construct the VAE architecture.

class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened 28x28 images.

    The encoder maps each 784-value input to ``2 * latent_dim`` numbers. These
    are split into a mean and a log-variance. The decoder maps a latent code
    back to 784 values squashed into (0, 1) by a sigmoid.

    Args:
        width: every hidden layer has ``width ** 2`` units (default 25 -> 625).
        latent_dim: size of the latent code (default 3, small enough to plot).
    """

    def __init__(self, width=25, latent_dim=3):
        super(VAE, self).__init__()
        hidden = width ** 2
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(784, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            # One output per mean and per log-variance coordinate; BatchNorm is
            # applied to these raw statistics as in the original design.
            nn.Linear(hidden, 2 * latent_dim),
            nn.BatchNorm1d(2 * latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, 784),
            nn.Sigmoid(),
        )

    def reparam_trick(self, mu, logvar):
        """Return ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)`` in
        training mode, and ``mu`` unchanged in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` may have any shape whose size is a multiple of 784; it is
        flattened to ``(batch, 784)``. Returns ``(x_tilde, z, mu, logvar)``:
        the ``(batch, 784)`` reconstruction and three ``(batch, latent_dim)``
        tensors.
        """
        mu_logvar = self.encoder(x.reshape(-1, 784)).view(-1, 2, self.latent_dim)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparam_trick(mu, logvar)
        return self.decoder(z), z, mu, logvar
        
width = 25  # hidden-layer size parameter: layers have width**2 units
# initialize the NN
model = VAE().to(device)

from torchsummary import summary
# torchsummary's summary() defaults to device="cuda"; pass the device chosen
# above so this cell also runs on CPU-only machines.
summary(model, (1, 28 * 28), device=device.type)
        
        

In [None]:
# Adam over the current model's parameters (step size is essentially 1e-3).
lrate = 0.00100031312
optimizer = torch.optim.Adam(params=model.parameters(), lr=lrate)


In [None]:
def VAE_loss(x_tilde, x, mu, logvar, beta):
    """Negative (beta-weighted) ELBO summed over the batch.

    Args:
        x_tilde: reconstructions in (0, 1), shaped like ``x`` flattened to (batch, 784).
        x: target images with values in [0, 1]; flattened to (batch, 784).
        mu: encoder means of q(z|x).
        logvar: encoder log-variances of q(z|x).
        beta: weight on the KL term (1 = standard VAE, > 1 = beta-VAE).

    Returns:
        Scalar tensor: summed binary cross-entropy + beta * KL(q(z|x) || N(0, I)).
    """
    target = x.view(-1, 784)
    reconstruction = fun.binary_cross_entropy(x_tilde, target, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and the N(0, I) prior.
    kl_terms = logvar.exp() - logvar - 1 + mu.pow(2)
    kl = 0.5 * kl_terms.sum()
    return reconstruction + beta * kl
    

#### Train model

In [None]:
def _optimizer_for(model, opt):
    """Return an optimizer that actually updates ``model``.

    ``opt`` is returned unchanged when every parameter of ``model`` belongs to
    one of its param groups. Otherwise (e.g. ``opt`` was built for a model that
    has since been replaced) a warning is issued. A new optimizer of the same
    class and constructor defaults is then created for ``model``'s parameters.
    """
    import warnings

    managed = {id(p) for group in opt.param_groups for p in group['params']}
    if all(id(p) in managed for p in model.parameters()):
        return opt
    warnings.warn("optimizer does not manage this model's parameters; "
                  "creating a fresh one with the same settings")
    return type(opt)(model.parameters(), **opt.defaults)


def train_model(beta, epochs, model, opt=None):
    """Train ``model`` as a (beta-)VAE and snapshot its latent space every epoch.

    Epoch 0 only evaluates, giving a baseline before any update. Epochs
    1..``epochs`` each run one pass over ``train_gen`` and then evaluate on
    ``test_gen``.

    Args:
        beta: weight of the KL term passed to ``VAE_loss`` (1 = vanilla VAE).
        epochs: number of training epochs.
        model: the VAE to train; called as ``model(X)`` and expected to return
            ``(x_tilde, z, mu, logvar)``.
        opt: optimizer to use. Defaults to the notebook-level ``optimizer``.
            If the chosen optimizer does not manage ``model``'s parameters, a
            fresh optimizer with the same settings is used instead.

    Returns:
        dict with lists ``latent_space``, ``mu_list``, ``logsig2_list`` and
        ``y``. Element ``i`` of each list holds the CPU tensors gathered on the
        test split after epoch ``i``; index 0 is the untrained model.
    """
    opt = _optimizer_for(model, optimizer if opt is None else opt)

    dic = dict(latent_space=list(), mu_list=list(), logsig2_list=list(), y=list())
    for epoch in range(0, epochs + 1):
        # ========= TRAINING =========
        # Epoch 0 skips training so the first snapshot shows the untrained model.
        if epoch > 0:
            model.train()
            train_loss, n_train = 0.0, 0
            for X, _ in train_gen:
                X = X.to(device)
                # forward pass ...
                x_tilde, z, mu, logvar = model(X)
                loss = VAE_loss(x_tilde, X, mu, logvar, beta)
                train_loss += loss.item()
                n_train += X.size(0)
                # backward pass ...
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Average over the samples actually seen: if the loader uses
            # drop_last=True, len(dataset) would overcount.
            print(f'----> Epoch: {epoch} Average loss: {train_loss / max(n_train, 1):.4f}')

        # ========= TESTING =========
        z_list, means, logvars, labels = list(), list(), list(), list()
        test_loss, n_test = 0.0, 0
        model.eval()
        with torch.no_grad():
            for X, Y in test_gen:
                X = X.to(device)
                # forward ...
                x_tilde, z, mu, logvar = model(X)
                test_loss += VAE_loss(x_tilde, X, mu, logvar, beta).item()
                n_test += X.size(0)
                # Snapshots are kept for every epoch, so move them to the CPU
                # instead of accumulating them in GPU memory.
                z_list.append(z.cpu())
                means.append(mu.cpu())
                logvars.append(logvar.cpu())
                labels.append(Y.cpu())
        # log ...
        dic['latent_space'].append(torch.cat(z_list))
        dic['mu_list'].append(torch.cat(means))
        dic['logsig2_list'].append(torch.cat(logvars))
        dic['y'].append(torch.cat(labels))
        test_loss /= max(n_test, 1)
        print(f'----> Test set loss: {test_loss:.4f}')
    return dic
    
# Vanilla VAE: beta = 1 weights the KL term exactly as in the standard ELBO.
beta = 1
epochs = 50
dic = train_model(beta=beta, epochs=epochs, model=model)

#### Plot latent space 

In [None]:
# Latent codes of the held-out split after the final epoch. Index 0 of each
# list is the snapshot taken before any training, so use -1 here.
z_arr = dic['latent_space'][-1].cpu().numpy()
y_arr = dic['y'][-1].cpu().numpy()
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].scatter(z_arr[:, 0], z_arr[:, 1], c=y_arr, cmap='tab10', s=4)
axes[0].set(xlabel='$z_0$', ylabel='$z_1$', title='VAE latent space: $z_0$ vs $z_1$')
points = axes[1].scatter(z_arr[:, 1], z_arr[:, 2], c=y_arr, cmap='tab10', s=4)
axes[1].set(xlabel='$z_1$', ylabel='$z_2$', title='VAE latent space: $z_1$ vs $z_2$')
fig.colorbar(points, ax=axes[1], label='digit')
fig.tight_layout()


The latent space above does not separate the MNIST digits well. Increasing the latent dimension (for example, from 3 to the `width` parameter) should help, at the cost of no longer being able to plot the latent space directly. If you do expand the latent dimension, you can then apply t-SNE to map the latent codes down to two or three dimensions. This lets you check visually whether distinct MNIST digit clusters emerge.

## Beta Variational Autoencoder

Here we introduce the $\beta$-VAE, which multiplies the KL term of the loss by $\beta > 1$. This sacrifices some reconstruction accuracy in exchange for a greater degree of disentanglement in the learned features.

In [None]:
# All we need to do now is build a fresh VAE and choose beta > 1.
width = 25
# initialize the NN
model = VAE().to(device)
print(model)

# The optimizer created earlier is bound to the previous model's parameters;
# rebuild it so training updates this model.
optimizer = torch.optim.Adam(model.parameters(), lr=lrate)

# torchsummary defaults to device="cuda"; pass the actual device so this also
# runs on CPU-only machines.
summary(model, (1, 28 * 28), batch_size=1, device=device.type)

beta = 3  # beta becomes an additional hyperparameter
epochs = 100

#### train $\beta$-VAE

In [None]:
# Train the beta-VAE with the beta and epoch count chosen above.
dic = train_model(beta=beta, epochs=epochs, model=model)

#### plot latent space

In [None]:
# Latent codes of the held-out split after the final beta-VAE epoch. Index 0
# of each list is the snapshot taken before any training, so use -1 here.
z_arr = dic['latent_space'][-1].cpu().numpy()
y_arr = dic['y'][-1].cpu().numpy()
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].scatter(z_arr[:, 0], z_arr[:, 1], c=y_arr, cmap='tab10', s=4)
axes[0].set(xlabel='$z_0$', ylabel='$z_1$', title=r'$\beta$-VAE latent space: $z_0$ vs $z_1$')
points = axes[1].scatter(z_arr[:, 1], z_arr[:, 2], c=y_arr, cmap='tab10', s=4)
axes[1].set(xlabel='$z_1$', ylabel='$z_2$', title=r'$\beta$-VAE latent space: $z_1$ vs $z_2$')
fig.colorbar(points, ax=axes[1], label='digit')
fig.tight_layout()


## Ref:

Overall, the workflow was adapted from Alfredo Canziani's work at https://atcold.github.io/pytorch-Deep-Learning/. The changes are to the preprocessing step and to specific architecture details, and the model was extended to include the $\beta$-VAE.