# Tutorial 8.1: Deep Autoencoders 

**Goal**: Show the working of a deep convolutional autoencoder, and visualize embeddings in tensorboard

In [1]:
USE_NOTEBOOK = True
TRAIN_CIFAR = False
TRAIN_STL = False

## Standard libraries
import os
import json
import math
import numpy as np 
import scipy.linalg

## Imports for plotting
import matplotlib.pyplot as plt
if USE_NOTEBOOK:
    %matplotlib inline 
    from IPython.display import set_matplotlib_formats
    set_matplotlib_formats('svg', 'pdf') # For export
    from matplotlib.colors import to_rgb
    import matplotlib
    matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## Progress bar
if USE_NOTEBOOK:
    from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import pytorch_lightning as pl

import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial8"

# Function for setting the seed
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
set_seed(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.determinstic = True
torch.backends.cudnn.benchmark = False

# Fetching the device that will be used throughout this notebook
device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
print("Using device", device)

Using device cuda:0


In [2]:
# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(0.5,0.5)])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)
set_seed(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
"""
class CIFAR10DataModule(pl.LightningDataModule):
    
    def __init__(self, batch_size=256):
        super().__init__()
        self.batch_size = batch_size
    
    def setup(self, stage):
        transform = transforms.Compose([transforms.ToTensor()])
        if stage == 'fit':
            # Loading the training set
            train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)
            set_seed(42)
            self.train_set, self.val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])
        if stage == 'test':
            # Loading the test set
            self.test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)
    
    def train_dataloader(self):
        train_loader = data.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True)
        return train_loader
    
    def val_dataloader(self):
        val_loader = data.DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, drop_last=False)
        return train_loader
    
    def test_dataloader(self):
        test_loader = data.DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, drop_last=False)
        return test_loader
"""

"\nclass CIFAR10DataModule(pl.LightningDataModule):\n    \n    def __init__(self, batch_size=256):\n        super().__init__()\n        self.batch_size = batch_size\n    \n    def setup(self, stage):\n        transform = transforms.Compose([transforms.ToTensor()])\n        if stage == 'fit':\n            # Loading the training set\n            train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)\n            set_seed(42)\n            self.train_set, self.val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])\n        if stage == 'test':\n            # Loading the test set\n            self.test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)\n    \n    def train_dataloader(self):\n        train_loader = data.DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True)\n        return train_loader\n    \n    def val_dataloader(self):\n        val_loader = data.

In [4]:
def wn_conv(*args, **kwargs):
    return nn.utils.weight_norm(nn.Conv2d(*args, **kwargs))

class Encoder(nn.Module):
    
    def __init__(self, num_input_channels, base_channel_size, latent_dim, act_fn=nn.ReLU):
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            wn_conv(num_input_channels, c_hid, kernel_size=5, padding=2, stride=2), # 32x32 => 16x16
            act_fn(),
            wn_conv(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            act_fn(),
            nn.Flatten(),
            nn.Linear(2*16*c_hid, latent_dim)
        )
    
    def forward(self, x):
        # This is the CIFAR10 encoder
        return self.net(x)

In [5]:
def wn_conv_trans(*args, **kwargs):
    return nn.utils.weight_norm(nn.ConvTranspose2d(*args, **kwargs))

class Decoder(nn.Module):
    
    def __init__(self, num_input_channels, base_channel_size, latent_dim, act_fn=nn.ReLU):
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            nn.ReLU()
        )
        self.net = nn.Sequential(
            wn_conv_trans(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv_trans(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            act_fn(),
            wn_conv(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv_trans(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

In [6]:
class Autoencoder(pl.LightningModule):
    
    def __init__(self, 
                 base_channel_size: int, 
                 latent_dim: int, 
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3, 
                 width: int = 32, 
                 height: int = 32):
        super().__init__()
        print("Encoder class", encoder_class)
        self.save_hyperparameters()
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)
        
    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat
    
    def _get_reconstruction_loss(self, x):
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        imgs, _ = batch # We do not need the labels
        loss = self._get_reconstruction_loss(imgs)
        
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=USE_NOTEBOOK)
        return result
    
    def validation_step(self, batch, batch_idx):
        imgs, _ = batch
        loss = self._get_reconstruction_loss(imgs)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        return result

In [7]:
class GenerateCallback(pl.Callback):
    
    def __init__(self, input_imgs):
        super().__init__()
        self.input_imgs = input_imgs
        
    def on_epoch_end(self, trainer, pl_module):
        input_imgs = self.input_imgs.to(pl_module.device)
        with torch.no_grad():
            pl_module.eval()
            reconst_imgs = pl_module(input_imgs)
            pl_module.train()
            
        imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)
        grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, range=(-1,1))
        trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [8]:
model = Autoencoder(base_channel_size=32, latent_dim=256)

Encoder class <class '__main__.Encoder'>


In [9]:
if TRAIN_CIFAR:
    exmp_imgs, _ = next(iter(train_loader))
    exmp_imgs = exmp_imgs[:8]
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "autoencoder"), gpus=1, max_epochs=1000, callbacks=[GenerateCallback(exmp_imgs)])
    trainer.fit(model, train_loader, val_loader)

In [10]:
# trainer.test(test_loader)

## STL10 Dataset
In contrast to CIFAR10, this dataset has 96x96 images (and a lot of unlabeled ones)

In [11]:
# Transformations applied on each image => only make them a tensor
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(0.5,0.5)])

# Loading the training dataset. We need to split it into a training and validation part
train_dataset = torchvision.datasets.STL10(root=DATASET_PATH, split='train+unlabeled', transform=transform, download=True)
set_seed(42)
train_set, val_set = torch.utils.data.random_split(train_dataset, [100000, 5000])

# Loading the test set
test_set = torchvision.datasets.STL10(root=DATASET_PATH, split='test', transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=256, shuffle=True, drop_last=True, pin_memory=True, num_workers=8)
val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, drop_last=False, num_workers=8)
test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
class ResNetEncoder(nn.Module):
    
    def __init__(self, c_hidden, act_fn):
        super().__init__()
        self.net = nn.Sequential(
            act_fn(),
            wn_conv(c_hidden, c_hidden, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(c_hidden, 2*c_hidden, kernel_size=1, padding=0)
        )
    
    def forward(self, x):
        val, gate = self.net(x).chunk(2, dim=1)
        return x + val * torch.sigmoid(gate)

class ResNetDecoder(nn.Module):
    
    def __init__(self, c_hidden, act_fn):
        super().__init__()
        self.net = nn.Sequential(
            wn_conv(c_hidden, 2*c_hidden, kernel_size=1, padding=0),
            act_fn(),
            wn_conv(2*c_hidden, 2*c_hidden, kernel_size=5, padding=2, groups=2*c_hidden),
            act_fn(),
            wn_conv(2*c_hidden, 2*c_hidden, kernel_size=1, padding=0)
        )
    
    def forward(self, x):
        val, gate = self.net(x).chunk(2, dim=1)
        return x + val * torch.sigmoid(gate)
    
class DeepEncoder(nn.Module):
    
    def __init__(self, num_input_channels, base_channel_size, latent_dim, act_fn=nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            wn_conv(num_input_channels, c_hid, kernel_size=5, padding=2, stride=2), # 96x96 => 48x48
            act_fn(),
            wn_conv(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 48x48 => 24x24
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 24x24 => 12x12
            ResNetEncoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetEncoder(c_hidden=2*c_hid, act_fn=act_fn),
            wn_conv(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 12x12 => 6x6
            ResNetEncoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetEncoder(c_hidden=2*c_hid, act_fn=act_fn),
            wn_conv(2*c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Flatten(),
            nn.Linear(36*c_hid, latent_dim)
        )
    
    def forward(self, x):
        x = self.net(x)
        return x
    
class DeepDecoder(nn.Module):
    
    def __init__(self, num_input_channels, base_channel_size, latent_dim, act_fn=nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 36*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            wn_conv(c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            wn_conv_trans(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 6x6 => 12x12
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            wn_conv_trans(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 12x12 => 24x24
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=2*c_hid, act_fn=act_fn),
            wn_conv_trans(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 24x24 => 48x48
            ResNetDecoder(c_hidden=c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=c_hid, act_fn=act_fn),
            ResNetDecoder(c_hidden=c_hid, act_fn=act_fn),
            wn_conv_trans(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 48x48 => 96x96
            nn.Tanh()
        )
    
    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 6, 6)
        x = self.net(x)
        return x

In [14]:
deep_model = Autoencoder(base_channel_size=32, latent_dim=512, encoder_class=DeepEncoder, decoder_class=DeepDecoder, width=96, height=96)

Encoder class <class '__main__.DeepEncoder'>


In [15]:
if TRAIN_STL:
    exmp_imgs, _ = next(iter(train_loader))
    exmp_imgs = exmp_imgs[:8]
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "deep_autoencoder"), gpus=1, max_epochs=1000, callbacks=[GenerateCallback(exmp_imgs)])
    trainer.fit(deep_model, train_loader, val_loader)