[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mariyamuneeb/ssl_wordspotting/blob/main/VariationalAE.ipynb)

## Installations, Imports, Plotting Utils

In [None]:
# Clone the project repository (presumably the source of the `models.dataset_utils` module imported in the IAM section).
!git clone https://github.com/mariyamuneeb/ssl_wordspotting
# Install Weights & Biases quietly; it is used for experiment logging throughout.
!pip -qqq install wandb

In [None]:
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import random 
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
import numpy as np
# Authenticate with Weights & Biases without hard-coding a secret in the notebook:
# wandb.login() reads the WANDB_API_KEY environment variable and falls back to an
# interactive prompt when it is not set.
# NOTE: the API key previously committed on this line is exposed and should be revoked.
wandb.login()
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mount Google Drive (Colab only) so that datasets stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Root folder of the datasets on Drive; the IAM section builds its paths from it.
GDRIVE_ROOT = '/content/drive/MyDrive/Datasets'

## Plotting Functions

In [None]:
## Plotting Few Samples
def plot_samples(dataset,num_samples):
    """Show ``num_samples`` random images from ``dataset`` on a near-square grid.

    Args:
        dataset: object exposing ``get_random_samples(k)``, which should return
            an iterable of images with ``.size`` and ``.mode`` attributes
            (presumably PIL images).
        num_samples: number of images to request and display.
    """
    import math

    random_imgs = list(dataset.get_random_samples(num_samples))
    if not random_imgs:
        return
    # Size the grid from the actual number of images. The old fixed 3x3 layout
    # silently dropped anything past nine and left blank panels below nine.
    n_cols = math.ceil(math.sqrt(len(random_imgs)))
    n_rows = math.ceil(len(random_imgs) / n_cols)
    _, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axs = axs.flatten()
    for img, ax in zip(random_imgs, axs):
        ax.imshow(img)
        ax.title.set_text(f'Image Shape {img.size},{img.mode}')
    # Hide any panels left over in the last row.
    for ax in axs[len(random_imgs):]:
        ax.axis('off')
    plt.show()

    
## Plotting Samples During Training
def plot_ae_custom_ds_outputs(encoder,decoder,test_dataset,n=10,log_key=None):
    """Plot the first ``n`` images of ``test_dataset`` above their reconstructions.

    Args:
        encoder: module mapping a ``(1, C, H, W)`` image batch to latent codes.
        decoder: module mapping latent codes back to an image batch.
        test_dataset: indexable dataset whose items are ``(image_tensor, label)``.
            Only ``image_tensor`` (assumed ``C, H, W``) is used.
        n: number of images to show; clamped to ``len(test_dataset)``.
        log_key: optional W&B key. When given, the original/reconstruction
            pairs are also logged as a ``wandb.Table`` under this key.
    """
    def _to_display(t):
        # (1, C, H, W) tensor -> (H, W) for one channel or (H, W, C) otherwise.
        # Moving the channel axis last keeps rows and columns in place. The
        # previous ``.T`` also swapped height and width, so every image was
        # shown transposed.
        arr = t.detach().cpu().squeeze(0).numpy()
        if arr.shape[0] == 1:
            return arr[0]
        # imshow clips RGB floats to [0, 1] anyway; clip explicitly (e.g. for tanh outputs).
        return np.clip(np.transpose(arr, (1, 2, 0)), 0.0, 1.0)

    n = min(n, len(test_dataset))
    if n == 0:
        return
    encoder.eval()
    decoder.eval()
    # Only build W&B objects when they will actually be logged.
    table = wandb.Table(columns=["Original", "Reconstruction"]) if log_key else None
    plt.figure(figsize=(16,4.5))
    for i in range(n):
        img = test_dataset[i][0].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        # Top row: originals.
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(_to_display(img), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Original images')

        # Bottom row: reconstructions.
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(_to_display(rec_img), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Reconstructed images')

        if table is not None:
            table.add_data(wandb.Image(img.cpu()), wandb.Image(rec_img.cpu()))
    plt.show()
    if table is not None:
        wandb.log({log_key: table})

# Model Design

## Encoder Class

In [None]:
class VariationalEncoder(nn.Module):
    """Convolutional encoder of the VAE.

    Five stride-2 3x3 convolutions (channels ``c, 2c, 4c, 8c, 16c`` with
    ``c = base_channel_size``) are followed by a hidden linear layer and two
    linear heads. One head gives the mean ``mu``; the other gives
    ``log(sigma)`` of a diagonal Gaussian. ``forward`` returns a
    reparameterised sample ``z`` of shape ``(batch, latent_dims)``. It also
    stores the KL divergence from ``N(mu, sigma^2)`` to ``N(0, I)``, summed
    over batch and latent dimensions, in ``self.kl``.

    The first linear layer is sized for 128x128 inputs: five stride-2 stages
    reduce 128 to 4, so the flattened feature size is ``4 * 4 * 16c``.
    """
    def __init__(self, num_input_channels,
                 base_channel_size,
                 latent_dims,
                 ):
        super(VariationalEncoder, self).__init__()
        c_hid = base_channel_size
        self.conv1 = nn.Conv2d(num_input_channels, c_hid, kernel_size = 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(c_hid, 2*c_hid,  kernel_size =3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(2*c_hid, 2*2*c_hid, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(2*2*c_hid, 2*2*2*c_hid, kernel_size=3, stride=2,padding=1)
        self.conv5 = nn.Conv2d(2*2*2*c_hid, 2*2*2*2*c_hid, kernel_size=3, padding=1, stride=2)
        self.linear1 = nn.Linear(4*4*16*c_hid, 4*2*c_hid)
        self.linear2 = nn.Linear(4*2*c_hid, latent_dims)  # mean head
        self.linear3 = nn.Linear(4*2*c_hid, latent_dims)  # log-sigma head
        # Kept so existing references to `encoder.N` still work. forward() draws
        # its noise with torch.randn_like instead, so the noise lives on mu's
        # device and dtype (sampling from N put it on the CPU and broke CUDA runs).
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = torch.flatten(x,start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        # Reparameterisation trick: z = mu + sigma * eps, with eps ~ N(0, I).
        z = mu + sigma * torch.randn_like(mu)
        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5*sigma^2 + 0.5*mu^2 - log(sigma) - 0.5 per
        # dimension, which is zero exactly at the prior (mu=0, sigma=1).
        self.kl = (0.5 * (sigma**2 + mu**2) - log_sigma - 0.5).sum()
        return z
        # return x

In [None]:
# Smoke test: push a random batch of eight 1-channel 128x128 images through a freshly
# built encoder (32 base channels, 512-dim latent) and print the input and latent shapes.
# `z` is reused by the decoder smoke test below.
x = torch.rand(8,1,128,128)
print(x.shape)
enc = VariationalEncoder(1, 32,512)
z = enc(x)
print(z.shape)
# z2 = z.reshape(z.shape[0], -1, 4, 4)
# print(z2.shape)

## Decoder Class

In [None]:
class Decoder(nn.Module):
    """Convolutional decoder mirroring ``VariationalEncoder``.

    A two-layer MLP lifts a latent vector to ``16 * base_channel_size``
    feature maps on a 4x4 grid. Five stride-2 transposed convolutions then
    double the spatial size at each stage (4 -> 128) while halving the
    channel count. The last stage emits ``num_input_channels`` channels,
    squashed into [-1, 1] by ``tanh``.
    """

    def __init__(self, num_input_channels,
                 base_channel_size,
                 latent_dims):
        super().__init__()
        width = base_channel_size
        hidden = 8 * width
        flat = 256 * width  # 16 * width channels x 4 x 4 positions

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, flat),
            nn.ReLU(True),
        )

        # Channel progression 16w -> 8w -> 4w -> 2w -> w -> num_input_channels,
        # with a ReLU between consecutive transposed convolutions.
        channels = [16 * width, 8 * width, 4 * width, 2 * width, width, num_input_channels]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            if stage:
                layers.append(nn.ReLU(True))
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
            )
        layers.append(nn.Tanh())
        self.decoder_conv = nn.Sequential(*layers)

    def forward(self, x):
        hidden = self.decoder_lin(x)
        grid = hidden.reshape(hidden.shape[0], -1, 4, 4)
        return self.decoder_conv(grid)

In [None]:
# Smoke test: decode the latent batch `z` from the encoder smoke test with a 3-channel decoder
# (32 base channels, 512-dim latent); the printed shape is expected to be (8, 3, 128, 128).
dec = Decoder(3,32,512)
z = dec(z)
print(z.shape)

## Variational AE

In [None]:
class VariationalAutoencoder(nn.Module):
    """VAE that chains a ``VariationalEncoder`` and a ``Decoder``.

    ``forward`` returns the reconstruction. The KL term of the latent sampled
    during that pass is left in ``self.encoder.kl`` for the loss.
    """
    def __init__(self, num_channels,base_channel_size,latent_dim):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(num_channels,base_channel_size,latent_dim)
        self.decoder = Decoder(num_channels,base_channel_size,latent_dim)

    def forward(self, x):
        # Move the input to wherever the model's parameters live. The old code
        # used the module-level `device` global, which breaks if the model was
        # placed elsewhere.
        x = x.to(next(self.parameters()).device)
        z = self.encoder(x)
        return self.decoder(z)

The `VariationalAutoencoder` class merges the encoder and decoder into a single model.

Initialize the VariationalAutoencoder class, the optimizer, and the device to use the GPU in the code.

## Training Function

Functions to train and evaluate the Variational Autoencoder

In [None]:
### Training function
def train_epoch(vae, device, dataloader, optimizer):
    """Run one optimisation pass of ``vae`` over ``dataloader``.

    The per-batch loss is the summed squared reconstruction error plus
    ``vae.encoder.kl``. Labels yielded by the loader are ignored
    (unsupervised training). The total loss over all batches is divided by
    ``len(dataloader.dataset)``, logged to W&B as ``train_loss``, and returned.
    """
    vae.train()
    running = 0.0
    for batch, _ in dataloader:
        batch = batch.to(device)
        reconstruction = vae(batch)
        squared_error = (batch - reconstruction) ** 2
        loss = squared_error.sum() + vae.encoder.kl

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running += loss.item()
    average = running / len(dataloader.dataset)
    wandb.log({"train_loss": average})
    return average

## Testing Function

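The loss is composed of two terms. The reconstruction term is the sum of the squared differences between the input and its reconstruction. The regularization term is the KL divergence between the encoder's Gaussian N(μ, σ²) and the standard normal prior (stored in `vae.encoder.kl`).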

In [None]:
### Testing function
def test_epoch(vae, device, dataloader):
    """Evaluate ``vae`` on ``dataloader`` without tracking gradients.

    Uses the same loss as ``train_epoch``: summed squared reconstruction error
    plus ``vae.encoder.kl``. The total is divided by
    ``len(dataloader.dataset)``, logged to W&B as ``val_loss``, and returned.
    Note: the encoder in this notebook samples ``z`` on every forward pass,
    even in eval mode, so the value varies between calls unless the RNG is seeded.
    """
    vae.eval()
    val_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        for x, _ in dataloader:
            x = x.to(device)
            # A single forward pass. A separate `vae.encoder(x)` call used to be
            # made here and discarded, which doubled encoder work and drew extra noise.
            x_hat = vae(x)
            loss = ((x - x_hat)**2).sum() + vae.encoder.kl
            val_loss += loss.item()
    val_loss_ave = val_loss / len(dataloader.dataset)
    wandb.log({"val_loss": val_loss_ave})
    return val_loss_ave

## Plotting Function


The function below plots inputs alongside their reconstructions periodically during training of the VAE model, and logs the pairs to W&B.

In [None]:
def plot_ae_outputs(encoder,decoder,n=10,dataset=None):
    """Plot one image per class ``0 .. n-1`` above its reconstruction and log the pairs to W&B.

    Args:
        encoder: module mapping a ``(1, C, H, W)`` image batch to latent codes.
        decoder: module mapping latent codes back to an image batch.
        n: number of classes to show. Class ``i`` is represented by the first
            sample whose target equals ``i``.
        dataset: dataset exposing ``targets`` and ``(image_tensor, label)``
            items. Defaults to the notebook-level ``test_dataset`` (the
            original behaviour).

    Raises:
        IndexError: if some class in ``range(n)`` has no sample in ``dataset``.
    """
    if dataset is None:
        dataset = test_dataset

    def _to_display(t):
        # (1, C, H, W) tensor -> (H, W) for one channel or (H, W, C) otherwise.
        # Moving the channel axis last keeps rows and columns in place. The
        # previous ``.T`` also swapped height and width, so images were shown
        # transposed.
        arr = t.detach().cpu().squeeze(0).numpy()
        if arr.shape[0] == 1:
            return arr[0]
        # imshow clips RGB floats to [0, 1] anyway; clip explicitly (e.g. for tanh outputs).
        return np.clip(np.transpose(arr, (1, 2, 0)), 0.0, 1.0)

    # np.asarray accepts both list targets (CIFAR-10) and tensor targets (MNIST).
    targets = np.asarray(dataset.targets)
    t_idx = {i: np.where(targets == i)[0][0] for i in range(n)}

    encoder.eval()
    decoder.eval()
    my_table = wandb.Table(columns=["Original", "Reconstruction"])
    plt.figure(figsize=(16,4.5))
    for i in range(n):
        img = dataset[t_idx[i]][0].unsqueeze(0).to(device)
        with torch.no_grad():
            rec_img = decoder(encoder(img))

        # Top row: originals.
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(_to_display(img), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Original images')

        # Bottom row: reconstructions.
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(_to_display(rec_img), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title('Reconstructed images')

        my_table.add_data(wandb.Image(img.cpu()), wandb.Image(rec_img.cpu()))
    plt.show()

    # Log the table to W&B. The key is kept exactly as before (typo included)
    # so it matches the key used by earlier runs.
    wandb.log({"vae_reconstrunction_cifar10": my_table})

# Standard Datasets

## Initialization

In [None]:
num_channels = 3
base_channel_size=32
# `10e-2` evaluated to 0.1, an unusually high Adam learning rate and 100x the
# 1e-3 used by the other experiments in this notebook; use 1e-3 here too.
lr = 1e-3
latent_dim = 384
epochs = 300
plot_freq = 10  # plot reconstructions every `plot_freq` epochs

wandb.init(
      # W&B project the run is logged to
      project="SSL",
      # Explicit run name (otherwise W&B assigns a random one)
      name="VAE",
      # Track hyperparameters and run metadata (lr included, as in the other runs)
      config={
      "architecture": "CNN",
      "dataset": "CIFAR-10",
      "lr": lr,
      "epochs": epochs,
      "latent_dim":latent_dim
      })

In [None]:
# Seed torch's RNG so that weight initialisation is reproducible.
torch.manual_seed(0)

# Build the model from the hyper-parameters above. The previous call
# `VariationalAutoencoder(latent_dims=4)` raised a TypeError: there is no such
# keyword, and the three required arguments were missing.
vae = VariationalAutoencoder(num_channels, base_channel_size, latent_dim)

# NOTE: rebinding `optim` shadows the `torch.optim` module imported as `optim` at the top.
optim = torch.optim.Adam(vae.parameters(), lr=lr)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')

vae.to(device)

## Import/Load Datasets

In [None]:
# Download (if missing) and load the CIFAR-10 train and test splits into ./dataset.
# No transform is attached yet, so items are returned as raw (image, label) pairs until
# the transforms are assigned further down.
data_dir = 'dataset'
# train_dataset = torchvision.datasets.MNIST(data_dir, train = True, download = True)
# test_dataset  = torchvision.datasets.MNIST(data_dir, train=False, download=True)

train_dataset = torchvision.datasets.CIFAR10(data_dir,train=True,download=True)
test_dataset  = torchvision.datasets.CIFAR10(data_dir,train=False, download=True)

In [None]:
# Peek at one raw training sample: print the image mode and its label, then display it.
img, label = train_dataset[1]
print(img.mode)
print(label)
plt.imshow(img)
plt.show()

In [None]:
# Notebook display of the sample image's (width, height).
img.size

In [None]:
# The encoder's first linear layer and the decoder's output size are both fixed
# for 128x128 inputs (five stride-2 stages: 128 -> 4). The 32x32 CIFAR-10 images
# must therefore be upsampled, or the first forward pass fails with a shape mismatch.
train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Resize((128, 128)),])

test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Resize((128, 128)),])

In [None]:
train_dataset.transform = train_transform
test_dataset.transform = test_transform

m=len(train_dataset)

# Hold out 20% for validation. Derive the training size from the validation size
# so the two lengths always sum to m; int(m - 0.2*m) + int(0.2*m) can fall one
# short (e.g. m = 50001), which random_split rejects.
n_val = int(m*0.2)
train_data, val_data = random_split(train_dataset, [m - n_val, n_val])
batch_size=256

# Reshuffle the training batches every epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,shuffle=True)

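The VariationalAutoencoder class combines the encoder and decoder classes.
The encoder contains **five stride-2 convolutional layers** followed by **three fully connected layers**: a shared hidden layer plus separate heads for the mean and the log standard deviation. The decoder mirrors it with **two fully connected layers** and **five transposed convolutional layers**.
Batch normalization layers are not currently used; they are commented out in the model definitions.
Unlike a standard autoencoder, the **encoder predicts a mean and a standard deviation vector** for each input and uses them to draw the sampled latent vector via the reparameterization trick.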

## Training Loop

In [None]:
# Train on CIFAR-10. Validate every epoch, and plot reconstructions on the
# first epoch and every `plot_freq` epochs after it.
for epoch in range(epochs):
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, valid_loader)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(
        epoch + 1, epochs, train_loss, val_loss))
    if not epoch % plot_freq:
        plot_ae_outputs(vae.encoder, vae.decoder, n=10)

In [None]:
# This cell appears to exist only to halt "Run all" before the next experiment.
# A bare `break` outside a loop is a SyntaxError, which stopped execution but hid
# the intent, so stop explicitly instead. Delete this cell to run everything.
raise RuntimeError("Stopping here: the cells below run a separate experiment.")

In [None]:
# Notebook display of the training DataLoader object.
train_loader

# VAE on Hand-written Dataset 

## Dataset

### Connect to GDrive

In [None]:
# Mount Google Drive (Colab only) and point PATH at the Drive folder that holds the hand-written images.
from google.colab import drive
drive.mount('/content/drive')
PATH = '/content/drive/MyDrive/Datasets/tif'

### Imports

In [None]:
import os
import cv2
import math
from PIL import Image
import matplotlib.pyplot as plt
from statistics import mean
from math import floor

### Custom Dataset Definition

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Unlabelled image dataset backed by a list of file paths.

    Each item is ``(image, '1')``: the image loaded from disk (passed through
    ``transform`` when one is given) and a constant placeholder label. The
    data is only used for unsupervised training.
    """
    def __init__(self, img_paths,transform=None):
        super(MyDataset, self).__init__()
        self.img_paths = img_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_name = self.img_paths[idx]
        # PIL opens files lazily and keeps the handle open. Copy the decoded
        # pixels and close the file so that long training runs do not exhaust
        # file descriptors.
        with Image.open(img_name) as im:
            image = im.copy()
        if self.transform:
            image = self.transform(image)
        return image,'1'

    @property
    def targets(self):
        """Placeholder targets, one ``'null'`` per image, for code that reads ``dataset.targets``."""
        return ['null'] * len(self.img_paths)

### Dataset Loading

In [None]:
# List the image files and split them 85% / 15% into train and test.
# os.listdir returns entries in arbitrary order; sorting makes the split
# reproducible across machines and runs.
images_paths = [os.path.join(PATH, i) for i in sorted(os.listdir(PATH))]
split = 0.85
train_idx = math.floor(split*len(images_paths))
train_images = images_paths[:train_idx]
test_images = images_paths[train_idx:]

### Plotting Few Images

In [None]:
# Preview nine randomly chosen training images on a 3x3 grid, each titled with its PIL size and mode.
random_imgs = [Image.open(path) for path in random.sample(train_images, 9)]
_, axs = plt.subplots(3, 3, figsize=(12, 12))
axs = axs.flatten()
for img, ax in zip(random_imgs, axs):
    ax.title.set_text(f'Image Shape {img.size},{img.mode}')
    ax.imshow(img)
plt.show()

### Finding Ave Image Dimensions

In [None]:
# Scan the training images for their average height/width and their channel count.
h_list = list()
w_list = list()

for p in train_images:
    # Open each file once (the loop used to open it twice) and close it right away.
    with Image.open(p) as im:
        w, h = im.size  # PIL size is (width, height)
    h_list.append(h)
    w_list.append(w)
# Count channels from PIL's band list (taken from the last image, as before).
# int(mode) only parsed the 1-bit mode "1" and raised ValueError for modes such as "L" or "RGB".
with Image.open(p) as im:
    num_channels = len(im.getbands())

h_ave = floor(mean(h_list))
w_ave = floor(mean(w_list))
# resize_size = (h_ave,w_ave)
resize_size = (128,128)  # the encoder/decoder are built for 128x128 inputs

### Dataset Definitions

In [None]:
# Build two separate but identical pipelines (PIL image -> float tensor -> resize
# to resize_size), one per split, and wrap the path lists into datasets.
train_transform, test_transform = (
    transforms.Compose([transforms.ToTensor(), transforms.Resize(resize_size)])
    for _ in range(2)
)

hw_train_dataset = MyDataset(train_images, train_transform)
hw_test_dataset = MyDataset(test_images, test_transform)

### Dataloader and Batching Definitions

In [None]:
m=len(hw_train_dataset)


batch_size=8

# Reshuffle the training images every epoch; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(hw_train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(hw_test_dataset, batch_size=batch_size)

## Plotting

## Initialization

### Hyperparameters

In [None]:
# Hyper-parameters for the hand-written dataset run.
base_channel_size = 32
lr = 1e-3          # Adam learning rate
latent_dim = 512
epochs = 300
plot_freq = 10     # plot reconstructions every `plot_freq` epochs

In [None]:
# Shape of the first transformed training image; expected to be (num_channels, 128, 128) given resize_size.
hw_train_dataset[0][0].shape

### W&B Init

In [None]:
# Start a W&B run in the "SSL" project for the hand-written dataset experiment,
# recording its hyper-parameters in the run config.
wandb.init(
    project="SSL",
    name="VAE",
    config={
        "architecture": "CNN",
        "dataset": "Hand Written Dataset",
        "lr": lr,
        "epochs": epochs,
        "latent_dim": latent_dim,
    },
)

### Initialize VAE

Initialize the VariationalAutoencoder class, the optimizer, and the device to use the GPU in the code.

In [None]:
# Seed torch's RNG for reproducible weight initialisation, then build the VAE and its
# Adam optimizer and move the model to the GPU when one is available.
torch.manual_seed(0)

vae = VariationalAutoencoder(num_channels,base_channel_size,latent_dim)
print(vae)

# NOTE: this rebinds `optim`, shadowing the `torch.optim` module imported as `optim` at the top.
optim = torch.optim.Adam(vae.parameters(), lr=lr)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')
vae.to(device)

## Training

### Train Loop

In [None]:
# Train on the hand-written dataset. Validate every epoch, and plot test-set
# reconstructions on the first epoch and every `plot_freq` epochs after it.
for epoch in range(epochs):
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, valid_loader)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(
        epoch + 1, epochs, train_loss, val_loss))
    if not epoch % plot_freq:
        plot_ae_custom_ds_outputs(vae.encoder, vae.decoder, hw_test_dataset, n=10)

# IAM Dataset


## Dataset

In [None]:
from models.dataset_utils import IAMDataset
from PIL import Image


In [None]:
# Locate the IAM word-image folders on Drive and wrap each split in an IAMDataset
# (defined in the cloned repo's models/dataset_utils.py).
dataset_root_dir = f'{GDRIVE_ROOT}/IAM_HW/words_full_dataset'
train_dir = f'{dataset_root_dir}/words_training'
test_dir = f'{dataset_root_dir}/words_test'
iam_train_dataset = IAMDataset(train_dir)
iam_test_dataset = IAMDataset(test_dir)

In [None]:
batch_size=8

# NOTE(review): the dimension scan below iterates `iam_train_dataset` as
# (PIL image, label) pairs, and no transform is passed to IAMDataset above.
# PyTorch's default collate cannot batch PIL images, so these loaders presumably
# need a ToTensor/Resize transform inside IAMDataset. Confirm against
# models/dataset_utils.py.
# Reshuffle the training batches every epoch.
train_loader = torch.utils.data.DataLoader(iam_train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(iam_test_dataset, batch_size=batch_size)

In [None]:
# Record every IAM training image's height and width (items are iterated as
# (image, label) pairs whose `.size` is (width, height)) to compute the average size.
# The averages are kept for reference; resize_size stays at 128x128, the input
# size the encoder is built for.
h_list, w_list = [], []
for p, _ in iam_train_dataset:
    w_list.append(p.size[0])
    h_list.append(p.size[1])
num_channels = 3

h_ave, w_ave = floor(mean(h_list)), floor(mean(w_list))
# resize_size = (h_ave,w_ave)
resize_size = (128, 128)

## Initialization

### HyperParameters

In [None]:
# Hyper-parameters for the IAM run.
base_channel_size = 32
lr = 1e-3          # Adam learning rate
latent_dim = 512
epochs = 300
plot_freq = 10     # plot reconstructions every `plot_freq` epochs

### W&B Init

In [None]:
wandb.init(
      # W&B project the run is logged to
      project="SSL",
      # Explicit run name (otherwise W&B assigns a random one)
      name="VAE",
      # Track hyperparameters and run metadata. The dataset label was a
      # copy-paste of the hand-written run's "Hand Written Dataset", which made
      # IAM runs indistinguishable in W&B.
      config={
      "architecture": "CNN",
      "dataset": "IAM Dataset",
      "lr":lr,
      "epochs": epochs,
      "latent_dim":latent_dim
      })

### Initialize VAE

In [None]:
# Seed torch's RNG for reproducible weight initialisation, then build the VAE for
# `num_channels`-channel inputs, create its Adam optimizer, and move the model to the
# GPU when one is available.
torch.manual_seed(0)

vae = VariationalAutoencoder(num_channels,base_channel_size,latent_dim)
print(vae)

# NOTE: this rebinds `optim`, shadowing the `torch.optim` module imported as `optim` at the top.
optim = torch.optim.Adam(vae.parameters(), lr=lr)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Selected device: {device}')
vae.to(device)

## Training

### Train Loop

In [None]:
# Train on IAM. Validate every epoch, and plot test-set reconstructions every `plot_freq` epochs.
for epoch in range(epochs):
   train_loss = train_epoch(vae,device,train_loader,optim)
   val_loss = test_epoch(vae,device,valid_loader)
   print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, epochs,train_loss,val_loss))
   if epoch%plot_freq==0:
       # Plot samples from the IAM test split. This used to plot `hw_test_dataset`,
       # a leftover from the hand-written section whose images do not match this
       # 3-channel model.
       # NOTE(review): this assumes IAMDataset items are (image_tensor, label) pairs;
       # see the note on the IAM DataLoaders.
       plot_ae_custom_ds_outputs(vae.encoder,vae.decoder,iam_test_dataset,n=10)