In [None]:
import torch
import torch.nn as nn

from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from IPython import get_ipython
from IPython.display import Image as IPImage, display
import matplotlib.pyplot as plt


import os
import numpy as np
from tqdm import tqdm

# Run on the GPU when CUDA is available to PyTorch, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [None]:
# Locations used by the notebook: the dataset root and the folder for sampled image grids.
data_dir = './data'
image_dir = 'images'

# Create both folders up front; exist_ok=True makes re-running this cell a no-op.
for _directory in (data_dir, image_dir):
    os.makedirs(_directory, exist_ok=True)

In [None]:
# Hyperparameters shared by the data pipeline, both networks and the optimizers.
params = dict(
    num_classes=10,          # number of label classes used for conditioning
    latent_space=100,        # dimensionality of the generator's noise vector
    input_size=(1, 32, 32),  # (channels, height, width) of the images
    image_size=32,           # size passed to transforms.Resize
    lr=2e-4,                 # Adam learning rate
    b1=0.5,                  # Adam beta1
    b2=0.999,                # Adam beta2
    epochs=100,              # number of passes over the training set
    batch_size=64,           # samples per mini-batch
)

In [None]:
# Preprocessing applied to every MNIST image, in order:
#   1. ToTensor  - converts the image to a tensor with pixel values scaled to [0, 1]
#   2. Resize    - resizes the image to params['image_size']
#   3. Normalize - applies (x - 0.5) / 0.5, so values land in [-1, 1],
#                  the same range as the generator's Tanh output
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(params['image_size']),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST training split (train=True), stored under data_dir and downloaded
# when it is not found locally. train_transform is applied to each image.
train_dataset = datasets.MNIST(data_dir, train=True, transform=train_transform, download=True)

# Iterates over the dataset in shuffled mini-batches of params['batch_size'] samples,
# yielding (images, labels) pairs.
train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

In [None]:
# Conditional GAN Generator

class Generator(nn.Module):
    """
    Generator for CGAN.

    Maps a noise vector plus a class label to an image: the label is embedded,
    concatenated with the noise, and pushed through a stack of fully connected
    layers ending in Tanh.

    Args:
        params (dict): Dictionary of model parameters with keys:
            - num_classes (int): Number of classes for label embedding.
            - latent_space (int): Dimensionality of the noise vector.
            - input_size (tuple): Shape of the generated image (channels, height, width).

    Returns:
        torch.Tensor: Generated images of shape (batch_size, *input_size).

    Example:
        >>> params = {'num_classes': 10, 'latent_space': 100, 'input_size': (1, 32, 32)}
        >>> generator = Generator(params)
        >>> noise = torch.randn(16, 100)
        >>> labels = torch.randint(0, 10, (16,))
        >>> generated_images = generator(noise, labels)
        >>> print(generated_images.shape)  # torch.Size([16, 1, 32, 32])
    """

    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.latent_dim = params['latent_space']
        self.input_size = params['input_size']

        # Label embedding matrix: one num_classes-dimensional vector per class
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        # Generator model: (noise + label embedding) -> flattened image
        self.model = nn.Sequential(
            *self.block(self.latent_dim + self.num_classes, 128, normalize=False),
            *self.block(128, 256),
            *self.block(256, 512),
            *self.block(512, 1024),
            nn.Linear(1024, int(np.prod(self.input_size))),  # last fully connected layer
            nn.Tanh(),  # output in [-1, 1]
        )

    def block(self, in_channels, out_channels, normalize=True):
        """
        Build one fully connected block: Linear, optional BatchNorm1d, LeakyReLU(0.2).

        Args:
            in_channels (int): Number of input features.
            out_channels (int): Number of output features.
            normalize (bool, optional): Whether to apply Batch Normalization (default=True).

        Returns:
            list: The layers of the block, in order: Linear, BatchNorm1d (if normalize=True),
            and LeakyReLU.
        """
        layers = [nn.Linear(in_channels, out_channels)]
        if normalize:
            # NOTE(review): the original passed 0.8 positionally, which binds to `eps`
            # (not `momentum`). Kept as eps to preserve training behaviour; if
            # momentum=0.8 was intended, change it deliberately.
            layers.append(nn.BatchNorm1d(out_channels, eps=0.8))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    def forward(self, noise, labels):
        """
        Forward pass of the Generator. Concatenates the label embedding with the noise and
        passes it through the model to generate an image.

        Args:
            noise (torch.Tensor): Input noise vector of shape (batch_size, latent_dim).
            labels (torch.Tensor): Input labels of shape (batch_size,).

        Returns:
            torch.Tensor: Generated image of shape (batch_size, channels, height, width).
        """
        # Concatenate label embedding and noise to produce the model input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        # Reshape the flat output back to (batch_size, *input_size)
        img = img.view(img.size(0), *self.input_size)
        return img

In [None]:
# Conditional GAN Discriminator

class Discriminator(nn.Module):
    """
    Discriminator for CGAN.

    Scores an (image, label) pair: the image is flattened, concatenated with an
    embedding of its label, and passed through a fully connected network that
    ends in a single linear output.

    Args:
        params (dict): Dictionary of model parameters with keys:
            - num_classes (int): Number of classes for label embedding.
            - input_size (tuple): Shape of the input image (channels, height, width).

    Returns:
        torch.Tensor: Validity scores of shape (batch_size, 1).

    Example:
        >>> params = {'num_classes': 10, 'input_size': (1, 32, 32)}
        >>> discriminator = Discriminator(params)
        >>> images = torch.randn(16, 1, 32, 32)
        >>> labels = torch.randint(0, 10, (16,))
        >>> validity = discriminator(images, labels)
        >>> print(validity.shape)  # torch.Size([16, 1])
    """

    def __init__(self, params):
        super().__init__()
        self.num_classes = params['num_classes']
        self.input_size = params['input_size']

        # One learnable num_classes-dimensional vector per class label
        self.label_emb = nn.Embedding(self.num_classes, self.num_classes)

        flat_pixels = int(np.prod(self.input_size))
        hidden = 512

        # Input layer takes the flattened image plus the label embedding
        stack = [
            nn.Linear(self.num_classes + flat_pixels, hidden),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two hidden layers, each followed by dropout and LeakyReLU
        for _ in range(2):
            stack.extend([
                nn.Linear(hidden, hidden),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Final layer: one raw validity score per sample (no activation applied here)
        stack.append(nn.Linear(hidden, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img, labels):
        """
        Score a batch of images conditioned on their labels.

        Args:
            img (torch.Tensor): Input image of shape (batch_size, channels, height, width).
            labels (torch.Tensor): Input labels of shape (batch_size,).

        Returns:
            torch.Tensor: Validity score of shape (batch_size, 1).
        """
        n_samples = img.size(0)
        flat_img = img.view(n_samples, -1)
        label_vec = self.label_emb(labels)
        # Image features first, label embedding last
        dis_input = torch.cat((flat_img, label_vec), -1)
        return self.model(dis_input)

In [None]:
# Loss function: mean squared error between the discriminator's scores and the
# real (1.0) / fake (0.0) targets built in the training loop.
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator.
# The training loop moves every batch to `device`, so the networks must live
# there too; otherwise the first forward pass fails on a CUDA machine.
generator = Generator(params).to(device)
discriminator = Discriminator(params).to(device)

# Optimizers: one Adam instance per network, both configured from params
#   - lr:    learning rate (params['lr'])
#   - betas: (params['b1'], params['b2']), i.e. (0.5, 0.999) here rather than
#            Adam's library default
optimizer_G = torch.optim.Adam(generator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=params['lr'], betas=(params['b1'], params['b2']))

In [None]:
# Sample Image

def sample_image(n_row, batches_done):
    """
    Generate an n_row x n_row grid of images, save it to disk and display it.

    Each row uses the labels 0, 1, ..., n_row - 1 (wrapped modulo the number of
    classes), so column j shows samples conditioned on class j.

    Args:
        n_row (int): Number of rows/columns in the grid of images.
        batches_done (int): Identifier for the batch (used for the saved file name).

    Returns:
        None (the grid is written to `<image_dir>/<batches_done>.png` and displayed).
    """
    # Create the inputs on whatever device the generator's weights live on,
    # so sampling works whether or not the model was moved to the GPU.
    gen_device = next(generator.parameters()).device

    # Sample noise
    z = torch.randn(n_row ** 2, params['latent_space']).to(gen_device)
    # Labels 0..n_row-1 repeated for every row; the modulo keeps them valid
    # embedding indices when n_row exceeds the number of classes.
    labels = (torch.arange(n_row).repeat(n_row) % params['num_classes']).to(gen_device)

    # Sampling is inference only: skip gradient tracking
    with torch.no_grad():
        gen_imgs = generator(z, labels)

    # Save the grid once (the original wrote the same file twice)
    file_path = os.path.join(image_dir, f"{batches_done}.png")
    save_image(gen_imgs, file_path, nrow=n_row, normalize=True)

    # Display the saved image in the notebook
    display(IPImage(file_path))

In [None]:
batch_count = 0  # NOTE(review): not used by the loop below; kept in case other cells read it

# Training loop: every batch performs one generator update followed by one
# discriminator update.
#
# Names used inside the loop:
#   epoch      - current epoch number
#   imgs       - batch of real images, shape (batch_size, channels, height, width)
#   labels     - batch of real labels, shape (batch_size,)
#   z          - noise fed to the generator, shape (batch_size, latent_space)
#   gen_labels - randomly drawn labels for the generated images, shape (batch_size,)
#   valid      - target tensor filled with 1.0 ('real')
#   fake       - target tensor filled with 0.0 ('fake')
for epoch in tqdm(range(params['epochs'])):
    for i, (imgs, labels) in enumerate(train_dataloader):
        n_samples = imgs.size(0)

        # Adversarial ground truths, created directly on the training device
        valid = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        real_imgs = imgs.to(device)
        real_labels = labels.to(device)

        # ---------------
        # Train Generator
        # ---------------
        optimizer_G.zero_grad()

        # Sample noise and random class labels as Generator input
        z = torch.randn(n_samples, params['latent_space']).to(device)
        gen_labels = torch.randint(0, params['num_classes'], (n_samples,)).to(device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Generator loss: MSE between the discriminator's score for the generated
        # images and the 'real' target, i.e. how well the generator fools it
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # -------------------
        # Train Discriminator
        # -------------------
        optimizer_D.zero_grad()

        # Discriminator loss: average of the MSE on real pairs (target 1.0) and on
        # generated pairs (target 0.0). The generated images must be paired with
        # gen_labels, the labels they were conditioned on, not the real batch's
        # labels. detach() keeps this step from back-propagating into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs, real_labels), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Log the losses and save a sample grid every 1000 batches
        batches_done = epoch * len(train_dataloader) + i
        if batches_done % 1000 == 0:
            print(
                "[Epoch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, params['epochs'], d_loss.item(), g_loss.item())
            )
            sample_image(n_row=10, batches_done=batches_done)

# About the results (rule of thumb, not a guarantee):
#   Discriminator loss: a value close to 0 means the discriminator separates
#                       real from generated samples well.
#   Generator loss:     a value that stays within a stable range suggests the
#                       generator is producing images of consistent quality.


Output hidden; open in https://colab.research.google.com to view.