In [None]:
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms

# Assignment 3

This assignment will briefly cover two major types of deep learning models:
discriminative and generative. You will first implement a simple convolutional
neural network (CNN) to classify handwritten digit images from the MNIST
dataset. Then, you will implement a variational autoencoder (VAE) to generate
images of handwritten digits that mimic the ones in MNIST.

**Models need to be implemented in PyTorch**. If you are new to PyTorch, the
[official tutorial](https://pytorch.org/tutorials/beginner/basics/intro.html)
is a great place to get started.

You need a CUDA-compatible GPU to train the models. If your own computer is
not equipped with one, you can finish this assignment using Google Colab.
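To check whether PyTorch can actually see a GPU (for example, after selecting a GPU runtime in Colab), a quick check like the one below may help. This is a minimal sketch: the CPU fallback is a convenience assumption and not part of the assignment's setup, which creates `torch.device('cuda')` directly.

```python
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# Training still works on the CPU, only more slowly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
```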

## 1. Image Classification with CNN

### 1a. Load and Visualize Data (5 points)

Let's first load the MNIST dataset with the help of TorchVision.

In [None]:
transform = transforms.ToTensor()
dataset = torchvision.datasets.MNIST(
    root='data', train=True, transform=transform, download=True)

In the following cell, pick 10 samples from the dataset, and visualize the
images and the corresponding labels using Matplotlib.

In [None]:
fig, axes = plt.subplots(1, 10, figsize=(10, 1))
for i in range(10):
    # TODO
plt.show()

We now split the original training set into a training set and a validation set.
The model will be trained on the training set, and the validation set will be
used to tune any hyperparameters.
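Seeding the generator passed to `random_split` makes the split reproducible, so the validation set stays the same across notebook restarts. The sketch below demonstrates this on a hypothetical toy dataset of 100 items rather than MNIST.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# A toy dataset of 100 (feature, label) pairs standing in for MNIST.
toy = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.zeros(100))

def split(seed: int):
    """Split the toy dataset 80/20 with a seeded generator."""
    generator = torch.Generator().manual_seed(seed)
    return random_split(toy, [80, 20], generator=generator)

train_a, val_a = split(1231)
train_b, val_b = split(1231)

print(len(train_a), len(val_a))                      # sizes of the two parts
print(train_a.indices == train_b.indices)            # same seed, same partition
print(set(train_a.indices).isdisjoint(val_a.indices))  # no overlap between parts
```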

In [None]:
# Split the original training set into a training set and a validation set.
generator = torch.Generator().manual_seed(1231)
num_train_samples = int(0.8 * len(dataset))
num_val_samples = len(dataset) - num_train_samples
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [num_train_samples, num_val_samples], generator=generator)

device = torch.device('cuda')
# Feel free to change these parameters.
train_kwargs = {
    'batch_size': 64,
    'shuffle': True,
    'num_workers': 1,
    'pin_memory': True,
}
test_kwargs = {
    'batch_size': 64,
    'shuffle': False,
    'num_workers': 1,
    'pin_memory': True,
}

# Create data loaders.
train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

### 1b. Implement CNN (20 points)

In the following cell, implement a neural network that can classify MNIST images
into 10 classes corresponding to digits 0-9. The network should satisfy the
following requirements:
- It must contain convolutional layers.
- It takes single-channel images of size 28 x 28 as input, and outputs
  a 10-dimensional score vector for each sample in the batch.
- The output scores should be unnormalized logits (i.e., not the output of a
  softmax layer).
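When stacking convolution and pooling layers, it helps to track the spatial size so the final linear layer receives the right number of input features. The sketch below uses the standard output-size formula (dilation 1) and checks it against `nn.Conv2d` and `nn.MaxPool2d`; the helper `conv_output_size` and the layer sizes are hypothetical, chosen only for illustration.

```python
import torch
from torch import nn

def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of nn.Conv2d / nn.MaxPool2d with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

x = torch.zeros(4, 1, 28, 28)  # a batch of 4 single-channel 28 x 28 images
conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2)  # stride defaults to the kernel size

after_conv = conv(x)           # padding=1 with a 3x3 kernel preserves 28 x 28
after_pool = pool(after_conv)  # 2x2 max pooling halves it to 14 x 14
flat = after_pool.flatten(1)   # 16 * 14 * 14 features for a following nn.Linear

print(after_conv.shape, after_pool.shape, flat.shape)
print(conv_output_size(28, 3, padding=1), conv_output_size(28, 2, stride=2))
```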

In [None]:
class CNN(nn.Module):
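    def __init__(self):
        super().__init__()
        # TODO

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input:
        - x: (N, 1, 28, 28), a batch of single-channel images.

        Returns:
        - logits: (N, 10), unnormalized class scores.
        """
        # TODO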

### 1c. Train the Model (20 points)

Complete the training loop in the following cell and train your model.

Note that the model outputs unnormalized logits, so be sure to choose an
appropriate loss function (`criterion` in the code). You are allowed to use
existing loss functions provided by PyTorch (either the Module version in
`torch.nn` or the function version in `torch.nn.functional`).

You are encouraged to experiment with different optimizers and hyperparameters,
and use any optimizer of your choice in your submission.
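Because the model outputs raw logits, a loss that applies the (log-)softmax internally is a natural fit. As an illustration (not a prescribed choice), the sketch below checks on random toy data that `nn.CrossEntropyLoss` on logits equals `log_softmax` followed by `nll_loss`; feeding it softmax probabilities instead would normalize twice.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 10)             # unnormalized scores for 5 samples
target = torch.tensor([3, 0, 9, 1, 7])  # ground-truth class indices

# nn.CrossEntropyLoss applies log-softmax internally, so it expects raw logits.
ce = nn.CrossEntropyLoss()(logits, target)
# Equivalent two-step computation: log-softmax, then negative log-likelihood.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
# Passing probabilities applies softmax twice and generally gives a different,
# incorrect value.
wrong = nn.CrossEntropyLoss()(F.softmax(logits, dim=1), target)

print(ce.item(), manual.item(), wrong.item())
```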

In [None]:
def train(
    model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    log_interval: int,
) -> None:
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader, start=1):
        loss =  # TODO

        if batch_idx % log_interval == 0:
            print("Epoch {:>2d} [{:>6,}/{:>6,}] loss={:.3f}".format(
                epoch,
                batch_idx * len(data),
                len(data_loader.dataset),
                loss.item(),
            ))


def test(
    model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> None:
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            # TODO

    print("Test accuracy: {:>6,}/{:>6,} ({:.2f}%)".format(
        num_correct,
        len(data_loader.dataset),
        100. * num_correct / len(data_loader.dataset),
    ))

# TODO
num_epochs =
learning_rate =
# Other hyperparameters:

cnn =
criterion =
optimizer =

for epoch in range(1, num_epochs+1):
    train(cnn, train_loader, criterion, optimizer, device, epoch, log_interval=150)
    test(cnn, val_loader, device)

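import os

# torch.save does not create missing directories, so create 'output' first.
os.makedirs('output', exist_ok=True)
torch.save(cnn.state_dict(), 'output/mnist_cnn.pth')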

### 1d. Evaluate the Model on the Test Set (5 points)

In the following cell, load the test split of the MNIST dataset and evaluate
your trained model on it. Report the accuracy of your model.

The test set is supposed to be held out until you have finished model training.
If you have hyperparameters to tune, please do so on the validation set instead
of the test set.

To receive full credit, your model should achieve an accuracy of at least 97.00%
on the test set.
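Accuracy only needs the index of the largest logit per sample; softmax is unnecessary because it preserves the ordering of the scores. The sketch below shows the counting that the `num_correct` variable in `test` is meant to accumulate, using a hypothetical 3-class toy batch instead of real model outputs.

```python
import torch

logits = torch.tensor([
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.3, 0.2, 1.5],    # predicts class 2
    [0.0, 3.0, 0.5],    # predicts class 1
    [1.0, 0.9, 0.8],    # predicts class 0
])
target = torch.tensor([0, 2, 0, 0])

# The predicted class is the index of the largest logit in each row.
pred = logits.argmax(dim=1)
num_correct = (pred == target).sum().item()
accuracy = 100.0 * num_correct / len(target)
print(pred.tolist(), num_correct, accuracy)  # → [0, 2, 1, 0] 3 75.0
```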

In [None]:
test_dataset =  # TODO
test_loader =  # TODO
test(cnn, test_loader, device)

## 2. Image Generation with VAE

In this part, you will implement a variational autoencoder (VAE) that can
generate images of handwritten digits as if they were drawn from the actual MNIST
dataset.

### 2a. Evaluation Metric: Inception Score (5 points)

Before building the model, we first introduce how the generated samples will be
quantitatively evaluated. Specifically, we use the *Inception Score* introduced
in [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498).
The name comes from applying an Inception network (pretrained on ImageNet) to
generated images and using statistics of its predicted class probabilities. In
this assignment, you will not use the Inception network itself. Instead, you
will use the CNN model you implemented in Q1 as the scoring model.

The idea behind the Inception Score is simple. Ideally, we want the generated
samples to be (1) realistic and (2) diverse. For one specific generated image to
be realistic (i.e., actually look like a handwritten digit), the CNN model
should assign a high probability to one and only one of the 10 classes. For the
set of generated images to be diverse, the predicted probabilities should be
spread evenly across the 10 classes when averaged over all generated images.
Quantitatively, these translate to (1) the predicted distribution for each
sample having low entropy, and (2) the average predicted distribution over all
samples having high entropy.
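Written out, with $p(y \mid x_i)$ the CNN's softmax output for image $x_i$, $N$ the number of images in a split, and $H$ the entropy, the score is the exponential of the average KL divergence between each prediction and the marginal $p(y)$. Expanding the KL term gives exactly the entropy difference computed in the cell below (`entropy(np.mean(split_probs, axis=0))` minus `np.mean(entropy(split_probs, axis=1))`):

$$
\begin{aligned}
p(y) &= \frac{1}{N} \sum_{i=1}^{N} p(y \mid x_i), \\
\log \mathrm{IS} &= \frac{1}{N} \sum_{i=1}^{N} D_{\mathrm{KL}}\big(p(y \mid x_i) \,\|\, p(y)\big)
= H\big(p(y)\big) - \frac{1}{N} \sum_{i=1}^{N} H\big(p(y \mid x_i)\big).
\end{aligned}
$$

Since this difference lies between $0$ and $\log 10$, the score ranges from 1 to 10 for 10 classes.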

To aid your implementation, the following cell already implements the scoring
function for you.

In [None]:
import numpy as np
from scipy.stats import entropy


def compute_inception_score(
    scoring_model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
    num_splits: int = 10,
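) -> Tuple[float, float]:
    """Compute the Inception Score of the samples in `data_loader`.

    The samples are divided into `num_splits` equal-sized splits; the score is
    computed on each split, and the mean and standard deviation across splits
    are returned.
    """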
    scoring_model.eval()
    probs = []
    with torch.no_grad():
        for data in data_loader:

            # HACK: Samples can be (image, label) pairs like those from the
            # TorchVision datasets, or simply images from data loaders directly
            # created on generated image samples.
            if isinstance(data, (list, tuple)):
                data = data[0]

            data = data.to(device)
            logits = scoring_model(data)
            batch_probs = torch.nn.functional.softmax(logits, dim=1)
            probs.append(batch_probs)
    probs = torch.cat(probs, dim=0).cpu().numpy()

    split_scores = []
    n = len(probs) // num_splits  # samples per split; any remainder is dropped
    for i in range(num_splits):
        split_probs = probs[i*n:(i+1)*n]
        # Ideally, high entropy with the averaged probabilities (diverse
        # samples), and low average entropy with probabilities of individual
        # samples (high quality samples).
        log_scores = entropy(np.mean(split_probs, axis=0)) \
            - np.mean(entropy(split_probs, axis=1))
        split_scores.append(np.exp(log_scores))
    split_scores = np.array(split_scores)

    mean = split_scores.mean()
    std = split_scores.std()
    return mean, std
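A quick sanity check of the formula on hand-made probability arrays can build intuition about its range. The sketch below uses a hypothetical helper, `inception_score_from_probs`, that applies the same entropy difference as `compute_inception_score` to a single split, with NumPy arrays in place of CNN outputs.

```python
import numpy as np
from scipy.stats import entropy

def inception_score_from_probs(probs: np.ndarray) -> float:
    """Inception Score of one split; probs is (N, K), rows are p(y|x)."""
    marginal_entropy = entropy(probs.mean(axis=0))        # H(p(y))
    mean_cond_entropy = np.mean(entropy(probs, axis=1))   # mean of H(p(y|x))
    return float(np.exp(marginal_entropy - mean_cond_entropy))

# Ideal: every prediction is confident (one-hot) and classes are balanced.
ideal = np.eye(10)[np.arange(1000) % 10]
# Unrealistic: every prediction is uniform over the 10 classes.
uniform = np.full((1000, 10), 0.1)
# Mode collapse: every sample is confidently predicted as the same class.
collapsed = np.eye(10)[np.zeros(1000, dtype=int)]

print(inception_score_from_probs(ideal))      # ≈ 10, the maximum for 10 classes
print(inception_score_from_probs(uniform))    # ≈ 1
print(inception_score_from_probs(collapsed))  # ≈ 1
```

Both failure modes (blurry, ambiguous digits and a generator that only draws one digit) bring the score down to its minimum of 1.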

In the following cell, evaluate the Inception Score on real MNIST images. You
should use the CNN model you implemented and trained in Q1 as the scoring model,
and use all images in the *test* split of the MNIST dataset.

In [None]:
mean, std =  # TODO
print("Inception score: {:.2f}±{:.2f}".format(mean, std))

### 2b. Implement VAE (20 points)

In the following cell, implement a variational autoencoder (VAE) that can
generate images of handwritten digits. The VAE should satisfy the following
requirements:

- Its encoder should contain convolutional layers.
- Its decoder should contain transposed convolutional layers.
- The output of the decoder should be unnormalized logits for each pixel.
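A common way to pair the encoder and decoder is to let stride-2 convolutions halve the spatial size and matching stride-2 transposed convolutions double it back. The sketch below only checks tensor shapes; the channel counts and the kernel 4 / stride 2 / padding 1 configuration are assumptions for illustration, and there are no activations, so it is not a VAE.

```python
import torch
from torch import nn

x = torch.zeros(2, 1, 28, 28)

# Conv2d: out = (in + 2*padding - kernel) // stride + 1
# (28 + 2 - 4) // 2 + 1 = 14, then (14 + 2 - 4) // 2 + 1 = 7.
down = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
    nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
)
h = down(x)

# ConvTranspose2d: out = (in - 1)*stride - 2*padding + kernel
# (7 - 1)*2 - 2 + 4 = 14, then (14 - 1)*2 - 2 + 4 = 28.
up = nn.Sequential(
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
)
y = up(h)

print(h.shape, y.shape)
```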

In [None]:
class VAE(nn.Module):
    """A simple variational autoencoder."""
    def __init__(
        self,
        latent_dim: int,
    ):
        """
        Args:
        - latent_dim: The dimension of the latent space.
        """
        super().__init__()
        self.latent_dim = latent_dim
        # TODO
        self.encoder =
        self.fc_mu =
        self.fc_logvar =
        self.decoder =

    def encode(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Input:
        - x: (N, C, H, W), a batch of input images.

        Returns:
        - mu: (N, D), the mean of the latent distribution.
        - logvar: (N, D), the log-variance of the latent distribution.
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(
        self,
        mu: torch.Tensor,
        logvar: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input:
        - mu: (N, D), the mean of the latent distribution.
        - logvar: (N, D), the log-variance of the latent distribution.

        Returns:
        - z: (N, D), a sample from the latent distribution.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(
        self,
        z: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input:
        - z: (N, D), a sample from the latent distribution.

        Returns:
        - recon: (N, C, H, W), unnormalized logits of the reconstructed images.
        """
        return self.decoder(z)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar
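`reparameterize` uses the reparameterization trick: instead of sampling $z$ directly from $\mathcal{N}(\mu, \sigma^2)$, it samples $\epsilon \sim \mathcal{N}(0, I)$ and computes $z = \mu + \epsilon \sigma$, so the randomness does not depend on the encoder outputs and gradients can flow back to `mu` and `logvar`. The standalone check below repeats the same three lines outside the `VAE` class, with small hypothetical tensors.

```python
import torch

torch.manual_seed(0)
mu = torch.zeros(3, 2, requires_grad=True)
logvar = torch.zeros(3, 2, requires_grad=True)

# Same computation as VAE.reparameterize.
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # does not require grad
z = mu + eps * std

z.sum().backward()
print(mu.grad)      # dz/dmu = 1 for every entry
print(logvar.grad)  # dz/dlogvar = 0.5 * eps * std
```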

### 2c. Train the Model (20 points)

Complete the training loop in the following cell and train your VAE. The loss
function has already been implemented for you. If you are interested in how it
is derived, Doersch's [Tutorial on Variational Autoencoders](https://arxiv.org/abs/1606.05908)
is a good reference.
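For reference, the reconstruction term in `vae_loss` is the negative log-likelihood of the pixels under independent Bernoulli distributions whose logits are the decoder outputs (which is what `binary_cross_entropy_with_logits` computes), and the KL term uses the closed form for a diagonal Gaussian posterior against a standard normal prior. With $D$ the latent dimension and $\sigma_j^2 = \exp(\texttt{logvar}_j)$, the per-sample KL term is:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
= \frac{1}{2} \sum_{j=1}^{D} \big(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\big).
$$

Summed over the batch and divided by the batch size, this matches `loss_kldiv`.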

After each epoch, you should generate some sample images using the current
decoder and visualize them in the output cell.
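Generating new images means sampling latent codes from the prior $\mathcal{N}(0, I)$ and passing them through the decoder only; the encoder is not involved. The sketch below shows that flow with a hypothetical stand-in decoder (a single linear layer, not the transposed-convolution decoder the assignment requires) and an assumed latent size of 16.

```python
import torch
from torch import nn

latent_dim = 16  # hypothetical latent size
# Hypothetical stand-in decoder: maps (N, latent_dim) to (N, 1, 28, 28) logits.
decoder = nn.Sequential(
    nn.Linear(latent_dim, 28 * 28),
    nn.Unflatten(1, (1, 28, 28)),
)

with torch.no_grad():
    z = torch.randn(32, latent_dim)   # samples from the prior N(0, I)
    samples = decoder(z).sigmoid()    # logits -> pixel intensities in [0, 1]

print(samples.shape)
```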

In [None]:
def vae_loss(
    image: torch.Tensor,
    recon: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    loss_recon = nn.functional.binary_cross_entropy_with_logits(
        recon, image, reduction='sum')
    loss_kldiv = 0.5 * torch.sum(mu**2 + logvar.exp() - 1 - logvar)

    loss_recon /= image.size(0)
    loss_kldiv /= image.size(0)
    loss = loss_recon + loss_kldiv

    return loss, loss_recon.detach(), loss_kldiv.detach()


def train(
    model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    log_interval: int,
) -> None:
    model.train()
    for batch_idx, (data, _) in enumerate(data_loader, start=1):
        # TODO
        loss, loss_recon, loss_kldiv =

        if batch_idx % log_interval == 0:
            print("Epoch {:>2d} [{:>6,}/{:>6,}] ".format(
                epoch,
                batch_idx * len(data),
                len(data_loader.dataset),
            ) + ' '.join(
                f"{name}={value:.3f}" for name, value in [
                    ('loss', loss.item()),
                    ('loss_recon', loss_recon.item()),
                    ('loss_kldiv', loss_kldiv.item()),
            ]))


def sample_and_visualize(
    model: nn.Module,
    device: torch.device,
    num_rows: int = 4,
    num_cols: int = 8,
) -> None:
    num_samples = num_rows * num_cols
    model.eval()
    with torch.no_grad():
        # TODO
        samples =
        # The decoder outputs logits; a sigmoid maps them to pixel values in [0, 1].
        samples = samples.sigmoid()

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols, num_rows))
    for i, ax in enumerate(axes.flat):
        ax.imshow(samples[i, 0].cpu().detach().numpy(), cmap='gray')
        ax.axis('off')
    plt.show()


# TODO
# Feel free to add other hyperparameters.
num_epochs: int =
learning_rate: float =

vae =
criterion = vae_loss
optimizer =

for epoch in range(1, num_epochs+1):
    train(vae, train_loader, criterion, optimizer, device, epoch, 150)
    sample_and_visualize(vae, device)

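import os

# torch.save does not create missing directories, so create 'output' first.
os.makedirs('output', exist_ok=True)
torch.save(vae.state_dict(), 'output/mnist_vae.pth')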

### 2d. Generate Images and Evaluate (5 points)

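In the following cell, generate 10,000 sample images using your trained VAE.
As in `sample_and_visualize`, apply a sigmoid to the decoder's logits so that the
pixel values lie in [0, 1], like the images the CNN was trained on. Then
calculate and report the Inception Score on the generated images.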

To receive full credit, your VAE should achieve a mean Inception Score of at
least 3.00.
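Wrapping a plain tensor of generated images in a `DataLoader`, as the cell below does, yields bare image batches rather than `(image, label)` tuples; this is the case the HACK branch in `compute_inception_score` handles. The sketch below uses random noise as a stand-in for 100 generated images.

```python
import torch
from torch.utils.data import DataLoader

# Stand-in for 100 generated images with pixel values in [0, 1].
samples = torch.rand(100, 1, 28, 28)

loader = DataLoader(samples, batch_size=64)
batch_sizes = []
for batch in loader:
    # Each batch is a tensor of shape (B, 1, 28, 28), not a tuple.
    assert isinstance(batch, torch.Tensor)
    batch_sizes.append(batch.shape[0])
print(batch_sizes)  # → [64, 36]
```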

In [None]:
vae.eval()
with torch.no_grad():
    # TODO
    samples =

data_loader = torch.utils.data.DataLoader(samples, batch_size=64)
mean, std = compute_inception_score(cnn, data_loader, device)
print("Inception score: {:.2f}±{:.2f}".format(mean, std))