# VQ-VAE

In [1]:
## Standard libraries
import os
from copy import deepcopy
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
# %matplotlib inline
# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import STL10
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
# %load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. STL10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial13"
# In this notebook, we use data loaders with heavier computational processing. It is recommended to use as many
# workers as possible in a data loader, which corresponds to the number of CPU cores.
# os.cpu_count() returns None when the count cannot be determined, so fall back to 1 to keep
# the later integer divisions (NUM_WORKERS//2, NUM_WORKERS//4) well-defined.
NUM_WORKERS = os.cpu_count() or 1

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
# (the attribute is spelled `deterministic`; a misspelled name is just a new, ignored attribute)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

Global seed set to 42


Device: cuda:0
Number of workers: 16


In [2]:
import urllib.request
from urllib.error import HTTPError
# Remote location (GitHub raw content) of the saved models for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial13/"
# Relative paths of the files to fetch; currently empty, so the loop below does nothing
pretrained_files = []
# Make sure the local checkpoint folder exists
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Fetch every listed file that is not already present locally
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        # File lives in a sub-folder: create everything up to the last "/"
        os.makedirs(file_path.rpartition("/")[0], exist_ok=True)
    if os.path.isfile(file_path):
        # Already downloaded earlier, nothing to do
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_path)
    except HTTPError as e:
        # Best effort: report the HTTP failure and move on to the next file
        print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)


In [3]:
# Preprocessing applied to every image: random horizontal flip, conversion to a tensor,
# then normalization with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5).
data_transforms = transforms.Compose([
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,))
                                ])

# The 'unlabeled' STL10 split is used for training.
train_data = STL10(root=DATASET_PATH, split='unlabeled', download=True,
                   transform=data_transforms)
# The STL10 'train' split is used as the held-out set (passed as validation data in train_vqvae).
# NOTE(review): this split shares `data_transforms`, so held-out images are also randomly
# flipped on every access - confirm that augmenting the evaluation data is intended.
test_data = STL10(root=DATASET_PATH, split='train', download=True,
                  transform=data_transforms)

def get_train_images(num):
    """Return the first `num` images of `train_data` stacked along a new leading dimension.

    Only the first element of each dataset item (the image) is kept.
    """
    images = []
    for idx in range(num):
        images.append(train_data[idx][0])
    return torch.stack(images, dim=0)

Files already downloaded and verified
Files already downloaded and verified


In [24]:
class GenerateCallback(pl.Callback):
    """Logs reconstructions of a fixed batch of images to the trainer's logger.

    Inputs:
        input_imgs - Tensor of images to reconstruct during training
        every_n_epochs - Only log the images every N epochs (otherwise tensorboard gets quite large)
    """

    def __init__(self, input_imgs, every_n_epochs=1):
        super().__init__()
        self.input_imgs = input_imgs # Images to reconstruct during training
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Reconstruct the stored images with `pl_module` and log them as an image grid."""
        if trainer.current_epoch % self.every_n_epochs != 0:
            return

        # Remember the current mode so that it can be restored afterwards instead of
        # unconditionally forcing the module into training mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            input_imgs = self.input_imgs.to(pl_module.device)
            with torch.no_grad():
                reconst_imgs = pl_module(input_imgs)
        finally:
            pl_module.train(was_training)

        # Interleave originals and reconstructions so that each grid row (nrow=2) shows one pair
        imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0,1)
        # `value_range` replaces the deprecated `range` keyword of make_grid
        grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, value_range=(-1,1))
        trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [25]:
class EncoderResNetBlock(nn.Module):
    """Residual block (BatchNorm -> activation -> conv, twice) with optional 2x downsampling."""

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in.
                    A non-positive value (the default) means "same as c_in".
        """
        super().__init__()
        # Without subsampling the skip connection is the identity, so the channel count cannot change.
        # With subsampling, fall back to c_in when no valid c_out was given (previously this built a
        # convolution with -1 output channels); this mirrors the default handling in DecoderResNetBlock.
        if not subsample or c_out <= 0:
            c_out = c_in
        stride = 2 if subsample else 1

        # Network representing F
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False)
        )

        # 1x1 convolution needs to apply non-linearity as well as not done on skip connection
        self.downsample = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False)
        ) if subsample else None

    def forward(self, x):
        """Return F(x) + skip(x), where skip is the identity or the strided 1x1 path."""
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return z + x

In [76]:
class DecoderResNetBlock(nn.Module):
    """Residual block for the decoder that keeps the spatial size and may change the channel count."""

    def __init__(self, c_in, act_fn, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            c_out - Number of output features; non-positive values mean "same as c_in".
                    When c_out differs from c_in, the skip path uses a 1x1 convolution.
        """
        super().__init__()
        c_out = c_in if c_out <= 0 else c_out

        # Residual branch F: the first conv keeps c_in channels, only the second one maps to c_out
        residual_layers = [
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_in, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
        ]
        self.net = nn.Sequential(*residual_layers)

        # Skip branch: plain identity if the channel count is unchanged, otherwise
        # normalization + non-linearity + 1x1 convolution to match the output channels
        if c_in == c_out:
            self.identity = nn.Identity()
        else:
            self.identity = nn.Sequential(
                nn.BatchNorm2d(c_in),
                act_fn(),
                nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),
            )

    def forward(self, x):
        """Return the sum of the residual branch and the skip branch."""
        residual = self.net(x)
        shortcut = self.identity(x)
        return residual + shortcut

In [77]:
class Encoder(nn.Module):
    """Convolutional ResNet encoder producing a feature map with `vocab_dim` channels."""

    def __init__(self,
                 num_blocks,
                 c_hidden,
                 vocab_dim,
                 c_in=3,
                 act_fn=nn.SiLU):
        """
        Inputs:
            num_blocks - Sequence with the number of residual blocks per group
            c_hidden - Sequence with the channel count of each group (indexed like num_blocks)
            vocab_dim - Number of output channels of the final 1x1 convolution
            c_in - Number of channels of the input images
            act_fn - Activation class constructor (e.g. nn.SiLU)
        """
        super().__init__()

        # Stem: strided 3x3 convolution from the image channels to the first hidden width
        stem = nn.Conv2d(c_in, c_hidden[0], kernel_size=3, stride=2, padding=1, bias=False)

        res_blocks = []
        for group_idx, group_size in enumerate(num_blocks):
            for pos in range(group_size):
                # The first block of every group except the very first one subsamples
                # and switches from the previous group's width to the current one.
                opens_later_group = (pos == 0 and group_idx > 0)
                if opens_later_group:
                    block_in = c_hidden[group_idx - 1]
                else:
                    block_in = c_hidden[group_idx]
                res_blocks.append(
                    EncoderResNetBlock(c_in=block_in,
                                       act_fn=act_fn,
                                       subsample=opens_later_group,
                                       c_out=c_hidden[group_idx])
                )

        # Head: normalization, non-linearity and 1x1 projection to the codebook dimensionality
        head = [
            nn.BatchNorm2d(c_hidden[-1]),
            act_fn(),
            nn.Conv2d(c_hidden[-1], vocab_dim, 1, bias=False),
        ]

        self.net = nn.Sequential(stem, *res_blocks, *head)

    def forward(self, img):
        """Run the image batch through the full encoder stack."""
        return self.net(img)

In [78]:
class Decoder(nn.Module):
    """Convolutional ResNet decoder mapping `vocab_dim`-channel feature maps to images."""

    def __init__(self,
                 num_blocks,
                 c_hidden,
                 vocab_dim,
                 c_out=3,
                 act_fn=nn.SiLU):
        """
        Inputs:
            num_blocks - Sequence with the number of residual blocks per group
            c_hidden - Sequence with the channel count of each group (indexed like num_blocks)
            vocab_dim - Number of channels of the incoming feature map
            c_out - Number of channels of the generated images
            act_fn - Activation class constructor (e.g. nn.SiLU)
        """
        super().__init__()

        # 1x1 convolution from the codebook dimensionality to the first hidden width
        modules = [nn.Conv2d(vocab_dim, c_hidden[0], 1, bias=False)]
        for group_idx, group_size in enumerate(num_blocks):
            is_later_group = group_idx > 0
            if is_later_group:
                # Every group after the first starts with a 2x nearest-neighbour upsampling
                modules.append(nn.Upsample(scale_factor=2.0, mode='nearest'))
            for pos in range(group_size):
                # Only the first block of a later group changes the number of channels
                if pos == 0 and is_later_group:
                    block_in = c_hidden[group_idx - 1]
                else:
                    block_in = c_hidden[group_idx]
                modules.append(
                    DecoderResNetBlock(c_in=block_in,
                                       act_fn=act_fn,
                                       c_out=c_hidden[group_idx])
                )

        # Output head: strided transposed convolution to image channels, squashed by tanh
        modules += [
            nn.BatchNorm2d(c_hidden[-1]),
            act_fn(),
            nn.ConvTranspose2d(c_hidden[-1], c_out, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.Tanh(),
        ]

        self.net = nn.Sequential(*modules)

    def forward(self, img):
        """Run the feature map through the full decoder stack."""
        return self.net(img)

In [84]:
class VQEmbedding(nn.Module):
    """Vector-quantization layer: replaces each feature vector by its nearest codebook entry."""

    def __init__(self, vocab_size, vocab_dim):
        """
        Inputs:
            vocab_size - Number of entries in the codebook
            vocab_dim - Dimensionality of each codebook entry
        """
        super().__init__()
        self.codebook = nn.Embedding(vocab_size, vocab_dim)
        # Codebook entries start uniformly in [-1/vocab_size, 1/vocab_size]
        bound = 1. / vocab_size
        self.codebook.weight.data.uniform_(-bound, bound)

    def forward(self, z):
        """
        Inputs:
            z - Feature map with four dimensions, channels in dimension 1
        Returns:
            z_quantized - Selected codebook vectors in the layout of z. If z requires grad, gradients
                          flow straight through to z; otherwise the tensor is detached.
            z_embeddings - Selected codebook vectors in the layout of z, attached to the codebook
        """
        batch, _, height, width = z.shape
        # Move channels last and merge batch/height/width: one row per spatial position
        flat = z.permute(0, 2, 3, 1).flatten(0, 2)
        # Look up the closest codebook entry for every row
        nearest = self._quantize_encodings(flat)
        selected = self.codebook(nearest)

        if z.requires_grad:
            # Straight-through estimator: value of the embedding, gradient of the encoding
            quantized = flat + (selected - flat).detach()
        else:
            quantized = selected.detach()

        def to_feature_map(rows):
            # Undo the flattening and move channels back to dimension 1
            return rows.reshape(batch, height, width, -1).permute(0, 3, 1, 2)

        return to_feature_map(quantized), to_feature_map(selected)

    @torch.no_grad()
    def _quantize_encodings(self, z):
        """Return, for each row of z, the index of the codebook entry with the smallest squared distance."""
        codes = self.codebook.weight

        # Squared distance expanded as |z|^2 - 2 * z . e + |e|^2
        code_norms = codes.pow(2).sum(dim=1)
        z_norms = z.pow(2).sum(dim=1)
        norm_sum = code_norms.unsqueeze(0) + z_norms.unsqueeze(1)
        distances = torch.addmm(norm_sum, z, codes.T, beta=1.0, alpha=-2.0)
        return torch.argmin(distances, dim=1)

In [97]:
class VQVAE(pl.LightningModule):
    """LightningModule combining an Encoder, a VQEmbedding bottleneck and a Decoder."""

    def __init__(self,
                 vocab_size,
                 vocab_dim,
                 beta,
                 resnet_blocks=[3, 3, 2],
                 resnet_hidden=[32, 64, 128],
                 lr=5e-4,
                 warmup=500,
                 max_iters=100000,
                 **kwargs):
        """
        Inputs:
            vocab_size - Number of entries in the codebook
            vocab_dim - Dimensionality of the codebook entries (and of the encoder output channels)
            beta - Weight of the loss term pulling the encoder output towards the detached embeddings
            resnet_blocks - Number of residual blocks per group in the encoder (reversed for the decoder)
            resnet_hidden - Channel count per group in the encoder (reversed for the decoder)
            lr - Learning rate passed to AdamW
            warmup - Number of warmup steps for the learning rate scheduler
            max_iters - Number of steps over which the scheduler spans its cosine curve
        """
        super().__init__()
        # Stores all constructor arguments in self.hparams
        self.save_hyperparameters()
        hp = self.hparams

        # Encoder network (ResNet)
        self.encoder = Encoder(num_blocks=hp.resnet_blocks,
                               c_hidden=hp.resnet_hidden,
                               vocab_dim=hp.vocab_dim,
                               c_in=3,
                               act_fn=nn.SiLU)
        # Vector quantized bottleneck layer
        self.vector_quantization = VQEmbedding(vocab_size=hp.vocab_size,
                                               vocab_dim=hp.vocab_dim)
        # Decoder network: same group layout as the encoder, in reversed order
        self.decoder = Decoder(num_blocks=hp.resnet_blocks[::-1],
                               c_hidden=hp.resnet_hidden[::-1],
                               vocab_dim=hp.vocab_dim,
                               c_out=3,
                               act_fn=nn.SiLU)

        # self.example_input_array = torch.zeros(2, 3, 96, 96)

    def forward(self, x, return_latents=False):
        """Encode, quantize and decode x; returns the reconstruction.

        Note: `return_latents` is accepted but currently has no effect.
        """
        quantized, _ = self.vector_quantization(self.encoder(x))
        return self.decoder(quantized)

    def configure_optimizers(self):
        """AdamW plus a cosine-warmup schedule that is stepped after every training step."""
        adamw = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        schedule = CosineWarmupScheduler(adamw,
                                         warmup=self.hparams.warmup,
                                         max_iters=self.hparams.max_iters)
        scheduler_config = {'scheduler': schedule, 'interval': 'step'}
        return [adamw], [scheduler_config]

    def _get_loss(self, batch, mode='train'):
        """Compute and log the total loss and its three components, prefixed with `mode`."""
        # Batches that come as (images, labels) tuples/lists: keep only the images
        imgs = batch[0] if isinstance(batch, (tuple, list)) else batch

        z = self.encoder(imgs)
        z_quantized, z_embeddings = self.vector_quantization(z)
        x_rec = self.decoder(z_quantized)

        # Reconstruction error in image space
        rec_loss = F.mse_loss(x_rec, imgs)
        # Moves the codebook entries towards the (fixed) encoder outputs
        embed_loss = F.mse_loss(z.detach(), z_embeddings)
        # Moves the encoder outputs towards the (fixed) codebook entries, weighted by beta
        quant_loss = self.hparams.beta * F.mse_loss(z, z_embeddings.detach())

        loss = rec_loss + embed_loss + quant_loss

        logged_values = {
            f'{mode}_loss': loss,
            f'{mode}_loss_rec': rec_loss,
            f'{mode}_loss_embed': embed_loss,
            f'{mode}_loss_quant': quant_loss,
        }
        for metric_name, metric_value in logged_values.items():
            self.log(metric_name, metric_value)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for the optimizer step."""
        return self._get_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        """Log the validation losses; nothing is returned."""
        self._get_loss(batch, mode='val')

    def test_step(self, batch, batch_idx):
        """Log the test losses; nothing is returned."""
        self._get_loss(batch, mode='test')

In [98]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Learning rate schedule with linear warmup followed by a cosine decay to a minimum factor.

    Inputs:
        optimizer - Optimizer whose learning rates are scaled
        warmup - Number of scheduler steps over which the factor is additionally scaled linearly from 0
        max_iters - Number of scheduler steps after which the cosine decay reaches its minimum
        min_factor - Smallest multiplicative factor, reached at (and kept after) max_iters
    """

    def __init__(self, optimizer, warmup, max_iters, min_factor=0.05):
        # Attributes must exist before the base constructor runs, since it already calls get_lr()
        self.warmup = warmup
        self.max_num_iters = max_iters
        self.min_factor = min_factor
        super().__init__(optimizer)

    def get_lr(self):
        """Return the base learning rates scaled by the factor of the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative learning rate factor for step `epoch`."""
        # Clamp the progress so that the factor stays at min_factor after max_iters instead of
        # following the cosine back up again.
        progress = min(epoch, self.max_num_iters)
        lr_factor = 0.5 * (1 + np.cos(np.pi * progress / self.max_num_iters))
        # Rescale the cosine from [0, 1] to [min_factor, 1]
        lr_factor = lr_factor * (1 - self.min_factor) + self.min_factor
        if epoch <= self.warmup and self.warmup > 0:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [99]:
def train_vqvae(batch_size, max_epochs=500, **kwargs):
    """Return a VQVAE, either loaded from a pretrained checkpoint or freshly trained.

    Inputs:
        batch_size - Batch size of the training and validation data loaders
        max_epochs - Maximum number of training epochs
        kwargs - Forwarded to the VQVAE constructor (e.g. vocab_size, vocab_dim, beta)
    """
    callbacks = [
        # We are interested in the reconstruction quality, so checkpoints are ranked by val_loss_rec
        ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss_rec'),
        LearningRateMonitor('step'),
        GenerateCallback(get_train_images(8), every_n_epochs=1),
    ]
    num_gpus = 1 if str(device)=='cuda:0' else 0
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, 'VQVAE'),
                         gpus=num_gpus,
                         max_epochs=max_epochs,
                         callbacks=callbacks,
                         progress_bar_refresh_rate=1)
    trainer.logger._log_graph = True          # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # A pretrained checkpoint, if present, takes precedence over training.
    # NOTE(review): the name 'AE.ckpt' does not match the 'VQVAE' folder used above - confirm the intended filename.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, 'AE.ckpt')
    if os.path.isfile(pretrained_filename):
        print(f'Found pretrained model at {pretrained_filename}, loading...')
        # Automatically loads the model with the saved hyperparameters
        return VQVAE.load_from_checkpoint(pretrained_filename)

    pl.seed_everything(42) # To be reproducible
    train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True, num_workers=NUM_WORKERS//2)
    val_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False,
                                 drop_last=False, pin_memory=True, num_workers=NUM_WORKERS//4)
    # The scheduler is stepped per training step, so its horizon is epochs * steps per epoch
    total_steps = max_epochs * len(train_loader)
    model = VQVAE(max_iters=total_steps, **kwargs)
    trainer.fit(model, train_loader, val_loader)
    # Return the best checkpoint found during training rather than the final weights
    return VQVAE.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

In [100]:
# Train (or load) the VQ-VAE. All arguments except batch_size are forwarded to the VQVAE constructor.
model = train_vqvae(batch_size=64,
                    vocab_size=256,  # number of codebook entries
                    vocab_dim=128,   # dimensionality of each codebook entry / encoder output channels
                    beta=0.25)       # weight of the quant_loss term in VQVAE._get_loss

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name                | Type        | Params
----------------------------------------------------
0 | encoder             | Encoder     | 1.1 M 
1 | vector_quantization | VQEmbedding | 32.8 K
2 | decoder             | Decoder     | 1.4 M 
----------------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
10.035    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]



IsADirectoryError: [Errno 21] Is a directory: '/home/phillip/Documents/DL_course/notebook_test/docs/tutorial_notebooks/tutorial13'