## Part 0: Setup and Imports

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Import cGAN from demo folder
sys.path.insert(0, os.path.join(os.path.dirname("."), "Demo"))
from models.cgan import create_cgan_models
from cgan_training import ConditionalGANTrainer, initialize_weights

# Human-readable label for each CIFAR-10 class id (list index == class id).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Fix the torch and NumPy seeds so repeated runs start from the same state.
torch.manual_seed(42)
np.random.seed(42)

# Pick the compute device, preferring MPS, then CUDA, then falling back to CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")


In [None]:
# Visualization helpers 
def plot_loss_curves_cgan(results):
    """Draw the cGAN loss histories side by side and print their statistics.

    Args:
        results: Mapping with the keys ``"d_losses"`` and ``"g_losses"``;
            each value is converted with ``np.array`` and plotted against
            its index.

    The left panel shows the discriminator loss together with a dashed
    reference line at y=0.5; the right panel shows the generator loss.
    After the figure is shown, the mean and standard deviation of both
    series are printed.

    NOTE(review): 0.5 is the discriminator's equilibrium *output*; the BCE
    loss at that output is ln 2 (about 0.693). Confirm how the trainer
    computes ``d_losses`` before reading the reference line as a target.
    """
    _, (disc_ax, gen_ax) = plt.subplots(1, 2, figsize=(16, 5))

    disc_history = np.array(results["d_losses"])
    gen_history = np.array(results["g_losses"])

    # Curves first; the formatting shared by both panels is applied below.
    disc_ax.plot(disc_history, linewidth=2, color="navy", label="D Loss")
    disc_ax.axhline(y=0.5, color="red", linestyle="--", alpha=0.5, label="Ideal (0.5)")
    gen_ax.plot(gen_history, linewidth=2, color="darkgreen", label="G Loss")

    panels = (
        (disc_ax, "Discriminator Loss Over Training"),
        (gen_ax, "Generator Loss Over Training"),
    )
    for panel, heading in panels:
        panel.set_xlabel("Training Step", fontsize=11)
        panel.set_ylabel("BCE Loss", fontsize=11)
        panel.set_title(heading, fontsize=12, fontweight="bold")
        panel.grid(alpha=0.3)
        panel.legend()

    plt.tight_layout()
    plt.show()

    # Summary statistics over the full history of each loss.
    d_mean, d_std = disc_history.mean(), disc_history.std()
    g_mean, g_std = gen_history.mean(), gen_history.std()
    print(f"D Loss - Mean: {d_mean:.4f}, Std: {d_std:.4f}")
    print(f"G Loss - Mean: {g_mean:.4f}, Std: {g_std:.4f}")


def visualize_class_grid(grid_images, class_names, title="cGAN Class Disentanglement"):
    """Show generated images as a grid with one row per class.

    Args:
        grid_images: Tensor of images ordered class by class (all samples of
            class 0, then class 1, ...). Each image is permuted with
            ``(1, 2, 0)`` before display, i.e. it is expected channel-first.
            Values are mapped with ``(x + 1) / 2`` and clamped to [0, 1].
        class_names: One label per row; its length sets the number of rows.
        title: First line of the figure title.

    Raises:
        ValueError: If ``class_names`` is empty or there are fewer images
            than classes.

    The number of columns is ``len(grid_images) // len(class_names)``, so the
    usual 100 images with 10 class names still produce a 10x10 grid.
    """
    num_classes = len(class_names)
    samples_per_class = len(grid_images) // num_classes if num_classes else 0
    if samples_per_class == 0:
        raise ValueError(
            "grid_images must contain at least one image per entry in class_names"
        )

    # Move to CPU once and rescale from [-1, 1] to the [0, 1] range imshow expects.
    images = torch.clamp((grid_images.detach().cpu() + 1) / 2, 0, 1)

    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    _, axes = plt.subplots(
        num_classes, samples_per_class, figsize=(16, 16), squeeze=False
    )

    for class_id, class_name in enumerate(class_names):
        for sample_id in range(samples_per_class):
            ax = axes[class_id, sample_id]
            img = images[class_id * samples_per_class + sample_id]
            ax.imshow(img.permute(1, 2, 0).numpy())

            # Hide ticks and the frame individually: ax.axis("off") would also
            # hide the y-label, so the class name on each row would never show.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if sample_id == 0:
                ax.set_ylabel(class_name, fontsize=10, fontweight="bold")

    plt.suptitle(
        f"{title}\n(Rows: Classes | Columns: Same Noise, Different Classes)",
        fontsize=14,
        fontweight="bold",
    )
    plt.tight_layout()
    plt.show()


def visualize_class_samples(class_samples, class_name, title="Generated Samples"):
    """Show up to 16 generated images of one class in a 2x8 grid.

    Args:
        class_samples: Tensor of images; each image is permuted with
            ``(1, 2, 0)`` before display, i.e. it is expected channel-first.
            Values are mapped with ``(x + 1) / 2`` and clamped to [0, 1].
            Only the first 16 images are shown.
        class_name: Class label inserted into the figure title.
        title: Phrase used in the figure title. With the default value and 16
            samples the title reads ``16 Generated Samples from Class "<name>"``.

    When fewer than 16 images are supplied, the remaining grid cells are left
    blank instead of raising an IndexError.
    """
    max_shown = 16  # capacity of the 2x8 grid

    # Move to CPU once and rescale from [-1, 1] to the [0, 1] range imshow expects.
    images = torch.clamp((class_samples.detach().cpu() + 1) / 2, 0, 1)
    num_shown = min(len(images), max_shown)

    _, axes = plt.subplots(2, 8, figsize=(16, 4))

    for i, ax in enumerate(axes.flat):
        if i < num_shown:
            ax.imshow(images[i].permute(1, 2, 0).numpy())
        ax.axis("off")

    plt.suptitle(
        f'{num_shown} {title} from Class "{class_name}"',
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    plt.show()


# Printed when this cell finishes, i.e. after the three helpers above are defined.
print(" Visualization helpers loaded")


---

## Part 1: Load CIFAR-10 Dataset

**TODO:** Load the CIFAR-10 dataset with appropriate transforms.

**Requirements:**
- Normalize images to [-1, 1] range (using mean=0.5, std=0.5)
- Create DataLoader with batch_size=64
- Enable shuffling for training

In [None]:
# TODO 1: Load CIFAR-10 dataset with transforms
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.
# Step 1a: Define transforms (ToTensor + Normalize to [-1, 1])
transform = transforms.Compose(
    [
        # YOUR CODE HERE: Add ToTensor()
        # YOUR CODE HERE: Add Normalize with mean=0.5, std=0.5 for each channel
    ]
)

# Step 1b: Load training dataset
# download=True lets torchvision fetch the data into ./data if it is missing.
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

# Step 1c: Create DataLoader
batch_size = 64
train_loader = DataLoader(
    # YOUR CODE HERE: Pass dataset and batch size
    # YOUR CODE HERE: Enable shuffling
    num_workers=0,  # MPS compatibility
)

# Report the number of training images and the number of batches per epoch.
print(f"Dataset loaded: {len(train_dataset)} training images")
print(f"DataLoader: {len(train_loader)} batches of size {batch_size}")


---

## Part 2: Create cGAN Models

**TODO:** Create the conditional generator and discriminator.

**Requirements:**
- Use `create_cgan_models()` function
- Set latent_dim=100 (noise vector dimension)
- Set num_classes=10 (CIFAR-10 has 10 classes)
- Initialize weights using DCGAN guidelines

In [None]:
# TODO 2: Create cGAN models
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.
latent_dim = 100
num_classes = 10
label_dim = 50  # Label embedding dimension

# Step 2a: Create generator and discriminator
# create_cgan_models is imported from models.cgan and returns the pair
# (generator, discriminator) unpacked here.
generator, discriminator = create_cgan_models(
    # YOUR CODE HERE: Pass latent_dim
    # YOUR CODE HERE: Pass num_classes
    # YOUR CODE HERE: Pass label_dim
    num_channels=3,
    # YOUR CODE HERE: Pass device
)

# Step 2b: Initialize weights
# initialize_weights is imported from cgan_training and applied to each network.
initialize_weights(generator)
initialize_weights(discriminator)

# Verify model parameters
# Sums the element count of every parameter tensor in each network.
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(
    f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}"
)


---

## Part 3: Create Trainer

**TODO:** Instantiate the ConditionalGANTrainer.

**Requirements:**
- Pass generator and discriminator
- Set lr_g=0.0002 (Generator learning rate)
- Set lr_d=0.0002 (Discriminator learning rate)
- Use Adam optimizer with beta1=0.5, beta2=0.999

In [None]:
# TODO 3: Create trainer
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.

# The learning rates are named once so the summary printed below always
# matches the values actually passed to the trainer.
lr_g = 0.0002  # Generator learning rate
lr_d = 0.0002  # Discriminator learning rate

trainer = ConditionalGANTrainer(
    # YOUR CODE HERE: Pass generator
    # YOUR CODE HERE: Pass discriminator
    # YOUR CODE HERE: Pass device
    lr_g=lr_g,
    lr_d=lr_d,
    # Adam beta1/beta2 as listed in this part's requirements.
    beta1=0.5,
    beta2=0.999,
)

print("Trainer initialized successfully")
print(f"Learning rates - G: {lr_g}, D: {lr_d}")


---

## Part 4: Train the cGAN

**TODO:** Train the model for 20 epochs.

**Requirements:**
- Use trainer.train() method
- Set num_epochs=20
- Pass latent_dim and num_classes
- Log progress every 50 batches

**Note:** This will take 20-30 minutes on GPU.

In [None]:
# TODO 4: Train the cGAN
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.
num_epochs = 20

print(f"Training cGAN for {num_epochs} epochs...")
print("This will take 20-30 minutes on GPU\n")

# `results` is later passed to plot_loss_curves_cgan, which reads
# results["d_losses"] and results["g_losses"].
results = trainer.train(
    # YOUR CODE HERE: Pass train_loader
    # YOUR CODE HERE: Pass num_epochs
    # YOUR CODE HERE: Pass latent_dim
    # YOUR CODE HERE: Pass num_classes
    log_interval=50,
)

print(f"\n Training complete after {num_epochs} epochs")


---

## Part 5: Plot Loss Curves

**TODO:** Visualize Generator and Discriminator losses during training.

**Requirements:**
- Create 1×2 subplot (D loss, G loss)
- Plot results['d_losses'] and results['g_losses']
- Add labels and titles
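- Add reference line at y=0.5 (labelled "Ideal" in the plot; note that 0.5 is the discriminator's equilibrium *output*, for which the BCE loss is ln 2 ≈ 0.693)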

In [None]:
# TODO 5: Plot loss curves
# plot_loss_curves_cgan (defined in the helpers cell) draws both loss curves
# and prints the mean and standard deviation of each.
plot_loss_curves_cgan(results)

---

## Part 6: Generate 10×10 Class Disentanglement Grid

**TODO:** Generate a 10×10 grid showing all classes with the same noise vector.

**Requirements:**
- Use trainer.generate_all_classes_grid()
- Generate 10 classes × 10 samples per class = 100 images
- Same noise z, different class labels y
- Denormalize from [-1,1] to [0,1]
- Visualize as 10×10 subplot grid

In [None]:
# TODO 6: Generate 10x10 class disentanglement grid
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.
print("Generating 10×10 class disentanglement grid...")

# Step 6a: Generate grid
# NOTE(review): shared_z=None presumably lets the trainer sample the noise
# itself — confirm against generate_all_classes_grid in cgan_training.
grid_images = trainer.generate_all_classes_grid(
    # YOUR CODE HERE: Pass num_classes=10
    # YOUR CODE HERE: Pass samples_per_class=10
    # YOUR CODE HERE: Pass latent_dim
    shared_z=None,
)

# Step 6b: Visualize using helper
# visualize_class_grid reads the images class by class (all samples of
# class 0 first, then class 1, and so on), one class per row.
visualize_class_grid(
    grid_images, CIFAR10_CLASSES, title="cGAN Class Disentanglement: 10×10 Grid"
)

print(f"✓ Grid generated: {grid_images.shape}")


---

## Part 7: Generate Single-Class Samples

**TODO:** Generate 16 images of a specific class (e.g., dogs).

**Requirements:**
- Use trainer.generate_class_samples()
- Set target_class=5 (dogs)
- Generate num_samples=16
- Denormalize to [0,1]
- Display as 2×8 grid

In [None]:
# TODO 7: Generate single-class samples
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code.
target_class = 5  # dogs

print(f"Generating 16 samples of class '{CIFAR10_CLASSES[target_class]}'...")

# Step 7a: Generate samples
class_samples = trainer.generate_class_samples(
    # YOUR CODE HERE: Pass target_class
    # YOUR CODE HERE: Pass num_samples=16
    # YOUR CODE HERE: Pass latent_dim
)

# Step 7b: Visualize using helper
# visualize_class_samples displays the first 16 images in a 2x8 grid.
visualize_class_samples(class_samples, CIFAR10_CLASSES[target_class])

print(
    f"✓ Generated 16 samples from class {CIFAR10_CLASSES[target_class]} (class {target_class})"
)



## Part 8: Analyze Class Disentanglement Quality

**TODO:** Answer the following questions about the cGAN's class disentanglement.

**Analysis Questions:**

1. **Within-Row Variation:** Looking at the 10×10 grid, do images in the same row (same class) show style variation? What causes this variation?

2. **Across-Row Differences:** Are rows clearly different from each other? Can you easily tell which row is dogs vs cars just by looking?

3. **Class Control:** Do the class labels seem to have a strong effect on the generated images? How can you tell?

4. **Quality Issues:** Do any rows look wrong or blurry? Which classes are generated most successfully?

5. **Data Augmentation:** Would you use these generated images to augment a training dataset? Why or why not?

## Part 8 Analysis: Your Answers

**Question 1: Within-Row Variation**
```
YOUR ANSWER HERE
```

**Question 2: Across-Row Differences**
```
YOUR ANSWER HERE
```

**Question 3: Class Control**
```
YOUR ANSWER HERE
```

**Question 4: Quality Issues**
```
YOUR ANSWER HERE
```

**Question 5: Data Augmentation**
```
YOUR ANSWER HERE
```

---

## Part 9: Compare Generated vs Real Images (Optional)

**TODO:** Create side-by-side comparison of generated vs real images for each class.

**Requirements:**
- Generate one image per class
- Find one real image per class from dataset
- Display as 3 rows × 10 columns
- Row 1: Generated, Row 2: Real, Row 3: Assessment

In [None]:
# TODO 9 (Optional): Compare generated vs real images
# NOTE: exercise template — this cell is incomplete until every
# "YOUR CODE HERE" placeholder below has been replaced with code
# (for example, idx stays None until Step 9b is filled in).
print("Generating comparison grid...")

# One column per class; rows are Generated / Real / Assessment.
fig, axes = plt.subplots(3, len(CIFAR10_CLASSES), figsize=(18, 6))

for class_id, class_name in enumerate(CIFAR10_CLASSES):
    # Step 9a: Generate image for this class
    generated = trainer.generate_class_samples(
        # YOUR CODE HERE: Pass target_class=class_id
        # YOUR CODE HERE: Pass num_samples=1
        # YOUR CODE HERE: Pass latent_dim
    )
    # Rescale from [-1, 1] to [0, 1] for display.
    generated_denorm = torch.clamp((generated.cpu() + 1) / 2, 0, 1)

    # Step 9b: Find real image for this class
    # YOUR CODE HERE: Find index of first image with class_id in train_dataset.targets
    idx = None  # YOUR CODE HERE
    real_img = (train_dataset[idx][0] + 1) / 2

    # Step 9c: Display generated image
    ax = axes[0, class_id]
    # YOUR CODE HERE: Display generated image
    ax.set_title(class_name, fontsize=9)
    ax.axis("off")

    # Step 9d: Display real image
    ax = axes[1, class_id]
    # YOUR CODE HERE: Display real image
    ax.axis("off")

    # Step 9e: Add quality assessment
    # NOTE(review): the same text is drawn for every class; edit it per class
    # if this row should reflect your own assessment.
    ax = axes[2, class_id]
    ax.text(
        0.5,
        0.5,
        "Quality:\nGood✓",
        ha="center",
        va="center",
        fontsize=8,
        bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.7),
    )
    ax.axis("off")

# Label each row once, just left of its first column (axes coordinates).
row_labels = ("Generated", "Real", "Assessment")
for row_axes, row_label in zip(axes, row_labels):
    first_ax = row_axes[0]
    first_ax.text(
        -0.3,
        0.5,
        row_label,
        transform=first_ax.transAxes,
        fontsize=10,
        fontweight="bold",
        va="center",
    )

plt.suptitle("Generated vs Real Images (One per Class)", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.show()

print(" Comparison complete")


---

## Summary

### What You've Learned

- **Conditional GANs:** How to add class information to Generator and Discriminator

- **Label Conditioning:** Early concatenation (G) vs late concatenation (D)



### Key Takeaways

1. **Conditioning enables control:** cGANs let you generate specific classes on demand

2. **Disentanglement matters:** Good separation of class from style is crucial for quality

3. **Grid visualization is powerful:** The 10×10 grid clearly shows whether class labels work

4. **Practical uses:** Generate synthetic data for underrepresented classes, augment datasets

### Next Steps to Improve

+ Train for more epochs (50+) for better quality
+ Try different architectures (spectral norm, progressive training)
+ Compute metrics (FID, IS) for quantitative evaluation
+ Use generated images to train a classifier and measure improvement
+ Experiment with CIFAR-100 or ImageNet subsets
+ Compare with other conditional architectures