# Building and Training a cGAN Model


## Contents



- [Intro](#intro)
- [Decoder](#decoder)
- [Encoder](#encoder)
- [cGAN Model](#cgan-model)
- [Digit Sampler](#digit-sampler)
- [Data](#data)
- [Training](#training)
- [Exploring Embeddings](#exploring-embeddings)

### What is the target of this workshop?

### cGAN (paper - https://arxiv.org/abs/1411.1784)

The goal of this workshop is to construct and train a cGAN model from scratch using the MNIST dataset.

This will use PyTorch's `torch.nn` module API to build the Generator and Discriminator models separately, then combine them to train a complete cGAN model.

### Notebook Breakdown

As with previous ML notebooks, the sections marked **## FINISH_ME ##** are to be completed by you.

Marks for the different parts are shown below.

* Sections are intended to be tackled in order, i.e. 1->8
* In this notebook different sections can be tackled independently
* Parts 1-5 are due as part of Homework 3, by 9:30am on Friday 27 February (though we recommend finishing this earlier!)

| <p align='left'> Title                         | <p align='left'> Parts | <p align='left'> Marks possible |  <p align='left'> Marks awarded |
| ------------------------------------- | ----- | --- | --- |
| <p align='left'> 1. Complete and test the Decoder (Generator) class | <p align='left'>  1  | <p align='left'> 1 | |
| <p align='left'> 2. Complete and test the Encoder (Discriminator) class | <p align='left'>  1  | <p align='left'> 1 | |
| <p align='left'> 3. Complete the full cGAN model | <p align='left'>  1  | <p align='left'> 1 | |
| <p align='left'> 4. Write a method to generate & plot an image for a given digit label | <p align='left'>  1  | <p align='left'> 1 | |
| <p align='left'> 5. Train the cGAN model for 10 epochs | <p align='left'>  1  | <p align='left'> 1 | |
| <p align='left'> 6. Examine the Decoder embedding class | <p align='left'> 1 | <p align='left'> -- | -- |
| <p align='left'> 7. Examine the Encoder embedding class  | <p align='left'> 1 | <p align='left'> -- | -- |
| <p align='left'> 8. Could the embeddings from Encoder & Decoder be shared?  | <p align='left'> 1 | <p align='left'> -- | -- |
| <p align='left'> **Total** | | <p align='left'> max **5** | |


---

One of the goals of this workshop is to understand how to embed non-image information alongside image information in the same training.

In principle, this can be done in many different ways.

The approach adopted here is to add one or more additional channels of embedded information per image alongside the input to the model. This is done by first embedding the label and then broadcasting that embedding into a mask over the whole image.

This allows the model to learn the best embedding for each digit class during training and to use this information alongside the image data.
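
The embed-then-broadcast idea can be sketched in a few lines of PyTorch. This is a minimal illustration under assumed sizes (a batch of 4 single-channel 28x28 images, 10 classes, 3 embedding channels); the variable names are placeholders chosen for this example rather than names from the classes in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumptions for this sketch)
batch_size, number_of_classes, embedding_size = 4, 10, 3
height, width = 28, 28

images = torch.rand(batch_size, 1, height, width)            # stand-in for real images
labels = torch.randint(0, number_of_classes, (batch_size,))  # integer class labels

# Learnable lookup table: one embedding vector per class.
embedding_table = nn.Embedding(number_of_classes, embedding_size)

label_vectors = embedding_table(labels)                       # (B, E)

# Reshape to (B, E, 1, 1), then expand so every pixel sees the same label values.
label_mask = label_vectors.view(batch_size, embedding_size, 1, 1).expand(-1, -1, height, width)

# Stack the label mask onto the image as extra channels.
conditioned_input = torch.cat([images, label_mask], dim=1)    # (B, 1 + E, H, W)
print(conditioned_input.shape)
```

Because the embedding table is an ordinary trainable parameter, gradients from the downstream convolutions flow back into it, which is how the embedding for each class is learned.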

## Intro

In [None]:
# ----- set thread env vars BEFORE importing torch/numpy/etc. -----
import os
#os.environ["OMP_NUM_THREADS"] = "12"       # OpenMP threads
#os.environ["MKL_NUM_THREADS"] = "12"       # Intel MKL threads
#os.environ["OPENBLAS_NUM_THREADS"] = "12"  # OpenBLAS threads
#os.environ["NUMEXPR_NUM_THREADS"] = "12"   # NumExpr threads
os.environ["TORCH_NUM_THREADS"] = "12"     # PyTorch CPU threads
#os.environ["MKL_DYNAMIC"] = "FALSE"        # Disable Intel MKL dynamic threading
#os.environ["OMP_WAIT_POLICY"] = "ACTIVE"   # Active waiting for better performance
#os.environ["GOMP_NUM_THREADS"] = "12"      # GNU OpenMP threads
#os.environ["KMP_NUM_THREADS"] = "12"       # Intel OpenMP threads

# ============================================================
# Set Random Seeds for Reproducibility
# ============================================================
import random

# Set seed value for reproducibility
SEED = 42

# Python's built-in random module
random.seed(SEED)

# NumPy random seed
import numpy as np
np.random.seed(SEED)

# ----- now import the rest -----
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


# PyTorch random seed
torch.manual_seed(SEED)

# CUDA random seed (if using GPU)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Set cuDNN to deterministic mode for reproducibility
# Note: This may reduce performance slightly
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Explicitly set PyTorch to use all 12 threads for CPU operations
torch.set_num_threads(12)
torch.set_num_interop_threads(12)

print(f"Random seeds set to {SEED} for reproducibility")
print(f"PyTorch threads: {torch.get_num_threads()}")
print(f"PyTorch interop threads: {torch.get_num_interop_threads()}")

In [None]:
device = "cpu"

## Decoder
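
Before reading the class, it can help to check the upsampling arithmetic. The sketch below applies the standard transposed-convolution size formula (assuming dilation 1 and no output padding, which are the `nn.ConvTranspose2d` defaults) to the kernel 4, stride 2, padding 1 layers used in the Decoder. The helper name `conv_transpose_output_size` is made up for this illustration.

```python
def conv_transpose_output_size(input_size, kernel_size, stride, padding):
    """Spatial output size of a transposed convolution along one dimension.

    Assumes dilation=1 and output_padding=0.
    """
    return (input_size - 1) * stride - 2 * padding + kernel_size


# The Decoder starts from a 7x7 seed map and applies two (4, 2, 1) transposed convolutions.
sizes = [7]
for _ in range(2):
    sizes.append(conv_transpose_output_size(sizes[-1], kernel_size=4, stride=2, padding=1))

print(sizes)  # [7, 14, 28]
```

Each such layer doubles the spatial size, so two of them take the 7x7 seed to the 28x28 MNIST resolution. The final 3x3 convolution with stride 1 and padding 1 leaves the spatial size unchanged.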

In [None]:
# ============================================================
# Decoder (Generator)
# ============================================================

class Decoder(nn.Module):
    """
    Decoder (Generator):
    - Input: noise vectors + digit labels
    - Output: generated images with pixels in range [0, 1]
    - Output shape: (batch_size, 1, 28, 28) for MNIST

    Conditioning ("c" in cGAN):
    - Each label is embedded to a small vector.
    - That vector is broadcast to a spatial label feature map and concatenated as channels.

    Shapes:
    - noise_vectors: (batch_size, noise_vector_size)
    - digit_labels: (batch_size,) integers 0..9
    - label_vectors: (batch_size, label_embedding_size)
    - generated_images: (batch_size, 1, 28, 28)
    """

    def __init__(self, noise_vector_size, number_of_classes, label_embedding_size, feature_size):
        """
        Mode notes:
        - decoder_model.train(): normalization layers update running stats and use batch stats.
        - decoder_model.eval(): normalization layers use stored stats (more stable for sampling).

        Parameters:
        - noise_vector_size: size of random noise input per image
        - number_of_classes: number of label classes (MNIST digits -> 10)
        - label_embedding_size: size of the learned label embedding vector
        - feature_size: controls width/capacity of convolution layers
        """
        super().__init__()
        self.noise_vector_size = noise_vector_size
        self.label_embedding_size = label_embedding_size

        # Embedding table maps integer digit labels -> dense vectors.
        self.label_embedding_table = nn.Embedding(number_of_classes, ## FINISH_ME ## fill in label_embedding_size here)

        # Noise-only seed: labels are injected later via channel concatenation.
        internal_logits_dim = feature_size * 4 * 7 * 7
        self.seed_feature_map_linear_layer = nn.Linear(noise_vector_size, internal_logits_dim)
        self.seed_feature_map_batch_normalization = nn.BatchNorm1d(internal_logits_dim)

        # Upsampling stack expects a conditioned seed map with
        # (feature_size*4 + label_embedding_size) channels at 7x7.
        self.upsampling_layers = nn.Sequential(
            nn.ConvTranspose2d(feature_size * 4 + label_embedding_size, feature_size * 2, 4, 2, 1, bias=False),
            nn.LayerNorm((feature_size * 2, 14, 14), eps=1e-6),
            nn.ReLU(),

            nn.ConvTranspose2d(feature_size * 2, feature_size, 4, 2, 1, bias=False),
            ## FINISH_ME ## add LayerNorm and ReLU here like above

            # Project down to a single output channel (spatial size stays 28x28)
            nn.Conv2d(feature_size, 1, 3, 1, 1),

            ## FINISH_ME ## add an activation here to ensure output pixels are in [0, 1] range (e.g., Sigmoid)
        )

        ## FINISH_ME ## call the weight initialization method

    def _initialize_weights(self):
        """Weight init with GAN-friendly defaults."""
        # Initialize weights for different layer types
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, noise_vectors, digit_labels):
        """Forward pass (conditional generation)."""

        # Noise -> seed feature map
        # The noise is transformed to a seed feature map of shape (batch_size, feature_size*4, 7, 7)
        batch_size = noise_vectors.size(0)
        seed_feature_vector = self.seed_feature_map_linear_layer(noise_vectors)
        seed_feature_vector = self.seed_feature_map_batch_normalization(seed_feature_vector)
        seed_feature_vector = F.relu(seed_feature_vector)
        seed_feature_map = seed_feature_vector.view(batch_size, -1, 7, 7)      # "-1" lets view() infer that dimension (here feature_size * 4)


        # digit_labels are first embedded, then broadcast across the same spatial dims as the seed feature map.
        # Label embedding (batch_size, embedding_size) -> reshape to (batch_size, embedding_size, 1, 1),
        # then expand to (batch_size, embedding_size, 7, 7) to match the seed feature map spatial size.
        label_vectors = self.label_embedding_table(digit_labels)  # (B, E)
        label_channel = label_vectors.view(-1, self.label_embedding_size, 1, 1).expand(-1, self.label_embedding_size, 7, 7)    # "-1" in expand means keep the same dimension


        # We now have a seed feature map and a label feature map, both with spatial dimensions 7x7.
        # Condition by concatenating the label feature map as extra channels.
        conditioned_seed = torch.cat([seed_feature_map, label_channel], dim=1)

        # Upsampling to generated images
        # This requires the conditioned seed to have shape ( ... , ... , 7, 7)
        generated_images = self.upsampling_layers(conditioned_seed)
        return generated_images

In [None]:
# Test Decoder
decoder = Decoder(noise_vector_size=16, number_of_classes=10, label_embedding_size=3, feature_size=16).to(device)

# Test with batch of 8
noise = torch.randn(8, 16, device=device)
labels = torch.randint(0, 10, (8,), device=device)

with torch.no_grad():
    output = decoder(noise, labels)

print(f"Input: noise {noise.shape}, labels {labels.shape}")
print(f"Output: {output.shape}")
print(f"Expected: torch.Size([8, 1, 28, 28])")
print(f"✓ Test passed!" if output.shape == (8, 1, 28, 28) else "✗ Test failed!")

## Encoder
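
The Encoder runs the same kind of size arithmetic in reverse. As a rough check, the sketch below applies the standard `nn.Conv2d` output-size formula (assuming dilation 1) to the three strided convolutions in the stack, which explains the `feature_size * 4 * 4 * 4` input width of the final linear layer. The helper name and the choice `feature_size = 16` are assumptions for illustration.

```python
def conv_output_size(input_size, kernel_size, stride, padding):
    """Spatial output size of a convolution along one dimension (dilation=1)."""
    return (input_size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for the three Conv2d layers in the Encoder stack
layer_settings = [(4, 2, 1), (4, 2, 1), (3, 2, 1)]

sizes = [28]
for kernel_size, stride, padding in layer_settings:
    sizes.append(conv_output_size(sizes[-1], kernel_size, stride, padding))
print(sizes)  # [28, 14, 7, 4]

feature_size = 16  # illustrative value
# The last conv layer has feature_size * 4 channels on a 4x4 map.
flattened_features = feature_size * 4 * sizes[-1] * sizes[-1]
print(flattened_features)  # 1024
```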

In [None]:
# ============================================================
# Encoder (Discriminator) built with Sequential (mirrors Decoder style)
# ============================================================

class Encoder(nn.Module):
    """
    Encoder (Discriminator):
    - Input: images + digit labels
    - Output: probabilities (higher -> more "real", lower -> more "fake") in [0, 1]

    Conditioning:
    - Each label is embedded to a small vector of size label_embedding_size.
    - That vector is broadcast to (label_embedding_size, 28, 28) and concatenated with the image.
      This makes the input have (1 + label_embedding_size) channels.
    """

    def __init__(self, number_of_classes, feature_size, label_embedding_size):
        """
        Mode notes:
        - This Encoder uses only LayerNorm, which keeps no running statistics, so
          encoder_model.train() and encoder_model.eval() produce the same outputs.

        Parameters:
        - number_of_classes: 10 for MNIST
        - feature_size: controls width/capacity of convolution layers
        - label_embedding_size: number of label-conditioning channels to concatenate
        """
        super().__init__()

        self.label_embedding_table = nn.Embedding(number_of_classes, label_embedding_size)
        self.label_embedding_size = label_embedding_size

        # IMPORTANT: no in-place activations (avoids autograd versioning errors)
        # Convolutional stack processes (image + label_channel) while downsampling spatial dims.
        self.convolution_stack = nn.Sequential(
            # Input = (1 + label_embedding_size) channels x 28 x 28
            nn.Conv2d(1 + label_embedding_size, feature_size, 4, 2, 1),
            nn.LeakyReLU(0.2),

            # feature_size x 14 x 14 -> (feature_size*2) x 7 x 7
            nn.Conv2d(feature_size, feature_size * 2, 4, 2, 1, bias=False),
            nn.LayerNorm((feature_size * 2, 7, 7), eps=1e-6),
            nn.LeakyReLU(0.2),

            # (feature_size*2) x 7 x 7 -> (feature_size*4) x 4 x 4
            nn.Conv2d(feature_size * 2, feature_size * 4, 3, 2, 1, bias=False),
            
            ## FINISH_ME ## add LayerNorm and LeakyReLU here like in the block above
            
            # Flatten to a vector and map to a single logit
            nn.Flatten(),
            nn.Linear(feature_size * 4 * 4 * 4, 1),  # logits

            ## FINISH_ME ## add an output activation here to ensure output probabilities are in [0, 1] range (e.g., Sigmoid)
        )

        ## FINISH_ME ## call the weight initialization method

    def _initialize_weights(self):
        """Initialize weights with GAN-friendly defaults."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1.0)
                nn.init.constant_(module.bias, 0.0)

    def forward(self, images, digit_labels):
        """
        Forward pass.

        Inputs:
        - images: (batch_size, 1, 28, 28) in [0, 1]
        - digit_labels: (batch_size,) integers 0..9

        Output:
        - output_probabilities: (batch_size, 1) in [0, 1]
        """
        # Images are of shape (batch_size, 1, 28, 28)
        # Digit labels are of shape (batch_size,)

        # As with the Decoder, embed labels and broadcast to spatial size 28x28

        ## FINISH_ME ## we need to embed from our digit_labels into a new set of labels with dim (batch_size, label_embedding_size, 28, 28) like we did in the Decoder 

        # Have (batch_size, 1, 28, 28) images and (batch_size, label_embedding_size, 28, 28) label channels.
        # Concatenate along channel dimension -> (batch_size, 1 + label_embedding_size, 28, 28)
        combined_input_tensor = torch.cat([images, label_channel], dim=1)

        # Forward through convolutional stack (ends in the output activation, so these are probabilities)
        output_probabilities = self.convolution_stack(combined_input_tensor)

        return output_probabilities

In [None]:
# Test Encoder
encoder = Encoder(number_of_classes=10, feature_size=16, label_embedding_size=3).to(device)

# Test with batch of 8 images
images = torch.randn(8, 1, 28, 28, device=device)
labels = torch.randint(0, 10, (8,), device=device)

with torch.no_grad():
    logits = encoder(images, labels)

print(f"Input: images {images.shape}, labels {labels.shape}")
print(f"Output: {logits.shape}")
print(f"Expected: torch.Size([8, 1])")
print(f"✓ Test passed!" if logits.shape == (8, 1) else "✗ Test failed!")

## cGAN Model
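
To make the numbers on the training progress bar easier to read, here is a small numeric sketch of the value function used in the class below, $V = \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(G(z)))]$, evaluated on hand-picked probabilities. The probability values and function names are assumptions chosen for illustration; only the formulas and the `1e-6` stabiliser mirror the training code.

```python
import math

EPS = 1e-6  # same small stabiliser as in the training code


def value_function(probs_real, probs_fake):
    """V = mean(log D(x)) + mean(log(1 - D(G(z))))."""
    real_term = sum(math.log(p + EPS) for p in probs_real) / len(probs_real)
    fake_term = sum(math.log(1.0 - p + EPS) for p in probs_fake) / len(probs_fake)
    return real_term + fake_term


def decoder_loss(probs_fake):
    """Minimax Decoder loss: mean(log(1 - D(G(z))))."""
    return sum(math.log(1.0 - p + EPS) for p in probs_fake) / len(probs_fake)


# Encoder cannot tell real from fake: every probability is 0.5.
undecided = value_function([0.5, 0.5], [0.5, 0.5])

# Encoder is confidently correct on both real and fake images.
confident = value_function([0.99, 0.98], [0.01, 0.02])

print(f"undecided: V = {undecided:.3f}, encoder loss = {-undecided:.3f}")
print(f"confident: V = {confident:.3f}, encoder loss = {-confident:.3f}")
print(f"decoder loss when D(G(z)) = 0.5: {decoder_loss([0.5]):.3f}")
```

When the Encoder outputs 0.5 everywhere, $V = -2\ln 2 \approx -1.386$, so an encoder loss hovering near 1.386 and a decoder loss near $-0.693$ suggest the two models are roughly balanced. An encoder loss close to 0 means the Encoder is separating real from fake almost perfectly.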

In [None]:
# ============================================================
# Conditional GAN class (training + model management only)
# ============================================================

class ConditionalGAN:
    """
    High-level trainer/wrapper for a conditional GAN.

    Responsibilities:
    - Holds the Decoder (Generator) and Encoder (Discriminator)
    - Performs one training update per batch (Encoder step + Decoder step)
    - Tracks loss history for plotting

    Non-goals:
    - Visualization/sampling (handled by DigitSampler)
    """

    def __init__(
        self,
        noise_vector_size=32,          # dimensionality of input noise
        number_of_classes=10,          # number of digit classes (10 for MNIST)
        label_embedding_size=3,        # dimensionality of label embeddings (Decoder + Encoder conditioning channels)
        feature_size=32,               # channel multiplier for conv layers
        learning_rate=1e-3,            # learning rate for Adam optimizers
        optimizer_betas=(0.5, 0.999),  # Adam betas
        gradient_clip_max_norm=1.0,    # for stability
        device=None,
    ):
        """
        Mode notes:
        - During training, both models should be in train() mode.
        - Sampling is handled by DigitSampler (which uses eval() temporarily).

        Parameters:
        - gradient_clip_max_norm: maximum gradient norm (stability)
        - optimizer_betas: Adam betas (beta1, beta2)
          - beta1 controls momentum of the first-moment estimate; 0.5 is common for GAN stability.
          - beta2 controls the decay of the second-moment estimate and is typically kept high.
        """
        self.noise_vector_size = noise_vector_size
        self.number_of_classes = number_of_classes
        self.gradient_clip_max_norm = gradient_clip_max_norm
        self.device = device

        # Models
        self.decoder_model = Decoder(
            noise_vector_size=noise_vector_size,
            number_of_classes=number_of_classes,
            label_embedding_size=label_embedding_size,
            feature_size=feature_size,
        ).to(self.device)

        self.encoder_model = Encoder(
            number_of_classes=number_of_classes,
            feature_size=feature_size,
            label_embedding_size=label_embedding_size,
        ).to(self.device)

        # Optimizers
        self.decoder_optimizer = optim.Adam(
            self.decoder_model.parameters(),
            lr=learning_rate,
            betas=optimizer_betas,
        )
        self.encoder_optimizer =  ## FINISH_ME ## create Adam optimizer for encoder_model like above

        # Training bookkeeping
        self.training_step = 0
        self.encoder_loss_history = []
        self.decoder_loss_history = []

    def create_noise_vectors(self, batch_size):
        """Creates fresh noise vectors on the correct device."""
        return torch.randn(batch_size, self.noise_vector_size, device=self.device)

    def update_encoder_model(self, real_images, digit_labels):
        """
        One update step for the Encoder (Discriminator).

        Key idea:
        - Train D to assign high probability to real images and low probability to fake images.
        - Fake images are generated by the Decoder and then detached so gradients do not flow
          into the Decoder during the Encoder update.

        Returns:
        - encoder_loss_value: float
        - value_function_value: float (useful for monitoring)
        """
        batch_size = real_images.size(0)

        noise_vectors = self.create_noise_vectors(batch_size)
        fake_images_detached = self.decoder_model(noise_vectors, digit_labels).detach()

        encoder_probs_for_real = self.encoder_model(real_images, digit_labels)
        encoder_probs_for_fake = self.encoder_model( ## FINISH_ME ## use fake_images_detached and digit_labels here)

        eps = 1e-6
        log_probability_real_is_real = torch.log(encoder_probs_for_real + eps)
        log_probability_fake_is_fake = torch.log(1.0 - encoder_probs_for_fake + eps)

        value_function = log_probability_real_is_real.mean() + log_probability_fake_is_fake.mean()
        encoder_loss = -value_function  # maximize V <=> minimize -V


        self.encoder_optimizer.zero_grad(set_to_none=True)
        encoder_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.encoder_model.parameters(), self.gradient_clip_max_norm)
        self.encoder_optimizer.step()

        return encoder_loss.item(), value_function.item()

    def update_decoder_model(self, digit_labels):
        """
        One update step for the Decoder (Generator).

        Key idea:
        - Update G using the Encoder signal: gradients flow through Encoder into Decoder.
        - This uses the classic minimax form E[log(1 - D(G(z)))]. (Many implementations
          instead use the non-saturating loss -E[log D(G(z))] for stronger gradients.)

        Returns:
        - decoder_loss_value: float
        """
        batch_size = digit_labels.size(0)
        noise_vectors = self.create_noise_vectors(batch_size)

        fake_images = self.decoder_model( ## FINISH_ME ## use noise_vectors and digit_labels here)
        encoder_probs_for_fake = self.encoder_model( ## FINISH_ME ## use fake_images and digit_labels here)

        eps = 1e-6
        decoder_loss = torch.log(1.0 - encoder_probs_for_fake + eps).mean()

        self.decoder_optimizer.zero_grad(set_to_none=True)
        decoder_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), self.gradient_clip_max_norm)
        self.decoder_optimizer.step()

        return decoder_loss.item()

    def train_one_epoch(self, training_data_loader):
        """
        Trains for one full epoch over the DataLoader.

        One training step per batch:
        1) Encoder update
        2) Decoder update
        """
        self.encoder_model.train()
        self.decoder_model.train()

        progress_bar = tqdm(training_data_loader, desc="Training", leave=True)
        for real_images, digit_labels in progress_bar:
            real_images = real_images.to(self.device)
            digit_labels = digit_labels.to(self.device)

            encoder_loss_value, value_function_value = self.update_encoder_model(real_images, digit_labels)
            decoder_loss_value = self.update_decoder_model(digit_labels)

            self.encoder_loss_history.append( ## FINISH_ME ## append encoder_loss_value to encoder_loss_history)
            self.decoder_loss_history.append( ## FINISH_ME ## append decoder_loss_value to decoder_loss_history)

            progress_bar.set_postfix(
                value_function=f"{value_function_value:.3f}",
                encoder_loss=f"{encoder_loss_value:.3f}",
                decoder_loss=f"{decoder_loss_value:.3f}",
                training_step=self.training_step,
            )
            self.training_step += 1


In [None]:
# Test ConditionalGAN
gan = ConditionalGAN(
    noise_vector_size=16,   # noise vector size
    number_of_classes=10,   # MNIST digits 0-9
    label_embedding_size=3, # label embedding size
    feature_size=16,        # feature size
    device=device
)

# Test create_noise_vectors
noise = gan.create_noise_vectors(batch_size=8)
print(f"Noise vectors: {noise.shape}")
print(f"Expected: torch.Size([8, 16])")  # matches noise_vector_size=16

# Test update_encoder_model with fake data
fake_images = torch.randn(8, 1, 28, 28, device=gan.device)
fake_labels = torch.randint(0, 10, (8,), device=gan.device)

encoder_loss, value_func = gan.update_encoder_model(fake_images, fake_labels)
print(f"\nEncoder update: loss={encoder_loss:.4f}, V={value_func:.4f}")

# Test update_decoder_model
decoder_loss = gan.update_decoder_model(fake_labels)
print(f"Decoder update: loss={decoder_loss:.4f}")


## Digit Sampler
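
The sampler below switches the Decoder to `eval()` before generating. One concrete reason is that the Decoder contains a `BatchNorm1d` layer, and batch statistics are undefined for a single sample. The sketch below reproduces this with a tiny stand-in network (the layer sizes are arbitrary assumptions, not the Decoder itself) and shows the save-mode, sample, restore-mode pattern.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in for the Decoder's Linear -> BatchNorm1d seed block (sizes are illustrative).
seed_block = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
single_noise_vector = torch.randn(1, 4)

# In train() mode BatchNorm1d needs more than one value per channel.
seed_block.train()
try:
    seed_block(single_noise_vector)
    train_mode_failed = False
except ValueError as error:
    train_mode_failed = True
    print(f"train() with batch size 1 failed: {error}")

# Remember the current mode, sample in eval(), then restore the mode.
was_training = seed_block.training
seed_block.eval()
with torch.no_grad():
    first_output = seed_block(single_noise_vector)
    second_output = seed_block(single_noise_vector)
if was_training:
    seed_block.train()

print(torch.equal(first_output, second_output))
```

In `eval()` the layer normalises with its stored running statistics, so a batch of one works and the same input gives the same output.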

In [None]:
# ============================================================
# Sampling helper (all sampling code lives here)
# ============================================================

class DigitSampler:
    """
    Sampling-only helper.

    Why this exists:
    - Keeps ConditionalGAN focused on training and optimization.
    - Centralizes eval()/train() toggling so sampling is stable/repeatable.

    Typical usage:
    - digit_sampler = DigitSampler(gan_model)
    - digit_sampler.display_digit_grid()
    """

    def __init__(self, gan_model):
        """
        Mode notes:
        - This class switches decoder_model to eval() while sampling,
          then restores train() afterward.

        Parameters:
        - gan_model: a ConditionalGAN instance
        """
        self.gan_model = gan_model
        self.decoder_model = gan_model.decoder_model
        self.noise_vector_size = gan_model.noise_vector_size
        self.number_of_classes = gan_model.number_of_classes
        self.device = gan_model.device

    def generate_single_digit(self, digit_label):
        """
        Generates a single digit image for a specific label.

        Mode notes:
        - Toggles decoder to eval() for normalization stability, then restores mode.
        - Uses torch.no_grad() to avoid autograd overhead.

        Args:
        - digit_label: integer 0-9

        Returns:
        - generated_image: (1, 1, 28, 28) tensor in [0, 1]
        """
        was_training = self.decoder_model.training
        self.decoder_model.eval()

        with torch.no_grad():
            digit_tensor = torch.tensor([digit_label], dtype=torch.long, device=self.device)
            noise = self.gan_model.create_noise_vectors(batch_size=1)
            generated_image = self.decoder_model(noise, digit_tensor)

        if was_training:
            self.decoder_model.train()
        return generated_image

    def display_digit_grid(self, repeats_per_digit=2):
        """
        Displays a grid of generated digits in the notebook.

        Mode notes:
        - eval() for stable normalization output
        - no_grad() for speed + lower memory
        """
        was_training = self.decoder_model.training
        self.decoder_model.eval()

        with torch.no_grad():
            generated_images = []
            for digit in range(self.number_of_classes):
                for _ in range(repeats_per_digit):
                    img = self.generate_single_digit(digit)  # (1, 1, 28, 28)
                    generated_images.append( ## FINISH_ME ## append img to generated_images list)

            generated_images = torch.cat(generated_images, dim=0)  # (N, 1, 28, 28)

            image_grid = make_grid(
                ## FINISH_ME ## use generated_images here,
                nrow= ## FINISH_ME ## set nrow to number_of_classes,
                normalize=True,
                value_range=(0, 1),
                padding=2,
            )

            plt.figure(figsize=(12, 4))
            plt.imshow(image_grid.permute(1, 2, 0).cpu().numpy())
            plt.axis("off")
            plt.tight_layout()
            plt.show()

        if was_training:
            self.decoder_model.train()

## Data

In [None]:
# ============================================================
# Data
# ============================================================

def build_mnist_training_data_loader(batch_size, number_of_workers=2):
    """
    Builds a DataLoader for the FULL MNIST training dataset.

    Mode notes:
    - Independent of model mode.
    - Outputs (real_images, digit_labels) per batch.

    Normalization:
    - MNIST ToTensor() -> [0,1]
    """
    # Compose basic transforms: ToTensor converts PIL [0,255] -> torch float [0,1]
    image_transform = transforms.Compose([
        transforms.ToTensor(),  # Convert PIL image to tensor in [0, 1]
    ])

    # Download and prepare the training dataset (60k MNIST training samples)
    training_dataset = datasets.MNIST(
        root="./data",            # directory to store downloaded data
        train=True,               # load training set (not test set)
        download=True,            # auto-download if not already present
        transform=image_transform # apply image transforms
    )

    # Construct the DataLoader
    training_data_loader = DataLoader( ## FINISH_ME ## make a DataLoader instance for the training data
        ## FINISH_ME ## use training_dataset here,
        ## FINISH_ME ## use batch_size here,
        ## FINISH_ME ## set shuffle to True,
        ## FINISH_ME ## use number_of_workers here,
    )

    return training_data_loader


## Training

In [None]:
number_of_epochs = 10  # Full training passes over the dataset
batch_size = 32       # Images per batch; trade-off between speed and stability

# Instantiate the cGAN
gan_model = ConditionalGAN(
    noise_vector_size=32,          # size of input noise vector
    number_of_classes=10,          # 10 for MNIST
    label_embedding_size=3,        # dimensionality of label embedding (used by both Decoder and Encoder)
    feature_size=32,               # channel multiplier for Conv stacks (capacity)
    learning_rate=1e-3,            # Adam learning rate
    optimizer_betas=(0.5, 0.999),  # (beta1, beta2) for Adam; (0.5, 0.999) is common for GANs
    gradient_clip_max_norm=1.0,    # stabilize training by clipping large gradients
    device=device,
)

# Build data loader (uses torchvision MNIST; downloads on first run)
training_data_loader = build_mnist_training_data_loader(
    batch_size=batch_size, # Adjust based on memory and speed needs
    number_of_workers=4    # Increase to better use CPU cores
)

# Helper dedicated to sampling/visualizing digits from the Decoder
# Keeps training logic separate and clean

digit_sampler = DigitSampler(gan_model)

# ============================================================
for epoch_index in range(number_of_epochs):
    print(f"\nEpoch {epoch_index + 1}/{number_of_epochs}")

    # One full pass over the loader: Encoder step then Decoder step per batch
    gan_model.train_one_epoch(training_data_loader)

    # Sample once per epoch to monitor progression in generation quality
    print(f"\nEpoch {epoch_index + 1} samples:")
    digit_sampler.display_digit_grid()


In [None]:

# Visualize training curves
def plot_losses(gan_model):
    """Plots per-batch losses after training in the notebook."""

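    # Plot both curves as non-negative values for readability:
    # - encoder_loss = -V is already non-negative (V is a sum of log-probabilities)
    # - decoder_loss = E[log(1 - D(G(z)))] is non-positive, so negate it
    encoder_losses_positive = [loss for loss in gan_model.encoder_loss_history]
    decoder_losses_positive = [-loss for loss in gan_model.decoder_loss_history]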

    ## FINISH_ME ## create a figure and axis for plotting

plot_losses(gan_model)

## Exploring Embeddings

In [None]:
# The following code extracts the embedding weights from the trained Decoder's label embedding table.
# This can be used for analysis or visualization of how the model learned to represent digit classes.

decoder_embeddings = gan_model.decoder_model.label_embedding_table.weight.detach().cpu().numpy()
print(f"decoder_embeddings shape: {decoder_embeddings.shape}")

## FINISH_ME ## extract the trained Encoder embeddings in the same way

Now we want to explore the embeddings from:

* Either the Encoder or Decoder (untrained)
* The Encoder from the trained cGAN
* The Decoder from the trained cGAN

It might be helpful to plot the distribution from different angles to compare how different numbers are embedded into this space.

The X, Y and Z axes should always use the same fixed range, such as `[-1.5,1.5]` or `[-2,2]`, so that you can clearly see the impact of training.

How does this compare to the latent space (LS) of the autoencoder (AE) we looked at for MNIST?
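
Alongside the plots, a pairwise distance matrix gives a simple numeric way to compare embeddings. The sketch below is a hedged illustration using a random `(10, 3)` array as a stand-in for an embedding matrix such as `decoder_embeddings`. The standard deviation of 0.02 mirrors the embedding initialisation used in this notebook, and all other names are assumptions for this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a (10, 3) embedding matrix; random values are used purely for illustration.
embeddings = rng.normal(loc=0.0, scale=0.02, size=(10, 3))

# Pairwise Euclidean distances via broadcasting: (10, 1, 3) - (1, 10, 3) -> (10, 10, 3)
differences = embeddings[:, None, :] - embeddings[None, :, :]
distance_matrix = np.linalg.norm(differences, axis=-1)  # (10, 10)

# Mean distance between different digits (diagonal excluded).
off_diagonal = ~np.eye(10, dtype=bool)
mean_distance = distance_matrix[off_diagonal].mean()

# Put infinity on the diagonal so a digit is never its own nearest neighbour.
masked = distance_matrix + np.diag(np.full(10, np.inf))
closest_pair = tuple(int(i) for i in np.unravel_index(np.argmin(masked), masked.shape))

print(f"mean off-diagonal distance: {mean_distance:.4f}")
print(f"closest pair of digits: {closest_pair}")
```

With an initialisation this narrow, untrained embeddings sit very close to the origin on axes of range `[-1.5, 1.5]`, so any spread seen after training comes from what the model has learned.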

In [None]:
# ============================================================
# Example color mapping visualization for the digit embeddings
# ============================================================
from matplotlib import patheffects  # needed for the text outline used below

# Define consistent colors per digit (0-9)
digit_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Create a dedicated color mapping image instead of printing the mapping
fig_colors, ax_colors = plt.subplots(figsize=(5, 5))
y_positions = list(range(10))
ax_colors.barh(y_positions, [1] * 10, color=digit_colors)
for y, color in zip(y_positions, digit_colors):
    ax_colors.text(
        0.5,
        y,
        f"Digit {y}  {color}",
        va='center',
        ha='center',
        color='white',
        fontweight='bold',
        path_effects=[patheffects.withStroke(linewidth=2.0, foreground='black')],
    )
ax_colors.set_yticks([])
ax_colors.set_xticks([])
ax_colors.set_xlim(0, 1)
ax_colors.invert_yaxis()
ax_colors.set_title("Digit Color Mapping", fontsize=12, fontweight='bold')
for spine in ax_colors.spines.values():
    spine.set_visible(False)
plt.tight_layout()
plt.show()


In [None]:
# ============================================================
# 3D Scatter Plot of Decoder Label Embeddings
# ============================================================

# Use raw decoder label embeddings directly (embedding size is 3)
decoder_embeddings_3d = decoder_embeddings
if decoder_embeddings_3d.shape[1] != 3:
    raise ValueError(f"Expected 3D embeddings, got {decoder_embeddings_3d.shape[1]} dimensions")

# Create 3D scatter plot with multiple viewing angles (stacked vertically for better readability)
# Multiple views help when points overlap along one axis
fig = plt.figure(figsize=(9, 16))

# Three different viewing angles, since the static plot cannot be rotated interactively
angles = [(20, 45), (20, 135), (20, 225)]
titles = ['View 1 (45°)', 'View 2 (135°)', 'View 3 (225°)']

for idx, (elev, azim) in enumerate(angles):

    ax = fig.add_subplot(3, 1, idx + 1, projection='3d')
    
    # Plot each digit as a large colored point
    for digit in range(10):
        ax.scatter(
            ## FINISH_ME ## use decoder_embeddings_3d[digit, 0] here,
            ## FINISH_ME ## this is the y dimension,
            ## FINISH_ME ## this is the z dimension,
            c= ## FINISH_ME ## use the correct color from digit_colors here,
            s=350,
            alpha=0.8,
            edgecolors='black',
            linewidth=1.5
        )       
    
    # Plot axes lines to orient embedding directions
    ax.plot([0, 1.2], [0, 0], [0, 0], 'r-', linewidth=2, alpha=0.5)
    ax.plot([0, 0], [0, 1.2], [0, 0], 'g-', linewidth=2, alpha=0.5)
    ax.plot([0, 0], [0, 0], [0, 1.2], 'b-', linewidth=2, alpha=0.5)
    
    ax.set_xlabel('Embedding dim 1', fontsize=10)
    ax.set_ylabel('Embedding dim 2', fontsize=10)
    ax.set_zlabel('Embedding dim 3', fontsize=10)
    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])
    ax.set_title(titles[idx], fontsize=11, fontweight='bold')

    ## This is the bit that gives us the different viewing angles(!)
    ax.view_init(elev=elev, azim=azim)
    ax.grid(True, alpha=0.15)

plt.tight_layout()
plt.show()


In [None]:
# ============================================================
# Untrained (vanilla) Decoder label embeddings (3D, single plot)
# ============================================================

## FINISH_ME ##

## First need to create a new untrained ConditionalGAN instance to extract its decoder embeddings
## Be careful not to overwrite the trained gan_model used above!

## Then as above, extract the decoder embeddings and plot them in 3D




In [None]:
# ============================================================
# Trained Decoder(Generator) label embeddings from cGAN (3D, single plot)
# ============================================================

## FINISH_ME ##

## Plot the trained decoder embeddings in 3D as above,
## How do they compare to the untrained ones?


In [None]:
# ============================================================
# Trained Encoder(Discriminator) label embeddings from cGAN (3D, single plot)
# ============================================================

## FINISH_ME ##

## Plot the trained encoder embeddings in 3D as above,
## How do they compare to the decoder embeddings?



Q: Now that we've looked at the embeddings for both the Encoder and the Decoder, do you think the embeddings could be shared between the Encoder and Decoder in this model?

A: ## FINISH_ME ##