# GAN on MNIST (PyTorch)

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader

# CONFIG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 0.0002
batch_size = 128
epochs = 50  # Increased for better convergence
noise_dim = 100
real_label_val = 0.9  # Label Smoothing (Key improvement)

print(f"Device: {device}")
os.makedirs("results", exist_ok=True)

Device: cuda


In [2]:
# DATASET
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)) # Normalize to [-1, 1] for Tanh
])

dataloader = DataLoader(
    datasets.MNIST("./data", train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 18.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 504kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.72MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 16.0MB/s]


In [3]:
# MODELS
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 28*28),
            nn.Tanh() # Output range [-1, 1]
        )

    def forward(self, x):
        return self.net(x).view(-1, 1, 28, 28)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)

# INIT
G = Generator().to(device)
D = Discriminator().to(device)
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [4]:
# TRAINING
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        bs = real_imgs.size(0)

        # Labels (Smoothing applied to real labels)
        real_labels = torch.full((bs, 1), real_label_val, device=device)
        fake_labels = torch.zeros((bs, 1), device=device)

        # --- Train Discriminator ---
        opt_D.zero_grad()

        # Real loss
        out_real = D(real_imgs)
        loss_real = criterion(out_real, real_labels)

        # Fake loss
        z = torch.randn(bs, noise_dim, device=device)
        fake_imgs = G(z)
        out_fake = D(fake_imgs.detach()) # Detach to avoid G gradients
        loss_fake = criterion(out_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        opt_D.step()

        # --- Train Generator ---
        opt_G.zero_grad()

        # Generator aims for real labels (1.0) to trick D
        # We use standard 1.0 here for G target, or reuse smoothed 0.9
        out_gen = D(fake_imgs)
        loss_G = criterion(out_gen, torch.ones((bs, 1), device=device))

        loss_G.backward()
        opt_G.step()

    print(f"Epoch {epoch+1}/{epochs} | Loss D: {loss_D.item():.4f} | Loss G: {loss_G.item():.4f}")

    # Save progress
    if (epoch+1) % 5 == 0:
        utils.save_image(fake_imgs, f"results/epoch_{epoch+1}.png", normalize=True)

# Final Save
z = torch.randn(64, noise_dim, device=device)
utils.save_image(G(z), "results/final_grid.png", normalize=True)
print("Training Complete. Images saved to 'results/' folder.")

Epoch 1/50 | Loss D: 1.7329 | Loss G: 2.4440
Epoch 2/50 | Loss D: 0.7515 | Loss G: 1.8923
Epoch 3/50 | Loss D: 1.1124 | Loss G: 1.1669
Epoch 4/50 | Loss D: 0.9086 | Loss G: 2.1398
Epoch 5/50 | Loss D: 1.0319 | Loss G: 1.3179
Epoch 6/50 | Loss D: 0.9729 | Loss G: 1.0665
Epoch 7/50 | Loss D: 0.9634 | Loss G: 1.3798
Epoch 8/50 | Loss D: 0.9208 | Loss G: 1.2606
Epoch 9/50 | Loss D: 0.8973 | Loss G: 1.3619
Epoch 10/50 | Loss D: 0.8658 | Loss G: 1.7379
Epoch 11/50 | Loss D: 1.0753 | Loss G: 1.5761
Epoch 12/50 | Loss D: 1.0968 | Loss G: 1.1694
Epoch 13/50 | Loss D: 1.1335 | Loss G: 1.2689
Epoch 14/50 | Loss D: 1.1684 | Loss G: 1.7083
Epoch 15/50 | Loss D: 1.1969 | Loss G: 0.9695
Epoch 16/50 | Loss D: 1.1457 | Loss G: 1.1068
Epoch 17/50 | Loss D: 1.2825 | Loss G: 1.6322
Epoch 18/50 | Loss D: 1.2308 | Loss G: 1.5759
Epoch 19/50 | Loss D: 1.2695 | Loss G: 1.2852
Epoch 20/50 | Loss D: 1.2600 | Loss G: 0.9970
Epoch 21/50 | Loss D: 1.2777 | Loss G: 0.9447
Epoch 22/50 | Loss D: 1.2429 | Loss G: 0.85