In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import os

In [35]:
# Hyperparameters
# Seed PyTorch's global RNG up front so weight init, data shuffling and latent
# sampling are repeatable under Restart & Run All. (Some GPU kernels may still
# be non-deterministic; this fixes the random streams, not the kernels.)
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img_size = 28      # image height/width in pixels
channels = 1       # number of image channels
batch_size = 64
latent_dim = 100   # size of the generator's input noise vector
lambda_gp = 10     # weight of the gradient-penalty term in the critic loss
n_critic = 5       # generator is updated once every n_critic critic updates
epochs = 150
lr = 0.0002        # Adam learning rate shared by both optimizers

In [36]:
# Data Preprocessing
# Normalize computes (x - 0.5) / 0.5 per channel, so tensors in [0, 1] end up
# in [-1, 1] -- the same range the generator's Tanh output covers.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)
# Training split of MNIST, downloaded into ./data on first use.
data = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
dataloader = DataLoader(dataset=data, batch_size=batch_size, shuffle=True)


In [37]:
# Generator
def weights_init_normal(m):
    """Initialise Conv*/Linear* layers with N(0, 0.02) weights and zero biases.

    Written to be passed to ``nn.Module.apply``. Modules whose class name
    contains neither "Conv" nor "Linear" are left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    layer_type = type(m).__name__
    if "Conv" not in layer_type and "Linear" not in layer_type:
        return
    nn.init.normal_(m.weight.data, 0.0, 0.02)
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is None:
        return
    nn.init.constant_(m.bias.data, 0.0)

class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to an image.

    The network widens 128 -> 256 -> 512 (BatchNorm on the last two hidden
    layers) and ends in Tanh, so every output pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        n_pixels = channels * img_size * img_size

        # First hidden layer has no BatchNorm.
        layers = [
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining hidden layers: Linear -> BatchNorm -> LeakyReLU.
        for n_in, n_out in ((128, 256), (256, 512)):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.BatchNorm1d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Output layer: one unit per pixel, squashed by Tanh.
        layers.append(nn.Linear(512, n_pixels))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return images of shape (batch, channels, img_size, img_size) for latent batch ``z``."""
        flat = self.model(z)
        return flat.view(flat.shape[0], channels, img_size, img_size)

# Build the generator, move it to the target device, then apply the
# N(0, 0.02) initialisation to every submodule. ``apply`` returns the module,
# so leaving it as the last expression also displays the architecture.
generator = Generator()
generator.to(device)
generator.apply(weights_init_normal)


Generator(
  (model): Sequential(
    (0): Linear(in_features=100, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Linear(in_features=256, out_features=512, bias=True)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Linear(in_features=512, out_features=784, bias=True)
    (9): Tanh()
  )
)

In [38]:
# Discriminator
class Discriminator(nn.Module):
    """Fully connected critic that scores a batch of images.

    The final layer is a plain Linear with no activation, so each image gets
    a single unbounded real-valued score.
    """

    def __init__(self):
        super().__init__()
        n_pixels = channels * img_size * img_size
        hidden = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        score_head = [nn.Linear(256, 1)]
        self.model = nn.Sequential(*hidden, *score_head)

    def forward(self, img):
        """Flatten each image to a vector and return scores of shape (batch, 1)."""
        return self.model(img.view(img.shape[0], -1))

# Build the critic, move it to the target device, then apply the N(0, 0.02)
# initialisation. The trailing ``apply`` call returns the module, which is
# what the cell displays.
discriminator = Discriminator()
discriminator.to(device)
discriminator.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2, inplace=True)
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [39]:
# Optimizers
# Both networks use Adam with the same learning rate and betas; build the two
# optimizers from one expression so their settings cannot drift apart.
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (generator, discriminator)
)

In [40]:
# Gradient Penalty
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Compute the WGAN-GP gradient penalty for critic ``D``.

    Random points are sampled on the straight lines between paired real and
    fake samples, and the critic's input-gradient norm at those points is
    penalised for deviating from 1.

    Args:
        D: callable critic mapping a batch of samples to per-sample scores.
        real_samples: batch of real samples, shape (N, ...).
        fake_samples: batch of generated samples, same shape and device as
            ``real_samples``.

    Returns:
        Scalar tensor ``mean((||grad_x D(x_hat)||_2 - 1) ** 2)``; it stays
        differentiable with respect to the critic's parameters.
    """
    # The penalty only regularises the critic, so treat both batches as
    # constants. Without this, a fake batch still attached to the generator
    # would drag the generator's graph into the critic's backward pass.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()

    batch = real_samples.size(0)
    # One mixing coefficient per sample, broadcast over all remaining dims.
    # Built on the inputs' own device/dtype rather than a module-level global,
    # and shaped from the input rank so non-image (e.g. 2-D) batches work too.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # Ones matching the critic output, whatever its exact shape is.
        grad_outputs=torch.ones_like(d_interpolates),
        # Keep the graph: the penalty itself is backpropagated into D.
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [41]:
# Training Loop
# The critic is updated on every batch; the generator only on every
# n_critic-th batch.
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        # Train Critic
        optimizer_D.zero_grad()
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        # The critic update must not backpropagate into the generator: build
        # the fake batch without an autograd graph so d_loss.backward() stops
        # at the images instead of computing (and accumulating) unused
        # generator gradients.
        with torch.no_grad():
            fake_imgs = generator(z)
        real_validity = discriminator(imgs)
        fake_validity = discriminator(fake_imgs)
        gradient_penalty = compute_gradient_penalty(discriminator, imgs, fake_imgs)
        # Wasserstein critic loss plus the weighted gradient penalty.
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # Train Generator every n_critic steps
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            # Regenerate from the same noise, this time with gradients enabled.
            fake_imgs = generator(z)
            g_loss = -torch.mean(discriminator(fake_imgs))
            g_loss.backward()
            optimizer_G.step()

    # Reports the losses of the last critic / generator step of the epoch.
    print(f"Epoch [{epoch + 1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

# Save models
os.makedirs("models", exist_ok=True)
torch.save(generator.state_dict(), "models/generator.pth")
torch.save(discriminator.state_dict(), "models/discriminator.pth")

Epoch [1/150] | D Loss: -2.0150 | G Loss: -7.6286
Epoch [2/150] | D Loss: -1.7162 | G Loss: -9.6233
Epoch [3/150] | D Loss: -1.6967 | G Loss: -9.2962
Epoch [4/150] | D Loss: -1.2508 | G Loss: -6.4426
Epoch [5/150] | D Loss: -2.0948 | G Loss: -3.3744
Epoch [6/150] | D Loss: -2.6375 | G Loss: -2.9998
Epoch [7/150] | D Loss: -2.9771 | G Loss: -1.6837
Epoch [8/150] | D Loss: -2.7262 | G Loss: -5.0573
Epoch [9/150] | D Loss: -2.4172 | G Loss: -3.0480
Epoch [10/150] | D Loss: -2.1630 | G Loss: -3.5768
Epoch [11/150] | D Loss: -2.3166 | G Loss: -4.2820
Epoch [12/150] | D Loss: -1.7793 | G Loss: -5.3115
Epoch [13/150] | D Loss: -1.7578 | G Loss: -4.3544
Epoch [14/150] | D Loss: -1.4586 | G Loss: -3.8592
Epoch [15/150] | D Loss: -1.6961 | G Loss: -4.1401
Epoch [16/150] | D Loss: -1.9963 | G Loss: -3.6713
Epoch [17/150] | D Loss: -1.6805 | G Loss: -4.8006
Epoch [18/150] | D Loss: -1.5447 | G Loss: -3.9769
Epoch [19/150] | D Loss: -1.1086 | G Loss: -4.9399
Epoch [20/150] | D Loss: -1.6587 | G Los

In [None]:
# Generate some random images
z = torch.randn(5000, latent_dim).to(device)
# Sampling only: disable autograd so no graph is kept alive for the
# 5000-sample forward pass.
with torch.no_grad():
    generated_imgs = generator(z)

In [19]:
# Save Generated Images
os.makedirs("generated_images", exist_ok=True)
# The previous version called `Image.fromarray`, but `Image` is never imported
# in this notebook, so the cell raised NameError on a fresh kernel. ToPILImage
# from the already-imported torchvision `transforms` does the same conversion:
# a 1-channel float tensor is scaled by 255 and turned into an 8-bit image.
to_pil = transforms.ToPILImage()
for i, img in enumerate(generated_imgs[:25]):
    # Rescale from the generator's Tanh range [-1, 1] to [0, 1]; the clamp
    # guards the 8-bit conversion against values just outside that range.
    img = ((img.detach().cpu() + 1) / 2).clamp(0, 1)
    to_pil(img).save(f"generated_images/img_{i + 1}.png")