In [None]:
!nvidia-smi

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import os

# ===================== 1. PARAMETERS =====================
# Training hyper-parameters read by the cells below.
batch_size = 128   # images per optimisation step
latent_dim = 100   # number of channels in the generator's input noise
image_size = 28    # MNIST side length (NOTE: the models below hard-code 28x28; this value is not read)
epochs = 10        # full passes over the training set

# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ===================== 2. LOAD DATA =====================
# ToTensor() yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, downloaded into ./data if it is not already there.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
# Reshuffled every epoch; the final batch may be smaller than batch_size.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# ===================== 3. MODEL ARCHITECTURES =====================
# Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping noise to a 1x28x28 image in [-1, 1].

    Spatial path: 1x1 -> 7x7 (kernel 7, stride 1, padding 0) -> 14x14 -> 28x28
    (kernel 4, stride 2, padding 1 for each of the last two layers).

    Args:
        noise_dim: Number of input noise channels. When omitted it falls back to
            the module-level ``latent_dim`` hyper-parameter, so ``Generator()``
            behaves exactly as before.
    """

    def __init__(self, noise_dim=None):
        super().__init__()
        if noise_dim is None:
            # Looked up at construction time, matching the original behaviour.
            noise_dim = latent_dim
        self.noise_dim = noise_dim
        self.model = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, 128, 7, 1, 0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()  # Output in the range [-1, 1]
        )

    def forward(self, x):
        """Return a batch of images of shape (N, 1, 28, 28).

        ``x`` is either (N, noise_dim, 1, 1), which is the layout the transposed
        convolutions expect, or a flat (N, noise_dim) batch, which is reshaped
        here for convenience.
        """
        if x.dim() == 2:
            x = x.reshape(x.size(0), -1, 1, 1)
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    """Convolutional classifier scoring 1x28x28 images with a probability of being real.

    Spatial path: 28x28 -> 14x14 -> 7x7 through two stride-2 convolutions. A
    single linear unit is applied to the flattened 128x7x7 feature map, followed
    by Sigmoid. Because that linear layer's size is fixed, inputs must be 28x28.
    """

    def __init__(self):
        super().__init__()
        layers = [
            # Stage 1: 1 -> 64 channels, 28 -> 14 pixels (no BatchNorm here).
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Stage 2: 64 -> 128 channels, 14 -> 7 pixels.
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Head: flatten and squash a single logit into (0, 1).
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N, 1) tensor of real-vs-fake probabilities for the batch ``x``."""
        scores = self.model(x)
        return scores

# ===================== 4. INITIALIZATION =====================
# Build both networks on the selected device (Generator first, then Discriminator).
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's Sigmoid output. Both networks are
# optimised with Adam using the same learning rate and betas.
criterion = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

# A single batch of 64 noise vectors reused after every epoch, so the sample
# grids are directly comparable from one epoch to the next.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

# Results directory; an existing directory is reused without error.
os.makedirs("mnist_gan_results", exist_ok=True)

# ===================== 5. TRAINING GAN =====================
# Fixed-noise samples captured after each epoch (CPU tensors, one entry per epoch).
generated_images = []

for epoch in range(epochs):
    # Per-epoch loss totals, kept on the device so accumulating them does not
    # force a host/device sync on every batch (as calling .item() would).
    d_loss_total = torch.zeros((), device=device)
    g_loss_total = torch.zeros((), device=device)
    num_batches = 0

    for real_imgs, _ in train_loader:
        real_imgs = real_imgs.to(device)
        batch = real_imgs.size(0)
        # One-sided label smoothing: D is pushed towards 0.9 (not 1.0) on real images.
        real_labels = torch.full((batch, 1), 0.9, device=device)
        fake_labels = torch.zeros(batch, 1, device=device)
        # Generator target: G wants D to score its samples as real (1.0). This is
        # a separate tensor rather than an in-place overwrite of fake_labels.
        gen_targets = torch.ones(batch, 1, device=device)

        # === Train the Discriminator ===
        optimizer_D.zero_grad()
        noise = torch.randn(batch, latent_dim, 1, 1, device=device)
        fake_imgs = generator(noise)

        # detach() stops the discriminator loss from back-propagating into G.
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # === Train the Generator ===
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), gen_targets)
        g_loss.backward()
        optimizer_G.step()

        d_loss_total += d_loss.detach()
        g_loss_total += g_loss.detach()
        num_batches += 1

    # Report epoch-average losses; a single last-batch value is very noisy.
    # max(..., 1) guards against an empty loader.
    denom = max(num_batches, 1)
    print(
        f"Epoch [{epoch+1}/{epochs}] | Avg D Loss: {d_loss_total.item() / denom:.4f} "
        f"| Avg G Loss: {g_loss_total.item() / denom:.4f}"
    )

    # Sample the fixed noise to track progress. Keep the result for plotting and
    # save it as a grid into the results folder.
    # NOTE(review): the generator is still in train mode here, so its BatchNorm
    # layers normalise with this batch's statistics and also fold the batch into
    # their running stats. Wrap with generator.eval()/generator.train() if that matters.
    with torch.no_grad():
        fake_imgs = generator(fixed_noise).detach().cpu()
    generated_images.append(fake_imgs)  # Store for visualization
    vutils.save_image(
        fake_imgs,
        os.path.join("mnist_gan_results", f"epoch_{epoch + 1:03d}.png"),
        nrow=8,
        normalize=True,
    )

print("Training complete!")

# ===================== 6. VISUALIZATION =====================
# Show the fixed-noise samples captured after each epoch, one grid per epoch.
# The layout is derived from how many epochs were actually recorded, so a
# different `epochs` value neither raises IndexError nor silently drops panels.
n_panels = len(generated_images)
if n_panels == 0:
    print("No generated images to display.")
else:
    n_cols = min(n_panels, 5)
    n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division
    # 3x3 inches per panel, which reproduces the original 15x6 figure for 10 epochs.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    fig.suptitle("Generated Images Across Epochs", fontsize=16)

    for idx, ax in enumerate(axes.flat):
        ax.axis("off")
        if idx >= n_panels:
            continue  # leave trailing unused cells blank
        img_grid = vutils.make_grid(generated_images[idx], nrow=8, normalize=True)
        ax.imshow(img_grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
        ax.set_title(f"Epoch {idx + 1}")

    plt.tight_layout()
    plt.subplots_adjust(top=0.88)  # make room for the suptitle
    plt.show()
