In [1]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as d

import math
import random
import numpy as np

In [2]:
BATCH_SIZE = 128

# Load data

In [3]:
path_train = "data/binarized_mnist_train.amat"
path_valid = "data/binarized_mnist_valid.amat"
path_test = "data/binarized_mnist_test.amat"

In [4]:
def data_loader(path):
    with open(path_train) as file:
        x = [
            np.array(list(map(int, img.split()))).reshape(1, 28, 28)
            for img in file.readlines()
        ]
    return torch.from_numpy(np.asarray(x)).float()

In [5]:
train = torch.utils.data.TensorDataset(data_loader(path_train))
trainloader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
valid = torch.utils.data.TensorDataset(data_loader(path_train))
validloader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
test = torch.utils.data.TensorDataset(data_loader(path_train))
testloader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, shuffle=False)

# Train a VAE (10pts)

In [8]:
class VAE(nn.Module):
    def __init__(self, dim=100):
        super(VAE, self).__init__()

        self.dim = dim

        #ENCODER LAYERS
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5),
            nn.ELU())
        self.hidden_1 = nn.Linear(256, 200)

        #DECODER LAYERS
        self.hidden_2 = nn.Linear(100, 256)
        self.decoder = nn.Sequential(
            nn.ELU(),
            nn.Conv2d(
                in_channels=256, out_channels=64, kernel_size=5, padding=4),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=2),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=32, out_channels=16, kernel_size=3, padding=2),
            nn.ELU(),
            nn.Conv2d(
                in_channels=16, out_channels=1, kernel_size=3, padding=2),
            nn.Sigmoid())

    def encode(self, x):
        q_params = self.hidden_1(self.encoder(x).view(x.shape[0], 256))
        mu = q_params[:, :self.dim]
        logvar = q_params[:, self.dim:]
        return mu, logvar

    def decode(self, z, x):
        return self.decoder(self.hidden_2(z).view(x.shape[0], 256, 1, 1))

    def sample(self, mu, logvar, x):
        eps = torch.randn(x.shape[0], self.dim).to(device)
        return mu + eps * (0.5 * logvar).exp()

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar, x)
        x_hat = self.decode(z, x)
        return x_hat, mu, logvar

    def criterion(self, x_hat, x, mu, logvar):
        # BCE = -log(p) because gradient descent and not ascent
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum') 
        KLD = 0.5 * torch.sum(-1 - logvar + mu.pow(2) + logvar.exp())
        # criterion returns -ELBO !
        return (BCE + KLD) / x.shape[0]

    def evaluate(self, loader):
        with torch.no_grad():
            loss = 0
            for i, data in enumerate(loader, 1):
                # get the inputs
                x = data[0].to(device)
                x_hat, mu, logvar = self.forward(x)
                loss += self.criterion(x_hat, x, mu, logvar).item()
        return loss / len(loader)

In [35]:
# create model and move it to device
model = VAE()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Let's use {}".format(device))

Let's use cuda:0


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
for epoch in range(1):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        x = data[0].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x_hat, mu, logvar = model.forward(x)

        loss = model.criterion(x_hat, x, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # print statistics
    train_loss = running_loss / i
    val_loss = model.evaluate(validloader)
    print('epoch %2d: loss: %8.3f    val_loss: %8.3f    val_ELBO: %8.3f' %
          (epoch + 1, train_loss, val_loss, -val_loss))



epoch  1: loss:  214.330    val_loss:  155.447    val_ELBO: -155.447


# Evaluating log-likelihood with Variational Autoencoders (20pts)

In [73]:
def importance_sampling_estimation(model, data, M, K=200, D=784, L=100):
    estimates = []
    for _ in range(M):
        x = random.choice(list(data))[0].view(1, 1, 28, 28).to(device)
        mu, logvar = model.encode(x)
        cov_mat = torch.diag((0.5*logvar).exp()[0]).to(device)
        normal = d.MultivariateNormal(mu[0], cov_mat)
        snd = d.MultivariateNormal(torch.zeros(L).to(device), torch.eye(L).to(device))
        p = torch.zeros_like(x).to(device)
        for _ in range(K):
            z = model.sample(mu, logvar, x)
            x_hat = model.decode(z, x)
            std = (z-mu)/((0.5*logvar).exp()).to(device)
            p += (snd.log_prob(std)-normal.log_prob(z)).exp()*x_hat
        estimates.append((p / K).log())
    return estimates

In [72]:
importance_sampling_estimation(model, test, M = 4)

[tensor([[[[0.4955, 0.4929, 0.4937, 0.4945, 0.4948, 0.4949, 0.4949, 0.4949,
            0.4949, 0.4949, 0.4948, 0.4948, 0.4949, 0.4949, 0.4949, 0.4950,
            0.4950, 0.4950, 0.4951, 0.4950, 0.4950, 0.4950, 0.4950, 0.4949,
            0.4950, 0.4953, 0.4953, 0.4967],
           [0.4930, 0.4903, 0.4909, 0.4920, 0.4920, 0.4921, 0.4922, 0.4922,
            0.4922, 0.4921, 0.4921, 0.4921, 0.4921, 0.4921, 0.4922, 0.4922,
            0.4923, 0.4923, 0.4923, 0.4923, 0.4923, 0.4923, 0.4924, 0.4923,
            0.4922, 0.4925, 0.4955, 0.4967],
           [0.4946, 0.4921, 0.4934, 0.4944, 0.4937, 0.4938, 0.4939, 0.4939,
            0.4939, 0.4939, 0.4939, 0.4939, 0.4939, 0.4939, 0.4939, 0.4939,
            0.4939, 0.4940, 0.4940, 0.4941, 0.4942, 0.4942, 0.4943, 0.4943,
            0.4937, 0.4935, 0.4952, 0.4967],
           [0.4951, 0.4931, 0.4947, 0.4956, 0.4948, 0.4949, 0.4950, 0.4950,
            0.4950, 0.4950, 0.4950, 0.4951, 0.4951, 0.4951, 0.4951, 0.4950,
            0.4950, 0.4950, 0