In [1]:
import torch
import numpy as np
import cv2
import os
import random as rand
import torchvision
import pandas as pd
from tqdm import tqdm
from torch import nn, Tensor
from torch.nn import utils
import matplotlib.pyplot as plt
from typing import Optional
from torch.nn import functional as F
from torchvision.transforms import v2 as T
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from math import ceil
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CIFAR(Dataset):
    def __init__(self, path="/scratch/s25090/archive/cifar-10/train", dataset:Optional[list]=None):
        super().__init__()
        self.path = path
        self.files = os.listdir(self.path) if dataset is None else dataset
        self.T = T.Compose([
           T.ToImage(), 
           T.ToDtype(torch.float32, scale=True),
           T.Resize((32, 32)),
           T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __len__(self,):
        return len(self.files)
    
    def __getitem__(self, idx):
        file = self.files[idx]
        img_path = os.path.join(self.path, file)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.T(img)
        return img

In [3]:
class SNDiscriminator(nn.Module):
    def __init__(self, channels_img=3, features_d=64):
        super(SNDiscriminator, self).__init__()
        
        # Helper to create a spectral normalized block
        def sn_block(in_channels, out_channels, kernel_size, stride, padding):
            return nn.Sequential(
                # WRAP THE CONV LAYER WITH SPECTRAL_NORM
                utils.spectral_norm(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
                ),
                nn.LeakyReLU(0.2), 
                # Note: SN-GANs usually remove Batch Norm in the Discriminator entirely
            )

        self.disc = nn.Sequential(
            # Input: N x 3 x 32 x 32
            sn_block(channels_img, features_d, 3, 1, 1),
            
            # Downsampling blocks
            sn_block(features_d, features_d * 2, 4, 2, 1),
            sn_block(features_d * 2, features_d * 4, 4, 2, 1),
            sn_block(features_d * 4, features_d * 8, 4, 2, 1),
            
            # Final output layer (Single scalar score)
            # We also apply Spectral Norm to the final layer
            utils.spectral_norm(nn.Conv2d(features_d * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.disc(x).view(-1)

class ResGenBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        )
        
        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        return self.conv_block(x) + self.shortcut(x)

In [4]:
class ResNetGenerator(nn.Module):
    def __init__(self, z_dim=100, features_g=256, channels_img=3):
        super().__init__()
        self.linear = nn.Linear(z_dim, 4 * 4 * features_g)
        self.features_g = features_g
        
        self.net = nn.Sequential(
            ResGenBlock(features_g, features_g),
            ResGenBlock(features_g, features_g),  
            ResGenBlock(features_g, features_g),  
            
            nn.BatchNorm2d(features_g),
            nn.ReLU(inplace=True),
            nn.Conv2d(features_g, channels_img, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.linear(z)
        x = x.view(-1, self.features_g, 4, 4)
        return self.net(x)

class SNDiscriminator(nn.Module):
    def __init__(self, channels_img=3, features_d=64):
        super(SNDiscriminator, self).__init__()
        
        def sn_block(in_channels, out_channels, kernel_size, stride, padding):
            return nn.Sequential(
                utils.spectral_norm(
                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
                ),
                nn.LeakyReLU(0.2, inplace=True), 
            )

        self.disc = nn.Sequential(
            sn_block(channels_img, features_d, 3, 1, 1),        
            sn_block(features_d, features_d * 2, 4, 2, 1),      
            sn_block(features_d * 2, features_d * 4, 4, 2, 1),
            sn_block(features_d * 4, features_d * 8, 4, 2, 1),
            utils.spectral_norm(nn.Conv2d(features_d * 8, 1, 4, 1, 0)),
        )

    def forward(self, x):
        return self.disc(x).view(-1)
    
class SGANModel(nn.Module):
    def __init__(self, z_dim=100, channels_img=3, features_g=256, features_d=64):
        super().__init__()
        self.generator = ResNetGenerator(z_dim, features_g, channels_img)
        self.discriminator = SNDiscriminator(channels_img, features_d)
        self.z_dim = z_dim

    def forward(self, z):
        return self.generator(z)

    def compute_discriminator_loss(self, real_imgs, z):
        with torch.no_grad():
            fake_imgs = self.generator(z).detach()

        real_logits = self.discriminator(real_imgs)
        fake_logits = self.discriminator(fake_imgs)
        
        loss_real = torch.mean(F.relu(1.0 - real_logits))
        loss_fake = torch.mean(F.relu(1.0 + fake_logits))
        
        d_loss = loss_real + loss_fake
        return d_loss

    def compute_generator_loss(self, z):
        fake_imgs = self.generator(z)
        fake_logits = self.discriminator(fake_imgs)
        
        g_loss = -torch.mean(fake_logits)
        
        return g_loss, fake_imgs

In [None]:
LEARNING_RATE = 2e-4
BATCH_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 200
DISC_ITERATIONS = 5
DEVICE = "cuda:4" if torch.cuda.is_available() else "cpu"

WEIGHT_PATH = "/scratch/s25090/sngan_outputs/weights/Experiment1"
PLOT_PATH = "/scratch/s25090/sngan_outputs/plots/Experiment1"
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(PLOT_PATH, exist_ok=True)

gan_model = SGANModel(z_dim=Z_DIM, channels_img=CHANNELS_IMG).to(DEVICE)
gan_model.train()

opt_gen = torch.optim.Adam(gan_model.generator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_dis = torch.optim.Adam(gan_model.discriminator.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

train_dataset = CIFAR()
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print("Starting SN-GAN Training (Hinge Loss)...")
gen_loss_list = []
dis_loss_list = []

for epoch in range(NUM_EPOCHS):
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    epoch_gen_loss = 0
    epoch_dis_loss = 0
    batches_processed = 0
    
    for batch_idx, (real_img) in enumerate(loop):
        if isinstance(real_img, list) or isinstance(real_img, tuple):
            real_img = real_img[0]
            
        real_img = real_img.to(DEVICE)
        cur_batch_size = real_img.shape[0]
        batches_processed += 1

        opt_dis.zero_grad()
        z = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
        fake_img = gan_model.generator(z)
        d_real = gan_model.discriminator(real_img).reshape(-1)
        d_fake = gan_model.discriminator(fake_img.detach()).reshape(-1)
        loss_d_real = torch.mean(F.relu(1.0 - d_real))
        loss_d_fake = torch.mean(F.relu(1.0 + d_fake))
        
        loss_dis = loss_d_real + loss_d_fake
        
        loss_dis.backward()
        opt_dis.step()
        
        epoch_dis_loss += loss_dis.item()

        loss_gen_item = 0 
        
        if batch_idx % DISC_ITERATIONS == 0:
            opt_gen.zero_grad()
            
            gen_fake_logits = gan_model.discriminator(fake_img).reshape(-1)
            
            loss_gen = -torch.mean(gen_fake_logits)
            
            loss_gen.backward()
            opt_gen.step()
            
            loss_gen_item = loss_gen.item()
            epoch_gen_loss += loss_gen_item
            
            loop.set_postfix(
                d_loss=loss_dis.item(),
                g_loss=loss_gen.item()
            )

    avg_dis_loss = epoch_dis_loss / len(loader)
    avg_gen_loss = epoch_gen_loss / (len(loader) / DISC_ITERATIONS)
    
    gen_loss_list.append(avg_gen_loss)
    dis_loss_list.append(avg_dis_loss)

    print(f"Generator Loss: {avg_gen_loss:.4f} | Discriminator Loss: {avg_dis_loss:.4f}")

    if (epoch + 1) % 10 == 0:
        torch.save(gan_model.state_dict(), f"{WEIGHT_PATH}/sngan_epoch_{epoch+1}.pth")
        
        gan_model.eval()
        with torch.no_grad():
            test_z = torch.randn(8, Z_DIM).to(DEVICE)
            gan_images = gan_model.generator(test_z)
            
            comparison = torch.cat([real_img[:8], gan_images[:8]], dim=0)
            grid = make_grid(comparison.cpu(), nrow=8, padding=2, normalize=True)
            
            plt.figure(figsize=(12, 4))
            plt.imshow(grid.permute(1, 2, 0))
            plt.axis('off')
            plt.title(f'Top: Original | Bottom: Generated Image (Epoch {epoch+1})')
            plt.savefig(f"{PLOT_PATH}/Epoch-{epoch+1}.png")
            plt.close()
        gan_model.train()

plt.figure(figsize=(10, 5))
plt.title("SN-GAN Hinge Loss")
plt.plot(gen_loss_list, label="Generator")
plt.plot(dis_loss_list, label="Discriminator")
plt.xlabel("Epochs")
plt.ylabel("Hinge Loss")
plt.savefig(f"{PLOT_PATH}/Experiment1_loss.png")
plt.legend()
plt.show()

In [7]:
DEVICE = "cuda:4" if torch.cuda.is_available() else "cpu"
model = SGANModel()
WEIGHT = torch.load("sngan_outputs/sngan_epoch_200.pth")
model.load_state_dict(WEIGHT)
model.to(DEVICE)
train_loader = DataLoader(CIFAR(), batch_size=128, shuffle=False)

In [8]:
def get_evaluation_metrics(generator, dataloader, device, num_imgs=10000):
    """
    Calculates FID and IS for a GAN generator.
    
    Args:
        generator: The GAN generator model.
        dataloader: DataLoader for real images (needed for FID reference).
        device: 'cuda' or 'cpu'.
        num_imgs: Number of images to generate/use for calculation.
                  (Standard for papers is 50k, but 10k is faster for debugging).
    
    Returns:
        fid_score (float), is_score (float)
    """
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    inception = InceptionScore(normalize=True).to(device)
    
    generator.eval()
    
    print(f"--- Computing Metrics (Samples: {num_imgs}) ---")
    
    real_count = 0
    for batch in tqdm(dataloader, desc="Processing Real Images"):
        batch = batch.to(device)
        remaining = num_imgs - real_count
        if remaining <= 0:
            break
            
        if batch.shape[0] > remaining:
            batch = batch[:remaining]
        if batch.min() < 0:
            batch = (batch + 1) / 2  # Now [0, 1]
            
        fid.update((batch * 255).to(torch.uint8), real=True)
        real_count += batch.shape[0]

    fake_count = 0
    while fake_count < num_imgs:
        batch_size = min(dataloader.batch_size, num_imgs - fake_count)
        
        z = torch.randn(batch_size, 100).to(device)
        
        with torch.no_grad():
            fake_imgs = generator(z)
        fake_imgs = (fake_imgs + 1) / 2
        fake_uint8 = (fake_imgs * 255).to(torch.uint8)
        
        fid.update(fake_uint8, real=False)
        inception.update(fake_uint8)
        
        fake_count += batch_size
        
    print("Finalizing calculations...")
    fid_score = fid.compute().item()
    is_score_mean, is_score_std = inception.compute()
    
    return fid_score, is_score_mean.item()
print("Evaluating WGAN...")
fid_wgan, is_score_wgan = get_evaluation_metrics(
    model.generator, 
    train_loader, 
    DEVICE, 
    num_imgs=2000
)
print(f"FID: {fid_wgan:.4f} | IS: {is_score_wgan:.4f}")

Evaluating WGAN...




--- Computing Metrics (Samples: 2000) ---


Processing Real Images:   4%|‚ñç         | 16/391 [00:44<17:27,  2.79s/it] 


Finalizing calculations...
FID: 50.3581 | IS: 4.5761
