In [1]:
import torch
import numpy as np
import cv2
import os
import random as rand
import torchvision
import pandas as pd
from tqdm import tqdm
from torch import nn, Tensor
import matplotlib.pyplot as plt
from typing import Optional
from torch.nn import functional as F
from torchvision.transforms import v2 as T
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from math import ceil
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CIFAR(Dataset):
    def __init__(self, path="/scratch/s25090/archive/cifar-10/train", dataset:Optional[list]=None):
        super().__init__()
        self.path = path
        self.files = os.listdir(self.path) if dataset is None else dataset
        self.T = T.Compose([
           T.ToImage(), 
           T.ToDtype(torch.float32, scale=True),
           T.Resize((32, 32)),
           T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self,):
        return len(self.files)
    
    def __getitem__(self, idx):
        file = self.files[idx]
        img_path = os.path.join(self.path, file)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.T(img)
        return img

In [3]:
class GenBlock(nn.Module):
    def __init__(self, in_channel, out_channel, is_final):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, (out_channel+in_channel)//2, 3, 1, 1),
            nn.BatchNorm2d((out_channel+in_channel)//2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d((in_channel+out_channel)//2, out_channel, 3, 1, 1)
        ]
        
        if not is_final:
            layers.append(nn.BatchNorm2d(out_channel))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.UpsamplingNearest2d(scale_factor=2))
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

class DisBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, (out_channel+in_channel)//2, 3, 1, 1),
            nn.BatchNorm2d((out_channel+in_channel)//2),
            nn.LeakyReLU(0.2),
            nn.Conv2d((out_channel+in_channel)//2, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, 2)
        )
    
    def forward(self, x):
        return self.layer(x)

class ResGenBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        )
        
        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        return self.conv_block(x) + self.shortcut(x)

In [4]:
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super().__init__()
        self.initial_linear = nn.Linear(z_dim, 1024 * 4 * 4)
        
        self.net = nn.Sequential(
            GenBlock(1024, 512, is_final=False), 
            GenBlock(512, 256, is_final=False),  
            GenBlock(256, 128,  is_final=False), 
            GenBlock(128,  64,  is_final=False),
            GenBlock(64,  64,  is_final=True),
        )
        
        self.final_layer = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        if len(z.shape) > 2:
            z = z.view(z.size(0), -1)
            
        x = self.initial_linear(z)
        x = x.view(-1, 1024, 4, 4)
        x = self.net(x)
        return self.final_layer(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            DisBlock(3, 32),   
            DisBlock(32, 64),
            DisBlock(64, 128),
        )
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 1) 
        )

    def forward(self, x):
        x = self.net(x)
        return self.classifier(x)

class ResNetGenerator(nn.Module):
    def __init__(self, z_dim=100, base_channels=256):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * base_channels)
        self.base_channels = base_channels

        self.blocks = nn.Sequential(
            ResGenBlock(base_channels, base_channels),    
            ResGenBlock(base_channels, base_channels // 2), 
            ResGenBlock(base_channels // 2, base_channels // 4),
        )
        
        self.final_layer = nn.Sequential(
            nn.BatchNorm2d(base_channels // 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels // 4, 3, 3, 1, 1), 
            nn.Tanh() 
        )

    def forward(self, z):
        if z.ndim > 2: z = z.view(z.size(0), -1)
            
        x = self.linear(z)
        x = x.view(-1, self.base_channels, 4, 4)
        x = self.blocks(x)
        return self.final_layer(x)

class WGANModel(nn.Module):
    def __init__(self, z_dim=100, is_res=True):
        super().__init__()
        self.generator = Generator(z_dim) if not is_res else ResNetGenerator()
        self.discriminator = Discriminator()
        self.z_dim = z_dim

    def forward(self, z):
        return self.generator(z)

    def compute_discriminator_loss(self, real_imgs, z):
        with torch.no_grad():
            fake_imgs = self.generator(z).detach()

        real_logits = self.discriminator(real_imgs)
        fake_logits = self.discriminator(fake_imgs)

        d_loss = -(torch.mean(real_logits) - torch.mean(fake_logits))
        
        return d_loss

    def compute_generator_loss(self, z):
        fake_imgs = self.generator(z)
        
        fake_logits = self.discriminator(fake_imgs)
        
        g_loss = -torch.mean(fake_logits)
        
        return g_loss, fake_imgs

In [None]:
LEARNING_RATE = 5e-5        
WEIGHT_CLIP = 0.01
N_CRITIC = 5  
DEVICE = 'cuda:3' if torch.cuda.is_available() else 'cpu'
epochs = 300

gan_model = WGANModel()
gan_model = gan_model.to(DEVICE)

opt_gen = torch.optim.RMSprop(gan_model.generator.parameters(), lr=LEARNING_RATE)
opt_dis = torch.optim.RMSprop(gan_model.discriminator.parameters(), lr=LEARNING_RATE)

train_dataset = CIFAR()
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

gen_loss_list = []
dis_loss_list = []

for epoch in range(epochs):
    gan_model.train()
    tqdm_data = tqdm(train_loader, desc=f"Epoch-{epoch+1}/{epochs}")
    
    batch_gen_loss = 0
    batch_dis_loss = 0
    
    for batch_idx, (real_img) in enumerate(tqdm_data):
        real_img = real_img.to(DEVICE)
        bs = real_img.size(0)

        for param in gan_model.discriminator.parameters():
            param.requires_grad = True

        z_dis = torch.randn(bs, 100).to(DEVICE)
        opt_dis.zero_grad()
        
        dis_loss = gan_model.compute_discriminator_loss(real_img, z_dis)
        dis_loss.backward()
        opt_dis.step()

        for p in gan_model.discriminator.parameters():
            p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)
        
        batch_dis_loss += dis_loss.item()

        if batch_idx % N_CRITIC == 0:
            for param in gan_model.discriminator.parameters():
                param.requires_grad = False # Freeze D to save computation
            
            z = torch.randn(bs, 100).to(DEVICE)
            opt_gen.zero_grad()
            
            gen_loss, fake_img = gan_model.compute_generator_loss(z)
            gen_loss.backward()
            opt_gen.step()
            
            batch_gen_loss += gen_loss.item()
            current_gen_loss = gen_loss.item() # For tqdm
        else:
            current_gen_loss = batch_gen_loss / (batch_idx + 1) if batch_idx > 0 else 0

        tqdm_data.set_postfix({
            "GenLoss": current_gen_loss,
            "DisLoss": dis_loss.item()
        })

    avg_gen_loss = batch_gen_loss / (len(train_loader) / N_CRITIC)
    avg_dis_loss = batch_dis_loss / len(train_loader)
    
    gen_loss_list.append(avg_gen_loss)
    dis_loss_list.append(avg_dis_loss)

    print(f"Generator Loss: {avg_gen_loss:.4f}\nDiscriminator Loss: {avg_dis_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save(gan_model.state_dict(), f'/scratch/s25090/wgan_outputs/weights/Experiment1/gan_epoch_{epoch+1}.pth')
        gan_model.eval()
        with torch.no_grad():
            test_z = torch.randn(8, 100).to(DEVICE)
            gan_image = gan_model.generator(test_z)
            
            comparison = torch.cat([real_img[:8], gan_image[:8]], dim=0)
            grid = make_grid(comparison.cpu(), nrow=8, padding=2, normalize=True)
            
            plt.figure(figsize=(12, 4))
            plt.imshow(grid.permute(1, 2, 0))
            plt.axis('off')
            plt.title(f'Top: Original | Bottom: Generated Image (Epoch {epoch+1})')
            plt.savefig(f"/scratch/s25090/wgan_outputs/plots/Experiment1/Epoch-{epoch+1}.png")
            plt.close()

plt.figure(figsize=(10, 5))
plt.title("Generator vs Discriminator Loss (WGAN)")
plt.plot(gen_loss_list, label="Generator")
plt.plot(dis_loss_list, label="Discriminator")
plt.xlabel("Epochs")
plt.ylabel("Wasserstein Loss")
plt.savefig(f"/scratch/s25090/wgan_outputs/plots/Experiment1_loss.png")
plt.legend()
plt.show()

In [None]:
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """
    Calculates the gradient penalty loss for WGAN GP.
    """
    alpha = torch.rand((real_samples.size(0), 1, 1, 1)).to(device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    
    d_interpolates = critic(interpolates)
    fake = torch.ones(d_interpolates.shape).to(device)
    
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    
    gradients = gradients.view(gradients.size(0), -1)
    
    gradient_norm = gradients.norm(2, dim=1)
    
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty

In [None]:
LEARNING_RATE = 1e-4 
BATCH_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 100
CRITIC_ITERATIONS = 5
LAMBDA_GP = 10
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"
WEIGHT_PATH = "/scratch/s25090/wgan_GP_outputs/weights/Experiment1"
PLOT_PATH = "/scratch/s25090/wgan_GP_outputs/plots/Experiment1"
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(PLOT_PATH, exist_ok=True)
gan_model = WGANModel(z_dim=Z_DIM).to(DEVICE)
gan_model.train()

opt_gen = torch.optim.Adam(gan_model.generator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = torch.optim.Adam(gan_model.discriminator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
train_dataset = CIFAR()
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print("Starting WGAN-GP Training...")
gen_loss_list = []
dis_loss_list = []
for epoch in range(NUM_EPOCHS):
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    batch_gen_loss = 0
    batch_dis_loss = 0
    
    for batch_idx, (real) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
        fake = gan_model.generator(noise)
        
        critic_real = gan_model.discriminator(real).reshape(-1)
        critic_fake = gan_model.discriminator(fake).reshape(-1)
        
        loss_critic_wasserstein = -(torch.mean(critic_real) - torch.mean(critic_fake))
        
        gp = compute_gradient_penalty(gan_model.discriminator, real, fake, DEVICE)
        
        loss_critic = loss_critic_wasserstein + (LAMBDA_GP * gp)
        
        opt_critic.zero_grad()
        loss_critic.backward(retain_graph=True)
        opt_critic.step()
        batch_dis_loss += loss_critic.item()
        batch_gen_loss += -torch.mean(critic_fake).item()
        if batch_idx % CRITIC_ITERATIONS == 0:
            gen_fake = gan_model.discriminator(fake).reshape(-1)
            loss_gen = -torch.mean(gen_fake)
            
            opt_gen.zero_grad()
            loss_gen.backward()
            opt_gen.step()
            
            loop.set_postfix(
                loss_critic=loss_critic.item(),
                loss_gen=loss_gen.item(),
                gp=gp.item()
            )

    avg_gen_loss = batch_gen_loss / (len(loader) / CRITIC_ITERATIONS)
    avg_dis_loss = batch_dis_loss / len(loader)
    
    gen_loss_list.append(avg_gen_loss)
    dis_loss_list.append(avg_dis_loss)

    print(f"Generator Loss: {avg_gen_loss:.4f}\nDiscriminator Loss: {avg_dis_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save(gan_model.state_dict(), f"{WEIGHT_PATH}/gan_epoch_{epoch+1}.pth")
        gan_model.eval()
        with torch.no_grad():
            test_z = torch.randn(8, 100).to(DEVICE)
            gan_image = gan_model.generator(test_z)
            
            comparison = torch.cat([real[:8], gan_image[:8]], dim=0)
            grid = make_grid(comparison.cpu(), nrow=8, padding=2, normalize=True)
            
            plt.figure(figsize=(12, 4))
            plt.imshow(grid.permute(1, 2, 0))
            plt.axis('off')
            plt.title(f'Top: Original | Bottom: Generated Image (Epoch {epoch+1})')
            plt.savefig(f"{PLOT_PATH}/Epoch-{epoch+1}.png")
            plt.close()

plt.figure(figsize=(10, 5))
plt.title("Generator vs Discriminator Loss (WGAN)")
plt.plot(gen_loss_list, label="Generator")
plt.plot(dis_loss_list, label="Discriminator")
plt.xlabel("Epochs")
plt.ylabel("Wasserstein Loss")
plt.savefig(f"{PLOT_PATH}/Experiment1_loss.png")
plt.legend()
plt.show()

In [8]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
wgan_model = WGANModel()
wgangp_model = WGANModel()
wgan_model = wgan_model.to(DEVICE)
wgangp_model = wgangp_model.to(DEVICE)
WEIGHT_GP = torch.load("wgan_GP_outputs/gan_epoch_100.pth")
WEIGHT = torch.load("wgan_outputs/gan_epoch_300.pth")
wgangp_model.load_state_dict(WEIGHT_GP)
wgan_model.load_state_dict(WEIGHT) 
test_loader = DataLoader(CIFAR(), batch_size=128, shuffle=False)

In [10]:
def get_evaluation_metrics(generator, dataloader, device, num_imgs=10000):
    """
    Calculates FID and IS for a GAN generator.
    
    Args:
        generator: The GAN generator model.
        dataloader: DataLoader for real images (needed for FID reference).
        device: 'cuda' or 'cpu'.
        num_imgs: Number of images to generate/use for calculation.
                  (Standard for papers is 50k, but 10k is faster for debugging).
    
    Returns:
        fid_score (float), is_score (float)
    """
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    inception = InceptionScore(normalize=True).to(device)
    
    generator.eval()
    
    print(f"--- Computing Metrics (Samples: {num_imgs}) ---")
    
    real_count = 0
    for batch in tqdm(dataloader, desc="Processing Real Images"):
        batch = batch.to(device)
        remaining = num_imgs - real_count
        if remaining <= 0:
            break
            
        if batch.shape[0] > remaining:
            batch = batch[:remaining]
        if batch.min() < 0:
            batch = (batch + 1) / 2  # Now [0, 1]
            
        fid.update((batch * 255).to(torch.uint8), real=True)
        real_count += batch.shape[0]

    fake_count = 0
    while fake_count < num_imgs:
        batch_size = min(dataloader.batch_size, num_imgs - fake_count)
        
        z = torch.randn(batch_size, 100).to(device)
        
        with torch.no_grad():
            fake_imgs = generator(z)
        fake_imgs = (fake_imgs + 1) / 2
        fake_uint8 = (fake_imgs * 255).to(torch.uint8)
        
        fid.update(fake_uint8, real=False)
        inception.update(fake_uint8)
        
        fake_count += batch_size
        
    print("Finalizing calculations...")
    fid_score = fid.compute().item()
    is_score_mean, is_score_std = inception.compute()
    
    return fid_score, is_score_mean.item()
print("Evaluating WGAN...")
fid_wgan, is_score_wgan = get_evaluation_metrics(
    wgan_model.generator, 
    test_loader, 
    DEVICE, 
    num_imgs=2000
)
print(f"FID: {fid_wgan:.4f} | IS: {is_score_wgan:.4f}")

print("Evaluating WGAN-GP...")
fig_wgan_gp, is_score_wgan_gp = get_evaluation_metrics(
    wgangp_model.generator, 
    test_loader, 
    DEVICE, 
    num_imgs=2000
)
print(f"FID: {fig_wgan_gp:.4f} | IS: {is_score_wgan_gp:.4f}")

Evaluating WGAN...
--- Computing Metrics (Samples: 2000) ---


Processing Real Images:   4%|▍         | 16/391 [00:39<15:30,  2.48s/it] 


Finalizing calculations...
FID: 65.8375 | IS: 3.8508
Evaluating WGAN-GP...
--- Computing Metrics (Samples: 2000) ---


Processing Real Images:   4%|▍         | 16/391 [00:05<02:08,  2.93it/s]


Finalizing calculations...
FID: 60.8983 | IS: 3.9543
