# Training Notebook for VQ-VAE

In [None]:
# imports
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

### Load CelebA dataset

In [None]:
# Mount Google Drive into the Colab filesystem so that later cells can read
# files stored on it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# access the kaggle.json API key from the main folder of your google drive
!mkdir -p ~/.kaggle
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# download the dataset from kaggle
!kaggle datasets download -d zuozhaorui/celeba
!mkdir ./data
!unzip -q celeba.zip -d ./data/celeba

In [None]:
# load dataset
class CelebATransform:
    """Callable transform: fixed 128x128 crop, resize to 64x64, convert to tensor."""

    def __call__(self, img):
        """Apply the crop / resize / to-tensor pipeline to a single image.

        Args:
            img: the image handed over by the dataset
                (presumably a PIL image from ImageFolder's loader; confirm).

        Returns:
            The result of torchvision's `to_tensor` on the cropped, resized image.
        """
        tvf = torchvision.transforms.functional
        # 128x128 window whose top-left corner sits at row 60, column 25
        cropped = tvf.crop(img, top=60, left=25, height=128, width=128)
        downscaled = tvf.resize(cropped, (64, 64))
        return tvf.to_tensor(downscaled)
# Build the dataset: every image found under ./data/celeba is passed through CelebATransform.
celeba = torchvision.datasets.ImageFolder(
    root='./data/celeba',
    transform=CelebATransform(),
)

# visualize the first grid_x * grid_y images, grid_x images per row
grid_x = 5
grid_y = 4
n_preview = grid_x * grid_y

# each dataset item is an (image, label) pair; only the image is kept
samples = torch.stack([celeba[idx][0] for idx in range(n_preview)])

img = torchvision.utils.make_grid(samples, nrow=grid_x, normalize=True)
plt.title('Sample Images')
plt.axis('off')
# reorder the grid's axes with permute(1, 2, 0) before handing it to imshow
plt.imshow(img.permute(1, 2, 0).cpu())

### Import model

In [None]:
# clone the github repository containing the VQ-VAE model
# (the clone lands in ./VQ-VAE-Tranformer-Image-Gen, the directory the import cell below switches into)
!git clone https://github.com/patrickmastorga/VQ-VAE-Tranformer-Image-Gen.git

In [None]:
# import the VQ-VAE building blocks from model.py in the cloned repository
import os

# model.py is imported by temporarily switching the working directory to the
# folder that contains it. The starting directory is restored in `finally`,
# so a failed import does not leave the notebook stranded inside the repository.
_notebook_dir = os.getcwd()
os.chdir('VQ-VAE-Tranformer-Image-Gen/VQ')
try:
    # VQ_VAE is instantiated by the training-setup cell, so it must be imported too.
    # NOTE(review): assumes model.py defines VQ_VAE alongside the other names; confirm.
    from model import Encoder, Decoder, Quantizer, VQ_VAE, EMBEDDING_DIM
finally:
    os.chdir(_notebook_dir)

### Train model

In [None]:
# Hyperparameters for the training run.
BATCH_SIZE = 256
EPOCHS = 1
BETA = 0.25  # weight applied to the commitment loss in the training loop

# Shuffled mini-batches over the CelebA dataset, loaded with num_workers=2.
dataloader = torch.utils.data.DataLoader(
    celeba,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Build the three sub-modules, then wrap them in the full VQ-VAE.
encoder = Encoder().to(device)
decoder = Decoder().to(device)
quantizer = Quantizer().to(device)
vq_vae = VQ_VAE(encoder, decoder, quantizer).to(device)

# Adam over everything returned by vq_vae.parameters().
optimizer = torch.optim.Adam(vq_vae.parameters(), lr=1e-3)

In [None]:
# Train the VQ-VAE: log the mean loss every `log_interval` batches and plot a
# few reconstructions at the end of every epoch.
training_losses = []  # mean loss of each completed logging window
log_interval = 128    # number of batches averaged per logged value

for epoch in range(EPOCHS):
    # The visualization at the end of each epoch switches the model to eval
    # mode, so training mode has to be re-enabled at the start of every epoch.
    vq_vae.train()

    running_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        # training step
        optimizer.zero_grad()
        images, _ = batch  # the second element of each batch is not used
        images = images.to(device)
        reconstructed, codebook_loss, commitment_loss = vq_vae(images)

        # NOTE(review): binary_cross_entropy requires `reconstructed` to lie in [0, 1];
        # confirm that the decoder's output is bounded accordingly.
        recon_loss = torch.nn.functional.binary_cross_entropy(reconstructed, images)
        loss = recon_loss + codebook_loss + BETA * commitment_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # keep track of loss and epoch progress
        if batch_idx % log_interval == log_interval - 1:
            training_losses.append(running_loss / log_interval)
            running_loss = 0.0
            print(f'TRAINING Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {training_losses[-1]: .1f}')

    # visualize reconstructions: first row = inputs, second row = model outputs
    samples, _ = next(iter(dataloader))
    samples = samples[:5].to(device)

    vq_vae.eval()
    with torch.no_grad():
        reconstructed, _, _ = vq_vae(samples)

    img = torchvision.utils.make_grid(torch.cat((samples, reconstructed), dim=0), 5, normalize=True)
    plt.title(f'Reconstructions')
    plt.axis('off')
    plt.imshow(img.permute(1,2,0).cpu())
    plt.show()  # render now so that every epoch gets its own figure

print(f'Training complete.')