# GAN Training Tutorial

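This notebook demonstrates how to train a Generative Adversarial Network (GAN) — specifically a DCGAN — with the `generative_ai` training framework. It configures the run, builds a synthetic image dataset, trains the model with `GANTrainer`, visualizes the losses and generated samples, and finally saves and reloads a checkpoint.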

## 1. Setup and Imports

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Import our framework
from generative_ai.models.gan import DCGAN
from generative_ai.training.gan_trainer import GANTrainer
from generative_ai.config.gan_config import GANConfig
from generative_ai.data.datasets import SyntheticImageDataset
from generative_ai.data.transforms import get_image_transforms
from generative_ai.utils.helpers import set_seed
from generative_ai.utils.visualization import plot_samples, plot_gan_losses

# Fix the RNG seed before anything random happens so the run is repeatable.
# NOTE(review): assumes set_seed seeds every generator used below
# (python / numpy / torch) — confirm in generative_ai.utils.helpers.
set_seed(42)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# Hyperparameters for the demo run, grouped by concern.
# Fields not passed here (e.g. the generator/discriminator feature widths read
# in section 4) are supplied by GANConfig itself — presumably its defaults.
# NOTE: num_epochs is overridden to 10 in section 5 for the notebook demo, so
# the value printed below is not the one actually trained with.
config = GANConfig(
    # data / image geometry
    batch_size=32,
    image_size=64,
    channels=3,
    # model and optimisation
    latent_dim=100,
    num_epochs=50,
    generator_lr=0.0002,
    discriminator_lr=0.0002,
    # runtime
    device=str(device),
    seed=42,
    log_interval=50,
)

print("Configuration:")
config_dict = config.to_dict()
for name in config_dict:
    print(f"  {name}: {config_dict[name]}")

## 3. Create Dataset and DataLoader

In [None]:
# A synthetic dataset keeps the tutorial self-contained (no downloads).
# normalize / augment are flags of the framework's transform factory.
transform = get_image_transforms(config.image_size, normalize=True, augment=True)
dataset = SyntheticImageDataset(
    size=2000,  # presumably the number of images — len(dataset) is printed below
    image_size=config.image_size,
    channels=config.channels,
    transform=transform,
)

# shuffle=True reshuffles every epoch; num_workers=2 loads batches in two
# background worker processes.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
)

n_images = len(dataset)
n_batches = len(dataloader)
print(f"Dataset size: {n_images}")
print(f"Number of batches: {n_batches}")

# Pull one batch to sanity-check the tensor layout the model will receive.
# NOTE(review): assumes the dataset yields bare image tensors (no labels), as
# section 6 also does when slicing a batch — confirm against SyntheticImageDataset.
sample_batch = next(iter(dataloader))
print(f"Sample batch shape: {sample_batch.shape}")

## 4. Create Model

In [None]:
# Build the DCGAN from the configured sizes and feature widths.
model = DCGAN(
    latent_dim=config.latent_dim,
    image_size=config.image_size,
    channels=config.channels,
    g_features=config.generator_features,
    d_features=config.discriminator_features
)
# Put the weights on the same device the noise is created on below; without this
# the smoke test fails with a device-mismatch error whenever CUDA is available.
model.to(device)

print(f"Generator parameters: {sum(p.numel() for p in model.generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in model.discriminator.parameters()):,}")

# Smoke-test the generator. Switch to eval mode first: in train mode any
# BatchNorm layers would fold this throwaway random batch into their running
# statistics even under no_grad. Train mode is restored afterwards so the model
# reaches the trainer in the same state it was constructed in.
model.eval()
with torch.no_grad():
    test_noise = model.generate_noise(4, device)
    test_samples = model.generator(test_noise)
    print(f"Generated sample shape: {test_samples.shape}")
model.train()

## 5. Training

In [None]:
# Shorten the run for the notebook demo (section 2 configured 50 epochs; raise it
# again for a real training run). This must happen BEFORE the trainer is built:
# if GANTrainer reads num_epochs at construction time, mutating config afterwards
# would silently have no effect.
config.num_epochs = 10

# Create trainer
trainer = GANTrainer(config, model, device)

print("Starting training...")
trainer.train(dataloader)

print("Training completed!")

## 6. Results Visualization

In [None]:
# Loss curves the trainer recorded during section 5.
plot_gan_losses(trainer.g_losses, trainer.d_losses)

# Side-by-side grid of real vs. generated images.
model.eval()
with torch.no_grad():
    noise = model.generate_noise(16, device)
    # Bring the fakes back to the host: the real batch below comes straight from
    # the DataLoader on the CPU, so both grids now live on the same device.
    # NOTE(review): plot_samples internals are not visible here; CPU tensors are
    # the safe input for plotting either way.
    fake_samples = model.generator(noise).cpu()

# Real images for comparison (already on the CPU).
real_samples = next(iter(dataloader))[:16]

plot_samples(real_samples, fake_samples, nrow=4)

## 7. Model Evaluation

In [None]:
# Draw a larger batch and summarise its overall pixel statistics as a quick
# sanity check on what the generator produces.
model.eval()
with torch.no_grad():
    noise = model.generate_noise(100, device)
    generated_samples = model.generator(noise)

    # Statistics over every element of the batch, keyed by display label.
    stats_by_label = {
        "Mean": generated_samples.mean(),
        "Std": generated_samples.std(),
        "Min": generated_samples.min(),
        "Max": generated_samples.max(),
    }

    print(f"Generated {generated_samples.shape[0]} samples")
    print("Sample statistics:")
    for label, stat in stats_by_label.items():
        print(f"  {label}: {stat.item():.4f}")

## 8. Save Model

In [None]:
# Save the trained model: both networks' weights, the config, and the loss history.
# Optimizer states are not included, so this checkpoint supports reloading and
# sampling (section 9) but not an exact resume of training.
# The path is defined once so the save location and the message cannot drift apart.
checkpoint_path = 'saved_models/dcgan_model.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

torch.save({
    'generator_state_dict': model.generator.state_dict(),
    'discriminator_state_dict': model.discriminator.state_dict(),
    'config': config.to_dict(),
    'g_losses': trainer.g_losses,
    'd_losses': trainer.d_losses
}, checkpoint_path)

print(f"Model saved to '{checkpoint_path}'")

## 9. Load and Test Saved Model

In [None]:
# Load the checkpoint written in section 8. map_location lets a checkpoint saved
# on one device (e.g. GPU) load on another (e.g. a CPU-only machine).
checkpoint = torch.load('saved_models/dcgan_model.pth', map_location=device)

# Rebuild the architecture with the SAME arguments as section 4, including the
# feature widths. If they were omitted, DCGAN would use its own defaults, and
# whenever those differ from config.generator_features / discriminator_features
# (which presumably set layer widths) load_state_dict would fail with a shape mismatch.
loaded_model = DCGAN(
    latent_dim=config.latent_dim,
    image_size=config.image_size,
    channels=config.channels,
    g_features=config.generator_features,
    d_features=config.discriminator_features
)

# Load state dictionaries
loaded_model.generator.load_state_dict(checkpoint['generator_state_dict'])
loaded_model.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
loaded_model.to(device)

# Test generation with loaded model.
# NOTE(review): generate_samples is not used elsewhere in this notebook — confirm
# DCGAN provides it.
loaded_model.eval()
with torch.no_grad():
    test_samples = loaded_model.generate_samples(8, device)
    print(f"Generated samples from loaded model: {test_samples.shape}")

    # Round-trip check: identical weights must map identical noise to the same
    # images (both models in eval mode; small atol absorbs kernel nondeterminism).
    model.eval()
    check_noise = model.generate_noise(8, device)
    assert torch.allclose(
        model.generator(check_noise),
        loaded_model.generator(check_noise),
        atol=1e-5,
    ), "Loaded generator does not reproduce the trained generator"

print("Model loaded and tested successfully!")