# Data Generation - Variational Auto Encoder (VAE)

In [1]:
import os
import math
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.data import Data

In [2]:
base_dir: str = os.path.join(os.getcwd(), 'data')

tuple_data: tuple[list[pd.DataFrame], np.ndarray, np.ndarray] = Data.load_data(base_dir)

# Unpack data
pd_data: list[pd.DataFrame] = tuple_data[0]
labels: np.ndarray = tuple_data[1]
classes: np.ndarray = tuple_data[2]

0 OUI
1 NON
2 VRAI
3 FAUX


In [3]:
data: np.ndarray = Data.convert_to_numpy(pd_data)

## Model

### Encoder

In [4]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim: int = 2):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, hidden_dim, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(hidden_dim * 2, hidden_dim * 2, kernel_size=3, stride=2, padding=1)

        # Activation functions
        self.relu = nn.ReLU()

        # Dropout layer(s)
        self.dropout = nn.Dropout(p=0.3)
        
        # Pooling layer
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        encoded = self.relu(self.conv4(x))

        encoded = self.dropout(encoded)
        return encoded

### Decoder

In [5]:
class Trim(nn.Module):
    def __init__(self, size1, size2):
        super(Trim, self).__init__()
        self.size1 = size1
        self.size2 = size2
    
    def forward(self, x):
        return x[:, :, :self.size1, :self.size2]

In [6]:
class Decoder(nn.Module):
    def __init__(self, hidden_dim: int = 2, img_shape: tuple = (72, 114)):
        super(Decoder, self).__init__()

        self.deconv1 = nn.ConvTranspose2d(hidden_dim * 2, hidden_dim * 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=3, stride=2, padding=1, output_padding=0)
        self.deconv3 = nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1, output_padding=0)
        self.deconv4 = nn.ConvTranspose2d(hidden_dim, 1, kernel_size=3, stride=2, padding=1, output_padding=0)

        # Activation functions
        self.relu = nn.ReLU()

        # Trimming
        self.trim = Trim(img_shape[0], img_shape[1])

    def forward(self, x):
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.relu(self.deconv3(x))
        decoded = self.relu(self.deconv4(x))

        decoded = self.trim(decoded)
        
        return decoded

### Variational Autoencoder (VAE)

<u>**Log-variance:**</u><br><br>
Instead of learning the variance $\sigma^2$, we here learn the log-variance $\log(\sigma^2)$ because it can take any real number, while the variance can only take positive values. We can then exponentiate the log-variance to get the variance. Here is how we can do this:<br>

$$\log(\sigma^2) = 2 * \log(\sigma)$$
$$\log(\sigma^2) / 2 = \log(\sigma)$$
$$e^{\log(\sigma^2) / 2} = \sigma$$

In [8]:
DEVICE = 'mps'

class VariationalAutoEncoder(nn.Module):
    def __init__(self, hidden_dim: int = 128, latent_dim: int = 2, input_shape: tuple = (72, 114)):
        super(VariationalAutoEncoder, self).__init__()
        self.encoder = Encoder(hidden_dim=hidden_dim)
        self.decoder = Decoder(hidden_dim=hidden_dim)

        self.nb_features = hidden_dim * 2
        self.intermediate_shape = (math.ceil(input_shape[0] / (2 ** 4)), math.ceil(input_shape[1] / (2 ** 4)))

        intermediary_dim = self.nb_features * self.intermediate_shape[0] * self.intermediate_shape[1]

        # Fully connected layers
        self.fc1 = nn.Linear(intermediary_dim, hidden_dim)
        self.mu_layer = nn.Linear(hidden_dim, latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, intermediary_dim)

        self.hidden_dim = hidden_dim

    def reparametrize(self, mu, logvar):
        sigma = torch.exp(logvar * 0.5).to(DEVICE)
        epsilon = torch.randn_like(sigma).to(DEVICE)

        return mu + epsilon * sigma
    
    def forward(self, x):
        encoded = self.encoder(x)
        flattened_encoded = encoded.view(x.size(0), -1)

        hidden = self.fc1(flattened_encoded)

        mu, logvar = self.mu_layer(hidden), self.logvar_layer(hidden)

        z = self.reparametrize(mu, logvar)
        z = self.fc2(z)
        reshaped_z = z.view(x.size(0),
                            self.nb_features,
                            self.intermediate_shape[0],
                            self.intermediate_shape[1])

        decoded = self.decoder(reshaped_z)

        return decoded, mu, logvar

    def loss_function(self, x, x_hat, mu, logvar):
        reconstruction_loss = nn.MSELoss()(x_hat, x)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return reconstruction_loss + kld_loss
    
    def fit(self, X_train, X_val, epochs: int = 10, learning_rate: float = 1e-3, batch_size: int = 32):
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        
        n_train_batches = X_train.shape[0] // batch_size
        n_val_batches = X_val.shape[0] // batch_size

        print('Training on {} batches, validating on {} batches'.format(n_train_batches, n_val_batches))

        for epoch in tqdm(range(epochs)):
            mean_loss = 0
            self.train()
            for i, x in enumerate(Data.unlabeled_data_generator(X_train, batch_size=batch_size)):
                x = x.to(DEVICE).unsqueeze(1)
                x_hat, mu, logvar = self(x)
                loss = self.loss_function(x, x_hat, mu, logvar)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                mean_loss += loss.item()
            
            mean_loss /= n_train_batches

            self.eval()
            with torch.no_grad():
                mean_val_loss = 0
                for i, x in enumerate(Data.unlabeled_data_generator(X_val, batch_size=batch_size)):
                    x = x.to(DEVICE).unsqueeze(1)
                    x_hat, mu, logvar = self(x)
                    val_loss = self.loss_function(x, x_hat, mu, logvar)

                    mean_val_loss += val_loss.item()

                mean_val_loss /= n_val_batches
                print('epoch [{}/{}], loss: {:.2f}, val_loss: {:.2f}                         '.format(epoch+1, epochs, mean_loss, mean_val_loss), end='\r')
seed = 42
torch.manual_seed(seed)
torch.mps.manual_seed(seed)

X_train, X_test = train_test_split(data, test_size=0.2, random_state=seed)
X_train, X_val = train_test_split(X_train, test_size=0.15, random_state=seed)

vae = VariationalAutoEncoder(hidden_dim=32, latent_dim=16).to(DEVICE)
vae.fit(X_train, X_val, epochs=1000, learning_rate=1e-3, batch_size=4)

Training on 17 batches, validating on 3 batches


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch [1000/1000], loss: 345.71, val_loss: 1432.10                          

In [45]:
vae.eval()
with torch.no_grad():
    test_mean_loss = 0
    n_test_batches = X_test.shape[0] // 4
    for i, x in enumerate(Data.unlabeled_data_generator(X_test, batch_size=4)):
        images = x.to(DEVICE).unsqueeze(1)
        x_hat, z_mu, z_logvar = vae(images)
        test_mean_loss = vae.loss_function(x_hat, images, z_mu, z_logvar)
        test_mean_loss += test_mean_loss.item()
        
    test_mean_loss = test_mean_loss / n_test_batches
    print(f"Test loss: {test_mean_loss}")

Test loss: 344.11517333984375
