In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.distributions import Categorical

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

AttributeError: partially initialized module 'torchvision' has no attribute 'extension' (most likely due to a circular import)

In [None]:
# Folder the MNIST files are downloaded to / read from (relative to the notebook)
root = "../datasets"

# Mini-batch size used by both data loaders
batch_size = 64

# Learning rate handed to the Adam optimizers
lr = 1e-4

In [None]:
# Index of the GPU to use when CUDA is present
gpu_indx = 0
use_cuda = torch.cuda.is_available()

# Fall back to the CPU when no CUDA device can be used
if use_cuda:
    device = torch.device(gpu_indx)
else:
    device = torch.device("cpu")

In [None]:
# Pre-processing for the training split.
# The digits are resized to 32x32 because that makes the network easier to
# construct (the two stride-2 stages give 32 -> 16 -> 8), and Normalize with
# mean 0.5 / std 0.5 moves pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# The test split goes through the same three steps
test_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Training data: downloaded on first use, reshuffled every epoch
train_set = Datasets.MNIST(root=root, train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=4)

# Held-out data: fixed order
test_set = Datasets.MNIST(root=root, train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size, num_workers=4)

In [None]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantiser with a learnable codebook.

    Args:
        code_book_size: number of codebook vectors.
        embedding_dim: length of each codebook vector; must equal the channel
            count of the feature map passed to ``forward``.
        commitment_cost: weight of the term that pulls the encoder output
            towards its selected codebook vector.
    """

    def __init__(self, code_book_size, embedding_dim, commitment_cost):
        super().__init__()
        self.code_book_size = code_book_size
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook, initialised uniformly in [-1/N, 1/N]
        self.embedding = nn.Embedding(code_book_size, embedding_dim)
        bound = 1 / code_book_size
        nn.init.uniform_(self.embedding.weight, -bound, bound)

    def forward(self, inputs):
        """Quantise a BSxCxHxW feature map.

        Returns:
            A tuple ``(loss, quantized, indices)`` where ``loss`` is the
            codebook loss plus the weighted commitment loss, ``quantized`` is
            BSxCxHxW and ``indices`` is BSx(H*W) holding the selected codebook
            index of every spatial position.
        """
        # Move channels last so every spatial position is one vector: BSxHxWxC
        channels_last = inputs.permute(0, 2, 3, 1).contiguous()
        latent_shape = channels_last.shape

        # BS*H*W x 1 x C, ready to broadcast against the 1 x N x C codebook
        vectors = channels_last.view(-1, 1, self.embedding_dim)
        codebook = self.embedding.weight.unsqueeze(0)

        # Mean squared difference to every codebook vector: BS*H*W x N
        distances = (vectors - codebook).pow(2).mean(2)

        # Index of the closest codebook vector per position: BS*H*W x 1
        encoding_indices = distances.argmin(dim=1).unsqueeze(1)

        # Look the vectors up and restore the BSxHxWxC layout
        quantized = self.embedding(encoding_indices).view(latent_shape)

        # The commitment term moves the encoder output, the codebook term moves
        # the codebook; each side is detached in the other's term
        commitment_term = F.mse_loss(quantized.detach(), channels_last)
        codebook_term = F.mse_loss(quantized, channels_last.detach())
        loss = codebook_term + self.commitment_cost * commitment_term

        # Straight-through estimator: forward value is the codebook vector,
        # gradient flows to the encoder output unchanged (training only)
        if self.training:
            quantized = channels_last + (quantized - channels_last).detach()

        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return loss, quantized, encoding_indices.reshape(latent_shape[0], -1)

    
class ResBlock(nn.Module):
    """Pre-activation residual block that keeps channel count and resolution.

    Args:
        channels: number of input and output channels (GroupNorm uses 8
            groups, so this must be divisible by 8).
    """

    def __init__(self, channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(8, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv2(elu(norm2(conv1(elu(norm1(x)))))) + x``."""
        h = F.elu(self.norm1(x))
        h = self.conv1(h)
        h = F.elu(self.norm2(h))
        return self.conv2(h) + x


# Residual block that halves the spatial resolution
class DownBlock(nn.Module):
    """Stride-2 residual block mapping ``channels_in`` to ``channels_out``.

    Both the main path (``conv1`` -> ``conv2``) and the skip path (``conv3``)
    start from the same normalised, activated input; the stride-2
    convolutions halve height and width.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.bn1 = nn.GroupNorm(8, channels_in)
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.GroupNorm(8, channels_out)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1)

        # Projection for the skip connection (matches channels and resolution)
        self.conv3 = nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        activated = F.elu(self.bn1(x))

        shortcut = self.conv3(activated)

        h = self.conv1(activated)
        h = F.elu(self.bn2(h))
        return self.conv2(h) + shortcut
    
    
# Residual block that doubles the spatial resolution
class UpBlock(nn.Module):
    """Nearest-neighbour x2 upsampling residual block.

    The input is normalised, activated and upsampled once; the main path
    (``conv1`` -> ``conv2``) and the skip projection (``conv3``) both start
    from that upsampled tensor.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()
        self.bn1 = nn.GroupNorm(8, channels_in)

        self.conv1 = nn.Conv2d(channels_in, channels_in, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.GroupNorm(8, channels_in)

        self.conv2 = nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1)

        # Projection for the skip connection
        self.conv3 = nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1)
        self.up_nn = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x_in):
        upsampled = self.up_nn(F.elu(self.bn1(x_in)))

        shortcut = self.conv3(upsampled)

        h = self.conv1(upsampled)
        h = F.elu(self.bn2(h))
        return self.conv2(h) + shortcut

    
# We split up our network into two parts, the Encoder and the Decoder
class Encoder(nn.Module):
    """Convolutional encoder producing a latent feature map.

    Args:
        channels: number of image channels.
        ch: base channel width; the network widens to ``4 * ch``.
        latent_channels: channel count of the returned feature map.

    Two stride-2 ``DownBlock`` stages reduce height and width by a factor of
    four; the output stays a spatial map (no flattening / linear layer).
    """

    def __init__(self, channels, ch=32, latent_channels=32):
        super().__init__()
        wide = ch * 4

        self.conv_1 = nn.Conv2d(channels, ch, 3, 1, 1)

        # Downsampling stages
        self.conv_block1 = DownBlock(ch, ch * 2)
        self.conv_block2 = DownBlock(ch * 2, wide)

        # Resolution-preserving residual stages
        self.res_block_1 = ResBlock(wide)
        self.res_block_2 = ResBlock(wide)
        self.res_block_3 = ResBlock(wide)

        # Projection down to the latent channel count
        self.conv_out = nn.Conv2d(wide, latent_channels, 3, 1, 1)

    def forward(self, x):
        h = self.conv_1(x)

        stages = (self.conv_block1, self.conv_block2,
                  self.res_block_1, self.res_block_2, self.res_block_3)
        for stage in stages:
            h = stage(h)

        # One last activation before the output projection
        return self.conv_out(F.elu(h))
    
    
class Decoder(nn.Module):
    """Convolutional decoder turning a latent feature map back into an image.

    Args:
        channels: number of image channels to produce.
        ch: base channel width; the network starts at ``4 * ch``.
        latent_channels: channel count of the incoming latent feature map.

    Three residual blocks at the latent resolution are followed by two
    ``UpBlock`` stages (x4 upsampling overall) and a ``tanh`` output.
    """

    def __init__(self, channels, ch = 32, latent_channels = 32):
        super(Decoder, self).__init__()

        self.conv1 = nn.Conv2d(latent_channels, 4 * ch, 3, 1, 1)
        self.res_block_1 = ResBlock(ch * 4)
        self.res_block_2 = ResBlock(ch * 4)
        # Previously assigned to ``res_block_2`` a second time, which discarded
        # one block and made forward() reuse the same weights twice.
        self.res_block_3 = ResBlock(ch * 4)

        self.conv_block1 = UpBlock(4 * ch, 2 * ch)
        self.conv_block2 = UpBlock(2 * ch, ch)
        self.conv_out = nn.Conv2d(ch, channels, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.res_block_1(x)
        x = self.res_block_2(x)
        x = self.res_block_3(x)

        x = self.conv_block1(x)
        x = self.conv_block2(x)

        # tanh keeps the output in [-1, 1]
        return torch.tanh(self.conv_out(x))

In [None]:
class VQVAE(nn.Module):
    """Encoder -> vector quantiser -> decoder.

    Args:
        channel_in: number of image channels (input and reconstruction).
        ch: base channel width of encoder and decoder.
        latent_channels: channel count of the latent map, which is also the
            codebook vector length.
        code_book_size: number of codebook vectors.
        commitment_cost: weight of the commitment term in the VQ loss.
    """

    def __init__(self, channel_in, ch=16, latent_channels=32, code_book_size=64, commitment_cost=0.25):
        super().__init__()
        self.encoder = Encoder(channels=channel_in, ch=ch, latent_channels=latent_channels)

        self.vq = VectorQuantizer(
            code_book_size=code_book_size,
            embedding_dim=latent_channels,
            commitment_cost=commitment_cost,
        )

        self.decoder = Decoder(channels=channel_in, ch=ch, latent_channels=latent_channels)

    def encode(self, x):
        """Return ``(vq_loss, quantized, encoding_indices)`` for a batch of images."""
        return self.vq(self.encoder(x))

    def decode(self, x):
        """Decode a (quantised) latent feature map into images."""
        return self.decoder(x)

    def forward(self, x):
        """Return ``(recon, vq_loss, quantized)``; the code indices are dropped."""
        vq_loss, quantized, _ = self.encode(x)
        return self.decode(quantized), vq_loss, quantized

In [None]:
# Grab one batch of test images (the labels are discarded)
dataiter = iter(test_loader)
test_images = next(dataiter)[0]

# Report the batch shape. A bare expression in the middle of a cell is not
# displayed, so it has to be printed explicitly.
print(test_images.shape)

# Visualise the batch as a single image grid
plt.figure(figsize = (5,5))
out = vutils.make_grid(test_images, normalize=True)
# Bind the return value so the AxesImage repr is not echoed under the plot
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

In [None]:
# --- VQ-VAE hyper-parameters ---
code_book_size = 32    # number of codebook vectors
latent_channels = 10   # channels of the latent map / length of each codebook vector
vq_nepoch = 50         # training epochs

# Build the model; the image channel count is read from the preview batch
vae_net = VQVAE(
    channel_in=test_images.shape[1],
    latent_channels=latent_channels,
    ch=16,
    code_book_size=code_book_size,
    commitment_cost=0.25,
).to(device)

# Adam plus a gradient scaler for the mixed-precision training loop
optimizer = optim.Adam(vae_net.parameters(), lr=lr)
scaler = torch.amp.GradScaler('cuda')

# Cosine decay of the learning rate to 0 over the full run (stepped once per epoch)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=vq_nepoch, eta_min=0)

# Loss histories and the running per-epoch reconstruction loss
recon_loss_log = []
qv_loss_log = []
test_recon_loss_log = []
train_loss = 0

In [None]:
# Total number of scalar parameters in the VQ-VAE
num_model_params = sum(param.numel() for param in vae_net.parameters())

# The "million" figure is floored, so it reads 0 for models under 1M parameters
print("-The VQVAE Model Has %d (Approximately %d Million) Parameters!"
      % (num_model_params, num_model_params//1e6))

In [None]:
# Pass the preview batch through the model to make sure everything is wired up.
# This is only a shape check, so no autograd graph is recorded.
with torch.no_grad():
    recon_data, vq_loss, quantized = vae_net(test_images.to(device))

# View the latent feature map shape
quantized.shape

In [None]:
pbar = trange(0, vq_nepoch, leave=False, desc="Epoch")
for epoch in pbar:
    # Shows the mean reconstruction loss of the epoch that just finished
    pbar.set_postfix_str('Loss: %.4f' % (train_loss/len(train_loader)))
    train_loss = 0

    # ---- training ----
    vae_net.train()
    for data in tqdm(train_loader, leave=False, desc="Training"):
        image = data[0].to(device)
        with torch.amp.autocast('cuda'):
            # Forward pass the image in the data tuple
            recon_data, vq_loss, quantized = vae_net(image)

            # Total loss = VQ loss + pixel-wise mean squared error
            recon_loss = (recon_data - image).pow(2).mean()
            loss = vq_loss + recon_loss

        # Take a training step (the scaler handles loss scaling for mixed precision)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Log the losses; read the scalar once instead of twice
        recon_loss_value = recon_loss.item()
        recon_loss_log.append(recon_loss_value)
        qv_loss_log.append(vq_loss.item())
        train_loss += recon_loss_value

    # One scheduler step per epoch
    lr_scheduler.step()

    # ---- evaluation: only the reconstruction loss is logged ----
    vae_net.eval()
    with torch.no_grad():
        for data in tqdm(test_loader, leave=False, desc="Testing"):
            image = data[0].to(device)
            with torch.amp.autocast('cuda'):
                recon_data, vq_loss, quantized = vae_net(image)
                recon_loss = (recon_data - image).pow(2).mean()
            test_recon_loss_log.append(recon_loss.item())

In [None]:
# Drop the first 200 logged values of each curve (presumably to hide the large
# initial losses) and stretch both curves over the same 0..vq_nepoch axis
train_curve = recon_loss_log[200:]
test_curve = test_recon_loss_log[200:]

x_train = np.linspace(0, vq_nepoch, len(train_curve))
x_test = np.linspace(0, vq_nepoch, len(test_curve))

_ = plt.plot(x_train, train_curve)
_ = plt.plot(x_test, test_curve)
_ = plt.title("Reconstruction Loss")

In [None]:
# Per-iteration VQ loss, skipping the first 100 logged values
vq_curve = qv_loss_log[100:]
_ = plt.plot(vq_curve)
_ = plt.title("VQ Loss")

In [None]:
# Reconstruct the preview batch with the trained model.
# Inference only: eval mode, and no autograd graph is recorded.
vae_net.eval()
with torch.no_grad():
    recon_data, vq_loss, quantized = vae_net(test_images.to(device))

In [None]:
# Encode the preview batch to obtain the codebook index chosen at every latent
# position. Inference only, so no autograd graph is recorded.
with torch.no_grad():
    vq_loss, quantized, encoding_indices = vae_net.encode(test_images.to(device))

In [None]:
# Codebook indices selected for the first image of the batch (one per latent position)
encoding_indices[0]

In [None]:
# First eight original test images, tiled into one grid
plt.figure(figsize=(20, 10))
out = vutils.make_grid(test_images[:8], normalize=True)
# CxHxW -> HxWxC, the layout imshow expects
_ = plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
# Reconstructions of the same eight images
plt.figure(figsize=(20, 10))
recon_preview = recon_data.detach().cpu()[:8]
out = vutils.make_grid(recon_preview, normalize=True)
# CxHxW -> HxWxC, the layout imshow expects
_ = plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal position embedding.

    Args:
        dim: length of the embedding produced for every position.

    ``forward`` takes a 1-D tensor of positions (length L) and returns an
    L x dim tensor: the first ``dim // 2`` columns are sines, the next
    ``dim // 2`` are cosines, with frequencies spaced geometrically from 1
    down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2

        # Guard the denominator: with a single frequency (dim of 2 or 3)
        # ``half_dim - 1`` is 0 and the original expression divided by zero.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)

        # Outer product of positions and frequencies: L x half_dim
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)

        # An odd ``dim`` would otherwise yield dim - 1 features; pad one zero
        # column so the output width always equals ``dim``.
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

# Causal Transformer encoder stack used by the token generator.
# NOTE: this rebinds the name `Encoder`, which the convolutional VQ-VAE encoder
# defined earlier in the notebook also uses. Any `VQVAE(...)` constructed after
# this cell has run will resolve `Encoder` to this class instead.
class Encoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks with a causal mask.

    Args:
        hidden_size: model width (``d_model``); the feed-forward width is
            ``4 * hidden_size``.
        num_layers: number of stacked layers.
        num_heads: attention heads per layer.
    """

    def __init__(self, hidden_size=128, num_layers=3, num_heads=4):
        super(Encoder, self).__init__()

        # Template layer: batch-first, no dropout
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads,
                                                   dim_feedforward=hidden_size * 4, dropout=0.0,
                                                   batch_first=True)
        # TransformerEncoder will clone the encoder_layer "num_layers" times
        self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, input_seq, padding_mask=None):
        """Run the stack over a BSxLxH sequence.

        Args:
            input_seq: embedded sequence of shape BSxLxH.
            padding_mask: optional BSxL boolean mask, True at positions that
                must be ignored as attention keys. It used to be accepted but
                silently dropped; it is now forwarded as
                ``src_key_padding_mask``.

        Returns:
            Tensor of shape BSxLxH.
        """
        # Unpacking also checks that the input is 3-dimensional
        _, seq_len, _ = input_seq.shape

        # True above the diagonal: position i may not attend to positions > i
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_seq.device), 1).bool()

        return self.encoder_layers(src=input_seq, mask=causal_mask,
                                   src_key_padding_mask=padding_mask)

In [None]:
class Transformer(nn.Module):
    """Autoregressive token model: embedding -> causal encoder -> logits.

    Args:
        num_emb: vocabulary size (also the number of output logits).
        hidden_size: model width.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
    """

    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super().__init__()

        # Learned token embedding
        self.embedding = nn.Embedding(num_emb, hidden_size)

        # Fixed sinusoidal position embedding (no parameters)
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Causal encoder stack (there is no separate decoder)
        self.encoder = Encoder(hidden_size=hidden_size, num_layers=num_layers,
                               num_heads=num_heads)

        # Projects every position back to vocabulary logits
        self.fc_out = nn.Linear(hidden_size, num_emb)

    def embed(self, input_seq):
        """Map BSxL token ids to BSxLxH embeddings with positions added."""
        token_embs = self.embedding(input_seq)
        bs, l, h = token_embs.shape

        # One position embedding per sequence index, shared across the batch
        positions = torch.arange(l, device=input_seq.device)
        pos_emb = self.pos_emb(positions).reshape(1, l, h).expand(bs, l, h)
        return token_embs + pos_emb

    def encode(self, input_seq):
        """Embed the token ids and run them through the causal encoder."""
        return self.encoder(self.embed(input_seq))

    def forward(self, input_seq):
        """Return BSxLxnum_emb logits; position i only sees tokens up to i."""
        hidden = self.encode(input_seq=input_seq)
        return self.fc_out(hidden)

In [None]:
# --- Transformer hyper-parameters ---
num_layers = 4      # transformer blocks
num_heads = 8       # attention heads per block
hidden_size = 256   # model width
tf_nepoch = 100     # training epochs

# Build the generator. The vocabulary is the codebook plus one extra
# "start-sequence" token, hence code_book_size + 1 embeddings.
tf_generator = Transformer(
    num_emb=code_book_size + 1,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
).to(device)

# Adam with a cosine learning-rate decay to 0 over the full run
tf_optimizer = optim.Adam(tf_generator.parameters(), lr=lr)
tf_lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(tf_optimizer, T_max=tf_nepoch, eta_min=0)

# Gradient scaler for mixed precision training
tf_scaler = torch.amp.GradScaler('cuda')

# Next-token classification loss
loss_fn = nn.CrossEntropyLoss()

# Per-iteration training loss history
training_loss_logger = []

In [None]:
# Total number of scalar parameters in the Transformer generator
num_model_params = sum(param.numel() for param in tf_generator.parameters())

# The "million" figure is floored, so it reads 0 for models under 1M parameters
print("-The TF Model Has %d (Approximately %d Million) Parameters!"
      % (num_model_params, num_model_params//1e6))

In [None]:
pbar = trange(0, tf_nepoch, leave=False, desc="Epoch")

# The VQ-VAE is only used to turn images into token sequences here
vae_net.eval()

# Start from a clean accumulator so the first progress-bar reading is not the
# value left behind by the VQ-VAE training cell
train_loss = 0
for epoch in pbar:
    # Shows the mean loss of the epoch that just finished
    pbar.set_postfix_str('Loss: %.4f' % (train_loss/len(train_loader)))
    train_loss = 0

    tf_generator.train()
    for data in tqdm(train_loader, leave=False, desc="Training"):
        image = data[0].to(device)

        # Tokenise the batch; no gradients are needed through the VQ-VAE
        with torch.no_grad():
            _, _, encoding_indices = vae_net.encode(image)

        # Shift the codes by one so that index 0 is free for the start token
        encoding_indices = encoding_indices + 1

        # Inputs: start token followed by all but the last code.
        # Targets: the codes themselves (next-token prediction).
        start_tokens = torch.zeros_like(encoding_indices[:, 0:1])
        tf_inputs = torch.cat((start_tokens, encoding_indices[:, :-1]), 1)
        tf_outputs = encoding_indices

        # Forward pass and loss both run under autocast, so the cross-entropy
        # is not evaluated on reduced-precision logits outside the region
        with torch.amp.autocast('cuda'):
            pred = tf_generator(tf_inputs)
            # CrossEntropyLoss expects the class dimension second: BSxCxL
            loss = loss_fn(pred.transpose(1, 2), tf_outputs)

        # Backpropagation
        tf_optimizer.zero_grad()
        tf_scaler.scale(loss).backward()
        tf_scaler.step(tf_optimizer)
        tf_scaler.update()

        # Log the training loss; read the scalar once instead of twice
        loss_value = loss.item()
        training_loss_logger.append(loss_value)
        train_loss += loss_value

    # One scheduler step per epoch
    tf_lr_scheduler.step()

In [None]:
# Per-iteration transformer loss, skipping the first 200 logged values
tf_loss_curve = training_loss_logger[200:]
_ = plt.plot(tf_loss_curve)
_ = plt.title("Loss per iteration")

In [None]:
# Sampling temperature: the generator's logits are divided by this value
# before the next token is sampled
temp = 0.99

In [None]:
# Generated tokens, one BSx1 tensor per step. Every one of the 64 sequences
# starts with the start-of-sequence token (index 0).
log_tokens = [torch.zeros(64, 1).long()]

# Set the generator model to evaluation mode
tf_generator.eval()

# Generate 64 tokens per sequence, one at a time
with torch.no_grad():
    for i in range(64):
        # Concatenate tokens from previous iterations
        input_tokens = torch.cat(log_tokens, 1)

        # Get model predictions for the next token
        data_pred = tf_generator(input_tokens.to(device))

        # Temperature-scaled logits of the last position (a new tensor, so
        # data_pred itself is left untouched)
        next_logits = data_pred[:, -1] / temp

        # Index 0 is the start token, not a codebook entry. Never sample it:
        # the later "- 1" shift back to codebook indices would turn it into
        # -1, which is not a valid embedding index.
        next_logits[:, 0] = float("-inf")

        # Sample the next token from the resulting distribution
        dist = Categorical(logits=next_logits)
        next_tokens = dist.sample().reshape(-1, 1)

        # Append the sampled token to the list of generated tokens
        log_tokens.append(next_tokens.cpu())

In [None]:
# Predicted next-token probabilities at the last position of the first sequence
last_step_probs = F.softmax(data_pred[0, -1], dim=-1)
_ = plt.plot(last_step_probs.flatten().cpu())

In [None]:
# Drop the start-token column and undo the +1 shift to get codebook indices
generated_tokens = torch.cat(log_tokens, 1)[:, 1:]
embs_indx = generated_tokens.to(device) - 1

# Look up the codebook vectors, lay each sequence out as an 8x8 grid and move
# the channels first: BSx64xC -> BSx8x8xC -> BSxCx8x8
embeds = vae_net.vq.embedding(embs_indx)
embeds = embeds.reshape(-1, 8, 8, latent_channels)
embeds = embeds.permute(0, 3, 1, 2).contiguous()

In [None]:
# Decode the generated latent grids into images
recon_data = vae_net.decode(embeds)

# Show all generated samples in one grid
plt.figure(figsize=(5, 5))
generated_images = recon_data.detach().cpu()
out = vutils.make_grid(generated_images, normalize=True)
# CxHxW -> HxWxC, the layout imshow expects
_ = plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
# Baseline: 64 sequences of 64 codebook indices drawn uniformly at random
rand_sample = torch.randint(code_book_size, (64, 64), device=device)

# Codebook lookup, then BSx64xC -> BSx8x8xC -> BSxCx8x8
rand_sample_embeds = vae_net.vq.embedding(rand_sample)
rand_sample_embeds = rand_sample_embeds.reshape(-1, 8, 8, latent_channels)
rand_sample_embeds = rand_sample_embeds.permute(0, 3, 1, 2).contiguous()

# Decode the random latents into images
recon_data = vae_net.decode(rand_sample_embeds)

plt.figure(figsize=(5, 5))
random_images = recon_data.detach().cpu()
out = vutils.make_grid(random_images, normalize=True)
# CxHxW -> HxWxC, the layout imshow expects
_ = plt.imshow(out.permute(1, 2, 0).numpy())