In [1]:
# Core PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim

# Torchvision for datasets, transforms, and utilities
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid # Specific utilities for saving and grid images

# Plotting and numerical operations
import matplotlib.pyplot as plt
import numpy as np

# System utilities
import os

In [2]:
# Select the compute device: prefer the second GPU (index 1), fall back to the
# first GPU when only one is visible, and use the CPU when CUDA is unavailable.
# Asking for "cuda:1" unconditionally would name a non-existent device on
# single-GPU machines and fail as soon as a tensor or model is moved there.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:1


## Define Models

In [3]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to latent Gaussian parameters.

    Five stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size while widening the channels 3 -> 32 -> 64 -> 128 -> 256 -> 512.
    The linear heads expect 512 * 4 * 4 flattened features, i.e. a
    3 x 128 x 128 input (128 -> 64 -> 32 -> 16 -> 8 -> 4).
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim  # size of the latent vector

        # Build the Conv2d/ReLU pairs from the channel progression; the layer
        # order (and therefore the Sequential indices) is conv, relu, conv, ...
        channel_plan = (3, 32, 64, 128, 256, 512)
        stages = []
        for in_channels, out_channels in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*stages)

        # Two linear heads over the flattened feature map
        flat_features = 512 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)      # mean of the latent distribution
        self.fc_logvar = nn.Linear(flat_features, latent_dim)  # log-variance of the latent distribution

    def forward(self, x):
        """Return the tuple (mu, logvar), each of shape (batch, latent_dim)."""
        features = self.conv_layers(x)
        flat = features.view(features.size(0), -1)  # (batch, 512 * 4 * 4)
        return self.fc_mu(flat), self.fc_logvar(flat)


In [4]:
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Transposed-convolution decoder that turns latent vectors into images.

    A linear layer expands the latent vector to a 512 x 4 x 4 feature map;
    five stride-2 transposed convolutions (kernel 4, padding 1) then double
    the spatial size each time (4 -> 8 -> 16 -> 32 -> 64 -> 128) while the
    channels shrink 512 -> 256 -> 128 -> 64 -> 32 -> 3. A final Sigmoid keeps
    the output in [0, 1].
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim  # size of the latent vector

        # Latent vector -> flattened 512 x 4 x 4 feature map
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        # ConvTranspose2d/activation pairs: ReLU after every stage except the
        # last one, which ends in Sigmoid.
        channel_plan = (512, 256, 128, 64, 32, 3)
        final_stage = len(channel_plan) - 2
        stages = []
        for stage, (in_channels, out_channels) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            stages.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            stages.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.deconv_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Decode latent vectors `z` of shape (batch, latent_dim) into images."""
        feature_map = self.fc(z).view(-1, 512, 4, 4)  # reshape for the deconvolution stack
        return self.deconv_layers(feature_map)


In [5]:
def reparameterize(mu, logvar):
    """Draw z = mu + eps * std with eps ~ N(0, I) (the reparameterization trick).

    Args:
        mu: mean tensor of the latent distribution.
        logvar: log-variance tensor, same shape as `mu`.

    Returns:
        A sampled tensor with the same shape as `mu`.
    """
    sigma = (0.5 * logvar).exp()       # std = exp(logvar / 2)
    noise = torch.randn_like(sigma)    # standard-normal noise, one value per element
    return mu + noise * sigma

In [6]:
import torch.nn.functional as F

def reconstruction_loss(x_reconstructed, x):
    """Sum of squared errors between the reconstruction and the original.

    Uses mean-squared-error with `reduction='sum'`, so the result is a scalar
    tensor totalled over every element (not averaged).
    """
    summed_squared_error = F.mse_loss(input=x_reconstructed, target=x, reduction='sum')
    return summed_squared_error

In [7]:
def kl_divergence_loss(mu, logvar):
    """KL divergence from a diagonal Gaussian N(mu, exp(logvar)) to N(0, I).

    Evaluates 0.5 * sum(mu^2 + exp(logvar) - logvar - 1), summed over every
    element of the inputs, and returns it as a scalar tensor.
    """
    per_element = mu.pow(2) + logvar.exp() - logvar - 1
    return 0.5 * torch.sum(per_element)

## Preparing Data

In [8]:
# Preprocessing applied to every image before it is batched
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),  # force a fixed 128x128 resolution
        transforms.ToTensor(),               # image -> tensor with pixel values scaled to [0, 1]
    ]
)

In [9]:
# Load the image dataset and wrap it in a shuffling DataLoader for training.

# The dataset location can be overridden with the ANIMALS_DATA_DIR environment
# variable so the notebook is not tied to one machine; when the variable is
# unset or empty, the original path is used.
data_root = os.environ.get('ANIMALS_DATA_DIR') or '/home/vishwa/data_large/GenAI/animals'

# Read images from the folder tree under data_root, applying `transform` to each one
animal_dataset = datasets.ImageFolder(root=data_root, transform=transform)
# Batches of 64, reshuffled every epoch, loaded by 2 worker processes
dataloader = torch.utils.data.DataLoader(animal_dataset, batch_size=64, shuffle=True, num_workers=2)

## Training Models

In [10]:
# Build both halves of the VAE with their default latent size, move them to
# `device`, and switch them to training mode. nn.Module.to() and .train() each
# return the module itself, so the calls can be chained.
encoder_model = Encoder().to(device).train()
decoder_model = Decoder().to(device).train()

# One Adam optimizer per network, both with a learning rate of 0.001
encoder_optimizer = torch.optim.Adam(encoder_model.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(decoder_model.parameters(), lr=1e-3)

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import os
import numpy as np

# Training cell. It relies on names created in earlier cells: encoder_model,
# decoder_model, encoder_optimizer, decoder_optimizer, dataloader, device,
# reparameterize, reconstruction_loss and kl_divergence_loss.

num_epochs = 500
num_samples = 10  # latent draws per batch; their reconstructions are averaged below

# Per-epoch loss histories (each entry is an average over the epoch's batches)
epoch_total_losses = []
epoch_reconstruction_losses = []
epoch_kl_divergence_losses = []

# Read the latent size from the model instead of repeating the literal, so the
# monitoring sample always matches the decoder's input width.
latent_dim = decoder_model.latent_dim
# Create a fixed latent vector for consistent image generation monitoring across epochs
# This allows you to observe how the decoder improves on generating from the *same* point in latent space.
fixed_latent_sample = torch.randn(1, latent_dim).to(device) # Generates one random latent vector

# Output folders are loop-invariant, so create them once before training starts
save_dir = 'generated_images'
save_models_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)
os.makedirs(save_models_dir, exist_ok=True)

# Training loop
for epoch_index in range(num_epochs):
    print(f"Epoch {epoch_index + 1}/{num_epochs}")
    epoch_loss = 0
    epoch_reconstruction_loss = 0
    epoch_kl_divergence_loss = 0

    # Back to training mode: the decoder is switched to eval() at the end of every epoch
    encoder_model.train()
    decoder_model.train()

    for batch_index, (images_batch, _) in enumerate(dataloader):
        images_batch = images_batch.to(device)

        # Forward pass through encoder to obtain mean and log variance
        mean_output, log_variance_output = encoder_model(images_batch)

        # Decode num_samples independent latent draws from the same posterior
        reconstructed_samples_list = []
        for sample_index in range(num_samples):
            # Sample latent vectors using the reparameterization trick
            latent_vector_z = reparameterize(mean_output, log_variance_output)
            reconstructed_sample = decoder_model(latent_vector_z)
            reconstructed_samples_list.append(reconstructed_sample)
        # Stack to (num_samples, batch, C, H, W)
        reconstructed_samples_tensor = torch.stack(reconstructed_samples_list, dim=0)

        # Average the reconstructions over the sample dimension
        mean_reconstructed_images = torch.mean(reconstructed_samples_tensor, dim=0)

        # NOTE(review): the reconstruction loss is taken on the *averaged*
        # reconstruction, not averaged over per-sample losses. Confirm this is
        # the intended objective.
        reconstruction_loss_value = reconstruction_loss(mean_reconstructed_images, images_batch)

        # KL term from the encoder's mean and log variance
        kl_divergence_loss_value = kl_divergence_loss(mean_output, log_variance_output)

        # Total loss is the unweighted sum of both terms
        total_loss_value = reconstruction_loss_value + kl_divergence_loss_value

        # Zero gradients for both encoder and decoder optimizers
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        # One backward pass populates gradients for both networks
        total_loss_value.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        # Accumulate losses for this epoch
        epoch_loss += total_loss_value.item()
        epoch_reconstruction_loss += reconstruction_loss_value.item()
        epoch_kl_divergence_loss += kl_divergence_loss_value.item()

    # --- Actions to perform AFTER each epoch completes ---
    # Average the summed batch losses over the number of batches and store them
    average_epoch_loss = epoch_loss / len(dataloader)
    average_epoch_reconstruction_loss = epoch_reconstruction_loss / len(dataloader)
    average_epoch_kl_divergence_loss = epoch_kl_divergence_loss / len(dataloader)

    epoch_total_losses.append(average_epoch_loss)
    epoch_reconstruction_losses.append(average_epoch_reconstruction_loss)
    epoch_kl_divergence_losses.append(average_epoch_kl_divergence_loss)

    print(f"Average total loss for epoch {epoch_index + 1}: {average_epoch_loss:.4f}")
    print(f"Average reconstruction loss for epoch {epoch_index + 1}: {average_epoch_reconstruction_loss:.4f}")
    print(f"Average KL divergence loss for epoch {epoch_index + 1}: {average_epoch_kl_divergence_loss:.4f}")

    # Generate and save an image using the fixed latent sample after each epoch
    with torch.no_grad(): # Disable gradient calculations for inference
        decoder_model.eval() # Set decoder to evaluation mode

        # Generate image from the fixed latent vector
        generated_image_tensor = decoder_model(fixed_latent_sample)

        # (1, C, H, W) -> (C, H, W) -> (H, W, C), then move to CPU as a NumPy array
        generated_image_np = generated_image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()

        # Scale pixel values from [0, 1] (Sigmoid output) to [0, 255]; astype truncates toward zero
        generated_image_np = (generated_image_np * 255).astype(np.uint8)

        # Define the path for saving the image
        image_path = os.path.join(save_dir, f'generated_image_epoch_{epoch_index + 1}.png')
        plt.imsave(image_path, generated_image_np) # Save the generated image using matplotlib
        plt.close() # Defensive: close any pyplot figure that may be open
        print(f"Generated image saved at: {image_path}")

    # Save models every epoch
    torch.save(encoder_model.state_dict(), os.path.join(save_models_dir, f'encoder_epoch_{epoch_index + 1}.pth'))
    torch.save(decoder_model.state_dict(), os.path.join(save_models_dir, f'decoder_epoch_{epoch_index + 1}.pth'))
    print(f"Models saved in '{save_models_dir}' folder at epoch {epoch_index + 1}\n")

Epoch 1/500
Average total loss for epoch 1: 225431.2494
Average reconstruction loss for epoch 1: 223964.2578
Average KL divergence loss for epoch 1: 1466.9920
Generated image saved at: generated_images/generated_image_epoch_1.png
Models saved in 'saved_models' folder at epoch 1

Epoch 2/500
Average total loss for epoch 2: 125173.8617
Average reconstruction loss for epoch 2: 123994.3296
Average KL divergence loss for epoch 2: 1179.5317
Generated image saved at: generated_images/generated_image_epoch_2.png
Models saved in 'saved_models' folder at epoch 2

Epoch 3/500
Average total loss for epoch 3: 97423.3409
Average reconstruction loss for epoch 3: 95572.1002
Average KL divergence loss for epoch 3: 1851.2407
Generated image saved at: generated_images/generated_image_epoch_3.png
Models saved in 'saved_models' folder at epoch 3

Epoch 4/500
Average total loss for epoch 4: 88540.1980
Average reconstruction loss for epoch 4: 86364.5956
Average KL divergence loss for epoch 4: 2175.6030
Gener