## 5. Evaluation
### Generated Image Quality Metrics

1. **FID (Fréchet Inception Distance)**
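   - Compares the distribution of generated images with that of real images in the feature space of an Inception-v3 network, summarising each set by the mean and covariance of its features
   - Lower is better (rough rules of thumb: good <50, excellent <20); absolute values depend on the dataset, image resolution and number of samples, so scores are only comparable under the same setup
   - Biased upward when few samples are used, so values computed on a few hundred images are usually much higher than values computed on larger sample sets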

2. **Inception Score (IS)**
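   - Scores the generated images alone (no real images are involved): it is high when each image receives a confident Inception-v3 class prediction and the predictions are spread over many classes
   - Higher is better (rough rules of thumb: good >3, excellent >7); it ranges from 1 up to the number of classifier classes (1000 for ImageNet)
   - Because the classes are ImageNet categories, IS can be less informative for single-domain datasets such as flowers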

3. **Kernel Inception Distance (KID)**
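   - Like FID, compares Inception-v3 features of real and generated images, but uses an unbiased estimate of the squared maximum mean discrepancy (MMD) with a polynomial kernel
   - Lower is better (rough rules of thumb: good <0.05, excellent <0.02)
   - Because the estimator is unbiased, it is more reliable than FID at small sample sizes, which makes it useful for quick evaluations; torchmetrics reports the mean and standard deviation over random subsets of `subset_size` images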

4. **Diversity Score**
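   - Custom metric: the average pixel-space L2 distance between randomly sampled pairs of generated images
   - Higher values indicate more diverse outputs
   - Helps detect mode collapse (when the model generates very similar images); since noisy or unstructured images are also far apart in pixel space, read it together with FID and KID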
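The three Inception-based metrics can be sketched directly on feature vectors and class probabilities. The NumPy code below is a minimal, illustrative re-implementation, not the torchmetrics code: it assumes `feats_real` and `feats_gen` are `(N, D)` feature arrays (Inception-v3 pool features in practice; random vectors in the checks) and `probs` is an `(N, K)` array of softmax outputs. The function names `frechet_distance`, `kernel_inception_distance` and `inception_score` are hypothetical. FID is $\|\mu_r-\mu_g\|^2 + \mathrm{Tr}\big(\Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2}\big)$; KID is the unbiased MMD² with the kernel $k(x,y)=(x^\top y/D+1)^3$, the same polynomial form that torchmetrics' `KernelInceptionDistance` exposes through its `degree`, `gamma` and `coef` arguments; IS is $\exp\big(\mathbb{E}_x\,\mathrm{KL}(p(y\mid x)\,\|\,p(y))\big)$.

```python
import numpy as np

def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(feats_real, feats_gen):
    """FID between two (N, D) feature arrays, modelling each set as a Gaussian."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_gen, rowvar=False)
    # Tr((S_r S_g)^{1/2}) = Tr((S_r^{1/2} S_g S_r^{1/2})^{1/2}); the inner matrix is symmetric PSD
    root_r = _sqrtm_psd(cov_r)
    inner = root_r @ cov_g @ root_r
    inner = (inner + inner.T) / 2.0
    trace_sqrt = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum()
    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(cov_r) + np.trace(cov_g) - 2.0 * trace_sqrt)

def kernel_inception_distance(feats_real, feats_gen, degree=3, coef=1.0):
    """Unbiased MMD^2 with the polynomial kernel k(x, y) = (x.y / D + coef) ** degree."""
    m, n = len(feats_real), len(feats_gen)
    d = feats_real.shape[1]
    k_rr = (feats_real @ feats_real.T / d + coef) ** degree
    k_gg = (feats_gen @ feats_gen.T / d + coef) ** degree
    k_rg = (feats_real @ feats_gen.T / d + coef) ** degree
    # Excluding the diagonal (each sample compared with itself) keeps the estimate unbiased
    term_rr = (k_rr.sum() - np.trace(k_rr)) / (m * (m - 1))
    term_gg = (k_gg.sum() - np.trace(k_gg)) / (n * (n - 1))
    return float(term_rr + term_gg - 2.0 * k_rg.mean())

def inception_score(probs, splits=1, eps=1e-12):
    """IS = exp(E_x KL(p(y|x) || p(y))) from an (N, K) array of class probabilities."""
    scores = []
    for chunk in np.array_split(probs, splits):
        p_y = chunk.mean(axis=0, keepdims=True)  # marginal class distribution of the chunk
        kl = np.sum(chunk * (np.log(chunk + eps) - np.log(p_y + eps)), axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

Unlike torchmetrics, this KID sketch returns a single estimate rather than averaging over random subsets, and `splits` mirrors the idea behind torchmetrics' `InceptionScore(splits=...)` without claiming to match its exact statistics. Two sanity checks follow from the formulas: identical feature sets give an FID of about 0, and confident predictions spread evenly over K classes give an IS of K, while collapsing every image onto one class drops it back to 1.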


In [45]:
import time

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

# Project module; assumed to provide Unet, sample, num_to_groups and diffusion_params
from model_unet import *

device = "cuda" if torch.cuda.is_available() else "cpu"

class Evaluator:
    def __init__(self, device=None):
        print("Initializing evaluation metrics...")
        # Use the requested device, falling back to CUDA when available
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        print(self.device)
        # normalize=True: the metrics expect float images in [0, 1]
        self.fid = FrechetInceptionDistance(normalize=True).to(self.device)
        self.inception_score = InceptionScore(normalize=True).to(self.device)
        # subset_size must not exceed the number of samples being evaluated
        self.kid = KernelInceptionDistance(normalize=True, subset_size=50).to(self.device)
        print("✓ Metrics initialized successfully")

    @torch.no_grad()
    def generate_images(self, model, num_images, image_size=64, batch_size=32, channels=3):
        # Put the model in evaluation mode on the evaluator's device
        model.eval()
        model = model.to(self.device)

        # Split num_images into batches of at most batch_size, e.g. 128 -> [32, 32, 32, 32]
        batches = num_to_groups(num_images, batch_size)
        print(batches)

        all_images_list = []
        for n in batches:
            # sample() returns the whole reverse-diffusion trajectory;
            # diffusion_params is assumed to be defined by model_unet
            trajectory = sample(model, image_size, diffusion_params=diffusion_params, batch_size=n, channels=channels)
            # The last entry of the trajectory is the fully denoised batch
            all_images_list.append(torch.as_tensor(trajectory[-1]).to(self.device))

        all_images = torch.cat(all_images_list, dim=0).cpu()

        # Map from the model's [-1, 1] range to [0, 1] and clip outliers
        all_images = torch.clamp((all_images + 1) * 0.5, 0, 1)

        # Optionally save a grid (requires torchvision.utils.save_image):
        # save_image(all_images, 'generated_images.png', nrow=8)
        return all_images


        
    @torch.no_grad()
    def evaluate_samples(self, real_dataloader, model, num_samples=100, batch_size=32):
        """
        Enhanced evaluation with KID score and progress logging
        """
        start_time = time.time()
        print(f"\n📊 Starting evaluation with {num_samples} samples...")
    
        print("\n1️⃣ Collecting real images...")
        real_images = []
        for batch, _ in real_dataloader:
            real_images.append(batch)
            if len(torch.cat(real_images)) >= num_samples:
                break
        real_images = torch.cat(real_images)[:num_samples].to(self.device)
        print(f"✓ Collected {len(real_images)} real images")

        print("\n2️⃣ Generating samples...")
        generated_images = self.generate_images(model, num_images=num_samples, image_size=64)
        generated_images = generated_images.to(self.device)
        print(f"✓ Generated {len(generated_images)} images")
        
        # Calculate metrics
        metrics = {}
        print("\n3️⃣ Computing metrics...")
        
        print("Computing FID score...")
        self.fid.reset()
        self.fid.update(real_images, real=True)
        self.fid.update(generated_images, real=False)
        metrics['fid'] = self.fid.compute().item()
        print(f"✓ FID Score: {metrics['fid']:.2f}")
        
        print("\nComputing Inception Score...")
        self.inception_score.reset()
        self.inception_score.update(generated_images)
        is_mean, is_std = self.inception_score.compute()
        metrics['inception_score_mean'] = is_mean.item()
        metrics['inception_score_std'] = is_std.item()
        print(f"✓ Inception Score: {metrics['inception_score_mean']:.2f} ± {metrics['inception_score_std']:.2f}")
        
        print("\nComputing KID score...")
        self.kid.reset()
        self.kid.update(real_images, real=True)
        self.kid.update(generated_images, real=False)
        kid_mean, kid_std = self.kid.compute()
        metrics['kid_mean'] = kid_mean.item()
        metrics['kid_std'] = kid_std.item()
        print(f"✓ KID Score: {metrics['kid_mean']:.4f} ± {metrics['kid_std']:.4f}")
        
        print("\nComputing diversity score...")
        if len(generated_images) >= 2:
            diversity_score = self.calculate_diversity(
                generated_images, 
                num_pairs=min(100, num_samples * 2)
            )
            metrics['diversity_score'] = diversity_score
            print(f"✓ Diversity Score: {metrics['diversity_score']:.2f}")
        
        elapsed_time = time.time() - start_time
        print(f"\n✨ Evaluation completed in {elapsed_time:.2f} seconds")
            
        return metrics
    
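    def calculate_diversity(self, images, num_pairs=100):
        """
        Average pixel-space L2 distance over randomly sampled pairs of distinct images
        """
        num_images = len(images)
        idx1 = torch.randint(0, num_images, (num_pairs,), device=images.device)
        # Offset by 1..num_images-1 (mod num_images) so no pair compares an image with itself
        offsets = torch.randint(1, num_images, (num_pairs,), device=images.device)
        idx2 = (idx1 + offsets) % num_images

        images_flat = images.reshape(num_images, -1)
        distances = torch.norm(images_flat[idx1] - images_flat[idx2], dim=1, p=2)
        return distances.mean().item()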

def evaluate_diffusion(model, test_loader, num_samples=100):
    """
    Comprehensive evaluation with all metrics
    """
    evaluator = Evaluator()
    metrics = evaluator.evaluate_samples(
        test_loader,
        model,
        num_samples=num_samples
    )
    
    print("\n📈 Final Evaluation Results:")
    print("=" * 40)
    print(f"FID Score: {metrics['fid']:.2f}")
    print(f"Inception Score: {metrics['inception_score_mean']:.2f} ± {metrics['inception_score_std']:.2f}")
    print(f"KID Score: {metrics['kid_mean']:.4f} ± {metrics['kid_std']:.4f}")
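    # The diversity score is only computed when at least two images were generated
    if 'diversity_score' in metrics:
        print(f"Diversity Score: {metrics['diversity_score']:.2f}")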
    print("=" * 40)
    
    # Provide some context about the scores
    print("\n📝 Quick interpretation:")
    print("FID: Lower is better (good: <50, excellent: <20)")
    print("IS: Higher is better (good: >3, excellent: >7)")
    print("KID: Lower is better (good: <0.05, excellent: <0.02)")
    print("Diversity: Higher indicates more varied outputs")
    
    return metrics
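`calculate_diversity` estimates the mean distance from randomly sampled pairs. For sample sizes like the ones used here, the exact average over all unordered pairs of distinct images is also cheap to compute. A minimal NumPy sketch, assuming the images arrive as an `(N, C, H, W)` array (or anything reshapeable to `(N, -1)`); `pairwise_diversity` is a hypothetical helper, not part of the code above:

```python
import numpy as np

def pairwise_diversity(images):
    """Mean L2 distance over all unordered pairs of distinct images."""
    flat = np.asarray(images, dtype=np.float64).reshape(len(images), -1)
    sq_norms = (flat ** 2).sum(axis=1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clipped at 0 against round-off
    sq_dists = np.clip(sq_norms[:, None] + sq_norms[None, :] - 2.0 * flat @ flat.T, 0.0, None)
    # Upper triangle without the diagonal: each pair once, no self-pairs
    i, j = np.triu_indices(len(flat), k=1)
    return float(np.sqrt(sq_dists[i, j]).mean())
```

In PyTorch, `torch.pdist(images.reshape(len(images), -1)).mean()` computes the same quantity, since `torch.pdist` returns one distance per pair of distinct rows. The all-pairs version builds an N×N matrix, so random pairs remain the cheaper choice for large N.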

In [46]:
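def load_model(image_size, channels, path):
    # Must match the architecture used during training
    model = Unet(
        dim=image_size,
        channels=channels,
        dim_mults=(1, 2, 4,)
    )

    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(device)
    model.eval()
    return model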

def prepare_data_loaders(train_data, train_labels, val_data, val_labels, test_data, test_labels, batch_size):
    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)
    test_dataset = TensorDataset(test_data, test_labels)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, val_loader, test_loader

def load_flowers(batch_size):
    # Note: the train_* files are loaded as the test split and the test_* files as the training split
    test_data = torch.load("data/prepared_datasets/train_flowers.pt")
    test_labels = torch.load("data/prepared_datasets/train_flowers_labels.pt")
    val_data = torch.load("data/prepared_datasets/val_flowers.pt")
    val_labels = torch.load("data/prepared_datasets/val_flowers_labels.pt")
    train_data = torch.load("data/prepared_datasets/test_flowers.pt")
    train_labels = torch.load("data/prepared_datasets/test_flowers_labels.pt")

    def rescale(x):
        # Min-max normalize each split to [0, 1], then map to [-1, 1] for the diffusion model
        x = (x - x.min()) / (x.max() - x.min())
        return x * 2 - 1

    train_data = rescale(train_data)
    val_data = rescale(val_data)
    test_data = rescale(test_data)

    train_loader, val_loader, test_loader = prepare_data_loaders(
        train_data, train_labels, val_data, val_labels, test_data, test_labels, batch_size
    )
    return train_loader, val_loader, test_loader
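The loaders above rescale every split to $[-1, 1]$, the range the diffusion model works in, whereas `FrechetInceptionDistance`, `InceptionScore` and `KernelInceptionDistance` with `normalize=True` expect float images in $[0, 1]$. Real and generated images therefore need the same conversion before they reach the metrics. A minimal NumPy sketch of the two mappings plus a range check that catches a $[-1, 1]$ batch passed by mistake; the helper names are hypothetical:

```python
import numpy as np

def to_unit_range(x):
    """Map model-space values in [-1, 1] to metric-space values in [0, 1]."""
    return np.clip((np.asarray(x, dtype=np.float64) + 1.0) * 0.5, 0.0, 1.0)

def to_model_range(x):
    """Inverse mapping from [0, 1] back to [-1, 1]."""
    return np.asarray(x, dtype=np.float64) * 2.0 - 1.0

def check_unit_range(x, name="images"):
    """Raise ValueError if a batch is not inside [0, 1]."""
    x = np.asarray(x)
    if x.min() < 0.0 or x.max() > 1.0:
        raise ValueError(f"{name} must be in [0, 1], got [{x.min():.3f}, {x.max():.3f}]")
```

Calling a check like this on both the real and the generated batch just before `fid.update(...)` makes a silent range mismatch between the two sets fail loudly instead of inflating FID and KID.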

In [47]:
model = load_model(64,3,'models/model_3.pth')
train_loader, val_loader, test_loader = load_flowers(32)

In [48]:
metrics = evaluate_diffusion(
    model,
    test_loader,
    num_samples=128
)

Initializing evaluation metrics...
cuda
✓ Metrics initialized successfully

📊 Starting evaluation with 128 samples...

1️⃣ Collecting real images...
✓ Collected 128 real images

2️⃣ Generating samples...
[32, 32, 32, 32]


sampling loop time step: 100%|██████████| 200/200 [00:19<00:00, 10.06it/s]
sampling loop time step: 100%|██████████| 200/200 [00:19<00:00, 10.12it/s]
sampling loop time step: 100%|██████████| 200/200 [00:20<00:00, 10.00it/s]
sampling loop time step: 100%|██████████| 200/200 [00:20<00:00,  9.81it/s]


✓ Generated 128 images

3️⃣ Computing metrics...
Computing FID score...
✓ FID Score: 293.59

Computing Inception Score...
✓ Inception Score: 1.41 ± 0.09

Computing KID score...
✓ KID Score: 0.3057 ± 0.0142

Computing diversity score...
✓ Diversity Score: 18.28

✨ Evaluation completed in 83.56 seconds

📈 Final Evaluation Results:
FID Score: 293.59
Inception Score: 1.41 ± 0.09
KID Score: 0.3057 ± 0.0142
Diversity Score: 18.28

📝 Quick interpretation:
FID: Lower is better (good: <50, excellent: <20)
IS: Higher is better (good: >3, excellent: >7)
KID: Lower is better (good: <0.05, excellent: <0.02)
Diversity: Higher indicates more varied outputs
