In [None]:
pip install torch torchvision numpy tqdm scipy


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
from scipy.linalg import sqrtm

In [3]:
# Set device to CPU
device = torch.device("cpu")

In [23]:
# Define Generator
class Generator(nn.Module):
    def __init__(self, latent_dim=64):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 3 * 16 * 16),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z).view(z.size(0), 3, 16, 16)
        return img

In [24]:
# Define Discriminator (Used for BCE-GAN & LS-GAN)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(3 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)

In [25]:
# Define Wasserstein Discriminator (For WGAN)
class WDiscriminator(nn.Module):
    def __init__(self):
        super(WDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(3 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)


In [26]:
# DataLoader for CIFAR-10 (Reduced Size)
transform = transforms.Compose([
    transforms.Resize(16),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


In [27]:
dataset = torchvision.datasets.CIFAR10(root="./data", download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)


Files already downloaded and verified


In [28]:
# Initialize Models
latent_dim = 64
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)
w_discriminator = WDiscriminator().to(device)

In [29]:
# Loss Functions
bce_loss = nn.BCEWithLogitsLoss()
ls_loss = nn.MSELoss()

In [30]:
# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0005)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0005)
optimizer_WD = optim.RMSprop(w_discriminator.parameters(), lr=0.00005)


In [32]:
# Training Function
def train_gan(gan_type="BCE", epochs=50):
    for epoch in range(epochs):
        for real_imgs, _ in tqdm(dataloader, desc=f"{gan_type}-GAN Epoch {epoch+1}/{epochs}", leave=False):
            real_imgs = real_imgs.to(device)
            batch_size = real_imgs.shape[0]

            z = torch.randn(batch_size, latent_dim).to(device)
            fake_imgs = generator(z)

            if gan_type in ["BCE", "LS"]:
                optimizer_D.zero_grad()
                real_logits = discriminator(real_imgs)
                fake_logits = discriminator(fake_imgs.detach())

                if gan_type == "BCE":
                    d_loss = (bce_loss(real_logits, torch.ones_like(real_logits)) +
                              bce_loss(fake_logits, torch.zeros_like(fake_logits))) / 2
                else:
                    d_loss = (ls_loss(real_logits, torch.ones_like(real_logits)) +
                              ls_loss(fake_logits, torch.zeros_like(fake_logits))) / 2

                d_loss.backward()
                optimizer_D.step()

                # Train Generator
                optimizer_G.zero_grad()
                gen_logits = discriminator(fake_imgs)
                g_loss = bce_loss(gen_logits, torch.ones_like(gen_logits)) if gan_type == "BCE" else ls_loss(gen_logits, torch.ones_like(gen_logits))
                g_loss.backward()
                optimizer_G.step()

            elif gan_type == "WGAN":
                optimizer_WD.zero_grad()
                real_logits = w_discriminator(real_imgs)
                fake_logits = w_discriminator(fake_imgs.detach())

                d_loss = -torch.mean(real_logits) + torch.mean(fake_logits)
                d_loss.backward()
                optimizer_WD.step()

                # Clip weights to enforce Lipschitz constraint
                for p in w_discriminator.parameters():
                    p.data.clamp_(-0.01, 0.01)

                # Train Generator less frequently
                if epoch % 2 == 0:
                    optimizer_G.zero_grad()
                    fake_logits = w_discriminator(fake_imgs)
                    g_loss = -torch.mean(fake_logits)
                    g_loss.backward()
                    optimizer_G.step()

        # Save sample images every 5 epochs
        if epoch % 5 == 0:
            save_image(fake_imgs[:25], f"{gan_type}_generated_{epoch}.png", nrow=5, normalize=True)

# Run all three GANs (Each Runs in ~2 min)
train_gan("BCE", epochs=50)
train_gan("LS", epochs=50)
train_gan("WGAN", epochs=50)




In [33]:
# Evaluation (FID Score Calculation)
def calculate_fid(real, fake):
    mu_real, sigma_real = np.mean(real, axis=0), np.cov(real, rowvar=False)
    mu_fake, sigma_fake = np.mean(fake, axis=0), np.cov(fake, rowvar=False)
    return np.real(np.sum((mu_real - mu_fake) ** 2) + np.trace(sigma_real + sigma_fake - 2 * sqrtm(sigma_real @ sigma_fake)))

z_eval = torch.randn(500, latent_dim).to(device)
fake_samples = generator(z_eval).detach().cpu().numpy().reshape(500, -1)
real_samples = next(iter(dataloader))[0].numpy().reshape(128, -1)

fid_score = calculate_fid(real_samples, fake_samples)
print(f"FID Score: {fid_score}")

FID Score: 672.5090311690965
