<a href="https://colab.research.google.com/github/ram130849/VQ_VAE/blob/main/VQ_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install umap-learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting umap-learn
  Downloading umap-learn-0.5.3.tar.gz (88 kB)
[K     |████████████████████████████████| 88 kB 3.3 MB/s 
Collecting pynndescent>=0.5
  Downloading pynndescent-0.5.7.tar.gz (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 10.8 MB/s 
Building wheels for collected packages: umap-learn, pynndescent
  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Created wheel for umap-learn: filename=umap_learn-0.5.3-py3-none-any.whl size=82829 sha256=574cb6f34cfec3acc4b63355dbbde8e435f2314470f6112b57db8b2fcb4a5517
  Stored in directory: /root/.cache/pip/wheels/b3/52/a5/1fd9e3e76a7ab34f134c07469cd6f16e27ef3a37aeff1fe821
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Created wheel for pynndescent: filename=pynndescent-0.5.7-py3-none-any.whl size=54286 sha256=1bf83acb9640c127405f310afe31b42917426ab9703b15740ba7b6327958a9b9
  Stored in directo

In [3]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
import umap


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# CIFAR-10 training split. ToTensor yields values in [0, 1]; Normalize with
# mean 0.5 and std 1 per channel only shifts them to [-0.5, 0.5].
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (1, 1, 1)),
        ]
    ),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data


In [7]:
# CIFAR-10 test split used for validation, with the same shift to [-0.5, 0.5]
# as the training split (mean 0.5, std 1 per channel).
validation_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
        ]
    ),
)

Files already downloaded and verified


In [8]:
# Variance over all raw training pixels rescaled to [0, 1]; presumably used later
# to normalise the reconstruction error (TODO confirm against the training loop).
data_var = (training_data.data / 255.0).var()

In [9]:
class VectorQuantizer(nn.Module):
    """Vector-quantisation bottleneck trained with codebook + commitment losses.

    Every D-dimensional vector of the input is replaced by its nearest (squared
    L2) entry in a learnable codebook of K embeddings. Gradients reach the input
    through the straight-through estimator.

    Args:
        no_embeddings: K, the number of codebook vectors.
        embedding_dim: D, the size of each codebook vector; must equal the
            channel dimension of the input.
        commit_loss: weight applied to the commitment term of the loss.
    """

    def __init__(self, no_embeddings: int, embedding_dim: int, commit_loss: float):
        super(VectorQuantizer, self).__init__()
        self.K = no_embeddings
        self.D = embedding_dim
        self.commit_loss = commit_loss

        self.embedding = nn.Embedding(self.K, self.D)
        # The in-place initialiser is `uniform_`; `uniform` does not exist.
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents):
        """Quantise ``latents``.

        Args:
            latents: tensor of shape (B, D, H, W).

        Returns:
            (quantized, vq_loss): ``quantized`` has shape (B, D, H, W) and passes
            gradients straight through to ``latents``; ``vq_loss`` is the scalar
            ``embedding_loss + commit_loss * commitment_loss``.

        Raises:
            ValueError: if the channel dimension of ``latents`` is not D.
        """
        # A C != D input divisible by D would otherwise be silently mis-reshaped.
        if latents.shape[1] != self.D:
            raise ValueError(
                f"expected {self.D} channels, got tensor of shape {tuple(latents.shape)}"
            )

        # BCHW -> BHWC so each spatial position is one contiguous D-dim vector.
        latents = latents.permute(0, 2, 3, 1).contiguous()
        latent_shape = latents.shape

        flat_input = latents.view(-1, self.D)  # [B*H*W, D]

        # Squared L2 distance ||z||^2 + ||e||^2 - 2 z.e for every (vector, code) pair.
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))  # [B*H*W, K]

        # Nearest code per vector, as one-hot rows.
        encoding_idx = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_idx.shape[0], self.K, device=latents.device)
        encodings.scatter_(1, encoding_idx, 1)  # [B*H*W, K]

        # Look up the selected codes and restore [B, H, W, D].
        quantized = torch.matmul(encodings, self.embedding.weight).view(latent_shape)

        # Commitment term trains the input side; embedding term moves the codebook.
        commitment_loss = F.mse_loss(quantized.detach(), latents)
        embedding_loss = F.mse_loss(quantized, latents.detach())
        vq_loss = commitment_loss * self.commit_loss + embedding_loss

        # Straight-through estimator: forward value is `quantized`, gradient with
        # respect to `latents` is the identity.
        quantized = latents + (quantized - latents).detach()
        # BHWC -> BCHW
        return quantized.permute(0, 3, 1, 2).contiguous(), vq_loss

In [None]:
### We will also implement a slightly modified version that uses exponential moving averages (EMA) to update the embedding vectors instead of an
### auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other
### parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

In [10]:
class VectorQuantizerEMA(nn.Module):
    """Vector quantiser whose codebook is updated by exponential moving averages.

    Instead of a codebook loss, each training-mode forward pass updates two EMA
    statistics per code: the assignment count (``_ema_cluster_size``) and the sum
    of assigned input vectors (``ema_w``). It then sets every embedding to their
    ratio. Only the commitment loss is returned.

    Args:
        no_embeddings: K, the codebook size.
        embedding_dim: D, the vector size; must equal the input channel dimension.
        commit_loss: weight of the commitment loss.
        decay: EMA decay factor (presumably close to 1, e.g. 0.99).
        epsilon: Laplace-smoothing constant added to every cluster size.
    """

    def __init__(self, no_embeddings, embedding_dim, commit_loss, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        self.K = no_embeddings
        self.D = embedding_dim

        self.embedding = nn.Embedding(self.K, self.D)
        # `.data` avoids an in-place op on a leaf that requires grad; the original
        # `weiht` typo raised AttributeError here.
        self.embedding.weight.data.normal_()
        self.commit_loss = commit_loss

        self.register_buffer('_ema_cluster_size', torch.zeros(self.K))
        self.ema_w = nn.Parameter(torch.Tensor(self.K, self.D))
        self.ema_w.data.normal_()

        self.decay = decay
        self.epsilon = epsilon

    def forward(self, latents):
        """Quantise ``latents`` and, in training mode, update the codebook by EMA.

        Args:
            latents: tensor of shape (B, D, H, W).

        Returns:
            (loss, quantized, perplexity, encodings):
            - ``loss``: ``commit_loss * commitment_loss`` (scalar).
            - ``quantized``: (B, D, H, W), straight-through gradients to ``latents``.
            - ``perplexity``: exp of the entropy of the average code usage.
            - ``encodings``: one-hot assignments of shape (B*H*W, K).

        Raises:
            ValueError: if the channel dimension of ``latents`` is not D.
        """
        if latents.shape[1] != self.D:
            raise ValueError(
                f"expected {self.D} channels, got tensor of shape {tuple(latents.shape)}"
            )

        # BCHW -> BHWC so each spatial position is one contiguous D-dim vector.
        latents = latents.permute(0, 2, 3, 1).contiguous()
        latent_shape = latents.shape

        flat_input = latents.view(-1, self.D)  # [B*H*W, D]

        # Squared L2 distance ||z||^2 + ||e||^2 - 2 z.e for every (vector, code) pair.
        distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(flat_input, self.embedding.weight.t()))

        # Nearest code per vector, as one-hot rows.
        encoding_idx = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_idx.shape[0], self.K, device=latents.device)
        encodings.scatter_(1, encoding_idx, 1)

        # Look up codes (pre-update weights) and restore [B, H, W, D].
        quantized = torch.matmul(encodings, self.embedding.weight).view(latent_shape)

        if self.training:
            # Update statistics in place so the registered parameters/buffer keep
            # their identity; re-creating nn.Parameters here would leave optimizers
            # and other holders pointing at stale tensors.
            with torch.no_grad():
                self._ema_cluster_size.mul_(self.decay).add_(
                    torch.sum(encodings, 0), alpha=1 - self.decay)

                # Laplace smoothing keeps every cluster size > 0 (safe division below)
                # while preserving the total count n.
                n = torch.sum(self._ema_cluster_size)
                self._ema_cluster_size.copy_(
                    (self._ema_cluster_size + self.epsilon)
                    / (n + self.K * self.epsilon) * n)

                dw = torch.matmul(encodings.t(), flat_input)  # per-code sum of inputs
                self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)
                self.embedding.weight.copy_(
                    self.ema_w / self._ema_cluster_size.unsqueeze(1))

        # Only the commitment term: the codebook is trained by the EMA update.
        commitment_loss = F.mse_loss(quantized.detach(), latents)
        loss = self.commit_loss * commitment_loss

        # Straight-through estimator.
        quantized = latents + (quantized - latents).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

In [None]:
class Residual(nn.Module):
    def __init__(self,input_chnl,no_hidden,no_residual_hidden):