# Generative Models in PyTorch

This notebook demonstrates various generative models including Autoencoders (AEs), Variational Autoencoders (VAEs), and Generative Adversarial Networks (GANs).

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Seed the global torch and NumPy RNGs so weight initialisation, data
# shuffling and sampled noise repeat from run to run.
SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Data Preparation

We'll use the MNIST dataset for demonstrating generative models.

In [None]:
# Directory for any outputs written by later cells
output_dir = "generative_models_outputs"
os.makedirs(output_dir, exist_ok=True)

# MNIST training split; ToTensor() turns uint8 pixels into floats in [0, 1]
transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Worker processes only on POSIX; elsewhere load in the main process
# (presumably to sidestep multiprocessing issues on Windows — TODO confirm intent)
loader_workers = 2 if os.name == 'posix' else 0
dataloader = DataLoader(mnist_dataset, batch_size=128, shuffle=True, num_workers=loader_workers)

# Image geometry shared by every model below
IMG_SIZE = 28
IMG_CHANNELS = 1
FLAT_IMG_SIZE = IMG_CHANNELS * IMG_SIZE * IMG_SIZE

# Preview the first 10 images of one shuffled batch in a 2x5 grid
samples, labels = next(iter(dataloader))
plt.figure(figsize=(10, 2))
for idx, (img, label) in enumerate(zip(samples[:10], labels[:10])):
    plt.subplot(2, 5, idx + 1)
    plt.imshow(img.squeeze(), cmap='gray')
    plt.title(f"Label: {label}")
    plt.axis('off')
plt.tight_layout()
plt.show()

## 2. Autoencoders (AE)

Autoencoders learn to compress data into a lower-dimensional representation and then reconstruct it.

In [None]:
class Autoencoder(nn.Module):
    """Fully-connected autoencoder for flattened images.

    Encoder: input_dim -> 256 -> 128 -> latent_dim, with a ReLU after every
    layer (including the bottleneck, so latent codes are non-negative).
    Decoder: latent_dim -> 128 -> 256 -> input_dim, ending in a sigmoid so
    reconstructions lie in (0, 1).

    Args:
        input_dim: Number of values per flattened image.
        latent_dim: Size of the bottleneck code.
    """

    def __init__(self, input_dim=FLAT_IMG_SIZE, latent_dim=32):
        super().__init__()
        # Layer order is fixed so parameter names (encoder.0, decoder.4, ...)
        # and initialisation order are stable.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(True),
            nn.Linear(256, 128), nn.ReLU(True),
            nn.Linear(128, latent_dim), nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(True),
            nn.Linear(128, 256), nn.ReLU(True),
            nn.Linear(256, input_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and decode a batch.

        Returns:
            (reconstruction reshaped to (batch, IMG_CHANNELS, IMG_SIZE, IMG_SIZE),
             latent code of shape (batch, latent_dim)).
        """
        batch = x.size(0)
        code = self.encoder(x.view(batch, -1))
        recon = self.decoder(code).view(batch, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
        return recon, code

### Training the Autoencoder

In [None]:
# Autoencoder with a 64-dimensional bottleneck
ae_model = Autoencoder(latent_dim=64).to(device)
ae_param_count = sum(p.numel() for p in ae_model.parameters())
print(f"Autoencoder parameters: {ae_param_count:,}")

# Pixel-wise MSE between input and reconstruction, optimised with Adam
criterion = nn.MSELoss()
optimizer = optim.Adam(ae_model.parameters(), lr=1e-3)

num_epochs = 5
ae_losses = []  # mean batch loss per epoch

for epoch in range(num_epochs):
    running_loss = 0.0
    for imgs, _ in dataloader:
        imgs = imgs.to(device)

        reconstructed, _ = ae_model(imgs)
        loss = criterion(reconstructed, imgs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # MSELoss already averages within a batch, so divide by the batch count
    avg_loss = running_loss / len(dataloader)
    ae_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

### Visualizing Autoencoder Results

In [None]:
# Plot the per-epoch training loss (epochs numbered from 1 to match the log)
plt.figure(figsize=(8, 4))
plt.plot(range(1, len(ae_losses) + 1), ae_losses, marker='o')
plt.title('Autoencoder Training Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.grid(True)
plt.show()

# Reconstruct a few images with the trained autoencoder.
# NOTE: these come from the training loader (train=True); there is no
# held-out split in this notebook.
n_show = 8
ae_model.eval()
with torch.no_grad():
    test_imgs, _ = next(iter(dataloader))
    test_imgs = test_imgs[:n_show].to(device)
    reconstructed, _ = ae_model(test_imgs)

# Originals on the top row, reconstructions on the bottom row
fig, axes = plt.subplots(2, n_show, figsize=(12, 4))
rows = [('Original', test_imgs), ('Reconstructed', reconstructed)]
for row, (row_label, row_imgs) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        ax.imshow(row_imgs[col].cpu().squeeze(), cmap='gray')
        # Hide ticks and frame instead of ax.axis('off'): turning the axis
        # off also hides its y-label, which is used as the row title below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label)

plt.suptitle('Autoencoder Reconstruction')
plt.tight_layout()
plt.show()

## 3. Variational Autoencoders (VAE)

VAEs learn a probabilistic latent representation, allowing for generation of new samples.

In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a diagonal-Gaussian posterior.

    The encoder maps a flattened image to the mean and log-variance of
    q(z|x); the decoder maps a latent vector back to per-pixel values in
    (0, 1) via a sigmoid.

    Args:
        input_dim: Number of values per flattened image.
        h_dim: Width of the single hidden layer in encoder and decoder.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim=FLAT_IMG_SIZE, h_dim=400, z_dim=20):
        super().__init__()
        # Encoder: one shared hidden layer feeding two heads, mu and log(sigma^2)
        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc_mu = nn.Linear(h_dim, z_dim)
        self.fc_logvar = nn.Linear(h_dim, z_dim)
        # Decoder: latent -> hidden -> flattened image
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        """Return (mu, logvar) for a batch of flattened inputs."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        Written this way so gradients flow to mu and logvar. Note that a
        fresh sample is drawn in both train and eval mode.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to flattened images with values in (0, 1)."""
        hidden = F.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode, sample and decode a batch.

        Returns:
            (reconstruction of shape (batch, IMG_CHANNELS, IMG_SIZE, IMG_SIZE),
             mu, logvar).
        """
        batch = x.size(0)
        mu, logvar = self.encode(x.view(batch, -1))
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z).view(batch, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
        return recon, mu, logvar

# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO summed over the batch.

    Args:
        recon_x: Decoder output with values in (0, 1); any shape that
            flattens to (-1, FLAT_IMG_SIZE).
        x: Target images, same number of elements as recon_x.
        mu, logvar: Posterior mean and log-variance from the encoder.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the closed-form KL
        divergence KL(N(mu, sigma^2) || N(0, I)).
    """
    flat_recon = recon_x.view(-1, FLAT_IMG_SIZE)
    flat_target = x.view(-1, FLAT_IMG_SIZE)
    recon_term = F.binary_cross_entropy(flat_recon, flat_target, reduction='sum')

    # -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    kl_term = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp())

    return recon_term + kl_term

### Training the VAE

In [None]:
# VAE with a 20-dimensional latent space
vae_model = VAE(z_dim=20).to(device)
vae_param_count = sum(p.numel() for p in vae_model.parameters())
print(f"VAE parameters: {vae_param_count:,}")

vae_optimizer = optim.Adam(vae_model.parameters(), lr=1e-3)

num_epochs = 10
vae_losses = []  # per-sample loss for each epoch

for epoch in range(num_epochs):
    total = 0.0
    for imgs, _ in dataloader:
        imgs = imgs.to(device)

        recon_imgs, mu, logvar = vae_model(imgs)
        loss = vae_loss(recon_imgs, imgs, mu, logvar)

        vae_optimizer.zero_grad()
        loss.backward()
        vae_optimizer.step()

        total += loss.item()

    # vae_loss sums over the batch, so normalise by the number of samples
    avg_loss = total / len(dataloader.dataset)
    vae_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

### Visualizing VAE Results

In [None]:
# Plot the per-sample training loss (epochs numbered from 1 to match the log)
plt.figure(figsize=(8, 4))
plt.plot(range(1, len(vae_losses) + 1), vae_losses, marker='o')
plt.title('VAE Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (BCE + KLD)')
plt.grid(True)
plt.show()

# Reconstruct a few images with the trained VAE.
# NOTE: these come from the training loader (train=True); there is no
# held-out split in this notebook. The forward pass still samples z, so
# reconstructions vary slightly between runs of this cell.
n_show = 8
vae_model.eval()
with torch.no_grad():
    test_imgs, _ = next(iter(dataloader))
    test_imgs = test_imgs[:n_show].to(device)
    recon_imgs, _, _ = vae_model(test_imgs)

# Originals on the top row, reconstructions on the bottom row
fig, axes = plt.subplots(2, n_show, figsize=(12, 4))
rows = [('Original', test_imgs), ('Reconstructed', recon_imgs)]
for row, (row_label, row_imgs) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        ax.imshow(row_imgs[col].cpu().squeeze(), cmap='gray')
        # Hide ticks and frame instead of ax.axis('off'): turning the axis
        # off also hides its y-label, which is used as the row title below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label)

plt.suptitle('VAE Reconstruction')
plt.tight_layout()
plt.show()

### Generating New Samples with VAE

In [None]:
# Generate new samples from random latent vectors
with torch.no_grad():
    # Sample from standard normal distribution
    z = torch.randn(16, 20).to(device)
    
    # Decode to generate images
    generated = vae_model.decode(z)
    generated = generated.view(16, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
    
    # Plot generated samples
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i in range(16):
        ax = axes[i // 4, i % 4]
        ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
        ax.axis('off')
    
    plt.suptitle('VAE Generated Samples')
    plt.tight_layout()
    plt.show()

### Exploring the Latent Space

In [None]:
# Interpolation in latent space
with torch.no_grad():
    # Get two random images
    imgs, _ = next(iter(dataloader))
    img1 = imgs[0:1].to(device)
    img2 = imgs[1:2].to(device)
    
    # Encode to latent space
    mu1, _ = vae_model.encode(img1.view(1, -1))
    mu2, _ = vae_model.encode(img2.view(1, -1))
    
    # Interpolate between latent vectors
    interpolations = []
    for alpha in np.linspace(0, 1, 10):
        z_interp = (1 - alpha) * mu1 + alpha * mu2
        img_interp = vae_model.decode(z_interp)
        interpolations.append(img_interp)
    
    # Plot interpolations
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i, img in enumerate(interpolations):
        axes[i].imshow(img.view(IMG_SIZE, IMG_SIZE).cpu(), cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f'{i/9:.1f}')
    
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()

## 4. Generative Adversarial Networks (GAN)

GANs consist of two networks competing against each other: a generator that creates fake data and a discriminator that tries to distinguish real from fake.

In [None]:
class Generator(nn.Module):
    """MLP generator mapping a noise vector to an image with values in (0, 1).

    Hidden widths are 256 -> 512 -> 1024, each followed by LeakyReLU(0.2).

    Args:
        z_dim: Length of the input noise vector.
        img_dim: Number of output values per (flattened) image.
    """

    def __init__(self, z_dim=100, img_dim=FLAT_IMG_SIZE):
        super().__init__()
        layers = []
        in_features = z_dim
        for width in (256, 512, 1024):
            layers += [nn.Linear(in_features, width), nn.LeakyReLU(0.2)]
            in_features = width
        # Sigmoid keeps pixels in (0, 1), the same range as the ToTensor() data
        layers += [nn.Linear(in_features, img_dim), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return images shaped (batch, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)."""
        flat = self.model(z)
        return flat.view(flat.size(0), IMG_CHANNELS, IMG_SIZE, IMG_SIZE)

class Discriminator(nn.Module):
    """MLP discriminator returning the probability that each image is real.

    Hidden widths are 512 -> 256, each followed by LeakyReLU(0.2).

    Args:
        img_dim: Number of values per flattened input image.
    """

    def __init__(self, img_dim=FLAT_IMG_SIZE):
        super().__init__()
        layers = []
        in_features = img_dim
        for width in (512, 256):
            layers += [nn.Linear(in_features, width), nn.LeakyReLU(0.2)]
            in_features = width
        # Single sigmoid unit: output in (0, 1), as nn.BCELoss expects
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image and return a (batch, 1) tensor of probabilities."""
        return self.model(img.view(img.size(0), -1))

### Training the GAN

In [None]:
# Hyperparameters
z_dim = 100                # length of the generator's noise vector
lr = 0.0002                # shared Adam learning rate for G and D
num_epochs = 20
adam_betas = (0.5, 0.999)  # beta1 lowered to 0.5 from Adam's default 0.9

# Build both networks on the target device (generator first, then discriminator)
generator = Generator(z_dim=z_dim).to(device)
discriminator = Discriminator().to(device)

for net_name, net in (("Generator", generator), ("Discriminator", discriminator)):
    print(f"{net_name} parameters: {sum(p.numel() for p in net.parameters()):,}")

# BCE on the discriminator's sigmoid output; one Adam optimiser per network
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

# Fixed noise batch, reused to show progress on identical inputs across epochs
fixed_noise = torch.randn(64, z_dim).to(device)

In [None]:
# Training loop: for every batch, one discriminator update followed by one
# generator update (non-saturating loss: G is trained against "real" labels).
g_losses = []
d_losses = []
SAMPLE_EVERY = 5  # show samples for `fixed_noise` every N epochs

for epoch in range(num_epochs):
    # Set training mode explicitly so re-running this cell after a cell that
    # called generator.eval() behaves like a fresh run.
    generator.train()
    discriminator.train()
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0

    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)  # the last batch can be smaller

        # BCE targets (1 = real, 0 = fake), created directly on `device`
        # rather than allocated on the host and copied every batch.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images should be classified as real
        real_pred = discriminator(real_imgs)
        d_loss_real = adversarial_loss(real_pred, real_labels)

        # Fake images should be classified as fake; detach() keeps this
        # loss from back-propagating into the generator.
        z = torch.randn(batch_size, z_dim).to(device)
        fake_imgs = generator(z)
        fake_pred = discriminator(fake_imgs.detach())
        d_loss_fake = adversarial_loss(fake_pred, fake_labels)

        # Average the real and fake terms
        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Fresh noise for the generator step
        z = torch.randn(batch_size, z_dim).to(device)
        fake_imgs = generator(z)

        # The generator wants the discriminator to label its samples as real
        fake_pred = discriminator(fake_imgs)
        g_loss = adversarial_loss(fake_pred, real_labels)

        g_loss.backward()
        optimizer_G.step()

        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    # Mean of the per-batch losses over the epoch
    avg_g_loss = epoch_g_loss / len(dataloader)
    avg_d_loss = epoch_d_loss / len(dataloader)
    g_losses.append(avg_g_loss)
    d_losses.append(avg_d_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] D_loss: {avg_d_loss:.4f}, G_loss: {avg_g_loss:.4f}")

    # Show samples for the same fixed noise so epochs can be compared directly
    if (epoch + 1) % SAMPLE_EVERY == 0:
        generator.eval()
        with torch.no_grad():
            sample_imgs = generator(fixed_noise)
        generator.train()

        fig, axes = plt.subplots(8, 8, figsize=(8, 8))
        for ax, img in zip(axes.flat, sample_imgs):
            ax.imshow(img.cpu().squeeze(), cmap='gray')
            ax.axis('off')
        plt.suptitle(f'GAN Generated Images - Epoch {epoch+1}')
        plt.tight_layout()
        plt.show()

### Visualizing GAN Training

In [None]:
# Mean generator and discriminator loss per epoch on shared axes
fig, ax = plt.subplots(figsize=(10, 5))
for series, series_label in ((g_losses, 'Generator Loss'), (d_losses, 'Discriminator Loss')):
    ax.plot(series, label=series_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('GAN Training Losses')
ax.legend()
ax.grid(True)
plt.show()

### Generating Images with Trained GAN

In [None]:
# Sample 100 images from the trained generator and display them as a 10x10 grid
generator.eval()
with torch.no_grad():
    z = torch.randn(100, z_dim).to(device)
    generated_imgs = generator(z)

    # make_grid tiles the batch into a single (C, H, W) image; normalize=True
    # min-max rescales it to [0, 1] for display.
    grid = make_grid(generated_imgs, nrow=10, normalize=True)
    grid_hwc = grid.cpu().permute(1, 2, 0).squeeze()  # matplotlib wants channels last

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(grid_hwc, cmap='gray')
    ax.set_title('100 GAN Generated Images')
    ax.axis('off')
    plt.show()

## 5. Comparison of Generative Models

Let's compare the different generative models we've trained.

In [None]:
# Side-by-side comparison: AE reconstructions vs. VAE and GAN samples.
n_show = 8
# Latent size the VAE decoder was built with (avoids a hard-coded 20)
vae_z_dim = vae_model.fc3.in_features

# Put every model in inference mode so this cell does not depend on which
# cells ran before it.
ae_model.eval()
vae_model.eval()
generator.eval()

with torch.no_grad():
    # Autoencoder: reconstructions of real images (it has no prior to sample from)
    test_imgs, _ = next(iter(dataloader))
    test_imgs = test_imgs[:n_show].to(device)
    ae_recon, _ = ae_model(test_imgs)

    # VAE: decode vectors drawn from the N(0, I) prior
    z_vae = torch.randn(n_show, vae_z_dim).to(device)
    vae_gen = vae_model.decode(z_vae).view(n_show, IMG_CHANNELS, IMG_SIZE, IMG_SIZE)

    # GAN: map random noise through the generator
    z_gan = torch.randn(n_show, z_dim).to(device)
    gan_gen = generator(z_gan)

rows = [('AE\nRecon', ae_recon), ('VAE\nGen', vae_gen), ('GAN\nGen', gan_gen)]
fig, axes = plt.subplots(len(rows), n_show, figsize=(12, 6))
for row, (row_label, row_imgs) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        ax.imshow(row_imgs[col].cpu().squeeze(), cmap='gray')
        # Hide ticks and frame instead of ax.axis('off'): turning the axis
        # off also hides its y-label, which is used as the row title below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label, rotation=0, labelpad=40, va='center')

plt.suptitle('Comparison of Generative Models')
plt.tight_layout()
plt.show()

## Summary

In this notebook, we've explored three fundamental types of generative models:

1. **Autoencoders (AE)**: Learn compressed representations and can reconstruct data
   - Good for dimensionality reduction and denoising
   - Not designed for sampling: the latent space has no prior, so decoding random latent codes rarely produces realistic images

2. **Variational Autoencoders (VAE)**: Learn probabilistic latent representations
   - Can generate new samples by sampling from the latent space
   - Often produce slightly blurry results
   - Learn a smooth, continuous latent space in which interpolating between codes yields gradual changes in the decoded images

3. **Generative Adversarial Networks (GAN)**: Use adversarial training
   - Can produce very sharp, realistic images
   - Training can be unstable and requires careful tuning
   - No direct way to encode images to latent space

Each model has its strengths and is suited for different applications. Modern generative models often combine ideas from these fundamental approaches.