## Import Required Libraries

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Import improved training module
from improved_gan_training import (
    BaselineGAN,
    ComparisonTrainer,
    FeatureMatchingGAN,
    LabelSmoothingGAN,
)
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.basic_gan import create_generator, create_discriminator

## Setup Environment

In [None]:
print("Setting up environment...")
# Seed torch's global RNG before anything stochastic runs.
torch.manual_seed(42)

# Backend preference: MPS first, then CUDA, otherwise CPU. The chained
# conditional short-circuits, so CUDA is only probed when MPS is unavailable.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}\n")


## Load Fashion MNIST Dataset

In [None]:
# Preprocessing pipeline: tensor conversion followed by normalization with
# mean=0.5 and std=0.5, which places pixel values in the range [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

print("Loading Fashion MNIST dataset...")
# Training split only; download=True lets torchvision fetch files into ./data.
train_dataset = datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=transform
)

batch_size = 64
# Shuffled mini-batches over the training split.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Summary of what was loaded, emitted as a single write.
summary_lines = [
    "✓ Dataset loaded",
    f"  Total images: {len(train_dataset)}",
    f"  Batch size: {batch_size}",
    f"  Batches per epoch: {len(train_loader)}\n",
]
print("\n".join(summary_lines))


## Create Comparison Trainer

In [None]:
# Announce the step, then build the trainer on the device selected earlier.
print("Initializing comparison trainer...\n")
# ComparisonTrainer is imported from improved_gan_training; later cells call
# train_all_variants, get_stability_metrics and print_comparison_report on it.
trainer = ComparisonTrainer(device=device)


## Train All Variants

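This cell trains three GAN variants (baseline, label smoothing, and feature matching) so that their training behavior can be compared in the sections below. Expect it to take roughly 15–20 minutes, depending on your device.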

In [None]:
banner = "=" * 80
print(banner)
print("TRAINING THREE GAN VARIANTS")
print("(This will take ~15-20 minutes depending on your device)")
print(banner + "\n")

# Hyperparameters forwarded verbatim to train_all_variants.
training_config = {
    "num_epochs": 50,
    "lr": 0.0002,
    "beta1": 0.5,
}

# The model factories come from models.basic_gan; the returned mapping is
# consumed by the loss-curve plots further down.
results = trainer.train_all_variants(
    generator_class=create_generator,
    discriminator_class=create_discriminator,
    train_loader=train_loader,
    **training_config,
)


## Visualize Loss Curves Comparison

In [None]:
print("\n" + "=" * 80)
print("VISUALIZING RESULTS")
print("=" * 80 + "\n")

fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Line styling per variant name. Names missing from these maps fall back to
# "blue" / 0.7 at plot time. `colors` is reused by the bar charts below.
colors = {
    "Baseline": "navy",
    "Label Smoothing": "orange",
    "Feature Matching": "green",
}
alphas = dict.fromkeys(colors, 0.8)

# One panel per network:
# (key into each variant's results, y-axis label, title, draw 0.5 reference?)
panels = [
    ("d_losses", "Discriminator Loss", "Discriminator Loss Comparison", True),
    ("g_losses", "Generator Loss", "Generator Loss Comparison", False),
]

for ax, (loss_key, y_label, panel_title, show_reference) in zip(axes, panels):
    # Overlay every variant's loss series on the same axes.
    for variant_name, losses in results.items():
        ax.plot(
            losses[loss_key],
            label=variant_name,
            color=colors.get(variant_name, "blue"),
            alpha=alphas.get(variant_name, 0.7),
            linewidth=1.5,
        )

    if show_reference:
        # NOTE(review): 0.5 is labeled as the ideal discriminator loss here;
        # confirm it matches how d_losses is computed in improved_gan_training.
        ax.axhline(y=0.5, color="gray", linestyle="--", alpha=0.5, label="Ideal (0.5)")

    ax.set_xlabel("Training Step", fontsize=11)
    ax.set_ylabel(y_label, fontsize=11)
    ax.set_title(panel_title, fontsize=12, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()


## Compute Stability Metrics

In [None]:
print("\nComputing stability metrics...\n")
metrics = trainer.get_stability_metrics()

print("Stability Metrics (Last 100 Batches):")
print("=" * 80)

# (display name, tag used inside the metric keys) for each network.
networks = (("Discriminator", "d"), ("Generator", "g"))

for variant_name, metric_dict in metrics.items():
    print(f"\n{variant_name}:")
    for network_label, tag in networks:
        # Keys follow the pattern avg_<tag>_loss / std_<tag>_loss / <tag>_oscillation.
        print(f"  {network_label} Loss:")
        print(f"    Average: {metric_dict[f'avg_{tag}_loss']:.4f}")
        print(f"    Std Dev: {metric_dict[f'std_{tag}_loss']:.4f}")
        print(f"    Oscillation: {metric_dict[f'{tag}_oscillation']:.6f}")


## Comparison Visualization

In [None]:
def _plot_metric_bars(ax, variants, values, ylabel, title, ideal=None, width=0.35):
    """Draw one bar per variant on ``ax`` with the notebook's shared styling.

    Bar colors are looked up in the notebook-level ``colors`` mapping defined
    in the loss-curve cell; variants missing from it are drawn in "blue".

    Args:
        ax: Matplotlib axes to draw on.
        variants: Variant names, in the order the bars should appear.
        values: One numeric value per entry in ``variants``.
        ylabel: Label for the y-axis.
        title: Panel title.
        ideal: Optional y-value for a dashed gray reference line. When given,
            the line is labeled "Ideal (<value>)" and a legend is shown.
        width: Bar width passed to ``ax.bar``.
    """
    x_pos = np.arange(len(variants))
    ax.bar(
        x_pos,
        values,
        color=[colors.get(name, "blue") for name in variants],
        alpha=0.7,
        width=width,
    )
    if ideal is not None:
        ax.axhline(
            y=ideal, color="gray", linestyle="--", alpha=0.5, label=f"Ideal ({ideal})"
        )
    ax.set_ylabel(ylabel, fontsize=10)
    ax.set_title(title, fontsize=11, fontweight="bold")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(variants, rotation=15, ha="right")
    if ideal is not None:
        # Only panels with a reference line have a labeled artist to show.
        ax.legend()
    ax.grid(alpha=0.3, axis="y")


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

variants = list(metrics.keys())

# One entry per panel:
# (axes, key into each variant's metrics, y-axis label, title, reference line)
# Driving the figure from this table replaces four copy-pasted plotting
# sections and avoids rebinding d_losses / g_losses, which the loss-curve
# cell uses for per-step loss series.
metric_panels = [
    (
        axes[0, 0],
        "avg_d_loss",
        "Average D Loss",
        "Average Discriminator Loss (Last 100 Batches)",
        0.5,
    ),
    (
        axes[0, 1],
        "std_d_loss",
        "Std Dev of D Loss",
        "D Loss Stability (Lower is Better)",
        None,
    ),
    (
        axes[1, 0],
        "avg_g_loss",
        "Average G Loss",
        "Average Generator Loss (Last 100 Batches)",
        None,
    ),
    (
        axes[1, 1],
        "d_oscillation",
        "D Loss Oscillation",
        "Loss Oscillation (Lower = Smoother Training)",
        None,
    ),
]

for ax, metric_key, ylabel, title, ideal in metric_panels:
    _plot_metric_bars(
        ax,
        variants,
        [metrics[name][metric_key] for name in variants],
        ylabel,
        title,
        ideal=ideal,
    )

plt.tight_layout()
plt.show()


## Detailed Comparison Report

In [None]:
# Emit the trainer's own comparison report; its content and format are
# defined by ComparisonTrainer in improved_gan_training, not in this notebook.
trainer.print_comparison_report()


## Analysis and Recommendations

In [None]:
rule = "=" * 80
print("\n" + rule)
print("ANALYSIS AND RECOMMENDATIONS")
print(rule)

print("\nANSWERS TO KEY QUESTIONS:")
print("-" * 80)

# Each winner is the variant minimizing one metric; ties resolve to the first
# variant in `metrics` iteration order. The best_* tuples hold (name, metrics).

# Q1: lowest standard deviation of the discriminator loss.
print("\n1. Which technique produced the most stable training?")
print("   Looking at Std Dev of D Loss (lowest is best)...")
stable_name = min(metrics, key=lambda name: metrics[name]["std_d_loss"])
best_stability = (stable_name, metrics[stable_name])
print(f"   WINNER: {stable_name}")
print(f"     Std Dev: {best_stability[1]['std_d_loss']:.6f}")
print("   Why: Lower variance means more consistent loss curve → better learning")

# Q2: average discriminator loss nearest to 0.5.
print("\n2. Which technique kept D loss closest to the ideal 0.5?")
print("   Looking for balance (not too good, not too bad)...")
balanced_name = min(metrics, key=lambda name: abs(metrics[name]["avg_d_loss"] - 0.5))
best_balance = (balanced_name, metrics[balanced_name])
print(f"    WINNER: {balanced_name}")
print(f"     Average D Loss: {best_balance[1]['avg_d_loss']:.4f}")
print("   Why: Balanced loss indicates good G-D competition")

# Q3: lowest discriminator-loss oscillation.
print("\n3. Which technique showed the smoothest training (least oscillation)?")
print("   Looking at D Loss Oscillation (lowest is best)...")
smooth_name = min(metrics, key=lambda name: metrics[name]["d_oscillation"])
best_smooth = (smooth_name, metrics[smooth_name])
print(f"    WINNER: {smooth_name}")
print(f"     Oscillation: {best_smooth[1]['d_oscillation']:.6f}")
print("   Why: Smooth training helps avoid mode collapse and instability")
