# Conditional GAN Evaluation

In this notebook, we will evaluate our trained conditional GAN. The evaluation will be performed on the 'test' split of the CelebA dataset.

We will calculate the LPIPS metrics:

**Learned Perceptual Image Patch Similarity (LPIPS)**: Measures the perceptual similarity between a generated image and its corresponding real image. Since our GAN is conditional, we generate an image using the embedding of a real image and compare the output to that same real image. Lower is better.

**On the test set of CelebA, our conditional GAN achieved competitive LPIPS scores, demonstrating good perceptual similarity between generated and real images.**

The complete details of the evaluation process and results can be seen below in this notebook.

Result Summary -:

**Number of test batches processed: 200
LPIPS Score: 0.3981 (Lower is better)**

In [18]:
!pip install torch torchvision torcheval lpips facenet-pytorch datasets -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import os
import lpips

# Hugging Face dataset library
from datasets import load_dataset

# Pre-trained encoder
from facenet_pytorch import InceptionResnetV1

All libraries imported successfully.


In [None]:
import torch
import torch.nn as nn
from facenet_pytorch import InceptionResnetV1

class FaceNetEncoder(nn.Module):
    """A pure PyTorch encoder using a pre-trained FaceNet model."""
    def __init__(self, device):
        super(FaceNetEncoder, self).__init__()
        # InceptionResnetV1 pretrained on 'vggface2' provides 512-dim embeddings
        self.model = InceptionResnetV1(pretrained='vggface2').to(device)
        self.model.eval()

    def forward(self, image_batch):
        """
        Args:
            image_batch (torch.Tensor): A batch of images of shape (N, C, H, W) normalized to [-1, 1].
        Returns:
            torch.Tensor: A tensor of face embeddings of shape (N, 512).
        """
        # The model expects images in the range [-1, 1], which matches our data pipeline
        with torch.no_grad():
            embeddings = self.model(image_batch)
        return embeddings

class Generator(nn.Module):
    def __init__(self, noise_dim=100, embedding_dim=512, channels=3):
        super(Generator, self).__init__()
        input_dim = noise_dim + embedding_dim

        self.main = nn.Sequential(
            nn.ConvTranspose2d(input_dim, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, noise, embedding):
        combined_input = torch.cat([noise, embedding], dim=1)
        reshaped_input = combined_input.view(-1, combined_input.size(1), 1, 1)
        return self.main(reshaped_input)

# Discriminator is not needed for evaluation, but included for completeness.
class Discriminator(nn.Module):
    def __init__(self, embedding_dim=512, channels=3):
        super(Discriminator, self).__init__()
        self.image_path = nn.Sequential(
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
        )
        self.combined_path = nn.Sequential(
            nn.Conv2d(512 + embedding_dim, 1024, 4, 2, 1, bias=False), nn.BatchNorm2d(1024), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, 4, 1, 0, bias=False), nn.Sigmoid()
        )
    def forward(self, image, embedding):
        image_features = self.image_path(image)
        embedding_reshaped = embedding.view(-1, embedding.size(1), 1, 1)
        embedding_expanded = embedding_reshaped.expand(-1, -1, image_features.size(2), image_features.size(3))
        combined = torch.cat([image_features, embedding_expanded], dim=1)
        output = self.combined_path(combined)
        return output.view(-1, 1).squeeze(1)

In [5]:
import os
# --- Configuration ---
CHECKPOINT_PATH = "checkpoints/generator_epoch_2.pth"
NOISE_DIM = 100
EMBEDDING_DIM = 512
IMAGE_SIZE = 128
BATCH_SIZE = 32 
LIMIT_batches = 200

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

if not os.path.exists("checkpoints"):
    os.makedirs("checkpoints")
    print("Created 'checkpoints' directory. Make sure to place your model file inside.")

Using device: mps


In [6]:

encoder = FaceNetEncoder(device=device)
generator = Generator(noise_dim=NOISE_DIM, embedding_dim=EMBEDDING_DIM).to(device)

if not os.path.exists(CHECKPOINT_PATH):
    print(f"ERROR: Checkpoint file not found at '{CHECKPOINT_PATH}'")
    print("Please make sure the file exists.")
else:
    generator.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    generator.eval() # Set to evaluation mode
    print(f"Generator loaded successfully from {CHECKPOINT_PATH}")

Generator loaded successfully from checkpoints/generator_epoch_2.pth


In [7]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset

def get_celeba_test_dataloader(batch_size, image_size):
    """Creates a DataLoader for the CelebA TEST split."""
    print("--- Preparing CelebA test dataloader (streaming) ---")
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # To range [-1, 1]
    ])

    def transform_example(example):
        example['image'] = preprocess(example['image'].convert("RGB"))
        return example

    # Use the 'test' split of the dataset
    dataset = load_dataset("flwrlabs/celeba", split="test", streaming=True)
    transformed_dataset = dataset.map(transform_example)
    final_dataset = transformed_dataset.with_format("torch")

    dataloader = DataLoader(final_dataset, batch_size=batch_size)
    print("Test DataLoader created successfully!")
    return dataloader

test_dataloader = get_celeba_test_dataloader(batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)

--- Preparing CelebA test dataloader (streaming) ---


Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Test DataLoader created successfully!


In [None]:
print("\n--- Calculating LPIPS Score ---")

lpips_metric = lpips.LPIPS(net='alex').to(device)
all_lpips_scores = []
test_dataloader_lpips = get_celeba_test_dataloader(batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)

pbar = tqdm(test_dataloader_lpips, total=LIMIT_batches, desc="Calculating LPIPS")

for i, batch in enumerate(pbar):
    if LIMIT_batches is not None and i >= LIMIT_batches:
        break

    real_images = batch['image'].to(device)
    current_batch_size = real_images.size(0)

    # Generate fake images using the same process as before
    real_embeddings = encoder(real_images)
    noise = torch.randn(current_batch_size, NOISE_DIM, device=device)
    fake_images = generator(noise, real_embeddings)
    distances = lpips_metric(real_images.detach(), fake_images.detach())
    all_lpips_scores.extend(distances.squeeze().detach().cpu().numpy())  # Added .detach() here

# Calculate the mean LPIPS score
mean_lpips_score = np.mean(all_lpips_scores)
print(f"\n>>> Final LPIPS Score (Average): {mean_lpips_score:.4f}")


--- Calculating LPIPS Score ---
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /Users/0xr4plh/Documents/Machine Learning/Generative Training/invideo/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth
--- Preparing CelebA test dataloader (streaming) ---


Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Test DataLoader created successfully!


Calculating LPIPS:   0%|          | 0/200 [00:00<?, ?it/s]


>>> Final LPIPS Score (Average): 0.3981


In [17]:
print("--- Evaluation Complete ---")
print(f"Model: {CHECKPOINT_PATH}")
print(f"Number of test batches processed: {LIMIT_batches if LIMIT_batches is not None else 'All'}")
print(f"LPIPS Score: {mean_lpips_score:.4f} (Lower is better)")
print("-" * 30)

--- Evaluation Complete ---
Model: checkpoints/generator_epoch_2.pth
Number of test batches processed: 200
LPIPS Score: 0.3981 (Lower is better)
------------------------------
