# Contrastive Variational Autoencoder for the MNIST-Data-Set

Tobias Haase

I am loosely following the paper by [Abid & Zou (2019)](https://arxiv.org/abs/1902.04601).
## Set Up
First, I load the required modules.

In [4]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import imageio 

import os
# Make the project folder the working directory; the relative paths used later
# in this notebook ('./input', '../outputs') are resolved against it.
# NOTE(review): this is a hard-coded, machine-specific absolute path - adjust
# it before running the notebook on another machine.
os.chdir("/home/tchaase/Documents/Universitaet/Master-Arbeit/pytorch_mnist_VAE/")

As in the [VAE notebook](vae.ipynb), I will first define the utility functions.

In [None]:
def final_loss(bce_loss, mu, logvar):
    """
    Combine the reconstruction loss with the KL-divergence term of a VAE.

    The KL term is computed as
    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    where the sum runs over every element of ``mu`` / ``logvar``.

    :param bce_loss: reconstruction loss (already reduced by the caller)
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :return: ``bce_loss`` plus the KL-divergence term
    """
    # 1 + log(sigma^2) - mu^2 - sigma^2, element-wise, then summed to a scalar
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_loss + kl_divergence



Here is the training function:

In [None]:
def train(model, dataloader, dataset, device, optimizer, criterion):
    """
    Run one training pass over ``dataloader`` and return the mean batch loss.

    :param model: network called as ``model(batch)``; must return
        ``(reconstruction, mu, logvar)``
    :param dataloader: iterable of batches; only element 0 of each batch is used
    :param dataset: only its length is read, to size the progress bar
    :param device: device the input batch is moved to
    :param optimizer: optimizer stepped once per batch
    :param criterion: reconstruction loss called as ``criterion(reconstruction, batch)``
    :return: sum of the per-batch losses divided by the number of batches
    """
    model.train()
    # Progress-bar length only; it does not limit the number of iterations.
    expected_batches = int(len(dataset) / dataloader.batch_size)

    loss_total = 0.0
    batches_seen = 0
    for _, batch in tqdm(enumerate(dataloader), total=expected_batches):
        batches_seen += 1
        inputs = batch[0].to(device)

        optimizer.zero_grad()
        reconstruction, mu, logvar = model(inputs)
        loss = final_loss(criterion(reconstruction, inputs), mu, logvar)

        # Back-propagate, record the scalar loss, then update the weights.
        loss.backward()
        loss_total += loss.item()
        optimizer.step()

    return loss_total / batches_seen

Here is the validation function:

In [None]:

def validate(model, dataloader, dataset, device, criterion):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    :param model: network called as ``model(batch)``; must return
        ``(reconstruction, mu, logvar)``
    :param dataloader: iterable of batches; only element 0 of each batch is used
    :param dataset: only its length is read (progress bar and batch selection)
    :param device: device the input batch is moved to
    :param criterion: reconstruction loss called as ``criterion(reconstruction, batch)``
    :return: tuple ``(val_loss, recon_images)`` where ``val_loss`` is the mean
        per-batch loss and ``recon_images`` is the reconstruction of batch
        number ``int(len(dataset) / batch_size) - 1``; if that batch index is
        never reached, the reconstruction of the last batch seen is returned
    :raises ZeroDivisionError: if ``dataloader`` yields no batches
    """
    model.eval()
    running_loss = 0.0
    counter = 0
    # Number of full batches; also used as the progress-bar length.
    full_batches = int(len(dataset) / dataloader.batch_size)
    # Pre-bind both names so the return below can never hit an unbound local
    # (previously this happened when the dataset was smaller than one batch).
    recon_images = None
    reconstruction = None
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=full_batches):
            counter += 1
            data = data[0]
            data = data.to(device)
            reconstruction, mu, logvar = model(data)
            bce_loss = criterion(reconstruction, data)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()

            # keep the output of the last full batch of every epoch
            if i == full_batches - 1:
                recon_images = reconstruction
    if recon_images is None:
        # No full batch was seen: fall back to the last reconstruction.
        recon_images = reconstruction
    val_loss = running_loss / counter
    return val_loss, recon_images

Again, I also want to save images. The goal is to obtain disentangled pictures.

In [None]:
# Converter from a tensor to a PIL image; used below to build the .gif frames.
to_pil_image = transforms.ToPILImage()


def image_to_vid(images):
    """
    Write ``images`` as an animated .gif to '../outputs/generated_images.gif'.

    Each element is converted to a PIL image and then to a numpy array,
    because the frames are handed to imageio as numpy arrays.

    :param images: iterable of image tensors, one per frame
    """
    frames = []
    for img in images:
        frames.append(np.array(to_pil_image(img)))
    imageio.mimsave('../outputs/generated_images.gif', frames)

# Counterpart of the function above for the reconstructions produced by the VAE.
def save_reconstructed_images(recon_images, epoch):
    """
    Save a batch of reconstructed images to '../outputs/output{epoch}.jpg'.

    :param recon_images: tensor of reconstructed images; copied to the CPU first
    :param epoch: value inserted into the output file name
    """
    output_path = f"../outputs/output{epoch}.jpg"
    # save_image comes from torchvision.utils
    save_image(recon_images.cpu(), output_path)

# Plot the training and validation losses with matplotlib and save the figure.
def save_loss_plot(train_loss, valid_loss):
    """
    Draw both loss curves, save the figure to '../outputs/loss.jpg' and show it.

    :param train_loss: sequence of training losses, one value per epoch
    :param valid_loss: sequence of validation losses, one value per epoch
    """
    plt.figure(figsize=(10, 7))

    # (values, line colour, legend label) for each curve, drawn in this order
    curves = (
        (train_loss, 'orange', 'train loss'),
        (valid_loss, 'red', 'validataion loss'),
    )
    for values, colour, legend_label in curves:
        plt.plot(values, color=colour, label=legend_label)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('../outputs/loss.jpg')
    plt.show()

Lastly, before defining the full model, I will set some parameters that it uses.

In [None]:
# Hyper-parameters. kernel_size and stride are read as module-level globals
# by the EncoderZ / EncoderS constructors defined below.
kernel_size = 3 # 3x3 convolution kernel used by both encoders
init_channels = 8 # initial number of filters; NOTE(review): not referenced by the code visible in this notebook
image_channels = 1 # MNIST images are grayscale; NOTE(review): not referenced by the code visible in this notebook
latent_dim = 16 # latent dimension for sampling
stride = 2 # stride of the encoder convolutions
channels = 1  # number of image channels (working with grayscale)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderZ(nn.Module):
    """
    Convolutional encoder producing the mean and log-variance of the z latent.

    The convolutions read the module-level ``kernel_size`` and ``stride``
    when the encoder is constructed. The first linear layer expects
    64 * 7 * 7 flattened features.
    NOTE(review): elsewhere in this notebook MNIST is resized to 32x32, which
    would not give 7x7 feature maps with the current settings - confirm the
    intended input size.

    :param latent_dim: size of the returned mean / log-variance vectors
    :param channels: number of channels of the input images
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        conv_settings = dict(kernel_size=kernel_size, stride=stride, padding=1)
        self.conv1 = nn.Conv2d(channels, 32, **conv_settings)
        self.conv2 = nn.Conv2d(32, 64, **conv_settings)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        # Two separate heads on top of the shared 256-dimensional features
        self.fc_mean = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Return ``(z_mean, z_log_var)`` for the image batch ``x``."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(self.flatten(features)))
        return self.fc_mean(hidden), self.fc_log_var(hidden)

class EncoderS(nn.Module):
    """
    Convolutional encoder producing the mean and log-variance of the s latent.

    The layer layout is the same as in ``EncoderZ``; the convolutions read the
    module-level ``kernel_size`` and ``stride`` when the encoder is
    constructed, and the first linear layer expects 64 * 7 * 7 flattened
    features.

    :param latent_dim: size of the returned mean / log-variance vectors
    :param channels: number of channels of the input images
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        conv_settings = dict(kernel_size=kernel_size, stride=stride, padding=1)
        self.conv1 = nn.Conv2d(channels, 32, **conv_settings)
        self.conv2 = nn.Conv2d(32, 64, **conv_settings)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        # Two separate heads on top of the shared 256-dimensional features
        self.fc_mean = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        """Return ``(s_mean, s_log_var)`` for the image batch ``x``."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(self.flatten(features)))
        return self.fc_mean(hidden), self.fc_log_var(hidden)

class Decoder(nn.Module):
    """
    Decoder mapping a latent vector back to an image with values in (0, 1).

    Two linear layers expand the latent vector to 64 feature maps of 7x7,
    and two stride-2 transposed convolutions upsample them (7 -> 14 -> 28).

    :param latent_dim: size of the incoming latent vector
    :param channels: number of channels of the produced image
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 7 * 7 * 64)
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(64, 32, **upsample)
        self.conv2 = nn.ConvTranspose2d(32, channels, **upsample)

    def forward(self, z):
        """Decode the latent batch ``z`` into a batch of images."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(z))))
        # Unflatten into (batch, 64, 7, 7) feature maps for the conv layers
        feature_maps = hidden.view(-1, 64, 7, 7)
        upsampled = F.relu(self.conv1(feature_maps))
        # Sigmoid keeps every output pixel between 0 and 1
        return torch.sigmoid(self.conv2(upsampled))

class cVAE(nn.Module):
    """
    Container for the contrastive VAE: two encoders (z and s) and one decoder.

    NOTE(review): this class defines no ``forward`` method, so calling
    ``model(data)`` on an instance will not work yet.

    :param latent_dim: latent size passed to both encoders and the decoder
    :param channels: number of image channels passed to all sub-modules
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        self.encoder_z = EncoderZ(latent_dim, channels)
        self.encoder_s = EncoderS(latent_dim, channels)
        self.decoder = Decoder(latent_dim, channels)

    def reparameterize(self, mean, log_var):
        """
        Draw a sample ``mean + eps * std`` with ``eps`` standard normal noise.

        :param mean: mean of the latent distribution
        :param log_var: log-variance of the latent distribution
        :return: sampled latent tensor
        """
        # std = exp(0.5 * log(sigma^2)) = sigma
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mean + noise * std

Right now, I can't directly access the images to superimpose them. Therefore, I load them with the following function. The greyscale conversion is done in a separate step further below.

In [5]:
from PIL import Image
import pickle

def unpickle(file):
    """
    Load and return the object stored in the pickle file at path ``file``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only use this on files from a trusted source.

    :param file: path of the pickle file to read
    :return: the unpickled object
    """
    with open(file, 'rb') as fo:
        # Named 'payload' so that the builtin ``dict`` is not shadowed
        payload = pickle.load(fo)
    return payload

def load_databatch(data_folder, idx, img_size=32):
    """
    Load one pickled training batch and return mean-centred, mirrored images.

    The file ``train_data_batch_<idx>`` inside ``data_folder`` is unpickled
    and its 'data', 'labels' and 'mean' entries are read. Pixel data and mean
    image are divided by 255, the mean image is subtracted, and every image is
    appended a second time flipped along its last axis.

    :param data_folder: folder that contains the batch files
    :param idx: number of the batch file to load
    :param img_size: edge length of the square images
    :return: dict with
        'X_train' - float32 array of shape (2 * N, 3, img_size, img_size),
        'Y_train' - int32 labels shifted to start at 0, repeated for the flips,
        'mean'    - float32 mean image divided by 255
    """
    batch_path = os.path.join(data_folder, f'train_data_batch_{idx}')
    batch = unpickle(batch_path)
    pixels = batch['data']  # assumes rows of 3 * img_size**2 values - TODO confirm
    labels = batch['labels']
    mean_image = batch['mean']

    scale = np.float32(255)
    pixels = pixels / scale
    mean_image = mean_image / scale

    # Labels are indexed from 1, shift it so that indexes start at 0
    labels = [label - 1 for label in labels]
    data_size = pixels.shape[0]

    pixels -= mean_image

    # Each row holds the three colour planes one after another; stack them on
    # a new last axis, then move that axis in front of the spatial axes.
    plane = img_size * img_size
    first_plane = pixels[:, :plane]
    second_plane = pixels[:, plane:2 * plane]
    third_plane = pixels[:, 2 * plane:]
    pixels = np.dstack((first_plane, second_plane, third_plane))
    pixels = pixels.reshape((pixels.shape[0], img_size, img_size, 3))
    pixels = pixels.transpose(0, 3, 1, 2)

    # create mirrored images
    X_train = pixels[0:data_size, :, :, :]
    Y_train = labels[0:data_size]
    mirrored = X_train[..., ::-1]
    X_train = np.concatenate((X_train, mirrored), axis=0)
    Y_train = np.concatenate((Y_train, Y_train), axis=0)

    return {
        'X_train': X_train.astype(np.float32),
        'Y_train': np.array(Y_train, dtype=np.int32),
        'mean': mean_image.astype(np.float32),
    }


Let's superimpose them with each other!

First, I load the MNIST pictures in the same way as [here](./vae.ipynb). The data is loaded and resized to 32x32 to match the ImageNet data.
I have once again split the data into a training and validation set.


In [6]:
# Preprocessing applied to every MNIST picture: resize to 32x32, then convert
# the picture to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# MNIST training split, downloaded into ./input if it is not there yet
trainset = datasets.MNIST(
    root='./input',
    train=True,
    download=True,
    transform=transform,
)
# MNIST test split with the same preprocessing
testset = datasets.MNIST(
    root='./input',
    train=False,
    download=True,
    transform=transform,
)

First, I unpickle one batch and apply the preprocessing steps defined above.

In [7]:
# Load training batch number 1 from the Imagenet32 folder
batch_1 = load_databatch(data_folder="./input/Imagenet32_train", idx=1)


Now, let's split the channels into matrices and then greyscale them.

In [19]:
# Convert the colour pictures in batch_1 to single-channel grayscale pictures.
image_data = batch_1["X_train"]

# Get the number of images
num_images = len(image_data)

print("The shape is now: ", image_data.shape)

# View every picture as (3, 32, 32) so that axis 1 is the colour channel
rgb_images = np.reshape(image_data, (num_images, 3, 32, 32))

# Weights of the red, green and blue channel in the grayscale picture
# (stored as float32 so that the result stays float32)
channel_weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)

# Weighted sum over the channel axis for all pictures at once; einsum writes
# straight into one (num_images, 32, 32) array instead of looping in Python
# or building one full-size temporary per channel.
grayscale_images = np.einsum("c,nchw->nhw", channel_weights, rgb_images)
grayscale_images = grayscale_images.astype(np.float32, copy=False)

# NOTE(review): load_databatch already divides the pixel data by 255 and
# subtracts the mean image, so this second division probably does not yield
# values in [0, 1] - kept as in the original, please confirm the intent.
grayscale_images /= 255.0

# Store the grayscale images in a new entry in the dictionary
batch_1["X_train_gray"] = grayscale_images
print("The shape is now: ", grayscale_images.shape)

This seems to have worked. Now, let's move on to overlaying them. The following still needs proper testing.

In [None]:
# Overlay MNIST digits on the pictures stored in batch_1.
# NOTE(review): this cell still needs proper end-to-end testing.

# Convert MNIST images to numpy arrays ('trainset' is the MNIST training set
# loaded above)
x_train_mnist = np.array(trainset.data)
y_train_mnist = np.array(trainset.targets)

# There may be fewer MNIST digits than pictures in the dictionary, so only as
# many pictures as there are digits receive an overlay.
num_images = len(batch_1["X_train"])
num_overlays = min(num_images, len(x_train_mnist))

# Resize MNIST images to 32x32 to match the pictures in the dictionary.
# Bilinear interpolation is used because np.resize only repeats or truncates
# the flattened data instead of rescaling the picture.
# assumes trainset.data holds pixel values in 0..255 - TODO confirm
mnist_digits = torch.from_numpy(x_train_mnist[:num_overlays]).unsqueeze(1).float() / 255.0
x_train_resized = F.interpolate(
    mnist_digits, size=(32, 32), mode="bilinear", align_corners=False
).squeeze(1).numpy()

alpha = 0.5  # Opacity of the overlay

# Blend in place: picture = alpha * digit + (1 - alpha) * picture.
# The resized digit covers the whole 32x32 picture, so no offset is needed,
# and the new channel axis broadcasts the digit over all colour channels.
overlay_target = batch_1["X_train"][:num_overlays]  # view into X_train
overlay_target *= (1 - alpha)
overlay_target += alpha * x_train_resized[:, np.newaxis, :, :]

# Update the labels of the overlayed pictures with the MNIST labels
batch_1["Y_train"][:num_overlays] = y_train_mnist[:num_overlays]

# Print the dimensions of one picture in the dictionary
print("Dimensions of a picture after overlaying MNIST dataset:", batch_1["X_train"][0].shape)