In [62]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, ConcatDataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import pickle
import os
import kagglehub
# Set random seed for reproducibility
manualSeed = 999
# manualSeed = random.randint(1, 10000)  # Use this if you want new results
print("Random Seed: ", manualSeed)
# Seeds Python's `random` module and PyTorch's RNG; NumPy's RNG is not seeded here.
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# Check for CUDA
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.use_deterministic_algorithms(False) # False = deterministic algorithms are NOT enforced, so runs may differ despite the fixed seed

In [65]:
# Root directory for dataset (replaced by the CelebA download path when `celeba` is True)
root_dir="images/wiki"

# Number of workers for dataloader
workers = 2

# Batch size during training
# NOTE: the FID evaluation cells later rebind `batch_size` to 100.
batch_size = 128

# Spatial size of training images. All images will be resized to this size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 40

# Learning rate for optimizers
# NOTE(review): lrG, lrD and beta1 are not read by the optimizer cell in this
# notebook, which hard-codes lr=1e-4 and betas=(0.0, 0.9) - confirm which is intended.
lrG = 0.0002
lrD = 0.0001
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
# The networks are wrapped in DataParallel only when this is > 1 and CUDA is in use.
ngpu = 1

# When True, the data-loading cell downloads CelebA through KaggleHub and overrides root_dir.
celeba=False

# Uncomment to load YOLO model

In [None]:
# Disabled preprocessing step: the whole cell is one triple-quoted string, so none of it runs.
# The text below walks `root_dir`, runs a YOLOv8 face model on every readable image,
# crops the first detected box (skipping boxes smaller than 50 px on either side) and
# saves each crop as "<name>_cropped.jpg" under "cropped2/cropped_faces".
# NOTE(review): `root_for` is assigned inside the text but not referenced again there.
'''
import os
import cv2
from ultralytics import YOLO
from tqdm import tqdm
from PIL import Image

# Load YOLOv8 model
model = YOLO('yolov8n-face-lindevs.pt')  

# Define input and output folders
input_root = root_dir
output_folder = "cropped2/cropped_faces"


os.makedirs(output_folder, exist_ok=True)
root_for = output_folder
# Process all images recursively
for root, _, files in os.walk(input_root):
    for filename in tqdm(files, desc=f"Processing {root}"):
        img_path = os.path.join(root, filename)

        # Load image
        img = cv2.imread(img_path)
        if img is None:
            continue  # Skip non-image files

        # Run YOLO detection
        results = model(img)

        # If no head/face detected, skip
        if len(results[0].boxes) == 0:
            continue

        # Extract first detected head/face
        x1, y1, x2, y2 = map(int, results[0].boxes.xyxy[0])
        cropped_img = img[y1:y2, x1:x2]

        print(f"Detected face coordinates: {x1}, {y1}, {x2}, {y2}")

        if (x2 - x1) < 50 or (y2 - y1) < 50:
            continue

        #cv2.imshow("Image", img)
        #cv2.imshow("Cropped Face", cropped_img)
        #cv2.waitKey(0)
        #cv2.destroyAllWindows()

        # Convert to PIL format
        pil_img = Image.fromarray(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))

        # Generate a unique filename (to avoid overwriting)
        new_filename = f"{os.path.splitext(filename)[0]}_cropped.jpg"
        save_path = os.path.join(output_folder, new_filename)

        # Save cropped face
        pil_img.save(save_path)


print(f"Processing complete! Cropped images saved in '{output_folder}'.")
'''

In [None]:
# Optionally fetch CelebA through KaggleHub and point root_dir at its image folder
if celeba:
    path = kagglehub.dataset_download("jessicali9530/celeba-dataset")
    print("✅ Dataset downloaded to:", path)
    root_dir = os.path.join(path, "img_align_celeba")

# Preprocessing: resize, centre-crop to image_size, convert to tensor,
# then normalise each channel with mean 0.5 / std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = dset.ImageFolder(root=root_dir, transform=preprocess)

# Shuffled loader over the image folder
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

# Preview: first 64 images of one batch arranged in a grid
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)))
plt.show()

Weight Initialization

In [69]:
def weights_init(m):
    """Re-initialise one module in place; meant for ``netG.apply`` / ``netD.apply``.

    Modules are matched by class *name*:
      * name contains ``Conv``      -> weight ~ N(0.0, 0.02)
      * name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias set to 0
    Any other module is left untouched. Returns ``None``.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Generator

In [70]:
class Generator(nn.Module):
    """DCGAN-style generator: latent vector ``(nz, 1, 1)`` -> image ``(nc, 64, 64)``.

    Four ConvTranspose2d/BatchNorm2d/ReLU stages followed by a final
    ConvTranspose2d + Tanh. Channel widths come from the module-level
    ``nz``, ``ngf`` and ``nc`` settings. ``ngpu`` is only stored on the instance.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) per normalised stage.
        # Spatial size goes 1 -> 4 -> 8 -> 16 -> 32.
        stages = (
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )
        layers = []
        for in_ch, out_ch, stride, padding in stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Final stage: 32 -> 64 pixels, squashed through Tanh
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Same layer order as a hand-written Sequential, so sub-module indices
        # (and therefore state_dict keys) are main.0 ... main.13.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the layer stack."""
        return self.main(input)

In [None]:
# Instantiate the generator on the selected device
netG = Generator(ngpu).to(device)

# Wrap in DataParallel only when running on CUDA with more than one GPU requested
if ngpu > 1 and device.type == 'cuda':
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Re-initialise every Conv / BatchNorm layer through ``weights_init``
# (conv weights: mean=0, stdev=0.02; batch-norm weights: mean=1, stdev=0.02, bias 0).
netG.apply(weights_init)

# Show the architecture
print(netG)

# Discriminator

In [72]:
class WGANDiscriminator(nn.Module):
    """WGAN critic: image ``(nc, 64, 64)`` -> one unbounded score per sample.

    Four stride-2 Conv2d + LeakyReLU(0.2) stages (no normalisation layers)
    followed by a 4x4 Conv2d that collapses the 4x4 map to a single value.
    Channel widths come from the module-level ``nc`` and ``ndf`` settings.
    ``ngpu`` is only stored on the instance.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        # Channel progression; each step halves the spatial size (64 -> 32 -> 16 -> 8 -> 4)
        widths = (nc, ndf, ndf * 2, ndf * 4, ndf * 8)
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final 4x4 convolution -> 1 x 1 x 1 score (flattened in forward)
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))

        # Same layer order as a hand-written Sequential: state_dict keys stay main.0 ... main.8
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor holding one critic score per input sample."""
        scores = self.main(input)
        return scores.view(-1)

In [None]:
# Instantiate the critic (discriminator) on the selected device
netD = WGANDiscriminator(ngpu).to(device)

# Wrap in DataParallel only when running on CUDA with more than one GPU requested
if ngpu > 1 and device.type == 'cuda':
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# Re-initialise the layers through ``weights_init``:
# conv weights are drawn with mean=0, stdev=0.02.
netD.apply(weights_init)

# Show the architecture
print(netD)

Loss Functions and Optimizers


In [74]:
# Fixed batch of 64 latent vectors, kept constant so generator progress can be compared over time
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Adam optimisers for the critic (D) and the generator (G).
# Both use lr=1e-4 and betas=(0.0, 0.9); the lrG / lrD / beta1 values from the
# configuration cell are not used here.
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.0, 0.9))

## Computing Gradient Penalty

In [75]:
def compute_gradient_penalty(D, real_samples, fake_samples, device=None):
    """Return the WGAN-GP gradient penalty ``mean((||grad D(x_hat)||_2 - 1) ** 2)``.

    ``x_hat`` is a per-sample random interpolation between ``real_samples``
    and ``fake_samples``.

    Args:
        D: callable mapping a batch of samples to critic scores.
        real_samples: tensor whose first dimension is the batch.
        fake_samples: tensor with the same batch size (and shape) as ``real_samples``.
        device: device on which the interpolation coefficients are drawn.
            Optional; defaults to ``real_samples.device``.

    Returns:
        Scalar tensor. Built with ``create_graph=True`` so it can be
        back-propagated into the critic's parameters.

    Raises:
        ValueError: if the two batches differ in size.
    """
    # Ensure both inputs have the same batch size
    if real_samples.size(0) != fake_samples.size(0):
        raise ValueError(f"Batch size mismatch: real({real_samples.size(0)}), fake({fake_samples.size(0)})")

    batch_size = real_samples.size(0)
    if device is None:
        device = real_samples.device

    # One coefficient in [0, 1) per sample, shaped (B, 1, ..., 1) so that it
    # broadcasts over every non-batch dimension whatever the input rank is.
    eps_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device)

    # Interpolate between real and fake samples
    interpolates = (epsilon * real_samples + (1 - epsilon) * fake_samples).requires_grad_(True)

    # Critic output for the interpolated samples
    d_interpolates = D(interpolates)

    # Gradient of the critic scores w.r.t. the interpolated inputs.
    # create_graph=True keeps the graph so the penalty itself is differentiable
    # (retain_graph then defaults to True as well).
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )[0]

    # Flatten per sample; reshape (not view) also copes with non-contiguous gradients
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalise deviation of the per-sample gradient norm from 1
    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty



# Computing Fréchet inception distance

In [76]:
import torch
import torchvision.transforms as transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3
from torchvision.utils import save_image

# Build the FID metric (feature=2048) and move it to CUDA when available, otherwise CPU.
fid_metric = FrechetInceptionDistance(feature=2048).to("cuda" if torch.cuda.is_available() else "cpu")

In [77]:
def _to_uint8(images, value_range):
    """Linearly map ``images`` from ``value_range`` to uint8 values in [0, 255]."""
    low, high = value_range
    scaled = (images.float() - low) / (high - low)
    # clamp guards against values slightly outside the nominal range
    return scaled.clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)


def compute_fid(real_images, fake_images, fid_metric, value_range=(-1.0, 1.0)):
    """Compute FID between one batch of real and one batch of fake images.

    Args:
        real_images: float image tensor with values inside ``value_range``.
        fake_images: float image tensor with values inside ``value_range``.
        fid_metric: object exposing ``update(images, real=...)``, ``compute()``
            and ``reset()``.
        value_range: ``(low, high)`` range of the incoming pixel values.
            Defaults to ``(-1, 1)``, which is what this notebook's dataloader
            (Normalize with mean 0.5 / std 0.5) and generator (Tanh) produce.
            Pass ``(0.0, 1.0)`` for un-normalised images.

    Returns:
        The FID score as a Python float.
    """
    # Multiplying [-1, 1] data by 255 and casting to uint8 corrupts every
    # negative pixel, so rescale to [0, 255] before the cast.
    fid_metric.update(_to_uint8(real_images, value_range), real=True)
    fid_metric.update(_to_uint8(fake_images, value_range), real=False)

    try:
        fid_score = fid_metric.compute()
    finally:
        # Always clear accumulated statistics so a failed compute cannot
        # leak into the next evaluation.
        fid_metric.reset()

    return fid_score.item()


In [None]:
# Report where the run and both networks live before training starts
print(f"Using device: {device}")
print(f"netG is on: {next(netG.parameters()).device}")
print(f"netD is on: {next(netD.parameters()).device}")

# Containers for tracking progress, plus a running count of generator iterations
img_list, G_losses, D_losses, FID_scores = [], [], [], []
iters = 0
def show_generated_images(img_tensor, title=None):
    """Display one image tensor in an 8x8-inch matplotlib figure with axes hidden.

    Args:
        img_tensor: 3-D tensor; its axes are reordered as (1, 2, 0) before
            plotting, i.e. presumably (channels, height, width) on input.
            May live on any device and may require grad.
        title: optional figure title; no title is set when falsy.
    """
    # detach + cpu: tensors coming straight from the generator are typically on
    # the GPU and attached to the autograd graph, and .numpy() rejects both.
    npimg = img_tensor.detach().cpu().numpy()
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    if title:
        plt.title(title)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
print("Starting Training Loop...")
n_critic = 5    # Update the Discriminator 5 times for every 1 Generator update
lambda_gp = 10  # Weight of the gradient-penalty term in the critic loss

# Create the checkpoint directory once, before training starts
checkpoint_path = "checkpointswgangp/"
os.makedirs(checkpoint_path, exist_ok=True)

# Make sure both networks are in training mode: the evaluation cells further
# down call netG.eval(), which would otherwise persist if this cell is re-run.
netG.train()
netD.train()

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader, 0):

        # The same real batch is reused for all n_critic critic updates
        real_images = images.to(device)
        curr_batch_size = real_images.size(0)

        ############################
        # (1) Update D network
        ############################
        for _ in range(n_critic):
            netD.zero_grad()

            # Fresh fake batch; detached so the critic update does not reach netG
            noise = torch.randn(curr_batch_size, nz, 1, 1, device=device)
            fake_images = netG(noise).detach()

            # Critic scores
            output_real = netD(real_images)
            output_fake = netD(fake_images)

            # Gradient penalty (neither input carries gradient history here)
            gp = compute_gradient_penalty(netD, real_images, fake_images, device)

            # Critic loss: Wasserstein estimate plus weighted gradient penalty
            D_loss = -torch.mean(output_real) + torch.mean(output_fake) + lambda_gp * gp

            # A failing backward pass is allowed to raise: stepping the optimizer
            # after a failed backward would apply missing or stale gradients.
            D_loss.backward()
            optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        netG.zero_grad()
        noise = torch.randn(curr_batch_size, nz, 1, 1, device=device)
        fake_images = netG(noise)
        output_fake = netD(fake_images)
        G_loss = -torch.mean(output_fake)
        G_loss.backward()
        optimizerG.step()

        # Print losses occasionally
        if i % 50 == 0:
            print(f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {D_loss.item():.4f} Loss_G: {G_loss.item():.4f}")

        # Save losses (D_loss is the value from the last critic step)
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())

        # `iters` is the only iteration counter; the previous extra increment of
        # an undefined `iter_counter` raised NameError on the first batch.
        iters += 1

    # Checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        save_data = {
            'epoch': epoch,
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'G_losses': G_losses,
            'D_losses': D_losses,
        }
        # Kept as a pickle file because the evaluation cells read it with pickle.load
        ckpt_file = os.path.join(checkpoint_path, f"wgangp_checkpoint_epoch_{epoch+1}.pkl")
        with open(ckpt_file, "wb") as f:
            pickle.dump(save_data, f)
        print(f"Checkpoint saved at epoch {epoch+1}")




# Evolution of the Fréchet Inception Distance (FID)

In [None]:
# NOTE(review): FID_scores is created as an empty list in the training cell and
# nothing visible in this notebook appends to it, so this presumably prints [].
print(FID_scores)

In [None]:
torch.use_deterministic_algorithms(False)
import re
from tqdm import tqdm

# ---- Settings ----
checkpoint_dir = "checkpointswgangp/"
# Only checkpoints carrying the "_small_dataset" suffix are evaluated here.
# NOTE(review): the training cell in this notebook writes files WITHOUT that
# suffix, so this presumably targets checkpoints renamed from another run.
checkpoint_pattern = r"wgangp_checkpoint_epoch_(\d+)_small_dataset\.pkl"
num_fid_batches = 10   # real/fake batch pairs scored per checkpoint
batch_size = 100       # images per side per batch (rebinds the global batch_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Collect all checkpoint paths and sort by epoch ----
checkpoints = []
for filename in os.listdir(checkpoint_dir):
    # fullmatch: the entire filename must match, so stray files such as
    # "....pkl.bak" are not picked up as checkpoints.
    match = re.fullmatch(checkpoint_pattern, filename)
    if match:
        epoch_num = int(match.group(1))
        checkpoints.append((epoch_num, os.path.join(checkpoint_dir, filename)))

checkpoints.sort()  # Sort by epoch number

if not checkpoints:
    # Make the empty case visible instead of silently producing no results
    print(f"⚠️ No checkpoint in '{checkpoint_dir}' matches {checkpoint_pattern!r}; nothing to evaluate.")

# ---- FID results storage ----
fid_results = []

# ---- Loop through each checkpoint ----
for epoch, path in checkpoints:
    print(f"\nEvaluating checkpoint: {path}")

    # Load checkpoint
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # checkpoint files produced by a trusted run of this notebook.
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)

    netG.load_state_dict(checkpoint['netG_state_dict'])
    netG.to(device)
    netG.eval()

    # Reset FID metric
    fid_metric = fid_metric.to(device)
    fid_metric.reset()

    # Generate and evaluate FID over multiple batches
    # NOTE(review): each score is computed from at most `batch_size` images per
    # side and the scores are then averaged - confirm this small-sample
    # estimate is what is intended.
    fid_scores = []
    real_iter = iter(dataloader)

    for _ in tqdm(range(num_fid_batches), desc=f"Epoch {epoch}"):
        with torch.no_grad():
            try:
                real_batch = next(real_iter)[0]
            except StopIteration:
                # Dataloader exhausted: start another pass over the data
                real_iter = iter(dataloader)
                real_batch = next(real_iter)[0]

            real_images = real_batch[:batch_size].to(device)
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = netG(noise)

            score = compute_fid(real_images, fake_images, fid_metric)
            fid_scores.append(score)

    mean_fid = sum(fid_scores) / len(fid_scores)
    fid_results.append((epoch, mean_fid))
    print(f"✅ Epoch {epoch}: Mean FID = {mean_fid:.4f}")

# ---- Display results ----
print("\n=== Final FID Scores ===")
for epoch, score in fid_results:
    print(f"Epoch {epoch}: FID = {score:.4f}")

In [None]:
# Split the (epoch, mean FID) pairs into x and y series
fid_epochs = [pair[0] for pair in fid_results]
fid_scores = [pair[1] for pair in fid_results]

# Line plot of mean FID against checkpoint epoch
plt.figure(figsize=(8, 5))
plt.plot(fid_epochs, fid_scores, linestyle='-', marker='o')
plt.xlabel("Epoch")
plt.ylabel("FID Score")
plt.title("FID Score Over Epochs")
plt.grid(True)
plt.tight_layout()
plt.show()

# Evolution of Generator and Discriminator Losses over Training

In [None]:
# Overlay the per-iteration generator and critic losses recorded during training
plt.figure(figsize=(10, 5))
for loss_series, series_label in ((G_losses, "G"), (D_losses, "D")):
    plt.plot(loss_series, label=series_label)
plt.title("Generator and Discriminator Loss During Training")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Image Generator Evolution / Progression over the Epochs 

In [None]:
torch.use_deterministic_algorithms(False)
import re
from tqdm import tqdm

# Settings: where checkpoints live and how much data each FID estimate uses
checkpoint_dir = "checkpointswgangp/"
checkpoint_pattern = r"wgangp_checkpoint_epoch_(\d+)\.pkl"
num_fid_batches = 10   # real/fake batch pairs scored per checkpoint
batch_size = 100       # images per side per batch (rebinds the global batch_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect (epoch, path) for every checkpoint file in the directory
checkpoints = []
for filename in os.listdir(checkpoint_dir):
    # fullmatch: the entire filename must match, so stray files such as
    # "....pkl.bak" are not picked up as checkpoints.
    match = re.fullmatch(checkpoint_pattern, filename)
    if match:
        epoch_num = int(match.group(1))
        checkpoints.append((epoch_num, os.path.join(checkpoint_dir, filename)))

checkpoints.sort()  # Sort by epoch number

if not checkpoints:
    # Make the empty case visible instead of silently producing no results
    print(f"⚠️ No checkpoint in '{checkpoint_dir}' matches {checkpoint_pattern!r}; nothing to evaluate.")

fid_results = []

for epoch, path in checkpoints:
    print(f"\nEvaluating checkpoint: {path}")

    # Load checkpoint
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # checkpoint files produced by a trusted run of this notebook.
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)

    netG.load_state_dict(checkpoint['netG_state_dict'])
    netG.to(device)
    netG.eval()

    # Reset FID metric
    fid_metric = fid_metric.to(device)
    fid_metric.reset()

    # Generate and evaluate FID over multiple batches
    fid_scores = []
    real_iter = iter(dataloader)

    for _ in tqdm(range(num_fid_batches), desc=f"Epoch {epoch}"):
        with torch.no_grad():
            try:
                real_batch = next(real_iter)[0]
            except StopIteration:
                # Dataloader exhausted: start another pass over the data
                real_iter = iter(dataloader)
                real_batch = next(real_iter)[0]

            real_images = real_batch[:batch_size].to(device)
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = netG(noise)

            score = compute_fid(real_images, fake_images, fid_metric)
            fid_scores.append(score)

    mean_fid = sum(fid_scores) / len(fid_scores)
    fid_results.append((epoch, mean_fid))
    print(f"✅ Epoch {epoch}: Mean FID = {mean_fid:.4f}")

print("\n=== Final FID Scores ===")
for epoch, score in fid_results:
    print(f"Epoch {epoch}: FID = {score:.4f}")

# Real Images vs Fake Images

In [None]:
# Take the last entry of the epoch-sorted checkpoint list
last_epoch, last_ckpt_path = checkpoints[-1]
print(f"\n🔄 Loading last checkpoint (Epoch {last_epoch}) from {last_ckpt_path}")

with open(last_ckpt_path, "rb") as f:
    checkpoint = pickle.load(f)

# Restore the generator weights and switch to inference mode
netG.load_state_dict(checkpoint['netG_state_dict'])
netG.to(device)
netG.eval()

# Sample 64 fake images (enough for an 8x8 grid) without tracking gradients
with torch.no_grad():
    noise = torch.randn(64, nz, 1, 1, device=device)
    fake_images = netG(noise).cpu()

# One batch of real images for the side-by-side comparison
real_batch = next(iter(dataloader))

real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu()
fake_grid = vutils.make_grid(fake_images, padding=5, normalize=True)

# Left panel: real images; right panel: generator output
plt.figure(figsize=(15, 15))
panels = (("Real Images", real_grid), ("Fake Images", fake_grid))
for position, (panel_title, grid) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.axis("off")
    plt.title(panel_title)
    plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()