In [None]:
import os
import shutil

import lpips
import torch
from torch_fidelity import calculate_metrics
from torchvision.utils import save_image
from tqdm import tqdm

from models import Generator
from encoder import FaceNetEncoder
from dataLoaders import get_dataloaders

# --- Evaluation configuration ---
# Which generator checkpoint to evaluate.
CHECKPOINT_EPOCH = 20
CHECKPOINT_DIR = "./checkpoints_widerface"
GENERATOR_PATH = os.path.join(CHECKPOINT_DIR, f"generator_epoch_{CHECKPOINT_EPOCH}.pth")

# Data / architecture hyperparameters. NOISE_DIM and EMBEDDING_DIM are passed to
# the Generator constructor, so they must match the checkpoint being loaded.
BATCH_SIZE = 32
IMAGE_SIZE = 128
NOISE_DIM = 100
EMBEDDING_DIM = 512

# Scratch directories holding one PNG per image for directory-based metrics.
EVAL_DIR = "./eval_images"
REAL_IMG_DIR = os.path.join(EVAL_DIR, "real")
FAKE_IMG_DIR = os.path.join(EVAL_DIR, "fake")

# Prefer CUDA, then Apple-silicon MPS, then CPU. Previously only MPS was checked,
# so CUDA machines silently evaluated on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Wipe and recreate the eval directory so images from a previous run cannot
# leak into this run's metrics.
if os.path.exists(EVAL_DIR):
    shutil.rmtree(EVAL_DIR)
os.makedirs(REAL_IMG_DIR, exist_ok=True)
os.makedirs(FAKE_IMG_DIR, exist_ok=True)
print(f"Created temporary evaluation directory: {EVAL_DIR}")

Using device: mps
Created temporary evaluation directory: ./eval_images


In [None]:
print("Loading models...")

# Encoder that produces the conditioning embeddings; used for inference only.
encoder = FaceNetEncoder(device=device)
encoder.eval()

# Rebuild the generator with the configured hyperparameters, then restore the
# weights saved at CHECKPOINT_EPOCH onto the selected device.
generator = Generator(noise_dim=NOISE_DIM, embedding_dim=EMBEDDING_DIM).to(device)
try:
    generator_state = torch.load(GENERATOR_PATH, map_location=device)
    generator.load_state_dict(generator_state)
except FileNotFoundError:
    # Show the resolved path, then propagate the original error unchanged.
    print(f"ERROR: Checkpoint file not found at '{GENERATOR_PATH}'")
    raise
else:
    print(f"Generator weights from epoch {CHECKPOINT_EPOCH} loaded successfully.")
generator.eval()

print("\nLoading test dataset...")
# Only the test split is needed for evaluation; the train loader is discarded.
_, test_loader = get_dataloaders(data_root='.', batch_size=BATCH_SIZE, image_size=IMAGE_SIZE)
if not test_loader:
    # Fail fast if get_dataloaders did not produce a usable test loader.
    raise RuntimeError("Could not create test dataloader.")

Loading models...
Generator weights from epoch 20 loaded successfully.

Loading test dataset...
Loading training data from: ./data/train
Loading test data from: ./data/test

DataLoaders created successfully!
Number of training images: 10000
Number of testing images: 500


In [None]:
# Fixed seed for the latent noise so FID/LPIPS are reproducible run-to-run.
# Noise comes from a dedicated CPU RNG and is then moved to `device`. This keeps
# the latents identical across accelerators and leaves the global RNG untouched.
# The RNG is named `noise_rng` so it does not clash with the `generator` model.
NOISE_SEED = 0
noise_rng = torch.Generator().manual_seed(NOISE_SEED)

all_real_images = []
all_fake_images = []
img_idx = 0

print("\nGenerating fake images for the entire test set...")
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating Images"):
        real_images_batch = batch[0].to(device)
        n_in_batch = real_images_batch.size(0)

        # Condition the generator on the encoder's embedding of each real image.
        real_embeddings = encoder(real_images_batch)
        noise = torch.randn(n_in_batch, NOISE_DIM, generator=noise_rng).to(device)

        fake_images_batch = generator(noise, real_embeddings)

        # Keep CPU copies of both batches for the LPIPS cell.
        all_real_images.append(real_images_batch.cpu())
        all_fake_images.append(fake_images_batch.cpu())

        # Write one PNG per image for directory-based FID.
        # `* 0.5 + 0.5` maps [-1, 1] back to [0, 1]. This assumes the loader
        # normalizes with mean=std=0.5 (TODO confirm in dataLoaders).
        real_unnorm = real_images_batch * 0.5 + 0.5
        fake_unnorm = fake_images_batch * 0.5 + 0.5
        for real_img, fake_img in zip(real_unnorm, fake_unnorm):
            save_image(real_img, os.path.join(REAL_IMG_DIR, f"{img_idx}.png"))
            save_image(fake_img, os.path.join(FAKE_IMG_DIR, f"{img_idx}.png"))
            img_idx += 1

print(f"\nSaved {img_idx} real and fake images to '{EVAL_DIR}' for metric calculation.")


Generating fake images for the entire test set...


Generating Images: 100%|██████████| 16/16 [00:09<00:00,  1.77it/s]


Saved 500 real and fake images to './eval_images' for metric calculation.





In [None]:
print("\n--- Calculating LPIPS Score ---")
# The 'alex' trunk is the standard choice when LPIPS is used as an evaluation metric.
# .eval() makes inference mode explicit.
lpips_model = lpips.LPIPS(net='alex').to(device).eval()

# zip() would silently drop unmatched batches, so check the pairing up front.
assert len(all_real_images) == len(all_fake_images), "real/fake batch lists are misaligned"

total_lpips_distance = 0.0
num_images = 0

with torch.no_grad():
    for real_batch, fake_batch in tqdm(zip(all_real_images, all_fake_images), total=len(all_real_images), desc="Calculating LPIPS"):
        # LPIPS expects inputs scaled to [-1, 1]. This assumes the tensors are still
        # in the loader's mean=std=0.5 normalization (TODO confirm).
        real_batch = real_batch.to(device)
        fake_batch = fake_batch.to(device)

        # Spatial averaging is off, so this yields one distance per image pair.
        distance = lpips_model(real_batch, fake_batch)
        total_lpips_distance += distance.sum().item()
        # Count the actual images. The final batch can be smaller than BATCH_SIZE
        # (e.g. 500 = 15 * 32 + 20), so num_batches * BATCH_SIZE over-counts.
        num_images += real_batch.size(0)

if num_images == 0:
    raise RuntimeError("No image pairs available for LPIPS; run the generation cell first.")

# Average the score over all images
average_lpips = total_lpips_distance / num_images

print("\n--- LPIPS Result ---")
print(f"Average LPIPS Score: {average_lpips:.4f}")
print("(Lower is better)")


--- Calculating LPIPS Score ---
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /Users/0xr4plh/Documents/Machine Learning/Generative Training 3/invideo/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth


Calculating LPIPS: 100%|██████████| 16/16 [00:00<00:00, 37.06it/s]


--- LPIPS Result ---
Average LPIPS Score: 0.5103
(Lower is better)



