<a href="https://colab.research.google.com/github/prajwalBirwadkar/GAN-Experiential-Learning/blob/main/GAN_Experiential_Learning_Tasks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Name : Prajwal Birwadkar**

**PRN: 24070149003**

In [1]:
!pip install torchmetrics
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_dim = 100
img_size = 32
channels = 3
batch_size = 64
epochs = 50
sample_interval = 10  # Save generated images every 10 epochs

# Data loading and preprocessing for CIFAR-10
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)



Collecting torchmetrics
  Downloading torchmetrics-1.6.2-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.13.1-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

100%|██████████| 170M/170M [00:13<00:00, 12.4MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [2]:
# Generator (shared across all GAN variants)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, channels, 3, padding=1),
            nn.Tanh()  # Output range: [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

# Discriminator for BCE GAN
class DiscriminatorBCE(nn.Module):
    def __init__(self):
        super(DiscriminatorBCE, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid()  # Output probability
        )

    def forward(self, img):
        return self.model(img)

# Discriminator for LS-GAN (no sigmoid)
class DiscriminatorLS(nn.Module):
    def __init__(self):
        super(DiscriminatorLS, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1)  # Real-valued output
        )

    def forward(self, img):
        return self.model(img)



In [3]:
# Critic for WGAN (no batch norm, no sigmoid)
class CriticWGAN(nn.Module):
    def __init__(self):
        super(CriticWGAN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 1)  # Real-valued output
        )

    def forward(self, img):
        return self.model(img)

# Function to save generated images
def save_generated_images(generator, epoch, latent_dim, num_images=64, filename='generated_images'):
    z = torch.randn(num_images, latent_dim).to(device)
    gen_imgs = generator(z)
    save_image(gen_imgs, f"{filename}_epoch_{epoch}.png", normalize=True)

# Training function for BCE GAN
def train_bce_gan(generator, discriminator, dataloader, epochs):
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real = torch.ones(imgs.size(0), 1).to(device)
            fake = torch.zeros(imgs.size(0), 1).to(device)
            real_imgs = imgs.to(device)
            z = torch.randn(imgs.size(0), latent_dim).to(device)
            gen_imgs = generator(z)

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), real)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(gen_imgs), real)
            g_loss.backward()
            optimizer_G.step()

        print(f"[BCE GAN] Epoch {epoch}/{epochs}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
        if epoch % sample_interval == 0:
            save_generated_images(generator, epoch, latent_dim, filename='bce_gan_images')



In [4]:
# Training function for LS-GAN
def train_ls_gan(generator, discriminator, dataloader, epochs):
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real = torch.ones(imgs.size(0), 1).to(device)
            fake = torch.zeros(imgs.size(0), 1).to(device)
            real_imgs = imgs.to(device)
            z = torch.randn(imgs.size(0), latent_dim).to(device)
            gen_imgs = generator(z)

            # Train Discriminator
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), real)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(gen_imgs), real)
            g_loss.backward()
            optimizer_G.step()

        print(f"[LS-GAN] Epoch {epoch}/{epochs}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
        if epoch % sample_interval == 0:
            save_generated_images(generator, epoch, latent_dim, filename='ls_gan_images')



In [5]:
# Training function for WGAN
def train_wgan(generator, critic, dataloader, epochs, n_critic=5, clip_value=0.01):
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    optimizer_C = optim.Adam(critic.parameters(), lr=0.0001, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real_imgs = imgs.to(device)
            z = torch.randn(imgs.size(0), latent_dim).to(device)
            gen_imgs = generator(z)

            # Train Critic
            for _ in range(n_critic):
                optimizer_C.zero_grad()
                real_loss = critic(real_imgs).mean()
                fake_loss = critic(gen_imgs.detach()).mean()
                c_loss = fake_loss - real_loss
                c_loss.backward()
                optimizer_C.step()
                # Weight clipping for Lipschitz constraint
                for p in critic.parameters():
                    p.data.clamp_(-clip_value, clip_value)

            # Train Generator
            optimizer_G.zero_grad()
            g_loss = -critic(gen_imgs).mean()
            g_loss.backward()
            optimizer_G.step()

        print(f"[WGAN] Epoch {epoch}/{epochs}, C Loss: {c_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
        if epoch % sample_interval == 0:
            save_generated_images(generator, epoch, latent_dim, filename='wgan_images')



In [7]:
# Main execution
if __name__ == "__main__":
    # Initialize models
    generator_bce = Generator().to(device)
    discriminator_bce = DiscriminatorBCE().to(device)
    generator_ls = Generator().to(device)
    discriminator_ls = DiscriminatorLS().to(device)
    generator_wgan = Generator().to(device)
    critic_wgan = CriticWGAN().to(device)

    # Train BCE GAN
    print("Training BCE GAN...")
    train_bce_gan(generator_bce, discriminator_bce, dataloader, epochs)
    torch.save(generator_bce.state_dict(), 'generator_bce.pth')



Training BCE GAN...
[BCE GAN] Epoch 0/50, D Loss: 1.7901, G Loss: 0.3826
[BCE GAN] Epoch 1/50, D Loss: 0.9169, G Loss: 1.5627
[BCE GAN] Epoch 2/50, D Loss: 0.9553, G Loss: 1.5259
[BCE GAN] Epoch 3/50, D Loss: 0.9091, G Loss: 1.7885
[BCE GAN] Epoch 4/50, D Loss: 0.8159, G Loss: 1.7274
[BCE GAN] Epoch 5/50, D Loss: 0.9142, G Loss: 2.3132
[BCE GAN] Epoch 6/50, D Loss: 0.8072, G Loss: 1.3929
[BCE GAN] Epoch 7/50, D Loss: 0.6925, G Loss: 1.7980
[BCE GAN] Epoch 8/50, D Loss: 0.8700, G Loss: 2.8023
[BCE GAN] Epoch 9/50, D Loss: 0.6832, G Loss: 1.8592
[BCE GAN] Epoch 10/50, D Loss: 0.8071, G Loss: 3.9306
[BCE GAN] Epoch 11/50, D Loss: 1.1767, G Loss: 1.0529
[BCE GAN] Epoch 12/50, D Loss: 1.0056, G Loss: 3.6858
[BCE GAN] Epoch 13/50, D Loss: 1.2526, G Loss: 1.2303
[BCE GAN] Epoch 14/50, D Loss: 0.6966, G Loss: 3.3356
[BCE GAN] Epoch 15/50, D Loss: 0.7845, G Loss: 2.3956
[BCE GAN] Epoch 16/50, D Loss: 0.4945, G Loss: 2.7821
[BCE GAN] Epoch 17/50, D Loss: 0.5503, G Loss: 2.7422
[BCE GAN] Epoch 18

In [8]:
    # Train LS-GAN
    print("Training LS-GAN...")
    train_ls_gan(generator_ls, discriminator_ls, dataloader, epochs)
    torch.save(generator_ls.state_dict(), 'generator_ls.pth')



Training LS-GAN...
[LS-GAN] Epoch 0/50, D Loss: 0.3399, G Loss: 0.5519
[LS-GAN] Epoch 1/50, D Loss: 0.4847, G Loss: 0.3407
[LS-GAN] Epoch 2/50, D Loss: 0.1509, G Loss: 1.3398
[LS-GAN] Epoch 3/50, D Loss: 0.3124, G Loss: 1.1296
[LS-GAN] Epoch 4/50, D Loss: 0.1637, G Loss: 0.9235
[LS-GAN] Epoch 5/50, D Loss: 0.1698, G Loss: 0.5785
[LS-GAN] Epoch 6/50, D Loss: 0.2135, G Loss: 1.5788
[LS-GAN] Epoch 7/50, D Loss: 0.4181, G Loss: 0.8506
[LS-GAN] Epoch 8/50, D Loss: 0.2281, G Loss: 0.5354
[LS-GAN] Epoch 9/50, D Loss: 0.1636, G Loss: 0.7851
[LS-GAN] Epoch 10/50, D Loss: 0.3046, G Loss: 0.8330
[LS-GAN] Epoch 11/50, D Loss: 0.1447, G Loss: 1.2773
[LS-GAN] Epoch 12/50, D Loss: 0.7511, G Loss: 0.0486
[LS-GAN] Epoch 13/50, D Loss: 0.2096, G Loss: 1.4310
[LS-GAN] Epoch 14/50, D Loss: 0.3126, G Loss: 0.9385
[LS-GAN] Epoch 15/50, D Loss: 0.3721, G Loss: 1.6163
[LS-GAN] Epoch 16/50, D Loss: 0.1552, G Loss: 0.8956
[LS-GAN] Epoch 17/50, D Loss: 0.2662, G Loss: 1.6505
[LS-GAN] Epoch 18/50, D Loss: 0.3481,

In [9]:
    # Train WGAN
    print("Training WGAN...")
    train_wgan(generator_wgan, critic_wgan, dataloader, epochs)
    torch.save(generator_wgan.state_dict(), 'generator_wgan.pth')



Training WGAN...
[WGAN] Epoch 0/50, C Loss: 0.5528, G Loss: 1.4260
[WGAN] Epoch 1/50, C Loss: 0.3721, G Loss: -2.8018
[WGAN] Epoch 2/50, C Loss: -0.0122, G Loss: -0.0229
[WGAN] Epoch 3/50, C Loss: 0.1817, G Loss: 2.2143
[WGAN] Epoch 4/50, C Loss: -0.0383, G Loss: 0.0366
[WGAN] Epoch 5/50, C Loss: -1.5126, G Loss: -10.6114
[WGAN] Epoch 6/50, C Loss: 0.3141, G Loss: 0.4430
[WGAN] Epoch 7/50, C Loss: -2.0881, G Loss: 8.1316
[WGAN] Epoch 8/50, C Loss: -3.3487, G Loss: 8.2016
[WGAN] Epoch 9/50, C Loss: -1.2108, G Loss: 8.5983
[WGAN] Epoch 10/50, C Loss: -0.9548, G Loss: 8.6378
[WGAN] Epoch 11/50, C Loss: -2.0376, G Loss: 9.6060
[WGAN] Epoch 12/50, C Loss: -1.1277, G Loss: 9.4575
[WGAN] Epoch 13/50, C Loss: -1.2971, G Loss: 11.1078
[WGAN] Epoch 14/50, C Loss: -0.2178, G Loss: 9.9282
[WGAN] Epoch 15/50, C Loss: -0.8622, G Loss: 13.7578
[WGAN] Epoch 16/50, C Loss: -3.0287, G Loss: -3.2560
[WGAN] Epoch 17/50, C Loss: -1.2163, G Loss: -11.4953
[WGAN] Epoch 18/50, C Loss: -0.0069, G Loss: 0.1502


In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import inception_v3
import numpy as np
from scipy.linalg import sqrtm

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained Inception V3
inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
inception_model.eval()

# Function to preprocess images for Inception V3
def preprocess_images(images):
    # Assuming images are in [-1, 1], convert to [0, 1]
    images = (images * 0.5 + 0.5).clamp(0, 1)
    # Resize to 299x299 as expected by Inception V3
    images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
    # Normalize to match Inception V3 input (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
    images = (images - mean) / std
    return images

# Compute Inception Score
def compute_inception_score_manual(generator, num_samples=1000, batch_size=50):
    generator.eval()
    preds = []
    with torch.no_grad():
        for _ in range(num_samples // batch_size):
            noise = torch.randn(batch_size, 100).to(device)  # Adjust 100 to your latent dim
            fake_images = generator(noise)
            fake_images = preprocess_images(fake_images)
            pred = inception_model(fake_images)
            preds.append(F.softmax(pred, dim=1).cpu().numpy())
    preds = np.concatenate(preds, axis=0)

    # Calculate IS
    scores = []
    for i in range(0, len(preds), batch_size):
        p = preds[i:i+batch_size]
        kl_div = p * (np.log(p + 1e-16) - np.log(np.mean(p, axis=0, keepdims=True) + 1e-16))
        scores.append(np.exp(np.mean(kl_div.sum(axis=1))))
    return np.mean(scores)

# Compute FID
def compute_fid_manual(generator, real_images, num_samples=1000, batch_size=50):
    generator.eval()

    # Get real image features
    with torch.no_grad():
        real_images = preprocess_images(real_images)
        real_features = inception_model(real_images).cpu().numpy()

    # Get fake image features
    fake_features = []
    with torch.no_grad():
        for _ in range(num_samples // batch_size):
            noise = torch.randn(batch_size, 100).to(device)  # Adjust 100 to your latent dim
            fake_images = generator(noise)
            fake_images = preprocess_images(fake_images)
            feat = inception_model(fake_images).cpu().numpy()
            fake_features.append(feat)
    fake_features = np.concatenate(fake_features, axis=0)

    # Calculate mean and covariance
    mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu_fake, sigma_fake = np.mean(fake_features, axis=0), np.cov(fake_features, rowvar=False)

    # FID calculation
    diff = mu_real - mu_fake
    covmean = sqrtm(sigma_real.dot(sigma_fake))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

# Example usage
real_images = next(iter(dataloader))[0].to(device)[:1000]  # Adjust size as needed

print("Evaluating BCE GAN...")
is_score_bce = compute_inception_score_manual(generator_bce)
fid_score_bce = compute_fid_manual(generator_bce, real_images)
print(f"BCE GAN - IS: {is_score_bce:.4f}, FID: {fid_score_bce:.4f}")

print("Evaluating LS-GAN...")
is_score_ls = compute_inception_score_manual(generator_ls)
fid_score_ls = compute_fid_manual(generator_ls, real_images)
print(f"LS-GAN - IS: {is_score_ls:.4f}, FID: {fid_score_ls:.4f}")

print("Evaluating WGAN...")
is_score_wgan = compute_inception_score_manual(generator_wgan)
fid_score_wgan = compute_fid_manual(generator_wgan, real_images)
print(f"WGAN - IS: {is_score_wgan:.4f}, FID: {fid_score_wgan:.4f}")

Evaluating BCE GAN...
BCE GAN - IS: 4.7519, FID: 720.1011
Evaluating LS-GAN...
LS-GAN - IS: 4.4756, FID: 706.1688
Evaluating WGAN...
WGAN - IS: 2.2496, FID: 1085.4244


Analysis

Inception Score (IS):
A higher IS generally indicates that the generated images are both diverse and of high quality. Here, the BCE GAN and LS-GAN have significantly higher scores than the WGAN, suggesting they produce more recognizable and varied images.

Fréchet Inception Distance (FID):
A lower FID implies that the generated images are closer in distribution to the real images. LS-GAN has the lowest FID (706.1688), with BCE GAN following closely, while WGAN's high FID indicates a larger discrepancy from the real image distribution.

Conclusion
BCE GAN and LS-GAN are performing comparably well in terms of both IS and FID, with LS-GAN slightly better on the FID metric.
WGAN appears to be underperforming compared to the other two, as reflected by both its lower IS and higher FID.