In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os


In [None]:
# =========================
# Config
# =========================
SEED = 42                           # seed for torch's global RNG (see torch.manual_seed below)
BATCH_SIZE = 256                    # samples per mini-batch handed to the DataLoader
EPOCHS = 50                         # number of full passes over the training set
NOISE_DIM = 100                     # length of the latent vector fed to the generator
LR = 1e-4                           # Adam learning rate, shared by both optimizers
BETA1 = 0.5                         # Adam beta1, shared by both optimizers
IMAGE_SIZE = 28                     # side length of an MNIST image (not referenced by the cells shown here)
OUTPUT_DIR = "generated_images_pt"  # folder that receives the per-epoch sample grids

# Seed once, before any model, noise tensor or data loader is created, so that
# weight initialisation, the fixed visualisation noise and batch shuffling
# repeat after a kernel restart. torch.manual_seed seeds every device; some
# CUDA kernels are still non-deterministic, so GPU runs may differ slightly.
torch.manual_seed(SEED)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Make sure the folder for the generated sample grids exists (no error if it already does).
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# =========================
# Dataset
# =========================
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales
# them to [-1, 1], the same range the generator's final Tanh produces.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST("./data", train=True, transform=transform, download=True)

# Reshuffle every epoch and drop the final short batch so every step sees
# exactly BATCH_SIZE samples.
dataloader = DataLoader(dataset, BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
# =========================
# Generator
# =========================
class Generator(nn.Module):
    """Transposed-convolution generator: latent vector -> 1x28x28 image in [-1, 1].

    Args:
        noise_dim: Length of the latent vector. When omitted, the module-level
            ``NOISE_DIM`` is used, so ``Generator()`` builds the same network
            as before.
    """

    def __init__(self, noise_dim=None):
        super().__init__()
        if noise_dim is None:
            # Fall back to the notebook-wide setting at construction time.
            noise_dim = NOISE_DIM
        self.noise_dim = noise_dim

        self.net = nn.Sequential(
            # Project the latent vector to 256 * 7 * 7 features. The bias is
            # dropped because the following BatchNorm has its own shift term.
            nn.Linear(noise_dim, 7 * 7 * 256, bias=False),
            nn.BatchNorm1d(7 * 7 * 256),
            nn.LeakyReLU(0.2, inplace=True),

            # (batch, 12544) -> (batch, 256, 7, 7)
            nn.Unflatten(1, (256, 7, 7)),

            # Stride 1 keeps the spatial size: 256x7x7 -> 128x7x7
            nn.ConvTranspose2d(256, 128, 5, 1, 2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # Stride 2 with output_padding=1 doubles it: 128x7x7 -> 64x14x14
            nn.ConvTranspose2d(128, 64, 5, 2, 2, output_padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # 64x14x14 -> 1x28x28, squashed into [-1, 1] by Tanh
            nn.ConvTranspose2d(64, 1, 5, 2, 2, output_padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate images from latent vectors.

        Args:
            z: Tensor of shape ``(batch, noise_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in ``[-1, 1]``.
        """
        return self.net(z)

In [None]:
# =========================
# Discriminator
# =========================
class Discriminator(nn.Module):
    """Convolutional discriminator: 1x28x28 image -> one raw (pre-sigmoid) logit."""

    def __init__(self):
        super().__init__()
        layers = [
            # 1x28x28 -> 64x14x14
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=0.3),
            # 64x14x14 -> 128x7x7
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=0.3),
            # 128x7x7 -> 6272 features -> single logit
            nn.Flatten(),
            nn.Linear(in_features=128 * 7 * 7, out_features=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of logits for a batch of images ``x``."""
        logits = self.net(x)
        return logits

In [None]:
# Build the generator first, then the discriminator, and move both to the selected device.
G, D = Generator().to(device), Discriminator().to(device)

In [None]:
# =========================
# Loss & Optimizers
# =========================
# One Adam optimizer per network; both use the same learning rate and betas.
optimizer_D = optim.Adam(params=D.parameters(), betas=(BETA1, 0.999), lr=LR)
optimizer_G = optim.Adam(params=G.parameters(), betas=(BETA1, 0.999), lr=LR)

# Binary cross-entropy applied directly to logits (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()

In [None]:
# =========================
# Fixed Noise for Visualization
# =========================
# A single batch of 16 latent vectors, sampled once and reused for every saved
# grid so that images from different epochs can be compared directly.
fixed_noise = torch.randn((16, NOISE_DIM), device=device)

In [None]:
# =========================
# Image Generation
# =========================
def save_images(epoch):
    """Save a 4x4 grid of generator samples for ``fixed_noise`` as a PNG.

    The image is written to ``OUTPUT_DIR`` as ``epoch_NNNN.png``, where
    ``NNNN`` is ``epoch`` zero-padded to four digits.

    Args:
        epoch: Integer epoch number used to build the file name.
    """
    # Sample in eval mode, then restore whatever mode the generator was in
    # (instead of unconditionally forcing train mode on the caller).
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            fake_images = G(fixed_noise).cpu()
    finally:
        G.train(was_training)

    fake_images = (fake_images + 1) / 2  # [-1, 1] -> [0, 1]

    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    try:
        for index, ax in enumerate(axes.flat):
            # Leave a cell blank rather than fail if fewer than 16 samples exist.
            if index < len(fake_images):
                ax.imshow(fake_images[index][0], cmap="gray")
            ax.axis("off")

        fig.savefig(os.path.join(OUTPUT_DIR, f"epoch_{epoch:04d}.png"))
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

In [None]:
# =========================
# Training Loop
# =========================
for epoch in range(1, EPOCHS + 1):
    # Running loss totals, kept on the training device so the epoch summary is
    # a mean over every batch without forcing a host sync on each step.
    epoch_loss_D = torch.zeros((), device=device)
    epoch_loss_G = torch.zeros((), device=device)
    num_batches = 0

    for real_images, _ in dataloader:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # =====================
        # Train Discriminator
        # =====================
        noise = torch.randn(batch_size, NOISE_DIM, device=device)
        fake_images = G(noise)

        # Targets for the logit loss: 1 = real, 0 = fake.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        D_real = D(real_images)
        # detach() stops this step from back-propagating into the generator.
        D_fake = D(fake_images.detach())

        loss_D_real = criterion(D_real, real_labels)
        loss_D_fake = criterion(D_fake, fake_labels)
        loss_D = loss_D_real + loss_D_fake

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # =====================
        # Train Generator
        # =====================
        # Score the same fakes with the just-updated discriminator and push
        # the generator towards outputs that are classified as real.
        output = D(fake_images)
        loss_G = criterion(output, real_labels)

        optimizer_G.zero_grad()
        # This backward pass also writes gradients into D's parameters; they
        # are discarded by optimizer_D.zero_grad() in the next iteration.
        loss_G.backward()
        optimizer_G.step()

        epoch_loss_D += loss_D.detach()
        epoch_loss_G += loss_G.detach()
        num_batches += 1

    if num_batches == 0:
        # With drop_last=True this happens when the dataset is smaller than
        # one batch; fail with a clear message instead of a NameError below.
        raise RuntimeError("dataloader produced no batches; nothing was trained")

    save_images(epoch)

    # Report the mean loss over the epoch rather than the last mini-batch only.
    print(
        f"Epoch [{epoch}/{EPOCHS}] "
        f"| D Loss: {epoch_loss_D.item() / num_batches:.4f} "
        f"| G Loss: {epoch_loss_G.item() / num_batches:.4f}"
    )

print("Training complete.")