<a href="https://colab.research.google.com/github/suvarnak/Test/blob/master/VAE_DFL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os 

EPOCHS = 10
CUDA = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
SEED = 1
LOG_INTERVAL = 10


In [0]:
torch.manual_seed(SEED)

device = torch.device("cuda" if CUDA else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:02, 3588110.31it/s]                             


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 55697.40it/s]                           
0it [00:00, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:01, 915968.27it/s]                             
0it [00:00, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 21124.99it/s]            

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!





In [0]:
folder = "results"

if not os.path.exists(folder):
    os.mkdir(folder)
    print("Directory " , folder ,  " Created ")
else:    
    print("Directory " , folder ,  " already exists")

Directory  results  Created 


In [0]:
# import shutil
# shutil.rmtree("results") 

In [0]:
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        # fc1, fc21 and fc22 are used by the encoder.
        # fc1 takes a vectorized MNIST image as input
        # fc21 and fc22 are both attached to the activation output of fc1 (using ReLU).
        # fc21 outputs the means, and fc22 the log-variances of
        # each component of th 20-dimensional latent Gaussian.
        self.fc1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3,
                  stride=(2,2))
        self.fc1_normalized = nn.BatchNorm2d(32)
        self.fc2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3,
                  stride=(2,2))
        self.fc2_normalized = nn.BatchNorm2d(64)
        self.fc21 = nn.Linear(6*6*64, self.latent_dim)
        self.fc22 = nn.Linear(6*6*64, self.latent_dim)
        # fc3 and fc4 are connected in series as the decoder.
        # fc3 takes a realization from the latent space as input
        # and the decoder generates a vectorized 28x28 image.
        # The output of fc3 passes through a ReLU,
        # while fc4 uses a sigmoid in order to output a probability for each pixel
        self.fc3 = nn.Linear(self.latent_dim, 6*6*64)
        self.fc4 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
                                     kernel_size=3, stride=(2,2))
        self.fc4_normalized = nn.BatchNorm2d(32)
        self.fc5 = nn.ConvTranspose2d(in_channels=32, out_channels=28,
                                     kernel_size=3, stride=(2,2))
        self.fc5_normalized = nn.BatchNorm2d(28)
        self.fc6 = nn.ConvTranspose2d(in_channels=28, out_channels=1,
                                     kernel_size=2, stride=(1,1))

    # TODO: Implement the following four functions.  Note that they should be able to accept arguments containing stacked information for multiple observations
    # e.g. a minibatch rather than a single observation.  Your solution will need to handle this.  If you treat the arguments as
    # representing a single observation in your logic, in most cases broadcasting will do the rest of the job automatically for you.
    def encode(self, x):
        # This should return the outputs of fc21 and fc22 as a tuple
        hidden1 = F.selu(self.fc1(x))
        hidden1 = self.fc1_normalized(hidden1)       
        hidden2 = F.selu(self.fc2(hidden1))
        hidden2 = self.fc2_normalized(hidden2)               
        flatten = hidden2.view(-1, 6*6*64)
        mu = self.fc21(flatten)
        logvar = self.fc22(flatten)
        return (mu, logvar)

    def reparameterize(self, mu, logvar):
        # This should sample vectors from an isotropic Gaussian, and use these to generate
        # and return observations with a mean vectors from mu, and log-variances of log-var
        latent_dim = mu.shape
        loc = torch.zeros(latent_dim)
        covariance_matrix = torch.diag_embed(torch.ones(latent_dim))
        gaussian = torch.distributions.MultivariateNormal(loc, covariance_matrix)
        epsilon = gaussian.sample().to(device)
        latent_vector = mu + logvar * epsilon
        return latent_vector

    def decode(self, z):
        # Pass z through the decoder. For each 20-dimensional latent realization, there should be a 784-dimensional vector of
        #probabilities generated, one per pixel
        hidden1 = F.selu(self.fc3(z))
        hidden1 = hidden1.view(-1,64,6,6)
        hidden2 = F.selu(self.fc4(hidden1))
        hidden2 = self.fc4_normalized(hidden2)
        hidden3 = F.selu(self.fc5(hidden2))
        hidden3 = self.fc5_normalized(hidden3)        
        hidden4 = self.fc6(hidden3)
        realizations = torch.sigmoid(hidden4)
#         import pdb; pdb.set_trace()
        return realizations

    def forward(self, x):
        # For each observation in x:
        # 1. Pass it through the encoder to get predicted variational distribution parameters
        # 2. Reparameterize an isotropic Gaussian with these parameters to get sample latent variable realizations
        # 3. Pass the realization through the decoder to get predicted pixel probabilities
        # Return a tuple with 3 elements: (a) the predicted pixel probabilities, (b) the predicted variational means, and (c) the predicted variational log-variances
        #x = x.view(-1,784) # Reshape x to provide suitable inputs to the encoder
        mu, logvar = self.encode(x)
        latent_vector = self.reparameterize(mu, logvar)
        recon_x = self.decode(latent_vector)
        return (recon_x, mu, logvar)
        
        
latent_dimension=100
model = VAE(latent_dimension).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# TODO: Implement this loss function
def loss_function(recon_x, x, mu, logvar):
    # The loss should be (an estimate of) the negative ELBO - remember we wish to maximise the ELBO - but the ELBO can be written in a number of forms.
    # In this case, the prior for the latent variable and the variational posterior are both Gaussians, and we will exploit this.
    # Specifically, we can analytically calculate a part of the ELBO, and only use Monte Carlo estimation for the rest.
    # 1. We use the form of the ELBO which includes a KL divergence between the latent prior and the variational family
    # - see the form at the bottom of page 6 of Blei et al's "Variational Inference: A Review for Statisticians".
    # 2. In this case, the expression for the relevant KL divergence can be obtained from Exercise (e) in Week 1.
    #
    # The other term is the expected conditional log-likelihood, which is estimated using a single Monte-Carlo sample.
    # For the log-likelihood, one evaluates the probability of observing an input point given the "conditional distribution" for
    # observations output by the network - in this case, each pixel is independently Bernoulli with parameter equal to the output probability.
    # You may find torch.nn.functional's binary_cross_entropy function useful here.
    #
    # Additional: the extraction of the KL divergence as above reduces the variance.  Investigate the effect of directly estimating
    # the full ELBO term for each observation with a single Monte Carlo sample.
    #
    # You may find torch.nn.functional's binary_cross_entropy function useful.
    #
    # Return a single value accumulating the loss over the whole batch.
    #
    # Arguments:
    # x is the batch of observations
    # recon_x, mu, and logvar are the outputs of forward(x) (above) - see the usage below
    #x = x.view(-1,784) # Reshape x to provide suitable inputs to the encoder
   
    # Computing the KL term
    mu_square = mu.pow(2)
    var = logvar.exp()
    KL = .5 * torch.sum((var + mu_square - logvar - 1), dim=-1)
    KL = KL.mean()    
    # Computing the conditional log likelihood expectation
#    covariance_matrix = torch.diag_embed(logvar.exp())
#    variational_posterior = torch.distributions.MultivariateNormal(mu, covariance_matrix)
#    z_sample = variational_posterior.sample()
    log_loss = -F.binary_cross_entropy(recon_x, x, reduction="mean")    
    ELBO = log_loss + KL   
    return ELBO


In [0]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(BATCH_SIZE, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


In [0]:
if __name__ == "__main__":
    for epoch in range(1, EPOCHS + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, latent_dimension).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: -0.1601
====> Test set loss: -0.1973
====> Epoch: 2 Average loss: -0.1938
====> Test set loss: -0.1971
====> Epoch: 3 Average loss: -0.2010
====> Test set loss: -0.2087
====> Epoch: 4 Average loss: -0.2108
====> Test set loss: -0.2140
====> Epoch: 5 Average loss: -0.2124
====> Test set loss: -0.2173
====> Epoch: 6 Average loss: -0.2119
====> Test set loss: -0.2156
====> Epoch: 7 Average loss: -0.2112
====> Test set loss: -0.2152
====> Epoch: 8 Average loss: -0.2105
====> Test set loss: -0.2119
====> Epoch: 9 Average loss: -0.2099
====> Test set loss: -0.2122
====> Epoch: 10 Average loss: -0.2095
====> Test set loss: -0.2115
