In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

# Hyperparameters
latent_dim = 64
hidden_dim = 256
image_dim = 784  # 28x28
num_epochs = 100
batch_size = 64
lr = 0.0002

# Load MNIST Dataset with fewer samples for quick demonstration
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = torchvision.datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transform,
                                   download=True)
# Use a subset of data for faster training
subset_size = 5000
subset_indices = torch.randperm(len(dataset))[:subset_size]
dataset = torch.utils.data.Subset(dataset, subset_indices)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator Network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, image_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

generator = Generator()
discriminator = Discriminator()

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

# Lists to store losses
g_losses = []
d_losses = []

# Fixed noise for visualization
fixed_noise = torch.randn(16, latent_dim)

def save_images(epoch):
    with torch.no_grad():
        fake_images = generator(fixed_noise).reshape(-1, 28, 28)
        plt.figure(figsize=(10, 2.5))
        for i in range(4):
            plt.subplot(1, 4, i+1)
            plt.imshow(fake_images[i].detach().numpy(), cmap='gray')
            plt.axis('off')
        plt.suptitle(f'Generated Images - Epoch {epoch}')
        plt.savefig(f'gan_epoch_{epoch}.png')
        plt.close()

# Training Loop
print("Starting training...")
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        batch_size = real_images.size(0)
        real_images = real_images.view(-1, image_dim)

        # Train Discriminator
        d_optimizer.zero_grad()
        label_real = torch.ones(batch_size, 1)
        label_fake = torch.zeros(batch_size, 1)

        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, label_real)

        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, label_fake)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, label_real)
        g_loss.backward()
        g_optimizer.step()

    # Save losses
    g_losses.append(g_loss.item())
    d_losses.append(d_loss.item())

    # Save images at specific epochs
    if epoch in [0, 4, 9]:
        save_images(epoch)
        print(f'Epoch [{epoch}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# Plot losses
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='Generator Loss')
plt.plot(d_losses, label='Discriminator Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Losses Over Time')
plt.legend()
plt.savefig('gan_losses.png')
plt.close()

print("Training completed! Generated images and loss plot have been saved.")

100%|██████████| 9.91M/9.91M [00:00<00:00, 57.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.70MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.48MB/s]


Starting training...
Epoch [0/100], d_loss: 1.1621, g_loss: 0.7033
Epoch [4/100], d_loss: 0.3841, g_loss: 1.8596
Epoch [9/100], d_loss: 0.6528, g_loss: 2.9646
Training completed! Generated images and loss plot have been saved.
