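# Training Notebook for VQ-VAE

Trains a VQ-VAE on 64x64 face crops from the CelebA dataset in Google Colab, saving training checkpoints to Google Drive.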

In [None]:
# imports
import os
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt

### Load CelebA dataset

In [None]:
# Mount Google Drive at /content/drive. The kaggle.json API key is read from it,
# and training checkpoints are written to it.
import google.colab.drive as drive
drive.mount('/content/drive')

In [None]:
# access the kaggle.json API key from the main folder of your google drive
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# download the dataset from kaggle
!kaggle datasets download -d zuozhaorui/celeba
!mkdir ./data
!unzip -q celeba.zip -d ./data/celeba

In [None]:
# load dataset
class CelebATransform:
    '''
    Crops a fixed square box around the face and resizes it. Output is a tensor of shape
    (3, size, size) scaled to [0, 1]; with the defaults this is (3, 64, 64).

    Args:
        top: row of the crop box's upper-left corner, in pixels.
        left: column of the crop box's upper-left corner, in pixels.
        crop_size: side length of the square crop box, in pixels.
        size: side length of the square output image, in pixels.

    The defaults reproduce the original behaviour: a 128x128 crop at (top=60, left=25),
    resized to 64x64. NOTE(review): the box is presumably tuned for the aligned CelebA
    images; re-check it if the source image size differs.
    '''
    def __init__(self, top=60, left=25, crop_size=128, size=64):
        self.top = top
        self.left = left
        self.crop_size = crop_size
        self.size = size

    def __call__(self, img):
        '''Crop, resize, and convert `img` (as delivered by ImageFolder's loader) to a tensor.'''
        F = torchvision.transforms.functional
        img = F.crop(img, top=self.top, left=self.left, height=self.crop_size, width=self.crop_size)
        img = F.resize(img, (self.size, self.size))
        return F.to_tensor(img)
celeba = torchvision.datasets.ImageFolder(root='./data/celeba', transform=CelebATransform())

# visualize the first grid_x * grid_y preprocessed images
grid_x = 5  # images per row
grid_y = 4  # number of rows

n_samples = min(grid_x * grid_y, len(celeba))  # don't index past a small dataset
samples = torch.stack([celeba[i][0] for i in range(n_samples)])

img = torchvision.utils.make_grid(samples, nrow=grid_x, normalize=True)
plt.title('Sample Images')
plt.axis('off')
plt.imshow(img.permute(1, 2, 0).cpu())
plt.show()  # render the figure instead of leaving an AxesImage repr as the cell output

### Import model

In [None]:
# clone the github repository containing the VQ-VAE model and import models
!git clone https://github.com/patrickmastorga/VQ-VAE-Tranformer-Image-Gen.git

os.chdir('VQ-VAE-Tranformer-Image-Gen/VQ')
from model import Encoder, Decoder, Quantizer, VQ_VAE, EMBEDDING_DIM
os.chdir('../../')

### Train model

In [None]:
BATCH_SIZE = 256
# NOTE(review): BETA is not passed to Encoder/Decoder/Quantizer/VQ_VAE below; confirm that the
# commitment-loss weight the model actually uses matches this value.
BETA = 0.25
SEED = 0

CHECKPOINT_DIR = '/content/drive/MyDrive/vq_models'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint.pt')
LOAD_FROM_CHECKPOINT = False  # set True to resume training from CHECKPOINT_PATH
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# seed the default generators on all devices so parameter initialization and DataLoader shuffling
# draw from a fixed seed; a loaded checkpoint overrides this with its saved RNG state below
torch.manual_seed(SEED)

# initialize dataloader, models, and optimizer for training
dataloader = torch.utils.data.DataLoader(celeba, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = Encoder().to(device)
decoder = Decoder().to(device)
quantizer = Quantizer().to(device)
model = VQ_VAE(encoder, decoder, quantizer).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# training state (overwritten when resuming from a checkpoint)
training_losses = []  # (training_step, mean loss over the preceding LOG_INTERVAL steps)
training_steps = 0
running_loss = 0.0  # loss summed since the last logged step

# load from checkpoint
if LOAD_FROM_CHECKPOINT:
    if not os.path.exists(CHECKPOINT_PATH):
        print(f'WARNING: Checkpoint not found at {CHECKPOINT_PATH}!')
    else:
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        training_steps = checkpoint['training_steps']
        training_losses = checkpoint['training_losses']
        running_loss = checkpoint['running_loss']
        # map_location moved every saved tensor, including the RNG states, to `device`.
        # Both RNG setters only accept CPU ByteTensors, so bring the states back to the CPU first.
        torch.set_rng_state(checkpoint['cpu_rng_state'].cpu())
        if torch.cuda.is_available() and 'cuda_rng_state' in checkpoint:
            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'].cpu())

        print(f'Checkpoint loaded. Resuming from training step {training_steps}.')

In [None]:
EPOCHS = 1
LOG_INTERVAL = 128
SAVE_INTERVAL = 10000
N_PREVIEW = 5  # number of images shown in each epoch's reconstruction preview


def save_checkpoint(path=CHECKPOINT_PATH):
    '''
    Save the current training state to `path` in the format the setup cell reloads.

    Reads the notebook-level `model`, `optimizer`, `training_steps`, `training_losses` and
    `running_loss`, plus the CPU RNG state and, when CUDA is available, the CUDA RNG state.
    '''
    checkpoint = {
        'training_steps': training_steps,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'training_losses': training_losses,
        'running_loss': running_loss,
        'cpu_rng_state': torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state()

    torch.save(checkpoint, path)
    print(f'Checkpoint saved at step {training_steps} to {path}')


total_steps = training_steps + len(dataloader) * EPOCHS

model.train()
for epoch in range(EPOCHS):
    for images, _ in dataloader:
        # training step
        optimizer.zero_grad()
        images = images.to(device)

        loss = model(images)
        loss.backward()
        optimizer.step()
        training_steps += 1

        running_loss += loss.item()

        # log the mean loss over the last LOG_INTERVAL steps
        if training_steps % LOG_INTERVAL == 0:
            avg_loss = running_loss / LOG_INTERVAL
            running_loss = 0.0
            training_losses.append((training_steps, avg_loss))
            # 4 decimals keep the log informative whatever the loss reduction (mean or sum) in VQ_VAE
            print(f'TRAINING Step [{training_steps}/{total_steps}], Loss: {avg_loss:.4f}')

        if training_steps % SAVE_INTERVAL == 0:
            save_checkpoint()

    # visualize reconstructions: top row originals, bottom row reconstructions
    preview_images, _ = next(iter(dataloader))
    preview_images = preview_images[:N_PREVIEW].to(device)

    model.eval()
    with torch.no_grad():  # no autograd graph: saves memory and lets the result be plotted directly
        preview_recon = model.reconstruct(preview_images)
    model.train()

    preview_grid = torchvision.utils.make_grid(
        torch.cat((preview_images, preview_recon), dim=0), nrow=N_PREVIEW, normalize=True
    )
    plt.figure()  # one figure per epoch so earlier previews are not drawn over
    plt.title(f'Reconstructions (epoch {epoch + 1})')
    plt.axis('off')
    plt.imshow(preview_grid.permute(1, 2, 0).cpu())
    plt.show()

# The in-loop save only fires on multiples of SAVE_INTERVAL; persist the final state otherwise,
# so steps since the last periodic save (or the whole run, if shorter) are not lost.
if training_steps % SAVE_INTERVAL != 0:
    save_checkpoint()

print('Training complete.')