# VQ-VAE

In [1]:
## Standard libraries
import os
from copy import deepcopy
import math
import numpy as np

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.set()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import STL10
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (STL10 in this notebook)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial13"
# In this notebook, we use data loaders with heavier computational processing. It is recommended to use as many
# workers as possible in a data loader, which corresponds to the number of CPU cores.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1 so that
# integer divisions like NUM_WORKERS//2 used for the data loaders keep working.
NUM_WORKERS = os.cpu_count() or 1

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility.
# The attribute name must be spelled exactly: assigning a misspelled one (e.g. `determinstic`)
# just creates an unused attribute and leaves cuDNN non-deterministic.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

Global seed set to 42


Device: cuda:0
Number of workers: 16


In [2]:
import urllib.request
from urllib.error import HTTPError

# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial13/"
# Files to download (currently empty, so the loop below does nothing)
pretrained_files = []
# Create the checkpoint folder up front so later cells can write into it
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Fetch every listed file that is not yet present under CHECKPOINT_PATH.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        # Nested names (e.g. "sub/model.ckpt") need their sub-folder created first.
        parent_dir = file_path.rsplit("/", 1)[0]
        os.makedirs(parent_dir, exist_ok=True)
    if os.path.isfile(file_path):
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_path)
    except HTTPError as e:
        # Download failures are reported, not raised, so the rest of the notebook can still run.
        print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)


In [15]:
# Training-time transforms: random horizontal flips as augmentation, then pixel values are
# mapped to [-1, 1] (Normalize applies x -> (x - 0.5) / 0.5 to the ToTensor output).
data_transforms = transforms.Compose([
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,))
                                ])
# Evaluation transforms: same normalization but no random augmentation, so the validation
# loss (monitored for checkpoint selection) does not fluctuate with random flips.
test_transforms = transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,))
                                ])

# The 'unlabeled' STL10 split is used for training; the labeled 'train' split is used
# as the held-out set for validation.
train_data = STL10(root=DATASET_PATH, split='unlabeled', download=True,
                   transform=data_transforms)
test_data = STL10(root=DATASET_PATH, split='train', download=True,
                  transform=test_transforms)

def get_train_images(num):
    """Stack the first ``num`` samples of ``train_data`` into one batch tensor.

    Only element 0 of each sample (the image) is kept; the dataset's transform is
    applied when each sample is accessed.
    """
    first_images = []
    for sample_idx in range(num):
        first_images.append(train_data[sample_idx][0])
    return torch.stack(first_images, dim=0)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
class GenerateCallback(pl.Callback):
    """Logs input/reconstruction image pairs to TensorBoard at the end of validation epochs.

    Args:
        input_imgs: Batch of images to reconstruct; moved to the module's device when logging.
        every_n_epochs: Log only on epochs where ``current_epoch % every_n_epochs == 0``
            (keeps the TensorBoard log from growing too large).
    """

    def __init__(self, input_imgs, every_n_epochs=1):
        super().__init__()
        self.input_imgs = input_imgs # Images to reconstruct during training
        self.every_n_epochs = every_n_epochs # Only save those images every N epochs (otherwise tensorboard gets quite large)

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        input_imgs = self.input_imgs.to(pl_module.device)
        # Reconstruct in eval mode, then restore whichever mode the module was in before,
        # instead of unconditionally switching it to train mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                reconst_imgs = pl_module(input_imgs)
        finally:
            pl_module.train(was_training)
        # Interleave originals and reconstructions so each grid row is (input, reconstruction).
        imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
        # Images live in [-1, 1]; `value_range` rescales that range to [0, 1] for display
        # (`range` is the deprecated name of this argument).
        grid = torchvision.utils.make_grid(imgs, nrow=2, normalize=True, value_range=(-1, 1))
        trainer.logger.experiment.add_image("Reconstructions", grid, global_step=trainer.global_step)

In [4]:
class BatchNorm2d(nn.Module):
    
    def __init__(self, c_in):
        super().__init__()
        self.norm = nn.BatchNorm2d(c_in, momentum=0.02)
        
    def forward(self, x):
        return self.norm(x)

    
class ResidualBlock(nn.Module):
    """Residual wrapper computing ``skip_connect(x) + net(x)``.

    Args:
        net: Module for the residual (main) branch.
        skip_connect: Module for the shortcut branch; ``None`` means identity.
    """

    def __init__(self, net, skip_connect=None):
        super().__init__()
        self.net = net
        if skip_connect is None:
            skip_connect = nn.Identity()
        self.skip_connect = skip_connect

    def forward(self, x):
        # Shortcut is evaluated first, then the main branch, matching the original order.
        shortcut = self.skip_connect(x)
        residual = self.net(x)
        return shortcut + residual
    

class PreActResNetBlock(ResidualBlock):
    """Pre-activation residual block.

    Main branch: BatchNorm -> activation -> 3x3 conv (with `stride`), then
    BatchNorm -> activation -> 3x3 conv. When the stride is not 1 or the channel count
    changes, the shortcut is BatchNorm -> activation -> 1x1 conv with the same stride;
    otherwise it is the identity.

    Args:
        c_in: Input channels.
        c_out: Output channels.
        stride: Stride of the first 3x3 conv (and of the projection shortcut).
        act_fn: Activation class, instantiated for every activation layer.
    """

    def __init__(self, c_in, c_out, stride=1, act_fn=nn.SiLU):
        # Main branch is built before the shortcut, keeping parameter-init order unchanged.
        main_layers = [
            BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
            BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False),
        ]
        main_branch = nn.Sequential(*main_layers)

        shortcut = None
        needs_projection = stride != 1 or c_in != c_out
        if needs_projection:
            shortcut = nn.Sequential(
                BatchNorm2d(c_in),
                act_fn(),
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, bias=False),
            )
        super().__init__(net=main_branch, skip_connect=shortcut)

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a square image to a vector of `num_latents` values.

    The image is downsampled by a factor 2 `num_layers` times, where
    ``num_layers = ceil(log2(width) - 2)`` keeps the final feature map at most 4x4. Each
    resolution has `num_blocks` PreActResNetBlocks; the last one in each group has stride 2
    and doubles the channels. A 1x1 conv and two linear layers then produce the latents.

    Args:
        c_hid: Base channel count; doubled at every downsampling step.
        num_latents: Size of the output vector.
        c_in: Number of input image channels.
        num_blocks: Residual blocks per resolution.
        width: Input height/width (the flattened size assumes square inputs).
        act_fn: Activation class, instantiated for every activation layer.

    Raises:
        ValueError: If `width` is too small or not divisible by ``2**num_layers`` (the
            flattened feature size would not match the linear layer), or if
            ``num_blocks < 1`` while downsampling is required.
    """

    def __init__(self, 
                 c_hid=32, 
                 num_latents=64,
                 c_in=3,
                 num_blocks=2,
                 width=32,
                 act_fn=nn.SiLU):
        super().__init__()
        num_layers = int(math.ceil(np.log2(width) - 2))
        # Stride-2 convs (kernel 3, padding 1) give ceil(w / 2) per stage, which matches
        # end_width = width // 2**num_layers only when width is divisible by 2**num_layers.
        if num_layers < 0 or width % 2 ** num_layers != 0:
            raise ValueError(f'Unsupported width={width}: it must be at least 3 and divisible '
                             f'by 2**num_layers (num_layers={num_layers})')
        if num_layers > 0 and num_blocks < 1:
            raise ValueError(f'num_blocks must be >= 1 when downsampling, got {num_blocks}')
        end_width = width // 2 ** num_layers
        print('Width', width, 'Num layers', num_layers, 'End width', end_width)
        
        net = [nn.Conv2d(c_in, c_hid, kernel_size=3, padding=1)]
        for l_idx in range(num_layers):
            for b_idx in range(num_blocks):
                # Only the last block of each resolution downsamples and doubles the channels.
                use_stride = (b_idx == num_blocks-1)
                stride = 2 if use_stride else 1
                res_c_in = c_hid * 2**l_idx
                res_c_out = c_hid * 2**(l_idx + (1 if use_stride else 0))
                net.append(PreActResNetBlock(c_in=res_c_in, c_out=res_c_out, stride=stride, act_fn=act_fn))
        # Channels after the last stage, computed explicitly instead of relying on the
        # loop variable (which is unbound when the loop does not run).
        c_last = c_hid * 2 ** num_layers
        
        self.net = nn.Sequential(
            *net,
            BatchNorm2d(c_last),
            act_fn(),
            nn.Conv2d(c_last, c_hid * 2, 1, bias=False),
            nn.Flatten(),
            act_fn(),
            nn.Linear((end_width**2) * c_hid * 2, c_hid * 2**num_layers),
            act_fn(),
            nn.Linear(c_hid * 2**num_layers, num_latents)
        )

    def forward(self, img):
        """Encode a batch of images into latent vectors of size `num_latents`."""
        return self.net(img)

In [6]:
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector back to a square image in (-1, 1).

    Two linear layers project the latents to a ``(c_hid*2, end_width, end_width)`` feature
    map. It is then upsampled by a factor 2 ``num_layers`` times (nearest-neighbour
    upsampling followed by PreActResNetBlocks, the first of which halves the channels).
    Finally, 1x1 convolutions and a Tanh produce `c_out` channels.

    Args:
        c_hid: Base channel count; the coarsest stage has ``c_hid * 2**num_layers`` channels.
        num_latents: Size of the input latent vector.
        width: Output height/width.
        act_fn: Activation class, instantiated for every activation layer.
        num_blocks: Residual blocks per resolution.
        c_out: Number of output image channels.

    Raises:
        ValueError: If `width` is too small or not divisible by ``2**num_layers`` (the output
            would otherwise silently have size ``end_width * 2**num_layers != width``), or if
            ``num_blocks < 1`` while upsampling is required.
    """

    def __init__(self, 
                 c_hid=32, 
                 num_latents=64,
                 width=96,
                 act_fn=nn.SiLU,
                 num_blocks=2,
                 c_out=3):
        super().__init__()
        self.width = width
        num_layers = int(math.ceil(np.log2(width) - 2))
        if num_layers < 0 or width % 2 ** num_layers != 0:
            raise ValueError(f'Unsupported width={width}: it must be at least 3 and divisible '
                             f'by 2**num_layers (num_layers={num_layers})')
        if num_layers > 0 and num_blocks < 1:
            raise ValueError(f'num_blocks must be >= 1 when upsampling, got {num_blocks}')
        end_width = width // 2 ** num_layers
        self.end_width = end_width
        
        self.linear = nn.Sequential(
            nn.Linear(num_latents, c_hid * 2**num_layers),
            act_fn(),
            nn.Linear(c_hid * 2**num_layers, (end_width**2) * c_hid * 2)
        )
        net = [
            act_fn(),
            nn.Conv2d(c_hid * 2, c_hid * 2**num_layers, 1, bias=False)
        ]
        for l_idx in reversed(range(num_layers)):
            net.append(nn.Upsample(scale_factor=2.0, mode='nearest'))
            for b_idx in range(num_blocks):
                # The first block of each resolution halves the channel count.
                res_c_in = c_hid * 2**(l_idx + (1 if b_idx == 0 else 0))
                res_c_out = c_hid * 2**l_idx
                net.append(PreActResNetBlock(c_in=res_c_in, c_out=res_c_out, stride=1, act_fn=act_fn))
        self.net = nn.Sequential(
            *net,
            BatchNorm2d(c_hid),
            act_fn(),
            nn.Conv2d(c_hid, c_hid, 1),
            BatchNorm2d(c_hid),
            act_fn(),
            nn.Conv2d(c_hid, c_out, 1),
            nn.Tanh()
        )
        
    def forward(self, x):
        """Decode latent vectors into images of shape ``(B, c_out, width, width)``."""
        x = self.linear(x)
        # Unflatten the projected vector into a (B, C, end_width, end_width) feature map.
        x = x.reshape(x.shape[0], -1, self.end_width, self.end_width)
        x = self.net(x)
        return x

In [11]:
class SimpleAE(pl.LightningModule):
    """Convolutional autoencoder trained with an MSE reconstruction loss.

    Args:
        c_hid: Base channel count for encoder and decoder.
        num_latents: Size of the latent vector produced by the encoder.
        lr: Peak learning rate of AdamW (scaled by the warm-up/cosine schedule).
        c_in: Number of image channels (also the decoder's output channels).
        warmup: Number of linear warm-up steps of the LR schedule.
        max_iters: Length, in optimizer steps, of the cosine LR schedule.
        img_width: Height/width of the square input images.
        **kwargs: Extra keyword arguments (e.g. ``max_epochs``); not used by the model itself.
    """

    def __init__(self, c_hid=32,
                       num_latents=64,
                       lr=5e-4,
                       c_in=3,
                       warmup=500, max_iters=100000,
                       img_width=64,
                       **kwargs):
        super().__init__()
        self.save_hyperparameters()
        
        self.encoder = Encoder(num_latents=self.hparams.num_latents,
                                  c_hid=self.hparams.c_hid,
                                  c_in=self.hparams.c_in,
                                  width=self.hparams.img_width,
                                  num_blocks=2,
                                  act_fn=nn.SiLU)
        self.decoder = Decoder(num_latents=self.hparams.num_latents,
                                  c_hid=self.hparams.c_hid,
                                  c_out=self.hparams.c_in,
                                  width=self.hparams.img_width,
                                  num_blocks=2,
                                  act_fn=nn.SiLU)
        # Lightning looks for this exact attribute name to run a dummy forward pass for the
        # model summary's in/out sizes and for TensorBoard graph logging.
        self.example_input_array = torch.zeros(2, c_in, img_width, img_width)
        # Backward-compatible alias for the previously used (misnamed) attribute.
        self.example_array = self.example_input_array

    def forward(self, x):
        """Encode and decode a batch of images; returns the reconstructions."""
        z = self.encoder(x)
        x_rec = self.decoder(z)
        return x_rec

    def configure_optimizers(self):
        """AdamW with a warm-up + cosine schedule that is stepped every optimizer step."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = CosineWarmupScheduler(optimizer,
                                             warmup=self.hparams.warmup,
                                             max_iters=self.hparams.max_iters)
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def _get_loss(self, batch, mode='train'):
        """Compute and log the MSE between a batch of images and its reconstruction.

        `batch` may be a bare image tensor or a tuple/list whose first element is the images
        (any other elements, e.g. labels, are ignored). Logged as ``f'{mode}_loss'``.
        """
        if isinstance(batch, (tuple, list)):
            imgs = batch[0]
        else:
            imgs = batch
        x_rec = self.forward(imgs)
        loss = F.mse_loss(x_rec, imgs)
        self.log(f'{mode}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._get_loss(batch, mode='train')
        return loss

    def validation_step(self, batch, batch_idx):
        self._get_loss(batch, mode='val')

    def test_step(self, batch, batch_idx):
        self._get_loss(batch, mode='test')

In [12]:
class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by a cosine decay that bottoms out at `min_factor`.

    At scheduler step ``t`` each base LR is multiplied by
    ``min_factor + (1 - min_factor) * 0.5 * (1 + cos(pi * min(t, max_iters) / max_iters))``,
    and additionally by ``t / warmup`` while ``t <= warmup`` (if ``warmup > 0``).

    Args:
        optimizer: Wrapped optimizer.
        warmup: Number of warm-up steps (0 disables warm-up).
        max_iters: Step at which the cosine reaches its minimum. Later steps stay at
            `min_factor` instead of the cosine wrapping around and raising the LR again.
        min_factor: Floor of the cosine factor, as a fraction of the base LR.
    """

    def __init__(self, optimizer, warmup, max_iters, min_factor=0.05):
        # Set before super().__init__, since the base class may call get_lr() while constructing.
        self.warmup = warmup
        self.max_num_iters = max_iters
        self.min_factor = min_factor
        super().__init__(optimizer)

    def get_lr(self):
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the LR multiplier for scheduler step `epoch` (i.e. ``self.last_epoch``)."""
        # Clamp the progress so the factor stays at its minimum after max_iters.
        progress_iters = min(epoch, self.max_num_iters)
        lr_factor = 0.5 * (1 + np.cos(np.pi * progress_iters / self.max_num_iters))
        lr_factor = lr_factor * (1 - self.min_factor) + self.min_factor
        if epoch <= self.warmup and self.warmup > 0:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor

In [17]:
def train_ae(batch_size, max_epochs=500, **kwargs):
    """Train a SimpleAE on STL10, or load it from ``CHECKPOINT_PATH/AE.ckpt`` if present.

    Args:
        batch_size: Batch size of both the training and the validation loader.
        max_epochs: Number of training epochs for the Trainer.
        **kwargs: Forwarded to ``SimpleAE`` (e.g. ``num_latents``, ``c_hid``, ``img_width``).
            If ``max_iters`` is not given, it defaults to the total number of training steps
            so that the per-step LR schedule spans the whole run.

    Returns:
        The pretrained model, or the best checkpoint (lowest ``val_loss``) after training.
    """
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, 'AE'),
                         gpus=1 if str(device)=='cuda:0' else 0,
                         max_epochs=max_epochs,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss'),
                                    LearningRateMonitor('step'),
                                    GenerateCallback(get_train_images(8), every_n_epochs=1)],
                         progress_bar_refresh_rate=1)
    trainer.logger._log_graph = True          # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, 'AE.ckpt')
    if os.path.isfile(pretrained_filename):
        print(f'Found pretrained model at {pretrained_filename}, loading...')
        model = SimpleAE.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42) # To be reproducible
        train_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                       drop_last=True, pin_memory=True, num_workers=NUM_WORKERS//2)
        val_loader = data.DataLoader(test_data, batch_size=batch_size, shuffle=False,
                                     drop_last=False, pin_memory=True, num_workers=NUM_WORKERS//4)
        # The LR scheduler is stepped once per optimizer step, so its length should equal the
        # number of training steps rather than SimpleAE's fixed default of 100000.
        kwargs.setdefault('max_iters', max_epochs * len(train_loader))
        model = SimpleAE(max_epochs=max_epochs, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = SimpleAE.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    return model

In [18]:
# Train (or load a pretrained) autoencoder configured for 96x96 inputs, with a
# 64-dimensional latent vector and a base width of 32 channels.
model = train_ae(batch_size=64,
                 num_latents=64,
                 c_hid=32,
                 img_width=96)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 42


Width 96 Num layers 5 End width 3


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 26.6 M
1 | decoder | Decoder | 17.1 M
------------------------------------
43.7 M    Trainable params
0         Non-trainable params
43.7 M    Total params
174.903   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

IsADirectoryError: [Errno 21] Is a directory: '/home/phillip/Documents/DL_course/notebook_test/docs/tutorial_notebooks/tutorial13'