# Tutorial 8.1: Deep Autoencoders 

In this tutorial, we take a closer look at autoencoders. In contrast to variational autoencoders (VAEs), standard autoencoders are not considered generative models because they do not model a distribution from which we can easily sample: the latent space has no constraint or incentive to follow a specific distribution. However, autoencoders are still useful, in particular for representing data in a lower-dimensional space and for compressing data.

First of all, we again import most of our standard libraries. We will use [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) to reduce the training code overhead.

In [1]:
USE_NOTEBOOK = False
TRAIN_CIFAR = True
TRAIN_STL = False

## Standard libraries
import os
import json
import math
import numpy as np 
import scipy.linalg

## Imports for plotting
import matplotlib.pyplot as plt
if USE_NOTEBOOK:
    %matplotlib inline 
    from IPython.display import set_matplotlib_formats
    set_matplotlib_formats('svg', 'pdf') # For export
    from matplotlib.colors import to_rgb
    import matplotlib
    matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

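## Progress bar
if USE_NOTEBOOK:
    from tqdm.notebook import tqdm
else:
    # Plain-text progress bar so that `tqdm` is defined outside of notebooks as well
    from tqdm import tqdm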

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateLogger

import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial8"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In this tutorial, we work with the CIFAR10 dataset. In CIFAR10, each image has 3 color channels and a resolution of 32x32 pixels. As autoencoders are not constrained to model images probabilistically, we can work with more complex image data (i.e. 3 color channels instead of black-and-white) much more easily than with VAEs.

If you have already downloaded CIFAR10 to a different directory, set `DATASET_PATH` accordingly to avoid another download.

In [2]:
# Transformations applied on each image => convert to a tensor and normalize from [0, 1] to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)
pl.seed_everything(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified
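
As a quick sanity check on the transform above, the following sketch reproduces the arithmetic of `ToTensor` followed by `Normalize((0.5,), (0.5,))` with NumPy on a hypothetical random image. It is only meant to show that pixel values in $[0, 1]$ end up in $[-1, 1]$; the helper names `normalize` and `denormalize` are illustrative and not part of torchvision.

```python
import numpy as np

def normalize(img, mean=0.5, std=0.5):
    # Same arithmetic as transforms.Normalize: (x - mean) / std
    return (img - mean) / std

def denormalize(img, mean=0.5, std=0.5):
    # Inverse mapping, e.g. to bring images back to [0, 1] before plotting
    return img * std + mean

rng = np.random.default_rng(0)
# Hypothetical image in channel-first layout with values in [0, 1], as ToTensor produces
img = rng.random((3, 32, 32))
img_norm = normalize(img)
img_back = denormalize(img_norm)

print(img_norm.min() >= -1.0, img_norm.max() <= 1.0)
```

With a mean and standard deviation of 0.5, the value 0 maps to -1 and the value 1 maps to 1, so the normalized data covers the same range as a `Tanh` output.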


## Building the autoencoder

In general, an autoencoder consists of an **encoder** that maps the input $x$ to a lower dimensional feature vector $z$, and a **decoder** that reconstructs the input $\hat{x}$ from $z$. We train the model by comparing $x$ to $\hat{x}$, and optimizing the parameters to increase similarity between $x$ and $\hat{x}$. See below for a small illustration of the autoencoder framework.

<span style="color:red;"> **TODO**: Add illustration of autoencoder </span>

<span style="color:red;"> **TODO**: Check whether weight norm is needed at all, and if not, remove it (less to explain) </span>

We start by implementing the encoder. The encoder is a deep convolutional network in which we scale down the image layer by layer using strided convolutions. After downscaling the image three times, we flatten the features and apply a linear layer. The latent representation $z$ is therefore a vector of size *d*, which can be chosen freely.

In [3]:
def wn_conv(*args, **kwargs):
    # Convolution with weight norm applied
    return nn.utils.weight_norm(nn.Conv2d(*args, **kwargs))

class Encoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            wn_conv(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            act_fn(),
            wn_conv(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            act_fn(),
            nn.Flatten(),
            nn.Linear(2*16*c_hid, latent_dim)
        )
    
    def forward(self, x):
        return self.net(x)
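
To see where the `2*16*c_hid` input size of the linear layer comes from, the spatial sizes can be traced with the standard convolution output-size formula. The sketch below is plain Python and assumes the default settings used in this tutorial (32x32 inputs, `c_hid=32`, dilation 1); the function names are illustrative.

```python
def conv_out_size(size, kernel_size=3, stride=1, padding=1):
    # Output-size formula for a convolution with dilation 1
    return (size + 2 * padding - kernel_size) // stride + 1

def encoder_feature_shape(img_size=32, c_hid=32):
    # Strides of the five convolutions in the encoder above
    strides = [2, 1, 2, 1, 2]
    size = img_size
    for s in strides:
        size = conv_out_size(size, stride=s)
    channels = 2 * c_hid  # the last convolution outputs 2*c_hid channels
    return channels, size, size

channels, h, w = encoder_feature_shape()
flat_features = channels * h * w
print(channels, h, w, flat_features)  # 64 4 4 1024
```

The flattened feature size is `2*c_hid * 4 * 4 = 2*16*c_hid`, which is 1024 for `c_hid=32`.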

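The decoder mirrors the encoder. The main difference is that we replace strided convolutions with transposed convolutions (sometimes called deconvolutions) to upscale the features. In addition, the final layer applies a `Tanh` so that the outputs lie in $[-1, 1]$, the same range as the normalized input images.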

In [4]:
def wn_conv_trans(*args, **kwargs):
    return nn.utils.weight_norm(nn.ConvTranspose2d(*args, **kwargs))

class Decoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            wn_conv_trans(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv_trans(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            act_fn(),
            wn_conv(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv_trans(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x
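
The upscaling in the decoder can be checked in the same way. For a transposed convolution with dilation 1, the output size is `(in - 1) * stride - 2 * padding + kernel_size + output_padding`. The sketch below applies this formula to the three transposed convolutions above; the function name is illustrative.

```python
def conv_transpose_out_size(size, kernel_size=3, stride=2, padding=1, output_padding=1):
    # Output-size formula for a transposed convolution with dilation 1
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding

# Trace the spatial size through the three transposed convolutions
sizes = [4]
for _ in range(3):
    sizes.append(conv_transpose_out_size(sizes[-1]))
print(sizes)  # [4, 8, 16, 32]

# Without output_padding, the first upscaling step would give 7 instead of 8
size_without = conv_transpose_out_size(4, output_padding=0)
print(size_without)  # 7
```

This suggests why `output_padding=1` is set: with a 3x3 kernel, stride 2 and padding 1, it is what makes each step exactly double the resolution.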

In [5]:
class Autoencoder(pl.LightningModule):
    
    def __init__(self, 
                 base_channel_size: int, 
                 latent_dim: int, 
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3, 
                 width: int = 32, 
                 height: int = 32):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)
        
    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
    
    def _get_reconstruction_loss(self, batch):
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                         mode='min', 
                                                         factor=0.2, 
                                                         patience=50, 
                                                         min_lr=5e-5)
        return [optimizer], [scheduler]
    
    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)                             
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        return result
    
    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        return {'val_loss': loss, 'checkpoint_on': loss, 'log': {'val_loss': loss}}
    
    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        result = pl.EvalResult()
        result.log("test_loss", loss)
        return result
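
The reduction in `_get_reconstruction_loss` sums the squared error over all channels and pixels of an image and only then averages over the batch. The NumPy sketch below mirrors this on hypothetical random data and relates it to the more familiar per-value MSE; it is an illustration of the reduction, not the training code itself.

```python
import numpy as np

def reconstruction_loss(x, x_hat):
    # Squared error per value, summed over (C, H, W), averaged over the batch
    sq_err = (x - x_hat) ** 2
    return sq_err.sum(axis=(1, 2, 3)).mean(axis=0)

rng = np.random.default_rng(0)
# Hypothetical batch of 4 "images" and "reconstructions" in [-1, 1]
x = rng.uniform(-1, 1, size=(4, 3, 32, 32))
x_hat = rng.uniform(-1, 1, size=(4, 3, 32, 32))

loss = reconstruction_loss(x, x_hat)
per_value_mse = ((x - x_hat) ** 2).mean()
values_per_image = 3 * 32 * 32

# The summed loss is the per-value MSE scaled by the number of values per image
print(np.isclose(loss, per_value_mse * values_per_image))
```

Under this reduction, a loss of about 240 on CIFAR10 corresponds to a per-value MSE of roughly 240 / 3072 ≈ 0.078 on the normalized scale.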

### Training the model

In [6]:
class GenerateCallback(pl.Callback):
    
    def __init__(self, input_imgs, every_n_epochs=1):
        super().__init__()
        self.input_imgs = input_imgs
        self.every_n_epochs = every_n_epochs
        
    def on_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch % self.every_n_epochs == 0:
            input_imgs = self.input_imgs.to(pl_module.device)
            with torch.no_grad():
                pl_module.eval()
                reconst_imgs = pl_module(input_imgs)
                pl_module.train()

            imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)
            grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, range=(-1,1))
            trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)
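
The `torch.stack(..., dim=1).flatten(0,1)` idiom in the callback interleaves inputs and reconstructions so that, with `nrow=2`, each grid row shows one image next to its reconstruction. The sketch below reproduces the same ordering with NumPy on tiny placeholder arrays; `interleave` is an illustrative helper, not part of the tutorial code.

```python
import numpy as np

def interleave(inputs, reconstructions):
    # Stack along a new axis 1 -> (N, 2, C, H, W), then merge the first two axes
    # -> (2N, C, H, W), ordered as input_0, recon_0, input_1, recon_1, ...
    stacked = np.stack([inputs, reconstructions], axis=1)
    return stacked.reshape(-1, *inputs.shape[1:])

n, c, h, w = 3, 1, 2, 2
# Placeholder "images": input i is filled with i, its reconstruction with 10 + i
inputs = np.stack([np.full((c, h, w), i) for i in range(n)])
recons = np.stack([np.full((c, h, w), 10 + i) for i in range(n)])

pairs = interleave(inputs, recons)
order = pairs[:, 0, 0, 0].tolist()
print(order)  # [0, 10, 1, 11, 2, 12]
```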

In [7]:
def train_cifar(latent_dim):
    exmp_imgs, _ = next(iter(train_loader))
    exmp_imgs = exmp_imgs[:8]
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "cifar10_%i" % latent_dim), 
                         gpus=1, 
                         max_epochs=1, 
                         callbacks=[GenerateCallback(exmp_imgs, every_n_epochs=10),
                                    LearningRateLogger("epoch")],
                         benchmark=True,
                         progress_bar_refresh_rate=1 if USE_NOTEBOOK else 0)
    
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "cifar10_%i.ckpt" % latent_dim)
    if os.path.isfile(pretrained_filename):
        model = Autoencoder.load_from_checkpoint(pretrained_filename)
    else:
        model = Autoencoder(base_channel_size=32, latent_dim=latent_dim)
        trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_dataloaders=test_loader)
    return model

if TRAIN_CIFAR:
    model_384 = train_cifar(384)
    model_256 = train_cifar(256)
    model_128 = train_cifar(128)
    model_64 = train_cifar(64)
    
if not USE_NOTEBOOK:
    import sys
    sys.exit(1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | In sizes       | Out sizes     
----------------------------------------------------------------------
0 | encoder | Encoder | 496 K  | [2, 3, 32, 32] | [2, 384]      
1 | decoder | Decoder | 496 K  | [2, 384]       | [2, 3, 32, 32]
Saving latest checkpoint..

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(240.6534, device='cuda:0')}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | In sizes       | Out sizes     
----------------------------------------------------------------------
0 | encoder | Encoder | 365 K  | [2, 3, 32, 32] | [2, 256]      
1 | decoder | Decoder | 365 K  | [2, 256]       | [2, 3, 32, 32]
Saving latest checkpoint..

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(256.4281, device='cuda:0')}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | In sizes       | Out sizes     
----------------------------------------------------------------------
0 | encoder | Encoder | 233 K  | [2, 3, 32, 32] | [2, 128]      
1 | decoder | Decoder | 234 K  | [2, 128]       | [2, 3, 32, 32]
Saving latest checkpoint..

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(240.8853, device='cuda:0')}
--------------------------------------------------------------------------------


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | In sizes       | Out sizes     
----------------------------------------------------------------------
0 | encoder | Encoder | 168 K  | [2, 3, 32, 32] | [2, 64]       
1 | decoder | Decoder | 169 K  | [2, 64]        | [2, 3, 32, 32]
Saving latest checkpoint..

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(246.1945, device='cuda:0')}
--------------------------------------------------------------------------------
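
The encoder parameter counts in the model summaries can be reproduced by hand, which also shows how strongly the final linear layer dominates the model size. The sketch below is plain Python and assumes that weight normalization on `nn.Conv2d` adds one gain parameter per output channel (my reading of the default `dim=0`); the helper names are illustrative. It also prints the compression factor relative to the 3x32x32 = 3072 input values.

```python
def conv_params(c_in, c_out, kernel_size=3, weight_norm=True):
    # Kernel weights + biases; assumed: weight norm adds one gain per output channel
    n = c_in * c_out * kernel_size ** 2 + c_out
    if weight_norm:
        n += c_out
    return n

def encoder_params(latent_dim, c_in=3, c_hid=32):
    # (in_channels, out_channels) of the five convolutions in the encoder
    convs = [(c_in, c_hid), (c_hid, c_hid), (c_hid, 2 * c_hid),
             (2 * c_hid, 2 * c_hid), (2 * c_hid, 2 * c_hid)]
    n_conv = sum(conv_params(i, o) for i, o in convs)
    # Linear layer from the flattened 2*c_hid x 4 x 4 features to the latent vector
    n_linear = 2 * 16 * c_hid * latent_dim + latent_dim
    return n_conv + n_linear

counts = {d: encoder_params(d) for d in (384, 256, 128, 64)}
for d, n in counts.items():
    print(f"latent_dim={d}: {n} parameters, compression factor {3 * 32 * 32 // d}x")
```

For `latent_dim=384` this gives 496352 parameters, consistent with the `496 K` reported above, of which 393600 belong to the linear layer.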




### Visualizing the reconstruction

## Finding similarities

In [None]:
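def embed_imgs(model, data_loader):
    # Encode all images in data_loader and return (images, latents), both on the CPU
    img_list, embed_list = [], []
    model.eval()
    for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
        with torch.no_grad():
            z = model.encoder(imgs.to(model.device))
        img_list.append(imgs)
        # Move latents to the CPU so that they live on the same device as the images
        embed_list.append(z.cpu())
    return (torch.cat(img_list, dim=0), torch.cat(embed_list, dim=0))

# Assumption: we use the model with latent_dim=128 here; any of the trained models works
model = model_128
train_img_embeds = embed_imgs(model, train_loader)
test_img_embeds = embed_imgs(model, test_loader)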


In [None]:
def find_similar_images(query_img, query_z, key_embeds, num_imgs=8):
    dist = torch.cdist(query_z[None,:], key_embeds[1], p=2)
    dist = dist.squeeze(dim=0)
    dist, indices = torch.sort(dist)
    
    imgs_to_display = torch.cat([query_img[None], key_embeds[0][indices[:num_imgs]]], dim=0)
    grid = torchvision.utils.make_grid(imgs_to_display, nrow=num_imgs+1, normalize=True, range=(-1,1))
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(10,2))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

In [None]:
for i in range(10):
    find_similar_images(test_img_embeds[0][i], test_img_embeds[1][i], key_embeds=train_img_embeds)
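
The retrieval step in `find_similar_images` is a nearest-neighbor search under the Euclidean distance in latent space. The NumPy sketch below shows the same idea on a tiny, hand-made set of 2-dimensional "embeddings"; the data and the helper `nearest_neighbors` are purely illustrative.

```python
import numpy as np

def nearest_neighbors(query_z, key_z, k=3):
    # Euclidean distance from the query to every key embedding
    dist = np.linalg.norm(key_z - query_z[None, :], axis=1)
    # Indices of the k closest keys, closest first
    idx = np.argsort(dist)[:k]
    return idx, dist[idx]

# Hypothetical key embeddings and a query that lies close to key 1
key_z = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
query_z = np.array([0.9, 0.1])

idx, dist = nearest_neighbors(query_z, key_z, k=2)
print(idx.tolist())  # [1, 0]
```

In the tutorial code, `torch.cdist` followed by `torch.sort` plays the role of `np.linalg.norm` and `np.argsort` here.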