<table class="table table-bordered">
    <tr>
        <th style="text-align:center;"><h1>Visual Generative AI Application: Generative Adversarial Networks</h1><h2>Assignment</h2><h3>Specialist Diploma in Applied Generative AI (SDGAI) 
</h3></th>
    </tr>
</table>

# (1) State clearly the goal and objectives you hope to achieve in this notebook


The objective of this project is to design and implement a generative model capable of creating images of fashion items from 10 distinct categories (classes). For this part of the assignment, I will develop a **conditional Diffusion Model**.

We will analyze the model's performance and tune its hyperparameters during the training phase. Specifically, we will explore the following:

- Vary the number of epochs:
    - Start with 50
    - Change this to 100
- Decrease the learning rate from 0.001 to 0.0001
- Decrease the batch size from 128 to 64

These are the steps I will follow:
1. Load and explore the dataset
2. Build the model
    1. Start with a baseline model
    2. Consider the different hyperparameters that can be tuned
    3. Perform the tuning
    4. Analyse the model's performance based on the different tuning strategies
3. Evaluate the model
    1. Use the model to generate new **specified** images from noise

# (2) Import libraries

In [None]:
import glob
import torch
import torch.nn.functional as F
from torch.optim import Adam
import torchvision.transforms as transforms

from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision import datasets
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# User defined libraries
from utils import other_utils
from utils import ddpm_utils
from utils import UNet_utils

import random
import datetime

import numpy as np
import matplotlib.pyplot as plt

# Seed PyTorch's global RNG so torch-based sampling is repeatable across runs.
# Only torch is seeded here; `random` and `numpy` (imported above) are not.
torch.manual_seed(0) 

In [None]:
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
# `torch.cuda.is_available` has to be *called*; the bare function object is
# always truthy and would pick 'cuda' even on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

# (3) Load/Download Dataset  

We will now download the Dataset for this assignment. You may amend the __batch_size__ parameter, __transform__ function or the attributes of the Dataloader as you deem fit for your processing.

In [None]:
# Loader settings: mini-batch size and the square side length images are resized to
batch_size = 128
img_size = 16

# Folder holding the FashionMNIST files (Windows path, backslashes escaped)
dataset_path = 'D:\\Users\\ng_a\\My NP SDGAI\\PDC-2\\VGAA\\Assignment\\'

# Pre-processing pipeline: resize -> tensor -> rescale to [-1, 1]
# (assumes ToTensor yields values in [0, 1] — TODO confirm)
transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Lambda(lambda pixels: pixels * 2 - 1),
    ]
)

# download=False: the files must already exist under dataset_path
dataset = datasets.FashionMNIST(root=dataset_path, download=False, transform=transform)

# Shuffle every epoch and drop the final incomplete batch so all batches have batch_size items
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# (4) Explore the data

In [None]:
# Human-readable name for each of the ten class indices (0-9)
labels_map = dict(enumerate([
    'T-shirt',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle Boot',
]))

# Preview the first 16 dataset items on a 4x4 grid (row-major order)
fig, axs = plt.subplots(4, 4, figsize=(8, 8))

for idx, ax in enumerate(axs.flat):
    image, label = dataset[idx]
    # squeeze() removes singleton dimensions so imshow receives a 2-D array
    ax.imshow(image.numpy().squeeze(), cmap='gray')
    ax.axis('off')
    ax.set_title(f"{labels_map[label]}")

plt.tight_layout()
plt.show()

In [None]:
def show_tensor_images(image_tensor, num_images=9, size=(1, 28, 28)):
    """Display the leading images of a tensor as a grid with three images per row.

    Args:
        image_tensor: tensor of images; it is reshaped to (-1, *size) first.
        num_images: number of images, taken from the front, to display.
        size: per-image shape used for the reshape.
    """
    # Work on a detached CPU copy so plotting never touches the autograd graph
    images = image_tensor.detach().cpu().view(-1, *size)
    selected = images[:num_images]

    grid = make_grid(selected, nrow=3)
    # Move the channel axis last, which is the layout imshow expects
    grid_channels_last = grid.permute(1, 2, 0).squeeze()

    plt.imshow(grid_channels_last)
    plt.axis('off')
    plt.show()

# (5) Modeling

In [None]:
# Image geometry, class count and batch size used by the model/training cells
IMG_CH = 1
IMG_SIZE = 16
N_CLASSES = 10
BATCH_SIZE = 128

# Grid dimensions; their product is used as the number of diffusion timesteps T
nrows, ncols = 10, 15
T = nrows * ncols

# T linearly spaced values from B_start to B_end, handed to the DDPM helper
# (presumably the beta/variance schedule — confirm in ddpm_utils)
B_start, B_end = 0.0001, 0.02
B = torch.linspace(B_start, B_end, T).to(device)
ddpm = ddpm_utils.DDPM(B, device)

In [None]:
# String form of every class index: "0" through "9"
class_names = [str(class_idx) for class_idx in range(10)]

In [None]:
def get_context_mask(c, drop_prob):
    """One-hot encode class indices and draw a random Bernoulli mask for them.

    Args:
        c: tensor of class indices; cast to int64 before encoding.
        drop_prob: probability that an individual mask entry is 0.

    Returns:
        Tuple (c_hot, c_mask): the one-hot encoding with N_CLASSES columns and
        a float mask of the same shape whose entries are 1 with probability
        1 - drop_prob. Both tensors are placed on the module-level `device`.
    """
    labels = c.to(torch.int64)
    c_hot = F.one_hot(labels, num_classes=N_CLASSES).to(device)

    # Per-entry keep probability; bernoulli() samples a 0/1 value for each entry
    keep_prob = torch.ones_like(c_hot).float() - drop_prob
    c_mask = torch.bernoulli(keep_prob).to(device)
    return c_hot, c_mask

In [None]:
# Build the diffusion U-Net. c_embed_dim is set to N_CLASSES, so the network
# takes a class-context input (presumably the one-hot label — confirm in UNet_utils).
model = UNet_utils.UNet(
    T,
    IMG_CH,
    IMG_SIZE,
    down_chs=(64, 64, 128),
    t_embed_dim=8,
    c_embed_dim=N_CLASSES,
)
n_params = sum(param.numel() for param in model.parameters())
print("Num params: ", n_params)

# Move to the selected device first, then wrap with torch.compile
model = torch.compile(model.to(device))

In [None]:
def do_create_model(timestep, img_ch, img_size, n_classes, device):
    """Create a fresh UNet, print its parameter count and return it compiled.

    Args:
        timestep: first positional argument of UNet_utils.UNet (callers pass T).
        img_ch: number of image channels.
        img_size: image side length.
        n_classes: value used for the UNet's c_embed_dim.
        device: device the model is moved to before compilation.

    Returns:
        The torch.compile-wrapped model located on `device`.
    """
    unet = UNet_utils.UNet(
        timestep,
        img_ch,
        img_size,
        down_chs=(64, 64, 128),
        t_embed_dim=8,
        c_embed_dim=n_classes,
    )
    n_params = sum(param.numel() for param in unet.parameters())
    print("Num params: ", n_params)

    return torch.compile(unet.to(device))

# (6) Training

In [None]:
import torch._dynamo

# Ask TorchDynamo to suppress compilation errors instead of raising them
# (presumably torch.compile then falls back to eager execution — confirm in the PyTorch docs)
torch._dynamo.config.suppress_errors = True

# NOTE(review): the two assignments below only create ordinary Python variables;
# nothing in this notebook reads them. If they were meant to be the PyTorch
# environment variables of the same names, they have no such effect written
# this way — confirm the intent (which would also conflict with "suppress" above).
TORCH_LOGS="+dynamo" 
TORCHDYNAMO_VERBOSE=1

In [None]:
# def plot_losses(epoch_losses):
#     # Plot the losses
#     plt.figure(figsize=(10, 5))
#     plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o', linestyle='-', color='b')
#     plt.xlabel('Epoch')
#     plt.ylabel('Loss')
#     plt.title('Loss per Epoch')
#     plt.grid(True)
#     plt.show()

def plot_eval_curves(epoch_losses, filename):
    """Plot the per-epoch loss curve, save it to `filename` and display it.

    Args:
        epoch_losses: sequence of loss values, one per epoch.
        filename: path the figure is written to before it is shown.
    """
    epoch_numbers = range(1, len(epoch_losses) + 1)  # epochs are numbered from 1 on the x-axis

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epoch_numbers, epoch_losses, marker='o', linestyle='-', color='b')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Loss per Epoch')
    ax.grid(True)

    # Save before show(): the figure must be written while it is still current
    plt.savefig(filename)
    plt.show()

In [None]:
def get_optimizer(model, lr):
    """Return an Adam optimizer over all parameters of `model` with learning rate `lr`."""
    return Adam(model.parameters(), lr)

In [None]:
def do_train(model, dataloader, opt, epochs, batch_size, img_ch, img_size, n_cols, n_classes, device):
    """Train `model` with the DDPM loss and return the average loss of each epoch.

    Every 100 steps the current loss is printed and `ddpm.sample_images` is
    called for one class; the previewed class advances cyclically.

    Relies on module-level names: `T` (number of timesteps), `ddpm`,
    `get_context_mask` and `class_names`.

    Args:
        model: network to train; it is switched to training mode.
        dataloader: iterable of batches where batch[0] holds images and batch[1] labels.
        opt: optimizer that updates `model`'s parameters.
        epochs: number of passes over `dataloader`.
        batch_size: kept for backward compatibility; the number of sampled
            timesteps is taken from the actual batch so the two always match.
        img_ch: image channel count forwarded to `ddpm.sample_images`.
        img_size: image side length forwarded to `ddpm.sample_images`.
        n_cols: column count forwarded to `ddpm.sample_images`.
        n_classes: number of classes the preview cycles through.
        device: device the image batch and timesteps are placed on.

    Returns:
        List with one float per epoch: summed step losses / len(dataloader).
    """
    train_drop_prob = 0.2  # context drop probability used for training steps
    preview_c = 0          # class index shown in the next preview

    model.train()  # Sets the model to training mode
    epoch_losses = []

    for epoch in range(epochs):
        epoch_loss = 0
        for step, batch in enumerate(dataloader):
            # Use the optimizer passed in, not whatever global `optimizer` happens to exist
            opt.zero_grad()

            x = batch[0].to(device)
            # One random timestep per image in this batch
            t = torch.randint(0, T, (x.shape[0],), device=device).float()

            c_hot, c_mask = get_context_mask(batch[1], train_drop_prob)
            loss = ddpm.get_loss(model, x, t, c_hot, c_mask)
            loss.backward()
            opt.step()

            epoch_loss += loss.item()  # Accumulate loss

            if step % 100 == 0:
                class_name = class_names[preview_c]
                print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {class_name}")
                # Do not drop context for the preview
                c_hot, c_mask = get_context_mask(torch.Tensor([preview_c]), 0)
                # Honour the n_cols argument instead of the global `ncols`
                ddpm.sample_images(model, img_ch, img_size, n_cols, c_hot, c_mask)
                preview_c = (preview_c + 1) % n_classes

        mean_loss = epoch_loss / len(dataloader)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch} Average Loss: {mean_loss}")  # Print average epoch loss

    return epoch_losses

## (6a) Number of Epochs: 50

In [None]:
# Experiment 1 (baseline): 50 epochs, learning rate 0.001, batch size 128
epochs = 50

cdif_1_model = do_create_model(T, IMG_CH, IMG_SIZE, N_CLASSES, device)
optimizer = get_optimizer(cdif_1_model, 0.001)
epoch_losses_1 = do_train(
    cdif_1_model,
    dataloader,
    optimizer,
    epochs,
    BATCH_SIZE,
    IMG_CH,
    IMG_SIZE,
    ncols,
    N_CLASSES,
    device,
)

In [None]:
# Plot the per-epoch loss of the 50-epoch run and save it as a PNG
plot_eval_curves(epoch_losses_1, 'cdif_1_loss.png')

In [None]:
# Raw per-epoch average losses of the 50-epoch run
print (epoch_losses_1)

## (6b) Number of Epochs: 100

In [None]:
# Experiment 2: 100 epochs, learning rate 0.001, batch size 128
epochs = 100

cdif_2_model = do_create_model(T, IMG_CH, IMG_SIZE, N_CLASSES, device)
optimizer = get_optimizer(cdif_2_model, 0.001)
epoch_losses_2 = do_train(
    cdif_2_model,
    dataloader,
    optimizer,
    epochs,
    BATCH_SIZE,
    IMG_CH,
    IMG_SIZE,
    ncols,
    N_CLASSES,
    device,
)

In [None]:
# Plot the per-epoch loss of the 100-epoch run and save it as a PNG
plot_eval_curves(epoch_losses_2, 'cdif_2_loss.png')

In [None]:
# Raw per-epoch average losses of the 100-epoch run
print (epoch_losses_2)

## (6c) Change Learning Rate to 0.0001 (was 0.001)

In [None]:
# Experiment 3: 100 epochs, learning rate lowered to 0.0001, batch size 128
epochs = 100

cdif_3_model = do_create_model(T, IMG_CH, IMG_SIZE, N_CLASSES, device)
optimizer = get_optimizer(cdif_3_model, 0.0001)
epoch_losses_3 = do_train(
    cdif_3_model,
    dataloader,
    optimizer,
    epochs,
    BATCH_SIZE,
    IMG_CH,
    IMG_SIZE,
    ncols,
    N_CLASSES,
    device,
)

In [None]:
# Plot the per-epoch loss of the learning-rate-0.0001 run and save it as a PNG
plot_eval_curves(epoch_losses_3, 'cdif_3_loss.png')

In [None]:
# Raw per-epoch average losses of the learning-rate-0.0001 run
print (epoch_losses_3)

## (6d) Change Batch Size to 64 (was 128)

In [None]:
# Experiment 4: 100 epochs, learning rate 0.001, batch size reduced to 64
NEW_BATCH_SIZE = 64

# Separate loader over the same dataset so the 128-batch loader stays untouched
dataloader_64 = DataLoader(
    dataset,
    batch_size=NEW_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

epochs = 100

cdif_4_model = do_create_model(T, IMG_CH, IMG_SIZE, N_CLASSES, device)
optimizer = get_optimizer(cdif_4_model, 0.001)
epoch_losses_4 = do_train(
    cdif_4_model,
    dataloader_64,
    optimizer,
    epochs,
    NEW_BATCH_SIZE,
    IMG_CH,
    IMG_SIZE,
    ncols,
    N_CLASSES,
    device,
)

In [None]:
# Plot the per-epoch loss of the batch-size-64 run and save it as a PNG
plot_eval_curves(epoch_losses_4, 'cdif_4_loss.png')

In [None]:
# Raw per-epoch average losses of the batch-size-64 run
print (epoch_losses_4)

## Display final results of each class

In [None]:
# Number of columns
# NOTE: this rebinds the module-level `ncols` that was 15 in the modeling cell
# (where it was used to compute T); any cell run after this one sees 3.
ncols = 3
# c_drop_prob = 0 # Change me to a value between 1 and 0
# NOTE(review): with 0.5, get_context_mask zeroes mask entries with probability
# 0.5, so part of the displayed samples may be generated without their class
# context — confirm this is intended for the "final results" display.
c_drop_prob = 0.5

def display_results(model, class_labels, num_classes, img_ch, img_size, n_cols, c_drop_prob):
    """Print each class label with its name and sample images for that class.

    Relies on module-level names: `labels_map`, `get_context_mask` and `ddpm`.

    Args:
        model: trained model handed to `ddpm.sample_images`.
        class_labels: sequence of class labels as strings convertible with
            int(); entry `c` is the label printed for class index `c`.
        num_classes: number of class indices (0 .. num_classes - 1) to show.
        img_ch: image channel count forwarded to `ddpm.sample_images`.
        img_size: image side length forwarded to `ddpm.sample_images`.
        n_cols: column count forwarded to `ddpm.sample_images`.
        c_drop_prob: drop probability used when building the context mask.
    """
    plt.figure(figsize=(8,8))

    for c in range(num_classes):
        # Read labels from the argument rather than the global `class_names`
        label = class_labels[c]
        print(label, ' ', labels_map[int(label)])
        c_hot, c_mask = get_context_mask(torch.Tensor([c]), c_drop_prob)
        ddpm.sample_images(model, img_ch, img_size, n_cols, c_hot, c_mask, axis_on=True)

In [None]:
# Per-class samples from the 50-epoch model
display_results(cdif_1_model, class_names, N_CLASSES, IMG_CH, IMG_SIZE, ncols, c_drop_prob)

In [None]:
# Per-class samples from the 100-epoch model
display_results(cdif_2_model, class_names, N_CLASSES, IMG_CH, IMG_SIZE, ncols, c_drop_prob)

In [None]:
# Per-class samples from the learning-rate-0.0001 model
display_results(cdif_3_model, class_names, N_CLASSES, IMG_CH, IMG_SIZE, ncols, c_drop_prob)

In [None]:
# Per-class samples from the batch-size-64 model
display_results(cdif_4_model, class_names, N_CLASSES, IMG_CH, IMG_SIZE, ncols, c_drop_prob)

## Conditioning Reverse Diffusion

In [None]:
@torch.no_grad()
def sample_w(
    model, input_size, T, c, w_tests=(-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0), store_freq=10
):
    """Run reverse diffusion for every (guidance weight, context row) pair.

    For each step the model is evaluated on a doubled batch: the first half
    with the context kept and the second half with it masked out. The two noise
    predictions are combined as (1 + w) * keep - w * drop.

    Relies on module-level names: `device` and `ddpm`.

    Args:
        model: noise-prediction model called as model(x_t, t, c, c_mask).
        input_size: per-sample shape used to draw the initial noise.
        T: number of timesteps; steps run from T - 1 down to 0.
        c: 2-D context tensor with one row per sample column.
        w_tests: sequence of guidance weights; one row of samples per weight.
        store_freq: every `store_freq`-th step is stored for the animation.

    Returns:
        Tuple (x_t, x_t_store): the final samples and a stacked tensor of the
        stored intermediate samples (first step, every `store_freq`-th step
        and the last ten steps).
    """
    # Prepare a "grid of samples" with w for rows and c for columns
    n_samples = len(w_tests) * len(c)

    # One w for each c
    w = torch.tensor(w_tests).float().repeat_interleave(len(c))
    w = w[:, None, None, None].to(device)  # Make w broadcastable
    x_t = torch.randn(n_samples, *input_size).to(device)

    # One c for each w
    c = c.repeat(len(w_tests), 1)

    # Double the batch
    c = c.repeat(2, 1)

    # First half keeps its context; second half has it fully masked out
    c_mask = torch.ones_like(c).to(device)
    c_mask[n_samples:] = 0.0

    x_t_store = []
    for i in range(0, T)[::-1]:
        # Duplicate t for each sample
        t = torch.tensor([i]).to(device)
        t = t.repeat(n_samples, 1, 1, 1)

        # Double the batch
        x_t = x_t.repeat(2, 1, 1, 1)
        t = t.repeat(2, 1, 1, 1)

        # Find weighted noise
        e_t = model(x_t, t, c, c_mask)
        e_t_keep_c = e_t[:n_samples]
        e_t_drop_c = e_t[n_samples:]
        e_t = (1 + w) * e_t_keep_c - w * e_t_drop_c

        # Deduplicate batch for reverse diffusion
        x_t = x_t[:n_samples]
        t = t[:n_samples]
        x_t = ddpm.reverse_q(x_t, t, e_t)

        # Store values for animation. The first step is i == T - 1 because i
        # never reaches T, so comparing against T would never store it.
        if i % store_freq == 0 or i == T - 1 or i < 10:
            x_t_store.append(x_t)

    x_t_store = torch.stack(x_t_store)
    return x_t, x_t_store

# (7) Model Evaluation

In [None]:
# Keep all category information for sampling by setting c_drop_prob to 0
def evaluate(model, input_size, n_classes, T, c, filename, c_drop_prob=0):
    """Sample images for the classes in `c` and save the process as an animation.

    Args:
        model: trained model handed to `sample_w`.
        input_size: per-sample shape forwarded to `sample_w`.
        n_classes: number of images per row in every animation frame.
        T: number of timesteps forwarded to `sample_w`.
        c: tensor of class indices to generate.
        filename: output path handed to `other_utils.save_animation`.
        c_drop_prob: drop probability passed to `get_context_mask`; only the
            one-hot encoding from that call is used here.
    """
    c_hot, _unused_mask = get_context_mask(c, c_drop_prob)
    _final_samples, frames = sample_w(model, input_size, T, c_hot.float())

    # Turn every stored step into one image grid
    grids = []
    for frame in frames:
        grid = make_grid(frame.cpu(), nrow=n_classes)
        grids.append(other_utils.to_image(grid))

    other_utils.save_animation(grids, filename)

In [None]:
# Animate generation of every class with the 50-epoch model
c = torch.arange(N_CLASSES).to(device) # Generates all 10 fashion items
input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)  # per-sample shape: (channels, height, width)

evaluate(cdif_1_model, input_size, N_CLASSES, T, c, "fashion_1.gif")

In [None]:
# Same for the 100-epoch model; reuses `c` and `input_size` from the previous cell
evaluate(cdif_2_model, input_size, N_CLASSES, T, c, "fashion_2.gif")

In [None]:
# Animate generation of every class with the learning-rate-0.0001 model
c = torch.arange(N_CLASSES).to(device) # Generates all 10 fashion items
input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)  # per-sample shape: (channels, height, width)

evaluate(cdif_3_model, input_size, N_CLASSES, T, c, "fashion_3.gif")

In [None]:
# Animate generation of every class with the batch-size-64 model
c = torch.arange(N_CLASSES).to(device) # Generates all 10 fashion items
input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)  # per-sample shape: (channels, height, width)

evaluate(cdif_4_model, input_size, N_CLASSES, T, c, "fashion_4.gif")

# (8) Generating New Samples From Generative Models

In [None]:
# Generate one specified class with the 50-epoch model.
# `input_size` comes from the evaluation cells above.
labels = [7] # Sneaker
c = torch.tensor(labels)
n_classes = 10  # forwarded to evaluate(), where it sets the images per row of each frame
evaluate(cdif_1_model, input_size, n_classes, T, c, "fashion-1a.gif")

In [None]:
# Generate one specified class with the 100-epoch model.
# `input_size` comes from the evaluation cells above.
labels = [7] # Sneaker
c = torch.tensor(labels)
n_classes = 10  # forwarded to evaluate(), where it sets the images per row of each frame
evaluate(cdif_2_model, input_size, n_classes, T, c, "fashion-2a.gif")

In [None]:
# Generate one specified class with the learning-rate-0.0001 model.
# `input_size` comes from the evaluation cells above.
labels = [7] # Sneaker
c = torch.tensor(labels)
n_classes = 10  # forwarded to evaluate(), where it sets the images per row of each frame
evaluate(cdif_3_model, input_size, n_classes, T, c, "fashion-3a.gif")

In [None]:
# Generate one specified class with the batch-size-64 model.
# `input_size` comes from the evaluation cells above.
labels = [7] # Sneaker
c = torch.tensor(labels)
n_classes = 10  # forwarded to evaluate(), where it sets the images per row of each frame
evaluate(cdif_4_model, input_size, n_classes, T, c, "fashion-4a.gif")

# (9) Saving the trained model

In [None]:
# Persist the learned weights of all four experiments (state_dict only, not the full module).
# NOTE(review): these models were wrapped with torch.compile in do_create_model;
# a compiled wrapper may prefix its state_dict keys (e.g. `_orig_mod.`) — confirm
# before loading these files into an uncompiled UNet.
torch.save(cdif_1_model.state_dict(), 'cdif_generator_1.pth')
torch.save(cdif_2_model.state_dict(), 'cdif_generator_2.pth')
torch.save(cdif_3_model.state_dict(), 'cdif_generator_3.pth')
torch.save(cdif_4_model.state_dict(), 'cdif_generator_4.pth')