The code has been split into separate sections for clarity. For it to be fully functional, all modules must either be set up as individual .py files or all be present within the same notebook. This split exists only to give structure to the working notes submission and make it easier to follow.

# Importing Libraries

In [None]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models, utils
from torchvision.models import inception_v3
from torchvision.datasets import ImageFolder, DatasetFolder
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch.autograd import Function
import torchvision.models as models
import torchvision.transforms as transforms
import shutil
import cv2
import random
from tqdm.notebook import tqdm
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
import time

# Tuning

Gradient penalty to regularize the discriminator, which stabilizes adversarial training and helps prevent mode collapse

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate gradient penalty for regularizing the discriminator.

    The discriminator is evaluated on random per-sample blends of the real
    and fake batches. The penalty is the batch mean of
    ``(||grad||_2 - 1) ** 2``, where ``grad`` is the gradient of the
    discriminator output with respect to the blended input.

    Args:
        D: Callable mapping an image batch to discriminator outputs.
        real_samples: Real image batch, indexed as a 4-D tensor (dim 0 is the
            batch dimension, dims 2 onward are the spatial size).
        fake_samples: Generated image batch. If its shape differs from
            ``real_samples`` it is bilinearly resized to the spatial size of
            ``real_samples`` before blending.

    Returns:
        A scalar tensor holding the penalty. The gradient is taken with
        ``create_graph=True``, so the result can itself be backpropagated.
    """
    batch_size = real_samples.size(0)

    # Resize fake samples to match real samples if dimensions differ.
    if real_samples.shape != fake_samples.shape:
        fake_samples = F.interpolate(
            fake_samples,
            size=real_samples.shape[2:],
            mode='bilinear',
            align_corners=False,
        )

    # One blending coefficient per sample, broadcast over channels and pixels.
    alpha = torch.rand((batch_size, 1, 1, 1), device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    # Discriminator output for the blended images.
    d_interpolates = D(interpolates)

    # Gradient of the discriminator output w.r.t. the blended input.
    # create_graph=True keeps the gradient differentiable so the penalty can
    # be used as a loss term.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view): autograd may hand back a non-contiguous gradient,
    # e.g. when the discriminator permutes its input, and view() would raise.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()

    return gradient_penalty

# Training

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the generator and discriminator (classes defined elsewhere in
# the project) and move each one to the selected device.
G = Generator()
G = G.to(device)
D = Discriminator()
D = D.to(device)

# Hyperparameter finetuning

In [None]:
# Wrap every convolutional and fully-connected layer of the discriminator
# with spectral normalization. The target layers are gathered first, then
# wrapped in a second pass.
sn_targets = [m for m in D.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
for module in sn_targets:
    nn.utils.spectral_norm(module)

In [None]:
# Build the HU loss with 100 bins, then move it to the training device.
# NOTE(review): HULoss is defined elsewhere; presumably a histogram-based
# intensity loss — confirm against its definition.
hu_loss = HULoss(bins=100)
hu_loss = hu_loss.to(device)

In [None]:
# Hyperparameters
# --- loss weights ---
lambda_adv = 1.0   # adversarial term of the generator loss
lambda_nce = 2.0   # PatchNCE term of the generator loss
lambda_tv = 5e-7   # NOTE(review): not referenced by the code visible in this section — confirm where it is used
lambda_fm = 5.0    # feature-matching term of the generator loss
lambda_hu = 1.0    # HU term of the generator loss
lambda_gp = 1.0    # gradient-penalty term of the discriminator loss

# --- schedule / optimizer settings ---
n_epochs = 60
batch_size = 256
lr_G = 2e-4
lr_D = 2e-4
beta1_G = 0.5
beta2_G = 0.9
beta1_D = 0.5
beta2_D = 0.999

# One Adam optimizer per network, no weight decay.
optimizer_G = optim.Adam(
    G.parameters(),
    lr=lr_G,
    betas=(beta1_G, beta2_G),
    weight_decay=0,
)
optimizer_D = optim.Adam(
    D.parameters(),
    lr=lr_D,
    betas=(beta1_D, beta2_D),
    weight_decay=0,
)

# Generator LR follows a cosine curve over n_epochs scheduler steps;
# discriminator LR is halved every 50 scheduler steps (slower decay for D).
scheduler_G = optim.lr_scheduler.CosineAnnealingLR(optimizer_G, T_max=n_epochs)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

# Loss modules. PatchNCELoss and FeatureMatchingLoss are defined elsewhere.
adv_criterion = nn.BCEWithLogitsLoss()
nce_criterion = PatchNCELoss(temperature=0.05).to(device)
fm_loss = FeatureMatchingLoss().to(device)

# Feature extractor: every child of D.model except the last one. It is built
# from D's own layer objects, so it shares their weights with D.
_d_children = list(D.model.children())
feat_extractor = nn.Sequential(*_d_children[:-1]).to(device)

In [None]:
# Directory that receives the periodically saved sample images; created if it
# does not exist yet (no error when it is already present).
save_dir = "./generated_from_real_not_used"
os.makedirs(save_dir, exist_ok=True)

In [None]:
def resize_to_match(tensor1, tensor2):
    """Return ``tensor2`` at the spatial size of ``tensor1``.

    Dimensions 2 and 3 of the two tensors are compared. When both already
    agree, ``tensor2`` is returned untouched; otherwise it is bilinearly
    interpolated (``align_corners=False``) to ``tensor1.shape[2:]``.
    """
    heights_and_widths_agree = (
        tensor1.size(2) == tensor2.size(2) and tensor1.size(3) == tensor2.size(3)
    )
    if heights_and_widths_agree:
        return tensor2

    target_size = tensor1.shape[2:]
    return F.interpolate(tensor2, size=target_size, mode='bilinear', align_corners=False)

In [None]:
# Adversarial training: each batch performs one generator update followed by
# one discriminator update; both learning-rate schedulers step once per epoch.
for epoch in range(n_epochs):
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    num_batches = 0

    # Per-batch wall-clock time is printed for profiling.
    for i, (real_img, _) in enumerate(tqdm(dataloader_generated, desc=f"Epoch {epoch+1}")):
        start = time.time()

        real_img = real_img.to(device)

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Generate fake images
        fake_img = G(real_img)

        pred_fake = D(fake_img)

        real_feats = feat_extractor(real_img)
        fake_feats = feat_extractor(fake_img)

        # The generator is rewarded when D labels its output as real (target = 1).
        adv_loss = adv_criterion(pred_fake, torch.ones_like(pred_fake))
        nce_loss = nce_criterion(real_feats, fake_feats)
        fm_reg_loss = fm_loss([real_feats], [fake_feats])
        hu_reg_loss = hu_loss(real_img, fake_img)

        G_loss = (lambda_adv * adv_loss +
                  lambda_nce * nce_loss +
                  lambda_fm * fm_reg_loss +
                  lambda_hu * hu_reg_loss)

        G_loss.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        # Also clears any gradients the generator's backward pass left on D.
        optimizer_D.zero_grad()

        pred_real = D(real_img)
        pred_fake = D(fake_img.detach())  # detach to avoid backprop through generator

        real_loss = adv_criterion(pred_real, torch.ones_like(pred_real))
        fake_loss = adv_criterion(pred_fake, torch.zeros_like(pred_fake))
        gan_loss = (real_loss + fake_loss) * 0.5

        gp = compute_gradient_penalty(D, real_img, fake_img.detach())

        D_loss = gan_loss + lambda_gp * gp

        D_loss.backward()
        optimizer_D.step()

        # Track statistics
        epoch_g_loss += G_loss.item()
        epoch_d_loss += D_loss.item()
        num_batches += 1

        # Batch time profiling. Wait for queued GPU work so the measurement is
        # meaningful; only synchronize on CUDA, because torch.cuda.synchronize()
        # raises when training falls back to the CPU.
        if device.type == "cuda":
            torch.cuda.synchronize()
        end = time.time()
        print(f"Batch {i+1} time: {end - start:.2f} sec")

    # Fail with a clear message instead of a ZeroDivisionError / NameError
    # further down when the loader produced nothing.
    if num_batches == 0:
        raise RuntimeError("dataloader_generated yielded no batches; nothing to train on")

    # Step the learning rate schedulers
    scheduler_G.step()
    scheduler_D.step()

    avg_g_loss = epoch_g_loss / num_batches
    avg_d_loss = epoch_d_loss / num_batches

    # Print epoch losses. The totals are epoch averages; the per-term values
    # in parentheses come from the LAST batch of the epoch only.
    print(f"Epoch [{epoch+1}/{n_epochs}]", flush=True)
    print(f"  G Loss: {avg_g_loss:.4f} (Adv: {adv_loss.item():.4f}, NCE: {nce_loss.item():.4f}, "
          f"FM: {fm_reg_loss.item():.4f}, HU: {hu_reg_loss.item():.4f})")
    print(f"  D Loss: {avg_d_loss:.4f} (GAN: {gan_loss.item():.4f}, GP: {(lambda_gp * gp).item():.4f})")

    # Save images on the first epoch and every 10th epoch
    if (epoch + 1) % 10 == 0 or epoch == 0:
        G.eval()
        with torch.no_grad():
            # Use up to 4 images from the last batch. The size is read from the
            # tensor itself so the `batch_size` hyperparameter is not overwritten.
            test_samples = min(4, real_img.size(0))
            fake_img = G(real_img[:test_samples])

            # Shift/scale by (x + 1) / 2 before saving
            # (assumes images are normalized to [-1, 1] — TODO confirm against the dataset transform).
            fake_img = (fake_img + 1) / 2.0
            real_comp = (real_img[:test_samples] + 1) / 2.0

            fake_img = resize_to_match(real_comp, fake_img)

            # Save generated images
            vutils.save_image(fake_img, os.path.join(save_dir, f"gen_epoch_{epoch+1}.png"), nrow=2)

            # Save image comparison: real images first, then the generated ones
            comparison = torch.cat([real_comp, fake_img], dim=0)
            vutils.save_image(comparison, os.path.join(save_dir, f"compare_epoch_{epoch+1}.png"),
                              nrow=test_samples, normalize=False)

        # Switch back to train mode
        G.train()