## Import Required Libraries

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Import cGAN from demo
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "Demo"))
from models.cgan import create_cgan_models
from cgan_training import ConditionalGANTrainer, initialize_weights

# Label names for CIFAR-10, ordered by class index 0..9.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Seed both the PyTorch and NumPy RNGs so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Use the first available accelerator: Apple MPS, then CUDA, else CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)

print(f"Using device: {device}\n")


## Solution 1: Load CIFAR-10 Dataset

In [None]:
print("=" * 80)
print("SOLUTION 1: Loading CIFAR-10 Dataset")
print("=" * 80)

# ToTensor yields pixels in [0, 1]; Normalize with mean=std=0.5 per channel
# rescales them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

batch_size = 64
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # single-process loading, kept at 0 for MPS compatibility
)

print(f"✓ Dataset loaded: {len(train_dataset)} training images")
print(f"✓ DataLoader: {len(train_loader)} batches of size {batch_size}\n")


## Solution 2: Create cGAN Models

In [None]:
print("=" * 80)
print("SOLUTION 2: Creating cGAN Models")
print("=" * 80)

latent_dim = 100  # length of the noise vector z
num_classes = 10  # number of CIFAR-10 labels
label_dim = 50    # forwarded as label_dim; presumably the label-embedding width — confirm in models.cgan

generator, discriminator = create_cgan_models(
    latent_dim=latent_dim,
    num_classes=num_classes,
    label_dim=label_dim,
    num_channels=3,  # RGB images
    device=device,
)

# Apply the project's DCGAN-style weight initialization to both networks.
initialize_weights(generator)
initialize_weights(discriminator)

# Report the total number of scalar parameters in each network.
print(f"✓ Generator created: {sum(map(torch.Tensor.numel, generator.parameters())):,} parameters")
print(f"✓ Discriminator created: {sum(map(torch.Tensor.numel, discriminator.parameters())):,} parameters\n")


## Solution 3: Create Trainer

In [None]:
print("=" * 80)
print("SOLUTION 3: Creating Trainer")
print("=" * 80)

# Adam hyperparameters. lr=2e-4 and beta1=0.5 are the settings used in the
# DCGAN paper. They are bound to names so the summary printed below always
# reflects what was actually passed to the trainer.
lr_g = 0.0002
lr_d = 0.0002
beta1 = 0.5
beta2 = 0.999

trainer = ConditionalGANTrainer(
    generator=generator,
    discriminator=discriminator,
    device=device,
    lr_g=lr_g,
    lr_d=lr_d,
    beta1=beta1,
    beta2=beta2,
)

print("✓ Trainer initialized with:")
print(f"  - Learning rate (G): {lr_g}")
print(f"  - Learning rate (D): {lr_d}")
print(f"  - Beta1: {beta1}, Beta2: {beta2}\n")


## Solution 4: Train the cGAN

In [None]:
# Define the epoch count first so the header below cannot disagree with it.
num_epochs = 20

print("=" * 80)
print(f"SOLUTION 4: Training cGAN ({num_epochs} epochs)")
print("=" * 80)
# NOTE: this runtime estimate was written for the default 20 epochs;
# update it if num_epochs changes.
print("Note: This will take 20-30 minutes on GPU\n")

# The returned dict is read by the loss-plotting cell via its
# "d_losses" / "g_losses" entries.
results = trainer.train(
    train_loader=train_loader,
    num_epochs=num_epochs,
    latent_dim=latent_dim,
    num_classes=num_classes,
    log_interval=50,
)

print(f"✓ Training complete after {num_epochs} epochs\n")


## Solution 5: Plot Training Loss Curves

In [None]:
print("=" * 80)
print("SOLUTION 5: Plotting Loss Curves")
print("=" * 80)


def _plot_loss_curve(ax, losses, *, color, label, title, reference=None):
    """Draw one per-step loss curve on ``ax`` using the notebook's shared styling.

    Args:
        ax: Matplotlib axes to draw on.
        losses: Sequence of per-step loss values, as returned by the trainer.
        color: Line color for the curve.
        label: Legend label for the curve.
        title: Axes title.
        reference: Optional ``(y, label)`` pair drawn as a dashed gray
            horizontal guide line.
    """
    values = np.asarray(losses, dtype=float)
    ax.plot(values, linewidth=1, color=color, alpha=0.7, label=label)
    if reference is not None:
        ref_y, ref_label = reference
        ax.axhline(y=ref_y, color="gray", linestyle="--", alpha=0.5, label=ref_label)
    ax.set_xlabel("Training Step", fontsize=11)
    ax.set_ylabel("BCE Loss", fontsize=11)
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.grid(alpha=0.3)
    ax.legend(fontsize=10)

    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return  # nothing to scale to; keep Matplotlib's autoscaled limits
    # Scale to the 99th percentile of the WHOLE run. The previous limit used the
    # max of only the first 100 steps, which clipped any later rise in the loss.
    # The percentile still ignores rare spikes.
    top = float(np.percentile(finite, 99))
    if reference is not None:
        top = max(top, reference[0])  # keep the guide line in view
    if top > 0:
        ax.set_ylim([0, top * 1.2])


# At the theoretical GAN equilibrium D outputs 0.5 for every input, which
# makes each BCE term equal ln 2 ≈ 0.693. A *loss* of 0.5 has no special
# meaning. NOTE(review): this assumes the trainer reports the mean of the
# real and fake BCE terms; if it reports their sum, the equilibrium is
# 2·ln 2 ≈ 1.386 — confirm against ConditionalGANTrainer.
d_equilibrium_loss = float(np.log(2))

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
_plot_loss_curve(
    axes[0],
    results["d_losses"],
    color="navy",
    label="D Loss",
    title="Discriminator Loss Over Training\n(cGAN)",
    reference=(d_equilibrium_loss, f"Equilibrium (ln 2 ≈ {d_equilibrium_loss:.3f})"),
)
_plot_loss_curve(
    axes[1],
    results["g_losses"],
    color="darkgreen",
    label="G Loss",
    title="Generator Loss Over Training\n(cGAN)",
)

plt.tight_layout()
plt.show()

# Print statistics
d_losses = np.array(results["d_losses"])
g_losses = np.array(results["g_losses"])
print(f"Discriminator Loss - Mean: {d_losses.mean():.4f}, Std: {d_losses.std():.4f}")
print(f"Generator Loss - Mean: {g_losses.mean():.4f}, Std: {g_losses.std():.4f}\n")


## Solution 6: Class Disentanglement Grid (10×10)

In [None]:
samples_per_class = 10  # grid columns; rows are the num_classes labels

print("=" * 80)
print(f"SOLUTION 6: Generating {num_classes}×{samples_per_class} Class Disentanglement Grid")
print("=" * 80)

grid_images = trainer.generate_all_classes_grid(
    num_classes=num_classes,
    samples_per_class=samples_per_class,
    latent_dim=latent_dim,
    shared_z=None,
)

# Denormalize from [-1, 1] back to [0, 1] for display.
grid_images_cpu = torch.clamp((grid_images.cpu() + 1) / 2, 0, 1)

# One row per class, one column per sample. The flat index below assumes the
# trainer returns images grouped class-by-class.
fig, axes = plt.subplots(num_classes, samples_per_class, figsize=(16, 16), squeeze=False)
for class_id in range(num_classes):
    for sample_id in range(samples_per_class):
        ax = axes[class_id, sample_id]
        idx = class_id * samples_per_class + sample_id
        ax.imshow(grid_images_cpu[idx].permute(1, 2, 0).numpy())
        # Hide ticks and frame manually. ax.axis("off") would also hide the
        # axis labels, so the class-name y-labels below would never appear.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        if sample_id == 0:
            ax.set_ylabel(CIFAR10_CLASSES[class_id], fontsize=10, fontweight="bold")

plt.suptitle(
    f"cGAN Class Disentanglement: {num_classes}×{samples_per_class} Grid\n"
    "(Rows: Classes | Columns: Same Noise, Different Classes)",
    fontsize=14,
    fontweight="bold",
    y=0.995,
)
plt.tight_layout()
plt.show()

print(f" Generated class grid: {grid_images.shape}")
print(
    f"  - {num_classes} classes × {samples_per_class} samples per class"
    f" = {num_classes * samples_per_class} total images"
)
print("  - Same noise z, different class labels y\n")


## Solution 7: Generate Single-Class Samples

In [None]:
print("=" * 80)
print("SOLUTION 7: Generating Single-Class Samples")
print("=" * 80)

target_class = 5  # index 5 in CIFAR10_CLASSES is "dog"
num_samples = 16  # single source of truth for the sample count used below
n_cols = 8
n_rows = (num_samples + n_cols - 1) // n_cols  # ceiling division

class_samples = trainer.generate_class_samples(
    target_class=target_class,
    num_samples=num_samples,
    latent_dim=latent_dim,
)

# Denormalize from [-1, 1] back to [0, 1] for display.
class_samples_cpu = torch.clamp((class_samples.cpu() + 1) / 2, 0, 1)

# 2 inches per cell, which gives the original 16x4 figure for 16 samples.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < num_samples:
        ax.imshow(class_samples_cpu[i].permute(1, 2, 0).numpy())
    ax.axis("off")  # also blanks any unused trailing cells

plt.suptitle(
    f'{num_samples} Generated Samples from Class "{CIFAR10_CLASSES[target_class]}"',
    fontsize=12,
    fontweight="bold",
)
plt.tight_layout()
plt.show()

print(
    f" Generated {num_samples} samples from class {CIFAR10_CLASSES[target_class]} (class {target_class})"
)
print("  - Each image generated with different noise z")
print(f"  - All images conditioned on same class label y={target_class}\n")


## Solution 8: Class Disentanglement Analysis

In [None]:
print("=" * 80)
print("SOLUTION 8: CLASS DISENTANGLEMENT ANALYSIS")
print("=" * 80)

# Narrative summary printed for the reader. In the grid rows are classes and
# each column shares one noise vector z, so "class changes" happens DOWN a
# column, not across it.
print(
    """
 OBSERVATIONS FROM THE 10×10 GRID:

1. WITHIN-ROW VARIATION (Same class, different noise):
   ✓ Each row shows the SAME CLASS with DIFFERENT APPEARANCES
   ✓ Different poses, colors, orientations
   ✓ Example: Row of dogs shows dogs in different styles
   ✓ This variation comes from the noise vector z

2. ACROSS-ROW DIFFERENCES (Different classes):
   ✓ Each row is DISTINCTLY DIFFERENT from other rows
   ✓ Row 0 clearly different from Row 1, etc.
   ✓ Example: Dogs (row 5) look completely different from cars (row 1)
   ✓ This difference comes from the class label y

3. CLASS DISENTANGLEMENT SUCCESS METRICS:
   
   Strong Disentanglement (What we want):
   ✓ Each row belongs unmistakably to one class
   ✓ Rows show clear variety in appearance
   ✓ Down each column (shared z), basic structure preserved but class changes
   
   Weak Disentanglement / Failure Modes:
   ✗ All samples look identical within a class (mode collapse)
   ✗ Some classes are indistinguishable from others
   ✗ Generator ignores class label, only uses noise

4. COMPARING cGAN vs DCGAN:
   
   DCGAN (Unconditional):
   - No class control
   - Random mix of classes in output
   - Pure image generation
   
   cGAN (Conditional):
   - Class control through label y
   - Consistent class generation per condition
   - More structured, controllable output
"""
)
