In [1]:
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as d

import math
import random
import numpy as np

In [2]:
BATCH_SIZE = 128

# Load data

In [3]:
path_train = "data/binarized_mnist_train.amat"
path_valid = "data/binarized_mnist_valid.amat"
path_test = "data/binarized_mnist_test.amat"

In [4]:
def data_loader(path):
    with open(path) as file:
        x = [
            np.array(list(map(int, img.split()))).reshape(1, 28, 28)
            for img in file.readlines()
        ]
    return torch.from_numpy(np.asarray(x)).float()

In [5]:
train = torch.utils.data.TensorDataset(data_loader(path_train))
trainloader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
valid = torch.utils.data.TensorDataset(data_loader(path_valid))
validloader = torch.utils.data.DataLoader(
    valid, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
test = torch.utils.data.TensorDataset(data_loader(path_test))
testloader = torch.utils.data.DataLoader(
    test, batch_size=BATCH_SIZE, shuffle=False)

# Train a VAE (10pts)

In [9]:
class VAE(nn.Module):
    def __init__(self, dim=100):
        super(VAE, self).__init__()

        self.dim = dim

        #ENCODER LAYERS
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=3), nn.ELU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5),
            nn.ELU())
        self.hidden_1 = nn.Linear(256, 200)

        #DECODER LAYERS
        self.hidden_2 = nn.Linear(100, 256)
        self.decoder = nn.Sequential(
            nn.ELU(),
            nn.Conv2d(
                in_channels=256, out_channels=64, kernel_size=5, padding=4),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=64, out_channels=32, kernel_size=3, padding=2),
            nn.ELU(), nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(
                in_channels=32, out_channels=16, kernel_size=3, padding=2),
            nn.ELU(),
            nn.Conv2d(
                in_channels=16, out_channels=1, kernel_size=3, padding=2),
            nn.Sigmoid())

    def encode(self, x):
        q_params = self.hidden_1(self.encoder(x).view(x.shape[0], 256))
        mu = q_params[:, :self.dim]
        logvar = q_params[:, self.dim:]
        return mu, logvar

    def decode(self, z, x):
        # we need x to abstract from the batch size
        return self.decoder(self.hidden_2(z).view(x.shape[0], 256, 1, 1))

    def sample(self, mu, logvar, x):
        # we need x to abstract from the batch size
        eps = torch.randn(x.shape[0], self.dim).to(device)
        return mu + eps * (0.5 * logvar).exp()

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar, x)
        x_hat = self.decode(z, x)
        return x_hat, mu, logvar

    def criterion(self, x_hat, x, mu, logvar):
        # BCE = -log(p) because gradient descent and not ascent
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum') 
        KLD = 0.5 * torch.sum(-1 - logvar + mu.pow(2) + logvar.exp())
        # criterion returns -ELBO !
        return (BCE + KLD) / x.shape[0]

    def evaluate(self, loader):
        with torch.no_grad():
            loss = 0
            for i, data in enumerate(loader, 1):
                # get the inputs
                x = data[0].to(device)
                x_hat, mu, logvar = self.forward(x)
                loss += self.criterion(x_hat, x, mu, logvar).item()
        return loss / len(loader)

In [10]:
# create model and move it to device
model = VAE()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Let's use {}".format(device))

Let's use cuda:0


In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
for epoch in range(20):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        x = data[0].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        x_hat, mu, logvar = model.forward(x)

        loss = model.criterion(x_hat, x, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # print statistics
    train_loss = running_loss / i
    val_loss = model.evaluate(validloader)
    print('epoch %2d: loss: %8.3f    val_loss: %8.3f    val_ELBO: %8.3f' %
          (epoch + 1, train_loss, val_loss, -val_loss))



epoch  1: loss:  222.158    val_loss:  162.954    val_ELBO: -162.954
epoch  2: loss:  147.473    val_loss:  134.615    val_ELBO: -134.615
epoch  3: loss:  126.507    val_loss:  120.244    val_ELBO: -120.244
epoch  4: loss:  116.214    val_loss:  113.082    val_ELBO: -113.082
epoch  5: loss:  110.597    val_loss:  108.831    val_ELBO: -108.831
epoch  6: loss:  107.334    val_loss:  106.028    val_ELBO: -106.028
epoch  7: loss:  105.135    val_loss:  104.327    val_ELBO: -104.327
epoch  8: loss:  103.409    val_loss:  103.136    val_ELBO: -103.136
epoch  9: loss:  102.074    val_loss:  101.656    val_ELBO: -101.656
epoch 10: loss:  101.036    val_loss:  100.551    val_ELBO: -100.551
epoch 11: loss:  100.130    val_loss:  100.067    val_ELBO: -100.067
epoch 12: loss:   99.438    val_loss:   99.362    val_ELBO:  -99.362
epoch 13: loss:   98.749    val_loss:   98.887    val_ELBO:  -98.887
epoch 14: loss:   98.186    val_loss:   98.078    val_ELBO:  -98.078
epoch 15: loss:   97.652    val_lo

# Evaluating log-likelihood with Variational Autoencoders (20pts)

In [12]:
def importance_sampling_estimation(model, data, M=10, K=200, D=784, L=100):
    with torch.no_grad():
        estimates = []
        if M < len(data):
            # randomly draw M input x_i
            x = [random.choice(list(data)) for _ in range(M)]
        else:
            x = data
        for it, xi in enumerate(x, 1):
            # send input to device
            xi = xi[0].view(1, 1, 28, 28).to(device)
            # encode the input and get the parameters of the normal distribution
            mu, logvar = model.encode(xi)
            # get variance
            var = (0.5 * logvar).exp()[0]
            # create covariance matrix (diagonal)
            cov_mat = torch.diag(var).to(device)
            # Normal(mu, diag(var)) of dimension L with
            normal = d.MultivariateNormal(mu[0], cov_mat)
            # Normal(0, I) of dimension L
            snd = d.MultivariateNormal(
                torch.zeros(L).to(device),
                torch.eye(L).to(device))
            # Initialize prob_xi
            sum_prob = 0
            # Sum K times
            for _ in range(K):
                # draw a sample from the model distribution q(z_ik|x_i)
                zik = model.sample(mu, logvar, xi)
                # compute p(z_ik)
                log_prob_zik = normal.log_prob(zik)
                # compute q(z_ik|x_i)
                log_prob_zik_xi = snd.log_prob(zik)
                # compute p(x_i|z_ik)
                xi_hat = model.decode(zik, xi)
                log_prob_xi_zik = -F.binary_cross_entropy(
                    xi_hat, xi, reduction="sum")
                sum_prob += (
                    log_prob_xi_zik + log_prob_zik - log_prob_zik_xi).exp()
            prob_xi = (sum_prob / K).log().cpu().numpy()
            estimates.append(prob_xi)
            print("\rEstimate {:6d}: {}".format(it, np.mean(estimates)), end="")
    return estimates

In [13]:
p = importance_sampling_estimation(model, test, M=len(test))

Estimate    115: -48.134845733642584

KeyboardInterrupt: 