In [43]:
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image


In [44]:
#Initialize the variables
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 100
EPOCHS = 10

#Autoencoder bottleneck of VAE, this is 20
ZDIMS = 20

# MNIST dataset is downloaded on to required location
os.chdir("/home/CS/mnist")

#utilize cuda if available, if not use cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [45]:
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f1dcc0f2ed0>

In [46]:
# Download or load downloaded MNIST dataset shuffle data at every epoch for training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, num_workers = 4)

# Download or load downloaded MNIST dataset shuffle data at every epoch for test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, num_workers = 4)


In [47]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # ENCODER
        # 28 x 28 pixels = 784 input pixels, 400 outputs
        self.fc1 = nn.Linear(784, 400)
        # ReLU from 400 to 400: max(0, x)
        self.relu = nn.ReLU()
        self.fc21 = nn.Linear(400, ZDIMS)  # mu 
        self.fc22 = nn.Linear(400, ZDIMS)  # logvariance
        # bottlenecks to ZDIMS: 20

        # DECODER
        # from bottleneck to hidden 400: (20,400)
        self.fc3 = nn.Linear(ZDIMS, 400)
        # from hidden 400 to 784 outputs
        self.fc4 = nn.Linear(400, 784)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x: Variable) -> (Variable, Variable):
        """
        Input vector x -> fully connected 1 -> ReLU -> (fully connected 21, fully connected 22)

        Input Parameter:
        ----------
        x : [128, 784] matrix; 128 digits of 28x28 pixels each

        Returns:
        --------

        (mu, logvar) : ZDIMS mean and variance units for each one for each latent dimension

        """

        # h1 is [128, 400]
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu: Variable, logvar: Variable) -> Variable:
        """
        REPARAMETERIZATION TECHNIQUE:
        Parameters
        ----------
        mu : [128, ZDIMS] mean matrix
        logvar : [128, ZDIMS] variance matrix

        Returns
        -------
        During training random sample from the learned ZDIMS-dimensional normal distribution; during inference its mean.

        """

        if self.training:
            # multiply log variance with 0.5, then in-place exponent yielding the standard deviation
            std = logvar.mul(0.5).exp_()
            # eps is  [128, ZDIMS] with all elements  picked from mean 0 and std dev 1: normal distribution
            eps = Variable(std.data.new(std.size()).normal_())
            # return the sample from normal distribution with standard deviation (std) and mean (mu)
            # by multiplying mean 0, and stddev 1 
            return eps.mul(std).add_(mu)

        else:
            # During inference, the mean of the learned distribution for the current input(mu has highest probability from random sample)
            return mu

    def decode(self, z: Variable) -> Variable:
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x: Variable) -> (Variable, Variable, Variable):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar



In [48]:
model = VAE()
model = model.to(device)
print(model)

VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (relu): ReLU()
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
  (sigmoid): Sigmoid()
)


In [49]:
def loss_function(recon_x, x, mu, logvar) -> Variable:
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
    # KLD is Kullback–Leibler divergence
    # D_{KL} = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by same number of elements as in reconstruction
    KLD /= BATCH_SIZE * 784

    # BCE loss: make our reconstruction as accurate as possible
    # KLD loss: push the distributions as close as possible to unit Gaussian
    return BCE + KLD

In [50]:
# Optimize the parameteres of the model using Adam optimizer with learning rate set to 1e-3
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [51]:
def train(epoch):
    model.train()
    train_loss = 0

    # each `data` is of BATCH_SIZE samples and has shape [128, 1, 28, 28]
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data).to(device)
        optimizer.zero_grad()

        # push whole batch of data through VAE.forward() to get recon_loss
        recon_batch, mu, logvar = model(data)
        # calculate loss
        loss = loss_function(recon_batch, data, mu, logvar)
        # calculate the gradient of the loss
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.data[0] / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0

    # each data is of BATCH_SIZE (default 128) samples
    for i, (data, _) in enumerate(test_loader):
        data = data.to(device)

        # Here, volatile=True, so no autograd at all required
        data = Variable(data, volatile=True)
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
        if i == 0:
            n = min(data.size(0), 8)
            # for the first 128 batch of the epoch, show the first 8 input digits with reconstructed image
            comparison = torch.cat([data[:n],
                                  recon_batch.view(BATCH_SIZE, 1, 28, 28)[:n]])
            save_image(comparison.data.cpu(),
                     '/home/CS/results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [52]:
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch)

    # 64 sets of random ZDIMS-float vectors
    # digits in latent space
    sample = Variable(torch.randn(64, ZDIMS)).to(device)
    sample = model.decode(sample).cpu()

    # save out as an 8x8 matrix of MNIST digits to visualize the latent space
    save_image(sample.data.view(64, 1, 28, 28),
               '/home/CS/results/sample_' + str(epoch) + '.png')



====> Epoch: 1 Average loss: 0.0016




====> Test set loss: 0.0012
====> Epoch: 2 Average loss: 0.0012
====> Test set loss: 0.0011
====> Epoch: 3 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 4 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 5 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 6 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 7 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 8 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 9 Average loss: 0.0011
====> Test set loss: 0.0010
====> Epoch: 10 Average loss: 0.0011
====> Test set loss: 0.0010
