<a href="https://colab.research.google.com/github/rdauncey/MNIST-LinearVAE/blob/master/MNIST_LinearVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Import Modules
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Import GoogleDrive
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
# Parameters
batch_size_train = 128
batch_size_test = 128
learning_rate = 0.01
momentum = 0.5
log_interval = 10
n_epochs = 5
CUDA = False
SEED = 1
n_latent = 20 # this is the number of connections through the autoencoder bottleneck

torch.manual_seed(SEED)
if CUDA:
  torch.cuda.manual_seed(SEED) # these lines allow for repetition of experiment (fixes starting seed for random sampling)
  
kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {} # Will load tensors directly into GPU memory if CUDA is True


In [4]:
# Load Data

transform = transforms.ToTensor()

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                            download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train,
                                          shuffle=True, **kwargs)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                           download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size_test,
                                        shuffle=False, **kwargs)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
# Define VAE

nodesLayer1 = 400

class VAE(nn.Module):
  def __init__(self):
    super(VAE, self).__init__()
    
    # ENCODER
    # 28x28 pixels, one channel (grayscale) - 784 input pixels (each with one value)
    self.fc1 = nn.Linear(784, nodesLayer1)
    self.relu = nn.ReLU()
    self.fc2a = nn.Linear(nodesLayer1, n_latent) #mn layer
    self.fc2b = nn.Linear(nodesLayer1, n_latent) #sd layer
    
    # DECODER
    # from bottleneck with n_latent inputs
    self.fc3 = nn.Linear(n_latent, nodesLayer1)
    self.fc4 = nn.Linear(nodesLayer1, 784)
    self.sigmoid = nn.Sigmoid()
    
  def encode(self, x):
    x = x.view(-1,784)
    x = self.relu(self.fc1(x))
    mn = self.fc2a(x)
    sd = self.fc2b(x)
    
    return x, mn, sd
  
  def reparametrise(self, x, mn, sd):
    
    if self.training:
      epsilon = torch.randn((x.shape[0], n_latent))
      z = mn + torch.mul(epsilon, torch.exp(sd))
      
      return z
    
    else:
      return mn # during inference just spit out mean of learned distribution
  
  
  def decode(self, z):
    z = self.relu(self.fc3(z))
    z = self.sigmoid(self.fc4(z))
    
    return z
  
  def forward(self, x):
    X_out, mn, sd = self.encode(x)
    z = self.reparametrise(X_out, mn, sd)
    output = self.decode(z)
    
    return output, mn, sd # note output has not been reshaped so has dims [batch_size, 784]
  






In [6]:
# Define Loss & Optimizer

vae = VAE()
if CUDA:
  vae.cuda()
  
def loss_function(output, original, mn, sd):
  
  BCE = F.binary_cross_entropy(output, original.view(-1,784))
  KLD = -0.5 * torch.sum(1 + sd - mn.pow(2) - sd.exp())
  KLD /= batch_size_train * 784
  
  return BCE + KLD

optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

In [9]:
# Define Training & Testing

def train(epoch):
  vae.train()
  train_loss = 0
  for batch_idx, (data, label) in enumerate(trainloader):
    if CUDA:
      data = data.cuda()  
    optimizer.zero_grad()
    output, mn, sd = vae(data)
    loss = loss_function(output, data, mn, sd)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()
    
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(trainloader.dataset), 100. * batch_idx / len(trainloader), loss.item() / len(data)))
      
  print('===> Epoch: {}, Average Loss {}'.format(epoch, train_loss / len(trainloader.dataset)))
  
def test(epoch):
  vae.eval()
  test_loss = 0
  for batch_idx, (data, label) in enumerate(testloader):
    if CUDA:
      data = data.cuda()
      
    with torch.no_grad():
      output, mn, sd = vae(data)
      test_loss += loss_function(output, data, mn, sd).item()
      if batch_idx == 0:
        n = min(data.size(0), 9) # show the first 9batch_size_test input digits with the reconstructed output digits below
        comparison = torch.cat([data[:n], output.view(-1, 1, 28, 28)[:n]])
        save_image(comparison.data.cpu(), '/content/gdrive/My Drive/Colab Notebooks/results/reconstruction_' + str(epoch) + '.png', nrow=n)
                
  test_loss /= len(testloader.dataset)
  print('===> Test set loss: {:.4f}'.format(test_loss))

In [10]:
##Training & Testing

for epoch in range(1, n_epochs+1):
  train(epoch)
  test(epoch)

===> Epoch: 1, Average Loss 0.0010131366096436977
===> Test set loss: 0.0010
===> Epoch: 2, Average Loss 0.0009862459012617668
===> Test set loss: 0.0009
===> Epoch: 3, Average Loss 0.0009751458714405695
===> Test set loss: 0.0009
===> Epoch: 4, Average Loss 0.0009691946056981882
===> Test set loss: 0.0009
===> Epoch: 5, Average Loss 0.000964678909505407
===> Test set loss: 0.0009
