In [None]:
# Importing the necessary libraries
import torch  # Import PyTorch library
import torch.nn as nn  # Import neural network module
import torch.optim as optim  # Import optimization module
from torchvision import datasets, transforms  # Import datasets and transforms
from torchvision.utils import save_image, make_grid  # Import utility to save images
from torch.utils.data import Dataset, DataLoader
import torchvision  # Import torchvision library
import matplotlib.pyplot as plt  # Import plotting library
import os  # Import os module for file operations
import numpy as np  # Import numpy library

from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
import matplotlib.pyplot as plt
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3
import numpy as np
import shutil
from PIL import Image  # Import PIL for image processing
import scipy


In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")  # Report which device was selected

## Hyperparameters

In [None]:
# Image side length (pixels) and latent-space dimensionality
IMG_SIZE = 128
LATENT_DIM = 128

# Number of training epochs
NUM_EPOCHS = 250

# Optional run tag read from the TAG environment variable; rendered as "_<tag>"
# when set and as an empty string otherwise.
tag = os.getenv("TAG", None)
if tag:
    tag = f"_{tag}"
else:
    tag = ""

# Switches selecting which stages of the notebook run
TRAIN_VAE = True
TRAIN_BETA_VAE = False
LOAD_BETA_VAE = False
TRAIN_CNN = False
TRAIN_MLP = False

# Dataset selection (DATASET environment variable, defaults to the butterfly set)
BUTTERFLY = "butterfly"
ANIMAL = "animal"
dataset = os.getenv('DATASET', BUTTERFLY)  # butterfly or animal

# Number of latent samples drawn per input image (NS environment variable)
NUM_SAMPLES = int(os.getenv('NS', 10))

# The batch size shrinks as more latent samples are drawn per input
# (presumably to bound memory use -- confirm).
if NUM_SAMPLES <= 10:
    BATCH_SIZE = 128
elif NUM_SAMPLES <= 15:
    BATCH_SIZE = 96
else:
    BATCH_SIZE = 64

# Learning rate shared by the encoder and decoder optimizers
lr = 0.005

In [None]:
# Output locations for this run, keyed by dataset name, samples-per-input and optional tag.
# `log_dir` is bound to the same string as LOG_DIR.
LOG_DIR = log_dir = f'VAE/tensorboard/{dataset}_ns_{NUM_SAMPLES}{tag}'
IMG_DIR = f'VAE/generated_images/{dataset}_ns_{NUM_SAMPLES}{tag}'


def recreate_directory(dir_path):
    """Replace `dir_path` with a fresh, empty directory.

    If something already exists at `dir_path`, the whole directory tree is
    removed first; the directory (and any missing parents) is then created.

    Args:
        dir_path: Path of the directory to reset.
    """
    already_present = os.path.exists(dir_path)
    if already_present:
        shutil.rmtree(dir_path)  # discard the previous contents
    os.makedirs(dir_path)

# Folder that receives the saved model weights; created if missing.
os.makedirs('VAE/models', exist_ok=True)

# A training run starts from empty log and image folders
# (recreate_directory deletes whatever was there before).
if TRAIN_VAE:
    recreate_directory(LOG_DIR)
    recreate_directory(IMG_DIR)

In [None]:

# TensorBoard plumbing: `writer` records scalars into LOG_DIR, and `event_acc`
# reads events back from the same directory (used later by plot_curves).
writer = SummaryWriter(LOG_DIR)
event_acc = EventAccumulator(LOG_DIR)


def log_losses_to_tensorboard(epoch, e_loss, d_loss, total_loss):
    """Record the three loss scalars for one epoch on the module-level `writer`.

    Args:
        epoch: Step value attached to each scalar.
        e_loss: Value stored under the 'Loss/Encoder' tag.
        d_loss: Value stored under the 'Loss/Decoder' tag.
        total_loss: Value stored under the 'Loss/Total' tag.
    """
    tagged_values = (
        ('Loss/Encoder', e_loss),
        ('Loss/Decoder', d_loss),
        ('Loss/Total', total_loss),
    )
    for scalar_tag, scalar_value in tagged_values:
        writer.add_scalar(scalar_tag, scalar_value, epoch)


def log_gradients_to_tensorboard(model, epoch, model_name):
    """Log per-parameter and overall gradient L2 norms to the module-level `writer`.

    Parameters whose `.grad` is None are skipped. The overall norm is the
    square root of the sum of the squared per-parameter norms.

    Args:
        model: Module whose `named_parameters()` are inspected.
        epoch: Step value attached to each scalar.
        model_name: Name inserted into the 'Gradients/<model_name>/...' tags.
    """
    squared_norm_sum = 0
    for param_name, parameter in model.named_parameters():
        if parameter.grad is None:
            continue
        grad_norm = parameter.grad.norm(2).item()
        squared_norm_sum += grad_norm ** 2
        writer.add_scalar(f'Gradients/{model_name}/{param_name}', grad_norm, epoch)
    overall_norm = squared_norm_sum ** 0.5
    writer.add_scalar(f'Gradients/{model_name}/total_norm', overall_norm, epoch)

## Define models

The Variational Autoencoder (VAE) architecture can be visualized as follows:
```
    Input (x)
       |
       v
   Encoder q(z|x)
       |
       v
   Latent Space (z)
       |
       v
   Decoder p(x|z)
       |
       v
   Reconstructed (x̂)
```
We can see the key components of a VAE:

- **Encoder:** We model $q(z|x)$ as a neural network with parameters $\phi$. The network takes in an observation $x$ and outputs the parameters of a Gaussian distribution, i.e. the mean $\mu_{\phi}(x)$ and the covariance $\Sigma_{\phi}(x)$.

- **Decoder:** We model $p_{\theta}(x|z)$ as a neural network with parameters $\theta$. The network takes in a latent variable $z$ sampled from the distribution with parameters $\mu_{\phi}(x)$ and $\Sigma_{\phi}(x)$ and outputs a data sample $\hat{x}$. After training, we use the decoder to generate new data samples, i.e. it acts as the generator.

### Encoder

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs the mean and log-variance of q(z|x).

    Five stride-2 convolutions, each followed by BatchNorm and LeakyReLU, halve
    the spatial size at every stage. For an N x 3 x 128 x 128 input the feature
    map ends at N x 512 x 4 x 4, which is the size the two linear heads expect.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.latent_dim = latent_dim  # kept on the instance: a key hyperparameter

        # Channel widths of consecutive stages: 3 -> 32 -> 64 -> 128 -> 256 -> 512
        channel_plan = [3, 32, 64, 128, 256, 512]
        stages = []
        for in_channels, out_channels in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.LeakyReLU())
        self.conv_layers = nn.Sequential(*stages)

        # Two heads on the flattened 512 x 4 x 4 feature map. The posterior uses a
        # diagonal covariance, so the second head emits one log-variance per latent
        # dimension; exponentiating it recovers the diagonal entries.
        flat_features = 512 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch; the linear heads require the conv stack to yield
                512 * 4 * 4 features per image (e.g. N x 3 x 128 x 128 input).

        Returns:
            Tuple `(mu, logvar)`, each of shape N x latent_dim.
        """
        features = self.conv_layers(x)
        flattened = features.view(features.size(0), -1)
        return self.fc_mu(flattened), self.fc_logvar(flattened)

In [None]:
# Layer-by-layer summary of a freshly built encoder for one 3 x IMG_SIZE x IMG_SIZE input.
summary(Encoder(), input_size=(1, 3, IMG_SIZE, IMG_SIZE))

### Decoder

In [None]:
class Decoder(nn.Module):
    """Decoder mapping a latent vector to an N x 3 x 128 x 128 image in [-1, 1].

    A linear layer expands the latent vector to a 512 x 4 x 4 feature map. Five
    Upsample(x2) + 3x3 Conv stages then double the resolution each time; the
    first four are followed by InstanceNorm and LeakyReLU, the last by Tanh.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.latent_dim = latent_dim

        # Latent vector -> flat 512 * 4 * 4 features (reshaped in forward)
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        stages = []
        # Hidden stages: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        for in_channels, out_channels in [(512, 256), (256, 128), (128, 64), (64, 32)]:
            stages += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(),
            ]
        # Output stage: 64x64 -> 128x128, three channels squashed by Tanh
        stages += [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.deconv_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape N x latent_dim.

        Returns:
            Tensor of shape N x 3 x 128 x 128 with values in [-1, 1] (Tanh output).
        """
        feature_map = self.fc(z).view(-1, 512, 4, 4)
        return self.deconv_layers(feature_map)

In [None]:
# Layer-by-layer summary of a freshly built decoder for one latent vector of size LATENT_DIM.
summary(Decoder(), input_size=(1, LATENT_DIM))

### Reparameterization Trick

We want to maximize the ELBO:
$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \,\|\, p_{\theta}(z) \right) $$

Focus on first term
$$\mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right]$$

We introduce a function $g_{\phi}(\epsilon)$ that transforms a noise variable $\epsilon$ into $z$

$$ z = g_{\phi}(\epsilon) $$

This allows us to rewrite the expectation in terms of $\epsilon$

$$ \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] = \mathbb{E}_{p(\epsilon)} \left[ \log p_{\theta}(x|g_{\phi}(\epsilon)) \right] $$

Where:
- $\epsilon \sim p(\epsilon)$ (typically a standard normal distribution)
- $g_{\phi}(\epsilon)$ is our reparameterization function typically:
$$ z = \mu + \sigma \odot \epsilon $$

The gradient can then be estimated using Monte Carlo sampling:

$$ \nabla_{\phi} \mathbb{E}_{q(z|x)}[\log p_{\theta}(x|z)] \approx \frac{1}{N} \sum_{i=1}^N \nabla_{\phi} [\log p_{\theta}(x|g_{\phi}(\epsilon_i))]  \quad \epsilon_i \sim p(\epsilon) $$

This approach allows the gradient to flow through the sampling process, enabling effective optimization of the VAE.

In [None]:
def reparameterize(mu, logvar):
    """Sample z ~ N(mu, diag(exp(logvar))) with the reparameterization trick.

    The sample is computed as ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``,
    so gradients flow to `mu` and `logvar` through the sampling step.

    Args:
        mu (torch.Tensor): Mean of the latent distribution.
        logvar (torch.Tensor): Log-variance of the latent distribution
            (same shape as `mu`).

    Returns:
        torch.Tensor: Sampled latent vector with the same shape, dtype and
        device as the inputs.
    """
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    # randn_like already allocates the noise on std's device with std's dtype, so
    # no transfer to a global device is needed; forcing one would break callers
    # whose tensors live on a different device.
    eps = torch.randn_like(std)
    return mu + eps * std

## Define the loss function

log likelihood is given by:

$$ \ell(\theta) = \log p_{\theta}(x) $$

We assume that each data point $x_i$ is associated with a latent variable $z_i$.
Hence, we will introduce the latent variable $z$ and marginalize over it:

$$ \ell(\theta) = \log \sum_z p_{\theta}(x, z) $$

Let $q(z|x)$ be a conditional distribution over $z$ given $x$.

$$ \ell(\theta) = \log \sum_z q(z|x) \frac{p_{\theta}(x, z)}{q(z|x)} $$

$$ \ell(\theta) = \log \mathbb{E}_{q(z|x)} \left[ \frac{p_{\theta}(x, z)}{q(z|x)} \right] $$

By Jensen's inequality, we have:

$$ \log \mathbb{E}_{q(z|x)} \left[ \frac{p_{\theta}(x, z)}{q(z|x)} \right] \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x, z)}{q(z|x)} \right] $$

Hence, we have:

$$ \ell(\theta) \geq \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x, z)}{q(z|x)} \right] = F_{\theta}(q) $$

Where $F_{\theta}(q)$ is the evidence lower bound (ELBO).

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(x|z)p_{\theta}(z)}{q(z|x)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] + \mathbb{E}_{q(z|x)} \left[ \log \frac{p_{\theta}(z)}{q(z|x)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - \mathbb{E}_{q(z|x)} \left[ \log \frac{q(z|x)}{p_{\theta}(z)} \right] $$

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \,\|\, p_{\theta}(z) \right) $$

Here the first term is the conditional log likelihood of the data under the model. We want to maximise this term.

The second term is the KL divergence between the posterior and the prior. We want to minimise this term.

If we assume $p_{\theta}(x|z) = N(x; \hat{x}, I)$, where $\hat{x}$ is the decoder output for $z = g_{\phi}(\epsilon)$ (a modelling assumption), the log-likelihood $\log p_{\theta}(x|z)$ can be computed from the generated samples $\hat{x}_j^i$ and the original input $x_i$ as follows (derivation skipped):

$$ \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] = \mathbb{E}_{p(\epsilon)} \left[ \log p_{\theta}(x|g_{\phi}(\epsilon)) \right] \approx \frac{1}{m} \sum_{j=1}^m \log p_{\theta}(x_i|g_{\phi}(\epsilon_j)) = -\frac{1}{2m} \sum_{j=1}^m \|x_i - \hat{x}_j^i\|_2^2 + \text{const} $$

Maximising this term is therefore equivalent to minimising the average squared reconstruction error over the $m$ samples.

#### Reconstruction loss

In [None]:
def reconstruction_loss(x_reconstructed, x):
    """Summed squared error between a reconstruction and its target.

    Args:
        x_reconstructed (torch.Tensor): Output of the decoder.
        x (torch.Tensor): Original input the reconstruction is compared against.

    Returns:
        torch.Tensor: Scalar tensor holding the squared error summed over every
        element (no averaging over the batch or over pixels).
    """
    squared_error_total = nn.functional.mse_loss(
        input=x_reconstructed, target=x, reduction='sum')
    return squared_error_total

#### KL Divergence loss

Recall that:

$$ F_{\theta}(q) = \mathbb{E}_{q(z|x)} \left[ \log p_{\theta}(x|z) \right] - D_{KL} \left( q(z|x) \,\|\, p_{\theta}(z) \right) $$

the second term is:

$$ D_{KL} \left( q(z|x) \,\|\, p_{\theta}(z) \right) $$

We want to minimise this term.

We assume the latent prior $p_\theta(z) \sim N(0, I)$, where $I$ is the identity matrix.

The approximate posterior $q(z|x)$ is modeled as $N(z; \mu_\phi(x), \Sigma_\phi(x))$.

Given these assumptions, we can derive the KL divergence in closed form as:

$$ D_{KL}(N(z; \mu_\phi(x), \Sigma_\phi(x)) \| N(0, I)) = \frac{1}{2} \sum_{j=1}^J \left( \mu_{\phi,j}^2(x) + \Sigma_{\phi,j}(x) - \log \Sigma_{\phi,j}(x) - 1 \right) $$

Where:
- $J$ is the dimensionality of the latent space
- $\mu_{\phi,j}(x)$ is the j-th element of the mean vector
- $\Sigma_{\phi,j}(x)$ is the j-th diagonal element of the covariance matrix

In [None]:
def kl_divergence_loss(mu, logvar):
    """Closed-form KL divergence between N(mu, diag(exp(logvar))) and N(0, I).

    Args:
        mu (torch.Tensor): Mean of the latent distribution.
        logvar (torch.Tensor): Log-variance of the latent distribution.

    Returns:
        torch.Tensor: Scalar tensor with the divergence summed over every
        element of the inputs (batch and latent dimensions alike).
    """
    # Per-element term: mu^2 + sigma^2 - log(sigma^2) - 1
    per_element = mu.pow(2) + logvar.exp() - logvar - 1
    return 0.5 * torch.sum(per_element)

## Data preparation

In [None]:
class ImageDataset(Dataset):
    """Unlabelled image dataset gathered recursively from a directory tree.

    Every file under `root_dir` (including subfolders) whose name ends in
    .png, .jpg or .jpeg (case-insensitive) becomes one item.

    Args:
        root_dir: Directory that is walked for image files.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform  # Transformations to apply to images

        discovered = []
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    discovered.append(os.path.join(root, file))
        # os.walk yields entries in filesystem-dependent order; sorting makes the
        # index -> image mapping reproducible across machines and runs.
        self.image_files = sorted(discovered)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at position `idx`.

        Returns:
            Tuple `(image, 0)`: the RGB image (transformed when a transform was
            given) and a constant dummy label of 0.
        """
        img_path = self.image_files[idx]
        # The context manager closes the underlying file handle promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, 0

In [None]:
# Preprocessing pipeline applied to every image before it reaches the model.
# The final Normalize maps the [0, 1] tensors to [-1, 1], matching the Tanh
# output range of the decoder.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),          # Randomly flip images horizontally
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),  # Convert images to PyTorch tensors and scale to [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per channel: (x - 0.5) / 0.5
])

In [None]:
# Load the dataset from the specified relative path
# Load the dataset from the specified relative path
# NOTE(review): both datasets are constructed regardless of the `dataset` setting,
# so both directories presumably have to exist even when only one is used -- confirm.
animal_dataset = datasets.ImageFolder(root='data/Animals_data/animals/animals',  # Specify the root directory of the dataset
                                       transform=transform)  # Apply the defined transformations to the dataset

# Name is spelled `butterly_dataset` (sic); later cells refer to it by this spelling.
butterly_dataset = ImageDataset(root_dir="data/butterfly_data", transform=transform)

In [None]:
# Create a DataLoader for the dataset to enable batch processing
d = butterly_dataset if dataset == BUTTERFLY else animal_dataset
dataloader = torch.utils.data.DataLoader(
    d, BATCH_SIZE, shuffle=True, num_workers=2)

### Training loop

In [None]:
def init_weights_decoder(m):
    """Per-module weight initializer for the decoder, meant for ``decoder.apply(...)``.

    * Conv2d and Linear: weights ~ N(0, 0.002^2), bias set to 0.001 when present.
    * BatchNorm2d / InstanceNorm2d: scale set to 0.2 and shift to 0 when the
      layer has affine parameters; layers without them are left untouched.
    * Any other module type is ignored.

    Args:
        m (nn.Module): Module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.002)
        # Layers built with bias=False have no bias tensor to initialise
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.001)

    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        # weight / bias are None when the norm layer has no affine parameters
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 0.2)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

            
def init_weights_encoder(m):
    """Per-module weight initializer for the encoder, meant for ``encoder.apply(...)``.

    * Conv2d and Linear: weights ~ N(0, 0.001^2), zero bias when present.
    * BatchNorm2d / InstanceNorm2d: scale 0.1 and shift 0 when the layer has
      affine parameters.
    * A module that owns Linear attributes named ``fc_mu`` / ``fc_logvar``:
      ``fc_mu`` gets zero bias and near-zero weights (N(0, 0.0001^2) noise);
      ``fc_logvar`` gets exactly zero weights and bias, so the initial
      log-variance is 0.

    The two heads cannot be told apart from inside the Linear branch: every
    Linear layer has the class name "Linear", so the previous test of
    ``m.__class__.__name__`` for 'mu' / 'logvar' never matched. They are
    recognised on their parent module instead. ``nn.Module.apply`` visits
    children before the module itself, so this head-specific initialisation
    runs last and overrides the generic Linear one.

    Args:
        m (nn.Module): Module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Generic layers: very small random weights, zero bias
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.001)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        # weight / bias are None when the norm layer has no affine parameters
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    # Head-specific initialisation, applied when visiting the module that owns the heads.
    mu_head = getattr(m, 'fc_mu', None)
    if isinstance(mu_head, nn.Linear):
        torch.nn.init.zeros_(mu_head.weight)
        if mu_head.bias is not None:
            torch.nn.init.zeros_(mu_head.bias)
        # Tiny noise so the mean head does not start from exactly zero weights
        with torch.no_grad():
            mu_head.weight.add_(torch.randn_like(mu_head.weight) * 0.0001)

    logvar_head = getattr(m, 'fc_logvar', None)
    if isinstance(logvar_head, nn.Linear):
        torch.nn.init.zeros_(logvar_head.weight)
        if logvar_head.bias is not None:
            torch.nn.init.zeros_(logvar_head.bias)

In [None]:
def generate_and_save_images(encoder_model: Encoder, decoder_model: Decoder, epoch_index, img_dir, display=False):
    """Produce a grid of generated images and a grid of reconstructions.

    100 latent vectors are drawn from N(0, I) and decoded ("generated" grid), and
    up to 100 real images from the global `dataloader` are encoded, sampled and
    decoded ("reconstructed" grid). Both models run in eval mode under
    ``torch.no_grad()`` and are returned to whatever train/eval mode they were
    in when the function was called.

    Args:
        encoder_model: Encoder used for the reconstructions.
        decoder_model: Decoder used for both grids.
        epoch_index: Zero-based epoch index; ``epoch_index + 1`` appears in the
            figure titles and file names.
        img_dir: Directory receiving ``generated_<n>.png`` and
            ``reconstructed_<n>.png`` when `display` is False (created if missing).
        display: When True the grids are shown with matplotlib and nothing is
            written to disk.
    """
    # Remember the current modes so the caller's choice is restored afterwards
    # (previously both models were unconditionally switched to train mode).
    encoder_was_training = encoder_model.training
    decoder_was_training = decoder_model.training
    decoder_model.eval()
    encoder_model.eval()
    try:
        with torch.no_grad():  # Disable gradient calculation for inference
            # Use the decoder's own latent size; fall back to the global constant
            latent_dim = getattr(decoder_model, 'latent_dim', LATENT_DIM)
            sample_latent_vector = torch.randn(100, latent_dim).to(device)
            generated_images = decoder_model(sample_latent_vector)

            # Collect real images from the dataloader until at least 100 are available
            real_images_list = []
            collected = 0
            for images, _ in dataloader:
                real_images_list.append(images.to(device))
                collected += images.size(0)
                if collected >= 100:
                    break

            # Concatenate the batches and keep at most 100 images
            real_images = torch.cat(real_images_list, dim=0)[:100]

            # Encode, sample from the posterior, and decode again
            mean, log_var = encoder_model(real_images)
            encoded_latents = reparameterize(mean, log_var)
            reconstructed_images = decoder_model(encoded_latents)
    finally:
        encoder_model.train(encoder_was_training)
        decoder_model.train(decoder_was_training)

    if display:
        # Function to display grid using matplotlib
        def show_images(tensor, title):
            tensor = (tensor + 1) / 2.0
            grid = make_grid(tensor, nrow=10, padding=2, normalize=True)  # Create grid
            plt.figure(figsize=(10, 10))
            plt.imshow(grid.permute(1, 2, 0).cpu().numpy())  # Convert to numpy for plotting
            plt.axis('off')
            plt.title(title)
            plt.show()

        show_images(generated_images, f'Generated Images - Epoch {epoch_index + 1}')
        show_images(reconstructed_images, f'Reconstructed Images - Epoch {epoch_index + 1}')
    else:
        # Make sure the target directory exists before writing the grids
        os.makedirs(img_dir, exist_ok=True)

        # Save generated images
        image_path_gen = f'{img_dir}/generated_{epoch_index + 1}.png'
        save_image(make_grid(generated_images, nrow=10, padding=2, normalize=True), image_path_gen)

        # Save reconstructed images
        image_path_recon = f'{img_dir}/reconstructed_{epoch_index + 1}.png'
        save_image(make_grid(reconstructed_images, nrow=10, padding=2, normalize=True), image_path_recon)

In [None]:
# papermill_description=Q1

def train_vae(beta=1, img_dir=IMG_DIR):
    """Train a fresh Encoder/Decoder pair on the global `dataloader`.

    Each batch is encoded once; NUM_SAMPLES latent vectors are then drawn from
    the posterior and decoded. The reconstruction term is the average over those
    samples of the summed squared error, i.e. the Monte Carlo estimate
    ``(1/m) * sum_j ||x - x_hat_j||^2``. The loss is
    ``reconstruction + beta * KL``.

    Per epoch, gradient norms and the losses of the last batch are written to
    TensorBoard; every fifth epoch the losses are printed and sample grids are
    saved to `img_dir`.

    Args:
        beta: Weight of the KL-divergence term.
        img_dir: Directory receiving the periodic image grids.

    Returns:
        Tuple `(encoder_model, decoder_model)` with the trained modules.
    """
    encoder_model = Encoder().to(device)
    decoder_model = Decoder().to(device)
    encoder_model.apply(init_weights_encoder)
    decoder_model.apply(init_weights_decoder)
    encoder_optimizer = torch.optim.Adam(encoder_model.parameters(), lr=lr)
    decoder_optimizer = torch.optim.Adam(decoder_model.parameters(), lr=lr)

    # Halve the learning rate every third of the run; StepLR needs step_size >= 1,
    # so very short runs (NUM_EPOCHS < 3) fall back to a step of 1.
    lr_step_size = max(1, NUM_EPOCHS // 3)
    encoder_scheduler = optim.lr_scheduler.StepLR(
        encoder_optimizer, step_size=lr_step_size, gamma=0.5)
    decoder_scheduler = optim.lr_scheduler.StepLR(
        decoder_optimizer, step_size=lr_step_size, gamma=0.5)

    for epoch_index in range(NUM_EPOCHS):

        for images_batch, _ in dataloader:
            # Move the batch of images to the specified device (GPU/CPU)
            images_batch = images_batch.to(device)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Encode the input images to get mean and log variance
            mean_output, log_variance_output = encoder_model(images_batch)

            # Monte Carlo estimate of the reconstruction term: average the
            # per-sample losses. Taking the loss of the averaged reconstruction
            # instead would under-estimate it (the squared error is convex).
            reconstruction_loss_value = 0
            for _ in range(NUM_SAMPLES):
                latent_vector_z = reparameterize(
                    mean_output, log_variance_output)
                reconstructed_sample = decoder_model(latent_vector_z)
                reconstruction_loss_value = reconstruction_loss_value + \
                    reconstruction_loss(reconstructed_sample, images_batch)
            reconstruction_loss_value = reconstruction_loss_value / NUM_SAMPLES

            # KL divergence between the posterior and the N(0, I) prior
            kl_divergence_loss_value = kl_divergence_loss(
                mean_output, log_variance_output)

            # Total loss for both encoder and decoder
            total_loss_value = reconstruction_loss_value + \
                beta*kl_divergence_loss_value

            # One backward pass feeds both optimizers
            total_loss_value.backward()

            encoder_optimizer.step()
            decoder_optimizer.step()

        encoder_scheduler.step()
        decoder_scheduler.step()
        # Gradients of the last batch are still attached to the parameters here
        log_gradients_to_tensorboard(encoder_model, epoch_index, 'Encoder')
        log_gradients_to_tensorboard(decoder_model, epoch_index, 'Decoder')

        if epoch_index % 5 == 0:
            # Kept on one line: a replacement field spanning lines is a syntax
            # error before Python 3.12.
            print(f"Epoch [{epoch_index}/{NUM_EPOCHS}]  Encoder loss: {kl_divergence_loss_value.item():.4f}, Decoder loss: {reconstruction_loss_value.item():.4f}")
            generate_and_save_images(encoder_model, decoder_model, epoch_index, img_dir)

        # The logged values are those of the last batch of the epoch
        log_losses_to_tensorboard(
            epoch_index, kl_divergence_loss_value.item(), reconstruction_loss_value.item(), total_loss_value.item())

    print("Training completed.")
    return encoder_model, decoder_model



# Weight files for this dataset / sample-count / tag combination
encoder_model_path = f'VAE/models/encoder_{dataset}_ns_{NUM_SAMPLES}{tag}.pth'
decoder_model_path = f'VAE/models/decoder_{dataset}_ns_{NUM_SAMPLES}{tag}.pth'
if TRAIN_VAE:
    # Train from scratch and persist the resulting weights
    encoder_model, decoder_model = train_vae()
    torch.save(encoder_model.state_dict(), encoder_model_path)
    torch.save(decoder_model.state_dict(), decoder_model_path)
    print(f"Models saved in 'VAE/models' folder")
else:
    # Reuse previously saved weights. map_location remaps the stored tensors onto
    # the current device, so checkpoints written on a GPU also load on a CPU-only machine.
    encoder_model = Encoder().to(device)
    decoder_model = Decoder().to(device)
    encoder_model.load_state_dict(
        torch.load(encoder_model_path, map_location=device, weights_only=True))
    decoder_model.load_state_dict(
        torch.load(decoder_model_path, map_location=device, weights_only=True))

### Display Images

In [None]:
# Show (rather than save) the generated and reconstructed grids for the final model.
# epoch_index is NUM_EPOCHS here, so the figure titles read "Epoch {NUM_EPOCHS + 1}".
generate_and_save_images(encoder_model, decoder_model,
                         NUM_EPOCHS, IMG_DIR, display=True)

### Plot loss curves

In [None]:

def plot_curves():
    """Plot the KL, reconstruction and total loss curves recorded in TensorBoard.

    Reads the 'Loss/Encoder', 'Loss/Decoder' and 'Loss/Total' scalars through the
    module-level `event_acc` and draws them against their step on one figure.
    """
    # The writer buffers events; flush so the most recent scalars are on disk
    # before the accumulator re-reads the log directory.
    writer.flush()
    event_acc.Reload()

    # (TensorBoard tag, legend label, line colour) for each curve
    curve_specs = (
        ('Loss/Encoder', 'KL Divergence Loss', 'blue'),
        ('Loss/Decoder', 'Reconstruction Loss', 'green'),
        ('Loss/Total', 'Total Loss', 'red'),
    )

    # Fetch every series before creating the figure, so a missing tag fails
    # without leaving an empty plot behind.
    series = [(event_acc.Scalars(scalar_tag), label, color)
              for scalar_tag, label, color in curve_specs]

    plt.figure(figsize=(10,6))
    for events, label, color in series:
        steps = [event.step for event in events]
        values = [event.value for event in events]
        plt.plot(steps, values, label=label, color=color)

    # Add labels and title
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Components Over Time')

    # Add legend and grid
    plt.legend()
    plt.grid(True)

    plt.show()

# Loss scalars are only logged by train_vae, so the curves are plotted only after a training run.
if TRAIN_VAE:
    plot_curves()

## Posterior inference for pairs of images: decoding along the linear interpolation path between their latents

In [None]:
# papermill_description=Q5

def encode_images(encoder, images):
    """Run `encoder` on `images` without tracking gradients.

    The encoder is switched to eval mode (and left there), and the images are
    moved to the global `device` before the forward pass.

    Args:
        encoder: Module returning a `(mu, log_var)` pair.
        images: Batch of input images.

    Returns:
        Tuple `(mu, log_var)` as produced by the encoder.
    """
    encoder.eval()
    with torch.no_grad():
        posterior_mean, posterior_log_var = encoder(images.to(device))
    return posterior_mean, posterior_log_var

def interpolate_latents(z1, z2, num_points=10):
    """Linearly interpolate between two latent vectors.

    Args:
        z1: Start latent vector (tensor).
        z2: End latent vector (tensor, same shape as `z1`).
        num_points: Number of evenly spaced points, both endpoints included.

    Returns:
        Tensor stacking the `num_points` interpolated vectors along a new first
        dimension; the first row equals `z1` and the last equals `z2`.
    """
    blend_weights = np.linspace(0, 1, num_points)
    path = [(1 - weight) * z1 + weight * z2 for weight in blend_weights]
    return torch.stack(path)

def decode_interpolations(decoder, interpolated_latents):
    """Decode a batch of latent vectors without tracking gradients.

    The decoder is switched to eval mode (and left there), and the latents are
    moved to the global `device` before the forward pass.

    Args:
        decoder: Module mapping latent vectors to images.
        interpolated_latents: Batch of latent vectors.

    Returns:
        The decoder output for the given latents.
    """
    decoder.eval()
    with torch.no_grad():
        decoded_images = decoder(interpolated_latents.to(device))
    return decoded_images

def plot_interpolation_path(original_images, reconstructed_images, pair_idx, save_path=None):
    """Show one interpolation path as a single row of images.

    The row contains the first original image, every interpolation step, and
    the second original image. The number of columns follows the number of
    reconstructed images instead of assuming exactly ten steps.

    Args:
        original_images: Sequence with the two endpoint images (C x H x W tensors).
        reconstructed_images: Sequence/tensor of decoded images along the path
            (each C x H x W).
        pair_idx: Index of the pair, used in the saved file name.
        save_path: Optional directory; when given, the figure is written to
            ``<save_path>/interpolation_pair_<pair_idx>.png`` (the directory is
            created if missing).
    """
    num_steps = len(reconstructed_images)
    num_columns = num_steps + 2  # two originals plus the interpolation steps

    plt.figure(figsize=(15, 3))

    # First original image
    plt.subplot(1, num_columns, 1)
    plt.imshow(original_images[0].cpu().permute(1, 2, 0))
    plt.axis('off')
    plt.title('Original 1')

    # Interpolation steps
    for step_index in range(num_steps):
        plt.subplot(1, num_columns, step_index + 2)
        plt.imshow(reconstructed_images[step_index].cpu().permute(1, 2, 0))
        plt.axis('off')
        plt.title(f'Step {step_index+1}')

    # Second original image
    plt.subplot(1, num_columns, num_columns)
    plt.imshow(original_images[1].cpu().permute(1, 2, 0))
    plt.axis('off')
    plt.title('Original 2')

    plt.tight_layout()
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f'{save_path}/interpolation_pair_{pair_idx}.png')
    plt.show()

def perform_interpolation_analysis(encoder, decoder, dataset, num_pairs=10, num_points=10,  save_path=None):
    """Decode and plot linear latent interpolations between random image pairs.

    For each pair, two distinct images are drawn at random from `dataset`, one
    latent vector is sampled from each image's posterior, the straight line
    between the two latents is decoded at `num_points` positions, and the result
    is passed to `plot_interpolation_path`.

    Args:
        encoder: Encoder returning `(mu, log_var)`.
        decoder: Decoder mapping latents to images.
        dataset: Indexable dataset whose items are `(image, label)` pairs.
        num_pairs: Number of random pairs to analyse.
        num_points: Number of interpolation points per pair (endpoints included).
        save_path: Optional directory forwarded to `plot_interpolation_path`.
    """

    def _to_display_range(image_tensor):
        # Map [-1, 1] images to [0, 1] for plotting. Assumes the dataset uses the
        # notebook's Normalize(0.5, 0.5) transform; the decoder output is Tanh.
        return ((image_tensor + 1) / 2.0).clamp(0, 1)

    for pair_idx in range(num_pairs):
        # Randomly select two distinct images
        idx1, idx2 = np.random.choice(len(dataset), 2, replace=False)
        img1, img2 = dataset[idx1][0], dataset[idx2][0]

        # Add batch dimension
        images = torch.stack([img1, img2])

        # Encode, then sample from the posterior with the shared helper
        mu, log_var = encode_images(encoder, images)
        z = reparameterize(mu, log_var)

        # Interpolate in latent space. The latents must reach the decoder unchanged:
        # the (x + 1) / 2 rescaling belongs to the images, not to the latent vectors.
        interpolated = interpolate_latents(z[0], z[1], num_points)

        # Decode interpolated vectors
        reconstructed = decode_interpolations(decoder, interpolated)

        # Plot originals and decoded path in display range
        plot_interpolation_path(
            [_to_display_range(img1), _to_display_range(img2)],
            _to_display_range(reconstructed), pair_idx, save_path)
        
# Run the interpolation analysis on the butterfly dataset with the beta-VAE models.
# NOTE(review): beta_encoder_model / beta_decoder_model are not defined in the cells
# visible above -- presumably another cell creates them; confirm this cell still runs
# on a fresh kernel when LOAD_BETA_VAE is True.
if LOAD_BETA_VAE:
    perform_interpolation_analysis(beta_encoder_model, beta_decoder_model, butterly_dataset, save_path="VAE/plots")