# W GAN

## Preliminaries

In [6]:
# Import necessary libraries
import torch  # Import PyTorch library for deep learning operations
import torch.nn as nn  # Import neural network module for building neural network architectures
import torch.optim as optim  # Import optimization module for training neural networks
from torchvision import datasets, transforms  # Import datasets and transforms for image processing
from torch.utils.data import Dataset, DataLoader, random_split  # Import Dataset, DataLoader, and random_split for data handling
from torchvision.utils import save_image, make_grid  # Import utility functions to save and display images
import torchvision  # Import torchvision library for computer vision tasks
import torchvision.models as models  # Import pre-trained models from torchvision
import matplotlib.pyplot as plt  # Import plotting library for visualizations
import os  # Import os module for file and directory operations
import numpy as np  # Import numpy library for numerical operations
from PIL import Image  # Import PIL for image processing
import pathlib  # Import pathlib for file path operations
import pandas as pd  # Import pandas for data manipulation and analysis
import scipy  # Import scipy for scientific computing
from tqdm import tqdm  # Import tqdm for progress bars
from torch.optim.lr_scheduler import ExponentialLR  # Import ExponentialLR scheduler for learning rate decay
from torch.nn.utils import spectral_norm  # Import spectral_norm for weight normalization in neural networks
from torchvision.transforms import ToTensor, Compose, Normalize, Resize  # Import specific transform functions
from torchvision.models import inception_v3  # Import Inception v3 model for FID calculation
from scipy import linalg  # Import linalg from scipy for matrix operations in FID calculation
from torch.nn.functional import adaptive_avg_pool2d  # Import adaptive average pooling for FID calculation
from torchvision.models import inception_v3  # Import Inception v3 model for FID calculation
from torchvision import transforms  # Import transforms for image preprocessing

In [7]:
# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Set device
print(f"Using device: {device}")  # Print the device being used

Using device: cuda


## Define models

In [8]:
# Note that the function returns batch_size number of noise vectors, each of dimension latent_dim. Scalar of each noise vector is sampled from a standard normal distribution with mean 0 and variance 1
def sample_noise(batch_size, latent_dim=100): # The function takes two arguments: batch_size and latent_dim. batch_size is the number of noise vectors to generate, and latent_dim is the dimension of each noise vector.
    """
    Generate noise vectors for the generator input.
    
    Args:
        batch_size (int): The number of noise vectors to generate.
        latent_dim (int): The dimension of each noise vector. Default is 100.
    
    Returns:
        torch.Tensor: A tensor of shape (batch_size, latent_dim, 1, 1) containing the generated noise vectors.
    """
    noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)  # Generate random noise vectors from a standard normal distribution
    return noise  # Return the generated noise tensor
    noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)  #  torch.randn(), which generates a tensor of random numbers from a standard normal distribution (mean 0, variance 1)
    return noise  # Return the noise tensor

In [9]:
# Define the Generator model class. It takes an input of dimension N x 100 x 1 x 1 ie a batch of noise vectors and outputs N x 3 x 128 x 128 ie a batch of images of dimension 3(RGB) x 128 x 128
class Generator(nn.Module):  # Defined a class named Generator that inherits from nn.Module of PyTorch. nn.module is a class that defines a neural network in PyTorch

    def __init__(self):  # Constructor function. Whenever an instance of the class is created, this function initializes the instance
        super(Generator, self).__init__()  # Call parent constructor before any Generator specific initialization. This pattern ensures that all the benefits and functionalities provided by nn.Module are available to Generator, like parameter tracking, module moving to different devices (CPU/GPU), etc.
        
        self.main = nn.Sequential( # We define property of class called main, which is a ordered  ordered sequence of neural network layers. When we pass input through self.main, it goes through these layers in the order they are listed.
            nn.ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),  # Input: N x 100 x 1 x 1, Output: N x 512 x 4 x 4, 512 filters of size 4x4x100
            nn.BatchNorm2d(num_features=512),  # Batch normalization for 512 channels
            nn.ReLU(inplace=True),  # ReLU activation
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),  # Input: N x 512 x 4 x 4, Output: N x 256 x 8 x 8, 256 filters of size 4x4x512
            nn.BatchNorm2d(num_features=256),  # Batch normalization for 256 channels
            nn.ReLU(inplace=True),  # ReLU activation
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),  # Input: N x 256 x 8 x 8, Output: N x 128 x 16 x 16, 128 filters of size 4x4x256
            nn.BatchNorm2d(num_features=128),  # Batch normalization for 128 channels
            nn.ReLU(inplace=True),  # ReLU activation
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),  # Input: N x 128 x 16 x 16, Output: N x 64 x 32 x 32, 64 filters of size 4x4x128
            nn.BatchNorm2d(num_features=64),  # Batch normalization for 64 channels
            nn.ReLU(inplace=True),  # ReLU activation
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1, bias=False),  # Input: N x 64 x 32 x 32, Output: N x 32 x 64 x 64, 32 filters of size 4x4x64
            nn.BatchNorm2d(num_features=32),  # Batch normalization for 32 channels
            nn.ReLU(inplace=True),  # ReLU activation
            nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),  # Input: N x 32 x 64 x 64, Output: N x 3 x 128 x 128, 3 filters of size 4x4x32
            nn.Tanh()  # Tanh activation to ensure output is in range [-1, 1]

        )

    def forward(self, input):  # Forward pass
        return self.main(input)  
    
    def initialize_weights(self):  # Method to initialize weights of the Generator
        for module in self.modules():  # Iterate through all modules in the Generator
            if isinstance(module, (nn.ConvTranspose2d, nn.BatchNorm2d)):  # Check if the module is a ConvTranspose2d or BatchNorm2d layer
                nn.init.normal_(module.weight.data, mean=0.0, std=0.02)  # Initialize weights with normal distribution (mean=0, std=0.02)
                if module.bias is not None:  # If the module has a bias parameter
                    nn.init.constant_(module.bias.data, 0.0)  # Initialize bias to zero


In [10]:
# Define the Discriminator model. The instance takes input of dimension N x 3 x 128 x 128 and outputs N x 1 tensor corresponding to f(x)
class Discriminator(nn.Module):  # Discriminator class

    def __init__(self):  # Constructor
        super(Discriminator, self).__init__()  # Call parent constructor
        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.BatchNorm2d(num_features=128),  # Batch normalization for 128 channels
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.BatchNorm2d(num_features=256),  # Batch normalization for 256 channels
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.BatchNorm2d(num_features=512),  # Batch normalization for 512 channels
            nn.LeakyReLU(0.2, inplace=True),  # LeakyReLU activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.BatchNorm2d(256),  # Batch normalization for 256 channels
            nn.LeakyReLU(0.2, inplace=True),  # Activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.BatchNorm2d(128),  # Batch normalization for 128 channels
            nn.LeakyReLU(0.2, inplace=True),  # Activation with negative slope of 0.2
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False)),  # Apply spectral normalization to Conv2d layer
            nn.Flatten()  # Flatten the output to N x 1
        )

    def forward(self, input):  # Forward pass
        return self.main(input) 

    def weight_clip(self, clip_value=0.01): # Method to clip weights of the discriminator
        for p in self.parameters(): # Iterate through all parameters of the discriminator
            p.data.clamp_(-clip_value, clip_value) # Clips the weights to be within the range [-clip_value, clip_value]

    def initialize_weights(self):  # Method to initialize weights of the discriminator
        for module in self.modules():  # Iterate through all modules in the discriminator
            if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):  # Check if the module is a Conv2d or BatchNorm2d layer
                nn.init.normal_(module.weight.data, mean=0.0, std=0.02)  # Initialize weights with normal distribution
                if module.bias is not None:  # Check if the module has a bias parameter
                    nn.init.constant_(module.bias.data, 0.0)  # Initialize bias to zero

## Loss for Wasserstein GANs

The Wasserstein GAN loss function is defined as:

$$
\theta^* = \arg\min_\theta \max_\omega \mathbb{E}_{x \sim P_{data}}[f(x)] - \mathbb{E}_{x \sim P_{generator}}[f(x)]
$$

Where:
- $\theta^*$ represents the optimal parameters of the generator
- $\theta$ are the generator parameters
- $\omega$ are the discriminator parameters
- $f(x)$ is 1 - Lipschitz discriminator function
- $P_{data}$ is the real data distribution
- $P_{generator}$ is the generator's distribution

### Generator Loss Function

The generator tries to minimize the following loss function:

$$
Loss = -\mathbb{E}_{x \sim P_{generator}}[f(x)]
$$

In [11]:
# Takes the output of the discriminator for the benerated image ie Batch size N x 1 values and returns the loss for generator
def generator_loss(fake_output):
    """
    Calculate the generator loss for Wasserstein GAN.
    
    Args:
        fake_output (torch.Tensor): The output of the discriminator for batch of images generated by the generator.
    
    Returns:
        torch.Tensor: The generator loss.
    """
    return -torch.mean(fake_output)

### Discriminator Loss Function

the discriminator tries to maximize the following loss function:

$$
Loss = -(\mathbb{E}_{x \sim P_{data}}[f(x)] - \mathbb{E}_{x \sim P_{generator}}[f(x)])
$$

In [12]:
# Takes the output of the discriminator for the real image and generated image
def discriminator_loss(real_output, fake_output):
    """
    Calculate the discriminator loss for Wasserstein GAN.
    
    Args:
        real_output (torch.Tensor): The output of the discriminator for a batch of real images.
        fake_output (torch.Tensor): The output of the discriminator for a batch of fake images ie images generated by the generator.
    
    Returns:
        torch.Tensor: The discriminator loss.
    """
    return -(torch.mean(real_output) - torch.mean(fake_output))

## Data preparation

In [39]:
class SimpleImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir  # Directory containing all images (including subfolders)
        self.transform = transform  # Transformations to apply to images
        self.image_files = []  # List to store all image file paths

        # Traverse through all subfolders
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_files.append(os.path.join(root, file))

    def __len__(self):
        return len(self.image_files)  # Return the total number of images

    def __getitem__(self, idx):
        img_path = self.image_files[idx]  # Get image path
        image = Image.open(img_path).convert('RGB')  # Open image and convert to RGB
        
        if self.transform:
            image = self.transform(image)  # Apply transformations if any
        
        return image  # Return only the image, no label needed for GAN

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images to 128x128
    transforms.ToTensor(),  # Convert images to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize each RGB channel to [-1, 1]
])

# Load the Animal dataset from local machine
data_dir = r'D:\Users\VICTOR\Desktop\ADRL\Assignment 1\Animal_data_resized\animals'  # Path to the Animal dataset

# Create dataset and dataloader
dataset = SimpleImageDataset(root_dir=data_dir, transform=transform)  # Use our simple dataset class
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # Create dataloader

print(f"Loaded {len(dataset)} images.")  # Print the total number of images loaded

Loaded 5400 images.


## Training loop

In [40]:
# Instantiate the models and move to device
generator = Generator().to(device) 
discriminator = Discriminator().to(device)

In [41]:
# Define paths for saved models
generator_path = r'D:\Users\VICTOR\Desktop\ADRL\Assignment 1\WGAN - Experiments\3 Further trained 2\saved_models_WGANs_butterfly\generator_epoch_9.pth'  # Path to the saved generator model
discriminator_path = r'D:\Users\VICTOR\Desktop\ADRL\Assignment 1\WGAN - Experiments\3 Further trained 2\saved_models_WGANs_butterfly\discriminator_epoch_9.pth'  # Path to the saved discriminator model

# Load the generator model
generator = Generator().to(device)  # Create a new generator instance and move it to the device
generator.load_state_dict(torch.load(generator_path))  # Load the saved state dictionary into the generator

# Load the discriminator model
discriminator = Discriminator().to(device)  # Create a new discriminator instance and move it to the device
discriminator.load_state_dict(torch.load(discriminator_path))  # Load the saved state dictionary into the discriminator

print("Models loaded successfully and moved to device.")  # Print a success message

  generator.load_state_dict(torch.load(generator_path))  # Load the saved state dictionary into the generator


Models loaded successfully and moved to device.


  discriminator.load_state_dict(torch.load(discriminator_path))  # Load the saved state dictionary into the discriminator


In [42]:
# Define optimizers and schedulers
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.0, 0.9))  # Optimizer for discriminator
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.0, 0.9))  # Optimizer for generator
scheduler_d = ExponentialLR(optimizer_d, gamma=0.99)  # Learning rate scheduler for discriminator
scheduler_g = ExponentialLR(optimizer_g, gamma=0.99)  # Learning rate scheduler for generator

In [43]:
# Set models to training mode
generator.train() 

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(

In [44]:
discriminator.train() 

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (12): BatchNorm2d(256, eps=1e-05, 

In [45]:
# Define hyperparameters
num_epochs = 100 # Number of epochs for training
noise_dim = 100 # Define the noise dimension
clip_value = 0.01  # Clipping parameter
n_critic = 1  # Number of critic iterations per generator iteration

### Differentiable Augmentation in GAN Training

1. The `DiffAugment` function takes an input tensor `x` (typically a batch of images) and applies a series of augmentations based on the specified policy.

2. The `policy` parameter is a string that can include 'color', 'translation', and 'cutout', separated by commas. These correspond to different types of augmentations.

3. If the input tensor doesn't have channels as the first dimension (after batch size), it rearranges the dimensions.

4. For each augmentation type specified in the policy:
   - It retrieves the corresponding augmentation functions from the `AUGMENT_FNS` dictionary.
   - It applies each of these functions to the input tensor `x`.

5. The augmentation functions (defined elsewhere in the code) include:
   - `rand_brightness`: Randomly adjusts brightness
   - `rand_saturation`: Randomly adjusts saturation
   - `rand_contrast`: Randomly adjusts contrast
   - `rand_translation`: Randomly translates the image
   - `rand_cutout`: Applies random cutouts to the image

6. After applying all augmentations, it ensures the tensor is contiguous in memory and returns the augmented tensor.

7. The key idea behind Differentiable Augmentation is that these augmentations are applied to both real and generated images during GAN training. This helps to regularize the discriminator and can significantly improve GAN performance, especially when training data is limited.

8. By applying the same augmentations to both real and generated images, the discriminator is forced to focus on more robust features rather than potentially overfitting to specific details of the limited training data. This can lead to better generalization and more stable training, particularly in scenarios where the amount of training data is small.


In [52]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# Code taken from: https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_pytorch.py

def DiffAugment(x, policy='color,translation,cutout', channels_first=True):
    """
    Apply differentiable augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
        policy (str): Comma-separated string specifying which augmentations to apply.
        channels_first (bool): Whether the input tensor has channels as the first dimension.
    
    Returns:
        torch.Tensor: Augmented tensor.
    """
    if not x.is_cuda:  # Check if input tensor is on GPU
        x = x.cuda()  # Move tensor to GPU if not already there
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)  # Rearrange dimensions if channels are not first
        for p in policy.split(','):
            for f in AUGMENT_FNS[p]:
                x = f(x)  # Apply each augmentation function specified in the policy
        if not channels_first:
            x = x.permute(0, 2, 3, 1)  # Rearrange dimensions back if necessary
        x = x.contiguous()  # Ensure tensor is contiguous in memory
    return x


def rand_brightness(x):
    """
    Apply random brightness augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
    
    Returns:
        torch.Tensor: Brightness-augmented tensor.
    """
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)  # Add random brightness offset
    return x


def rand_saturation(x):
    """
    Apply random saturation augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
    
    Returns:
        torch.Tensor: Saturation-augmented tensor.
    """
    x_mean = x.mean(dim=1, keepdim=True)  # Calculate mean across color channels
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1,
                                   dtype=x.dtype, device=x.device) * 2) + x_mean  # Adjust saturation randomly
    return x


def rand_contrast(x):
    """
    Apply random contrast augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
    
    Returns:
        torch.Tensor: Contrast-augmented tensor.
    """
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)  # Calculate mean across all dimensions except batch
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1,
                                   dtype=x.dtype, device=x.device) + 0.5) + x_mean  # Adjust contrast randomly
    return x


def rand_translation(x, ratio=0.125):
    """
    Apply random translation augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
        ratio (float): Maximum translation ratio relative to image size.
    
    Returns:
        torch.Tensor: Translation-augmented tensor.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)  # Calculate maximum shift
    translation_x = torch.randint(-shift_x, shift_x + 1,
                                  size=[x.size(0), 1, 1], device=x.device)  # Random horizontal shift
    translation_y = torch.randint(-shift_y, shift_y + 1,
                                  size=[x.size(0), 1, 1], device=x.device)  # Random vertical shift
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )  # Create meshgrid for indexing
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)  # Apply horizontal shift
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)  # Apply vertical shift
    x_pad = torch.nn.functional.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])  # Pad input tensor
    x = x_pad.permute(0, 2, 3, 1).contiguous()[
        grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()  # Apply translation using grid sampling
    return x


def rand_cutout(x, ratio=0.5):
    """
    Apply random cutout augmentation to the input tensor.
    
    Args:
        x (torch.Tensor): Input tensor to be augmented.
        ratio (float): Size of the cutout relative to image size.
    
    Returns:
        torch.Tensor: Cutout-augmented tensor.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)  # Calculate cutout size
    offset_x = torch.randint(0, x.size(
        2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)  # Random x offset
    offset_y = torch.randint(0, x.size(
        3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)  # Random y offset
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
    )  # Create meshgrid for indexing
    grid_x = torch.clamp(grid_x + offset_x -
                         cutout_size[0] // 2, min=0, max=x.size(2) - 1)  # Calculate x coordinates of cutout region
    grid_y = torch.clamp(grid_y + offset_y -
                         cutout_size[1] // 2, min=0, max=x.size(3) - 1)  # Calculate y coordinates of cutout region
    mask = torch.ones(x.size(0), x.size(2), x.size(3),
                      dtype=x.dtype, device=x.device)  # Create mask tensor
    mask[grid_batch, grid_x, grid_y] = 0  # Set cutout region to 0 in mask
    x = x * mask.unsqueeze(1)  # Apply mask to input tensor
    return x


AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],  # Color augmentation functions
    'translation': [rand_translation],  # Translation augmentation function
    'cutout': [rand_cutout],  # Cutout augmentation function
}

### Augmentation in GAN Training

Augmentation Application
- Both real and generated images are augmented using differentiable techniques before passing to the discriminator.
- This creates a more challenging task for the discriminator.

Discriminator Training
- The discriminator now sees augmented versions of both real and fake images.
- This increases difficulty in distinguishing between real and fake.

Generator Training
- The generator produces images as usual.
- These images are then augmented before being passed to the discriminator.
- Generator's loss is based on the discriminator's output on augmented fake images.

Backpropagation
- Differentiable augmentations allow gradients to flow through them.
- Both generator and discriminator learn from the augmented data.

Regularization Effect
- Augmentation reduces overfitting to specific details in limited training data.
- Encourages the discriminator to focus on more robust and general features.

In [50]:
# Training loop
for epoch in range(num_epochs):  # Outer loop for every epoch
    print(f"Epoch {epoch+1}/{num_epochs}")  # Print for every epoch

    for i, batch in enumerate(dataloader):  # Inner loop for every batch
        print(f"Batch {i+1}/{len(dataloader)}")  # Print for every batch

        # The batch is already the tensor of images
        real_images = batch.to(device)  # Move the tensor to the device (GPU or CPU)
        batch_size = real_images.size(0)  # Get batch size

        # Train Discriminator
        for _ in range(n_critic):  # Train discriminator more times than generator
            optimizer_d.zero_grad()  # Reset gradients for the discriminator's parameters to zero
            
            # Apply DiffAugment to real images
            augmented_real_images = DiffAugment(real_images, policy='color,translation,cutout').to(device)  # Move augmented images to device
            real_output = discriminator(augmented_real_images)

            noise = sample_noise(batch_size * len(augmented_real_images), noise_dim).to(device)  # Generate batch of noise vectors
            fake_images = generator(noise)  # Pass the noise through the generator to generate fake images
            
            # Apply DiffAugment to fake images
            augmented_fake_images = DiffAugment(fake_images.detach(), policy='color,translation,cutout').to(device)  # Move augmented images to device
            fake_output = discriminator(augmented_fake_images)
            
            loss_d = discriminator_loss(real_output, fake_output)  # Calculate discriminator loss

            loss_d.backward()  # Backpropagate discriminator loss
            optimizer_d.step()  # Update discriminator weights

        # Train Generator
        optimizer_g.zero_grad()  # Reset the gradients of generator 

        noise = sample_noise(batch_size, noise_dim).to(device)  # Generate batch of noise vectors 
        fake_images = generator(noise)  # Pass the noise through the generator to generate fake images
        
        # Apply DiffAugment to fake images for generator training
        augmented_fake_images = DiffAugment(fake_images, policy='color,translation,cutout').to(device)  # Move augmented images to device
        fake_output = discriminator(augmented_fake_images)

        loss_g = generator_loss(fake_output)  # Calculate generator loss

        loss_g.backward()  # Backpropagate generator loss
        optimizer_g.step()  # Update generator weights

        # Print losses
        print(f"Discriminator loss: {loss_d.item():.4f}, Generator loss: {loss_g.item():.4f}") 

        # Save generated images for every batch
        generator.eval()
        with torch.no_grad():  # Temporarily disables gradient computation
            fake = fake_images[0].unsqueeze(0).cpu()
            fake = (fake + 1) / 2.0  # Transform from [-1, 1] to [0, 1]
            img = torchvision.utils.make_grid(fake, normalize=False)
            img_np = img.detach().permute(1, 2, 0).numpy()
            plt.figure(figsize=(8, 8))
            plt.imshow(img_np)
            plt.axis('off')
            plt.title(f"Epoch {epoch+1}, Batch {i+1}")
            save_dir = r"D:\Users\VICTOR\Desktop\ADRL\Assignment 1\generated_images_WGAN_animal"
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f'generated_image_epoch_{epoch+1}_batch_{i+1}.png')
            plt.savefig(save_path)
            plt.close()
        generator.train()

        torch.cuda.empty_cache()

    # Save generator and discriminator models after every epoch
    save_dir = r"D:\Users\VICTOR\Desktop\ADRL\Assignment 1\saved_models_WGANs_animal"
    os.makedirs(save_dir, exist_ok=True)
    generator_save_path = os.path.join(save_dir, f'generator_epoch_{epoch+1}.pth')
    discriminator_save_path = os.path.join(save_dir, f'discriminator_epoch_{epoch+1}.pth')
    torch.save(generator.state_dict(), generator_save_path)
    torch.save(discriminator.state_dict(), discriminator_save_path)
    print(f"Models saved for epoch {epoch+1}")

    # Update learning rate
    scheduler_d.step()  # Update discriminator learning rate
    scheduler_g.step()  # Update generator learning rate

Epoch 1/100
Batch 1/2700



KeyboardInterrupt

