In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as T

import numpy as np
import random
import matplotlib.pyplot as plt

from typing import Any, Union, Tuple, Callable, Type

In [4]:
# to ensure reproducibility
SEED = 26

# torch related
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# others
np.random.seed(SEED)
random.seed(SEED)

## Sampling module

We implement a sampling layer corresponding to the reparametrization trick.
The input for the forward procedure is a tuple containing the mean and the log of the diagonal of the covariance matrix for the posterior distribution $P(z|x)$.

You need to implement the following procedure:
* sample $\epsilon$ from a standard normal $N(0, I)$
* return $z_{mean} + \exp(0.5 \cdot z_{log-var}) \cdot \epsilon$

In [17]:
class Sampling(nn.Module):
    def __init__(self, **kwargs):
        super(Sampling, self).__init__(**kwargs)
    
    def forward(self, inputs: Tuple[torch.tensor, torch.tensor]) -> torch.tensor:
        """
        Parameters
        ----------
        inputs
            tuple containing z_mean and z_log_var representing
            the mean and the log of the diagonal of the convariance
            matrix corresponding to the posterior distribution of
            latent variable
        
        Returns
        -------
            sample from the posterior distribution
        """
        z_mean, z_log_var = inputs
        
        # TODO 1
        # For this task you have to implement the sampling procedure
        eps =  torch.rand_like(z_mean) # sample from a standard normal 
        z   =  z_mean + torch.exp(0.5 * z_log_var) * eps # use the reparametrization trick to sample from the posterior
        return z 

## Encoder Module

In [22]:
class Encoder(nn.Module):
    def __init__(self, latent_dim: int = 2, **kwargs):
        """
        Parameters
        ----------
        latent_dim
            the dimension of the laten variable. We will use 2
            for visualization purposes.
        """
        
        super(Encoder, self).__init__(**kwargs)
        
        # define network architecture
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, padding=1, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, padding=1, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 16)
        
        # TODO 2
        # Add the two heads of the encoder corresponding
        # to the mean and the log_var of the posterior distribution
        #
        # Both of the should have an input of 16 and the output 
        # equal to latent dim
        self.fc_mean    =  nn.Linear(16, latent_dim) # mean head
        self.fc_log_var =  nn.Linear(16, latent_dim) # log var head
        
        # define sampling layer
        self.sampling = Sampling()
        
    def forward(self, inputs: torch.tensor) -> Tuple[torch.tensor, torch.tensor, torch.tensor]:
        """
        Parameters
        ----------
        inputs
            tensor containing the input image (B x C x H x W)
            
        Returns
            z_mean, z_log_var, z representing the mean of the posterior,
            log of diagonal covariance matrix, sample from the posterior
        """
        
        x = F.relu(self.conv1(inputs), inplace=True)
        x = F.relu(self.conv2(x), inplace=True)
        x = self.flatten(x)
        x = F.relu(self.fc1(x), inplace=True) 
        
        
        # TODO 3
        # Write the forward procedure to obtain the
        # mean and the log var of the posterior p(z | x)
        z_mean    = self.fc_mean(x)
        z_log_var = self.fc_log_var(x)
        
        # TODO 4
        # use the mean the log_var to sample from 
        # the posterior distribution p(z | x)
        z = self.sampling((z_mean, z_log_var))
        
        return z_mean, z_log_var, z

## Decoder Module

In [23]:
class Decoder(nn.Module):
    def __init__(self, latent_dim=2, **kwargs):
        """
        Parameters
        ----------
        latent_dim
            dimension of the latent variable
        """
        super(Decoder, self).__init__(**kwargs)
            
        # define network architecture
        self.fc1 = nn.Linear(latent_dim, 7 * 7 * 64)
        self.convtrans1 = nn.ConvTranspose2d(64, 64, kernel_size=4, padding=1, stride=2)
        self.convtrans2 = nn.ConvTranspose2d(64, 32, kernel_size=4, padding=1, stride=2)
        self.convtrans3 = nn.ConvTranspose2d(32, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, inputs: torch.tensor) -> torch.tensor:
        """
        Parameters
        ----------
        inputs
            tensor in the latent space
        
        Returns
        -------
            reconstructed tensor from input latent 
            input representation
        """
        x = F.relu(self.fc1(inputs), inplace=True)
        x = x.view(x.shape[0], 64, 7, 7)
        x = F.relu(self.convtrans1(x), inplace=True)
        x = F.relu(self.convtrans2(x), inplace=True)
        x = self.sigmoid(self.convtrans3(x))
        return x

## VAE Module

The variational loss contains a KL regularization term that has a closed for two normal distribution (the standard normal prior $p(z)$ and the posterior normal $p(z|x)$)

$$
    KL(p(z|x) || p(z)) = \sum_{i} -\frac{1}{2} [z_{log-var, i} - z_{mean, i}^2 - \exp(z_{log-var, i}) + 1]
$$

In [81]:
class VAE(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        """
        Parameters
        ----------
        encoder
            encoder module
        decoder
            decoder module
        """
        
        super(VAE, self).__init__(**kwargs)         
        # define device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # architecture
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)
        
        # kl divergence regularizer
        self.kl_loss = None
       
        
    def forward(self, inputs: torch.tensor) -> torch.tensor:
        """
        Parameters
        ----------
        inputs
            tensor in the input space B x C x H x W
        
        Returns
        -------
            reconstruction the the input tensor B x C x H x W
        """
        z_mean, z_log_var, z = self.encoder(inputs)
        rec_inputs = self.decoder(z)
        
        # TODO 6
        # Compute kl divergence regularization loss
        # 1. Use the formula and first obtain a vector of size (B, latent_dim)
        # 2. Do the summation over the first dimension
        # 3. Compute the bacth mean
        # 
        # We will use this in the training loop
        self.kl_loss =  - 0.5 * (z_log_var - torch.square(z_mean) - torch.exp(z_log_var) + 1.) # here you should have vector 
#         print(self.kl_loss.size())
        self.kl_loss =  torch.sum(self.kl_loss, dim=1) # sum over the dimension 1 
#         print(self.kl_loss)
        self.kl_loss =  torch.mean(self.kl_loss) # compute the mean of the batch
#         print(self.kl_loss)
        
        return rec_inputs
    
    
    def compile(self, optimizer: Type[torch.optim.Optimizer], 
                loss: Type[torch.nn.modules.loss._Loss], lr: float = 1e-3):
        """
        Compile function
        
        Parameters
        ----------
        optimizer
            class of the optimizer (e.g. torch.optim.Adam)
        loss
            class of the loss (e.g. torch.loss.BCELoss)
        """
        self.optimizer = optimizer(self.parameters(), lr=lr)
        self.loss = loss(reduction="sum")
    
    
    @staticmethod
    def __running_loss(r_loss: float, loss: float) -> float:
        """
        Update running loss
        
        Parameters
        ----------
        r_loss
            current running loss
        loss
            current loss
        
        Returns
        -------
        Updated running loss
        """
        
        r_loss = loss if r_loss is None else (0.99 * r_loss + 0.01 * loss)
        return r_loss
    
    
    def train_step(self, x: torch.tensor) -> Tuple[float, float, float]:
        """
        Performs a training step
        
        Parameters
        ----------
        x
            batch of input data B x C x H x W
        
        Returns
        -------
        Tuple containing total loss, reconstruction loss, and KL loss
        """
        
        self.train()
        x = x.to(self.device)
        
        # compute the reconstruction loss
        rec_x = self.forward(x)
        rec_loss = self.loss(rec_x, x) / x.shape[0]
        
        # TODO 7
        # Compute the final loss as the sum between
        # the reconstruction loss and the kl regularization loss
        loss = rec_loss + self.kl_loss

        # gradient step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item(), rec_loss.item(), self.kl_loss.item()

    
    def fit(self, dataset: Dataset, batch_size: int, epochs: int, 
            num_workers: int = 4, log_int: int = 100):
        """
        Training loop
        
        Parameters
        ----------
        dataset
            training dataset
        batch_size
            training batch size
        epochs
            number of training epochs
        num_workers
            number of workers for dataloader
        log_int
            logging interval
        """
        
        # define dataloader
        dataloader = DataLoader(dataset, batch_size=batch_size, 
                                shuffle=True, num_workers=num_workers)
        
        # running loss
        r_loss = None
        r_rec_loss = None
        r_kl_loss = None
        
        # training loop
        for epoch in range(epochs):
            for step, data in enumerate(dataloader):
                X, _ = data
                loss, rec_loss, kl_loss = self.train_step(X)
                
                # update running loss
                r_loss = VAE.__running_loss(r_loss, loss)
                r_rec_loss = VAE.__running_loss(r_rec_loss, rec_loss)
                r_kl_loss = VAE.__running_loss(r_kl_loss, kl_loss)
                
                # logging
                if step % log_int == 0:
                    print("Epoch: %d\t Step: %d\t Loss: %.3f\t RecLoss: %.3f\t KLLoss: %.3f"\
                          % (epoch, step, r_loss, r_rec_loss, r_kl_loss))
                

## Train

In [None]:
# define transformation
# feel free to add any other transformations
transform = T.Compose([
    T.ToTensor(),
])

# load the dateset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# define the autoencoder
enc = Encoder()
dec = Decoder()
vae = VAE(enc, dec)

# define optimizer and reconstruction loss
# notice that the kl loss is incorporated into
# the model itself
optimizer = torch.optim.Adam
loss_fn = torch.nn.BCELoss

# compile the model
vae.compile(optimizer=optimizer, loss=loss_fn, lr=1e-3)

# fit the model
vae.fit(trainset, batch_size=64, epochs=30, log_int=500)

Epoch: 0	 Step: 0	 Loss: 514.278	 RecLoss: 514.250	 KLLoss: 0.028
Epoch: 0	 Step: 500	 Loss: 186.611	 RecLoss: 186.333	 KLLoss: 0.278
Epoch: 1	 Step: 0	 Loss: 180.516	 RecLoss: 180.234	 KLLoss: 0.282
Epoch: 1	 Step: 500	 Loss: 176.968	 RecLoss: 176.645	 KLLoss: 0.323
Epoch: 2	 Step: 0	 Loss: 174.088	 RecLoss: 173.740	 KLLoss: 0.348
Epoch: 2	 Step: 500	 Loss: 171.609	 RecLoss: 171.250	 KLLoss: 0.358
Epoch: 3	 Step: 0	 Loss: 170.834	 RecLoss: 170.468	 KLLoss: 0.367
Epoch: 3	 Step: 500	 Loss: 169.098	 RecLoss: 168.734	 KLLoss: 0.364
Epoch: 4	 Step: 0	 Loss: 168.979	 RecLoss: 168.630	 KLLoss: 0.349
Epoch: 4	 Step: 500	 Loss: 167.930	 RecLoss: 167.589	 KLLoss: 0.341


## Visualization

For the visualization procedure, we will sample from an uniform grid $[-1, 1] \times [-1, 1]$ and we will pass this latent representation through the decoder to generate images in of size $28 \times 28$.

In [None]:
def plot_img(x: torch.tensor, nrow):
    """
    Plots a batch of reconstructed data in a 
    grid formation
    
    Parameters
    ----------
    x
        batch of reconstructed data B x C x H x W
    nrow
        number of elements in a row
    """
    img = torchvision.utils.make_grid(x, nrow=nrow)
    npimg = np.transpose(img.numpy(), (1, 2, 0))

    plt.figure(figsize = (15, 15))
    plt.imshow(npimg)
    plt.show()

In [None]:
# TODO 8
# Run the visualization procedure from below

# construct sample grid
scale = 1.
samples = 20 

x = np.linspace(-scale, scale, samples)
y = np.linspace(-scale, scale, samples)
xv, yv = np.meshgrid(x, y)
xv, yv = xv.reshape(-1, 1), yv.reshape(-1, 1)

# construct the laent space
z = np.concatenate([xv, yv], axis=1)
z = torch.tensor(z)

# set model to evauation
vae = vae.eval()
z = z.float().to(vae.device)

with torch.no_grad():
    x = vae.decoder(z).cpu()
    
# plot the image
plot_img(x, samples)