## Preliminaries

In [1]:
# Importing the necessary libraries
import torch  # Import PyTorch library
import torch.nn as nn  # Import neural network module
import torch.optim as optim  # Import optimization module
from torchvision import datasets, transforms  # Import datasets and transforms
from torchvision.utils import save_image, make_grid  # Import utility to save images
import torchvision  # Import torchvision library
import matplotlib.pyplot as plt  # Import plotting library
import os  # Import os module for file operations
import numpy as np  # Import numpy library

In [2]:
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")  # Report which device was selected

Using device: cuda


## Define models

The Variational Autoencoder (VAE) architecture can be visualized as follows:
```
    Input (x)
       |
       v
   Encoder q(z|x)
       |
       v
   Latent Space (z)
       |
       v
   Decoder p(x|z)
       |
       v
   Reconstructed (x̂)
```
We can see the key components of a VAE:

- **Encoder:** We model $q(z|x)$ as a neural network with parameters $\phi$. The network takes in an observation $x$ and outputs the parameters of a Gaussian distribution, i.e. the mean $\mu_{\phi}(x)$ and the covariance $\Sigma_{\phi}(x)$.

- **Decoder:** We model $p_{\theta}(x|z)$ as a neural network with parameters $\theta$. The network takes in a latent variable $z$ sampled from the Gaussian with parameters $\mu_{\phi}(x)$ and $\Sigma_{\phi}(x)$ and outputs a data sample $\hat{x}$. After training, the decoder alone is used to generate new data samples, i.e. it acts as the generator.


### Encoder

In [3]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the parameters of q(z|x).

    Five 4x4, stride-2, padding-1 convolutions (each followed by ReLU) halve the
    spatial size at every step: N x 3 x 128 x 128 -> N x 512 x 4 x 4. The
    flattened features feed two linear heads, one for the mean and one for the
    log-variance of a diagonal Gaussian over the latent space. The linear heads
    are sized for 512 * 4 * 4 features, i.e. for 128 x 128 inputs.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Stored on the instance because it is a key hyperparameter of the model
        self.latent_dim = latent_dim

        # Channel progression of the conv stack, starting from the RGB input
        channels = (3, 32, 64, 128, 256, 512)
        stack = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*stack)

        # Size of the flattened N x 512 x 4 x 4 feature map
        flat_features = channels[-1] * 4 * 4
        # Mean head: N x flat_features -> N x latent_dim
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        # Log-variance head: one value per latent dimension, i.e. the log of the
        # diagonal of the covariance matrix (latent variables assumed independent).
        # Exponentiating it during reparameterization recovers the variances.
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch of shape N x 3 x 128 x 128.

        Returns:
            Tuple ``(mu, logvar)``, each of shape N x latent_dim.
        """
        features = self.conv_layers(x)
        features = features.view(features.size(0), -1)  # Flatten everything but the batch axis
        return self.fc_mu(features), self.fc_logvar(features)

### Decoder

In [4]:
class Decoder(nn.Module):
    """Transposed-convolution decoder modelling p(x|z).

    A linear layer expands the latent vector to a 512 x 4 x 4 feature map, and
    five 4x4, stride-2, padding-1 transposed convolutions double the spatial
    size at every step: N x 512 x 4 x 4 -> N x 3 x 128 x 128. Hidden stages use
    ReLU; the last stage uses Sigmoid so that outputs lie in [0, 1].

    Args:
        latent_dim: Dimensionality of the latent space (kept as an attribute
            for consistency with the encoder).
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # Latent vector -> flattened 512 x 4 x 4 feature map
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        # Channel progression of the deconv stack, ending at the RGB output
        channels = (512, 256, 128, 64, 32, 3)
        final_stage = len(channels) - 2
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # Sigmoid only after the last layer, to bound the image in [0, 1]
            stack.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.deconv_layers = nn.Sequential(*stack)

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape N x latent_dim.

        Returns:
            Reconstructed images of shape N x 3 x 128 x 128 with values in [0, 1].
        """
        feature_map = self.fc(z).view(-1, 512, 4, 4)  # Reshape to 4D for the deconvolutions
        return self.deconv_layers(feature_map)

### Reparameterization Trick


We have to maximize the ELBO:
$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \mid p_{\theta}(z) \right) $$

Focus on the first term:
$$\mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right]$$

We introduce a function $g_{\phi}(\epsilon)$ that transforms a noise variable $\epsilon$ into $z$

$$ z = g_{\phi}(\epsilon) $$

This allows us to rewrite the expectation in terms of $\epsilon$

$$ \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] = \mathbb{E}_{p(\epsilon)} \left[ \log p_{\theta}(x|g_{\phi}(\epsilon)) \right] $$

Where:
- $\epsilon \sim p(\epsilon)$ (typically a standard normal distribution)
- $g_{\phi}(\epsilon)$ is our reparameterization function, typically:
$$ z = \mu + \sigma \odot \epsilon $$

The gradient can then be estimated using Monte Carlo sampling:

$$ \nabla_{\phi} \mathbb{E}_{q(z|x)}[\log p_{\theta}(x|z)] \approx \frac{1}{N} \sum_{i=1}^N \nabla_{\phi} [\log p_{\theta}(x|g_{\phi}(\epsilon_i))]  \quad \epsilon_i \sim p(\epsilon) $$

This approach allows the gradient to flow through the sampling process, enabling effective optimization of the VAE.

In [5]:
def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick).

    The randomness is isolated in ``eps``, so gradients can flow back to ``mu``
    and ``logvar`` through the returned sample.

    Args:
        mu (torch.Tensor): Mean of the latent Gaussian.
        logvar (torch.Tensor): Log-variance of the latent Gaussian.

    Returns:
        torch.Tensor: Sampled latent vector.
    """
    sigma = (0.5 * logvar).exp()  # sigma = exp(logvar / 2)
    noise = torch.randn_like(sigma)  # Standard-normal noise shaped like sigma
    return mu + noise * sigma


## Define the loss function

The log-likelihood is given by:

$$ \ell(\theta) = \log p_{\theta}(x) $$

We assume that each data point $x_i$ is associated with a latent variable $z_i$.
Hence, we will introduce the latent variable $z$ and marginalize over it:

$$ \ell(\theta) = \log \sum_z p_{\theta}(x, z) $$

Let $q(z|x)$ be a conditional distribution over $z$ given $x$.

$$ \ell(\theta) = \log \sum_z q(z|x) \frac{p_{\theta}(x, z)}{q(z|x)} $$

$$ \ell(\theta) = \log \mathbb{E}_{q(z|x)} \left[ \frac{p_{\theta}(x, z)}{q(z|x)} \right] $$

By Jensen's inequality, we have:

$$ \log \mathbb{E}_{q(z|x)} \left[ \frac{p_{\theta}(x, z)}{q(z|x)} \right] \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x, z)}{q(z|x)} \right] $$

Hence, we have:

$$ \ell(\theta) \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x, z)}{q(z|x)} \right] = F_{\theta}(q) $$

Where $F_{\theta}(q)$ is the evidence lower bound (ELBO).

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x|z)p_{\theta}(z)}{q(z|x)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] + \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(z)}{q(z|x)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - \mathbb{E}_{q(z|x)} \left[ \log \frac{q(z|x)}{p_{\theta}(z)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \mid p_{\theta}(z) \right) $$

Here the first term is the conditional log likelihood of the data under the model. We want to maximise this term.

The second term is the KL divergence between the posterior and the prior. We want to minimise this term.

We assume $p_{\theta}(x|z) = N(x; \hat{x}, I)$, where $\hat{x}$ is the decoder output for $z = g_{\phi}(\epsilon)$; this is a modelling assumption. It allows us to estimate the expected log-likelihood from the original input $x_i$ and the $m$ reconstructions $\hat{x}_j^i$ generated for it, as follows (derivation skipped):

$$ \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] = \mathbb{E}_{p(\epsilon)} \left[ \log p_{\theta}(x|g_{\phi}(\epsilon)) \right] \approx \frac{1}{m} \sum_{j=1}^m \log p_{\theta}(x_i|g_{\phi}(\epsilon_j)) = -\frac{1}{2m} \sum_{j=1}^m \|x_i - \hat{x}_j^i\|_2^2 + \text{const} $$

Maximising this term is therefore equivalent to minimising the average squared reconstruction error.

### Reconstruction loss

In [6]:
def reconstruction_loss(x_reconstructed, x):
    """
    Sum of squared errors between a reconstruction and its target.

    Args:
    x_reconstructed (torch.Tensor): The reconstruction produced by the decoder.
    x (torch.Tensor): The original input; expected to have the same shape as x_reconstructed.

    Returns:
    torch.Tensor: A scalar tensor holding the squared error summed (not averaged) over all elements.
    """
    sum_squared_error = nn.functional.mse_loss(
        input=x_reconstructed,
        target=x,
        reduction='sum',  # Total over the batch rather than a mean
    )
    return sum_squared_error

### KL Divergence loss

Recall that:

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \mid p_{\theta}(z) \right) $$

the second term is:

$$ D_{KL} \left( q(z|x) \mid p_{\theta}(z) \right) $$

We want to minimise this term.

We assume the latent prior $p_\theta(z) \sim N(0, I)$, where $I$ is the identity matrix.

The approximate posterior $q(z|x)$ is modeled as $N(z; \mu_\phi(x), \Sigma_\phi(x))$.

Given these assumptions, we can derive the KL divergence in closed form as:

$$ D_{KL}(N(z; \mu_\phi(x), \Sigma_\phi(x)) \| N(0, I)) = \frac{1}{2} \sum_{j=1}^J \left( \mu_{\phi,j}^2(x) + \Sigma_{\phi,j}(x) - \log \Sigma_{\phi,j}(x) - 1 \right) $$

Where:
- $J$ is the dimensionality of the latent space
- $\mu_{\phi,j}(x)$ is the j-th element of the mean vector
- $\Sigma_{\phi,j}(x)$ is the j-th diagonal element of the covariance matrix

In [7]:
def kl_divergence_loss(mu, logvar):
    """
    Closed-form KL divergence between N(mu, diag(exp(logvar))) and N(0, I).

    Args:
    mu (torch.Tensor): The mean of the latent distribution.
    logvar (torch.Tensor): The log variance of the latent distribution.

    Returns:
    torch.Tensor: A scalar tensor with the KL divergence summed over all elements of the inputs.
    """
    # Per-element term: mu^2 + sigma^2 - log(sigma^2) - 1
    per_element = mu.pow(2) + logvar.exp() - logvar - 1
    return 0.5 * per_element.sum()

## Data preparation

In [8]:
# Preprocessing applied to every image when it is loaded
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),  # Resize to the 128x128 input the encoder expects
        transforms.ToTensor(),  # Convert to a PyTorch tensor with values scaled to [0, 1]
    ]
)

In [9]:
# Build the image dataset from the folder at this relative path;
# every image goes through `transform` when it is read.
animal_dataset = datasets.ImageFolder(
    root='Animal_data_resized/animals',
    transform=transform,
)

In [10]:
# Serve the dataset in shuffled mini-batches of 64 images, loaded by 2 worker processes
dataloader = torch.utils.data.DataLoader(
    dataset=animal_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

## Training loop


In [11]:
# Build both halves of the VAE with their default latent size and place them on `device`.
# nn.Module.to() moves the parameters in place and returns the module itself,
# so construction and placement can be chained.
encoder_model = Encoder().to(device)
decoder_model = Decoder().to(device)

# Last expression of the cell: displays the decoder's architecture
decoder_model

Decoder(
  (fc): Linear(in_features=128, out_features=8192, bias=True)
  (deconv_layers): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): Sigmoid()
  )
)

In [12]:
# Put both networks in training mode. Neither the Encoder nor the Decoder defined in
# this notebook contains dropout or batch-normalization layers, so this is
# conventional setup and does not change their outputs.
encoder_model.train(mode=True)
decoder_model.train(mode=True)  # train() returns the module, so the cell displays the decoder

Decoder(
  (fc): Linear(in_features=128, out_features=8192, bias=True)
  (deconv_layers): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): Sigmoid()
  )
)

In [13]:
# One Adam optimizer per network, both with a learning rate of 1e-3
encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=1e-3)
decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=1e-3)

In [14]:
# Training hyperparameters
num_samples = 10  # Latent samples drawn per input in each training step
num_epochs = 100  # Number of full passes over the training data

In [15]:
# Train the VAE.
#
# For every batch the encoder produces (mean, log-variance), `num_samples` latent vectors
# are drawn with the reparameterization trick and decoded, and the loss is
#     (1 / num_samples) * sum_j ||x - x_hat_j||^2  +  KL(q(z|x) || N(0, I)).
# Both networks are updated from this single loss. After each batch, one image decoded
# from a random prior sample is written to 'generated_images/'; after each epoch, both
# state dicts are written to 'saved_models/'.

# Per-epoch loss histories; each entry is the epoch's summed batch loss divided by the number of batches
epoch_total_losses = []
epoch_reconstruction_losses = []
epoch_kl_divergence_losses = []

# The output directories never change during training, so create them once up front
os.makedirs('generated_images', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

for epoch_index in range(num_epochs):
    print(f"Epoch {epoch_index + 1}/{num_epochs}")
    epoch_loss = 0  # Running sums over the batches of this epoch
    epoch_reconstruction_loss = 0
    epoch_kl_divergence_loss = 0

    for batch_index, (images_batch, _) in enumerate(dataloader):  # Class labels are not used
        images_batch = images_batch.to(device)
        print(f"Processing batch {batch_index + 1} with {images_batch.size(0)} images")

        # Encode the batch to the parameters of q(z|x)
        mean_output, log_variance_output = encoder_model(images_batch)

        # Decode `num_samples` independent latent draws for the same batch.
        # Everything here already lives on `device`, so no extra transfers are needed.
        reconstructed_samples_list = []
        for sample_index in range(num_samples):
            latent_vector_z = reparameterize(mean_output, log_variance_output)
            reconstructed_sample = decoder_model(latent_vector_z)
            reconstructed_samples_list.append(reconstructed_sample)
        # Stacked along a new leading axis: num_samples x (decoder output shape)
        reconstructed_samples_tensor = torch.stack(reconstructed_samples_list, dim=0)

        # Mean reconstruction, kept only for inspection; the loss is NOT computed on it
        mean_reconstructed_images = reconstructed_samples_tensor.detach().mean(dim=0)

        # Monte Carlo estimate of the reconstruction term: the AVERAGE OF THE PER-SAMPLE
        # LOSSES, (1/m) * sum_j ||x - x_hat_j||^2. Computing the loss of the averaged
        # reconstruction instead is a different quantity that, by Jensen's inequality,
        # underestimates this one. The target is broadcast (as a view, without copying)
        # against the sample axis so a single summed-MSE call covers every sample.
        reconstruction_loss_value = reconstruction_loss(
            reconstructed_samples_tensor,
            images_batch.unsqueeze(0).expand_as(reconstructed_samples_tensor),
        ) / num_samples

        # KL term depends only on the encoder outputs, not on the individual samples
        kl_divergence_loss_value = kl_divergence_loss(mean_output, log_variance_output)

        # Negative ELBO (up to constants): reconstruction error plus KL divergence
        total_loss_value = reconstruction_loss_value + kl_divergence_loss_value

        # One backward pass through the shared loss updates both networks
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        total_loss_value.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        # Accumulate plain Python floats so no autograd graph is kept alive
        epoch_loss += total_loss_value.item()
        epoch_reconstruction_loss += reconstruction_loss_value.item()
        epoch_kl_divergence_loss += kl_divergence_loss_value.item()

        # Decode one sample from the prior N(0, I) to monitor generation quality
        with torch.no_grad():
            decoder_model.eval()
            # Use the decoder's own latent size rather than a hard-coded 128
            sample_latent_vector = torch.randn(1, decoder_model.latent_dim).to(device)
            generated_image = decoder_model(sample_latent_vector)
            # 1 x C x H x W tensor -> H x W x C numpy array for image saving
            generated_image = generated_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
            generated_image = (generated_image * 255).astype(np.uint8)  # [0, 1] floats -> 0-255 bytes
            decoder_model.train()

        # Save the generated image
        image_path = f'generated_images/generated_image_epoch_{epoch_index + 1}_batch_{batch_index + 1}.png'
        plt.imsave(image_path, generated_image)
        print(f"Generated image saved at: {image_path}")

    # Average the summed batch losses over the number of batches in the epoch
    average_epoch_loss = epoch_loss / len(dataloader)
    average_epoch_reconstruction_loss = epoch_reconstruction_loss / len(dataloader)
    average_epoch_kl_divergence_loss = epoch_kl_divergence_loss / len(dataloader)

    epoch_total_losses.append(average_epoch_loss)
    epoch_reconstruction_losses.append(average_epoch_reconstruction_loss)
    epoch_kl_divergence_losses.append(average_epoch_kl_divergence_loss)

    print(f"Average total loss for epoch {epoch_index + 1}: {average_epoch_loss}")
    print(f"Average reconstruction loss for epoch {epoch_index + 1}: {average_epoch_reconstruction_loss}")
    print(f"Average KL divergence loss for epoch {epoch_index + 1}: {average_epoch_kl_divergence_loss}")

    # Checkpoint both networks at the end of every epoch
    torch.save(encoder_model.state_dict(), f'saved_models/encoder_epoch_{epoch_index + 1}.pth')
    torch.save(decoder_model.state_dict(), f'saved_models/decoder_epoch_{epoch_index + 1}.pth')
    print(f"Models saved in 'saved_models' folder at epoch {epoch_index + 1}")

print("Training completed.")

Epoch 1/100
Processing batch 1 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_1.png
Processing batch 2 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_2.png
Processing batch 3 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_3.png
Processing batch 4 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_4.png
Processing batch 5 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_5.png
Processing batch 6 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_6.png
Processing batch 7 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_7.png
Processing batch 8 with 64 images
Generated image saved at: generated_images/generated_image_epoch_1_batch_8.png
Processing batch 9 with 64 images
Generated image saved at: generated_images/generat