# <center>This notebook computes the LPIPS (Learned Perceptual Image Patch Similarity) metric between a set of original images and a set of generated images</center>

### 1. Import the required libraries

In [1]:
import os
import sys
import cv2

import torch
from torchvision import transforms

sys.path.insert(0, '..')
from pfiles.lpips_metric import LPIPS

### 2. Custom function for calculating the `LPIPS` metric

In [2]:
def _read_rgb(path):
    """Read the image at ``path`` with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode the file.
            ``cv2.imread`` signals this by returning ``None`` rather than raising.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; convert so channels match the RGB normalization below.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def compute_dataset_lpips(original_dir, generated_dir, extensions=('.jpg',), device=None):
    """Return the mean LPIPS score over paired images from two directories.

    Images are paired by position after sorting the file names of each
    directory, so the i-th original is compared with the i-th generated image.
    If a generated image's shape differs from its original, the generated image
    is resized to the original's width and height before scoring.

    Args:
        original_dir: Directory holding the reference images.
        generated_dir: Directory holding the generated images.
        extensions: File-name suffix (str) or iterable of suffixes selecting which
            files are evaluated. Matching is case-insensitive. Defaults to ``.jpg``
            only, as before.
        device: Torch device to run on. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.

    Returns:
        The average of the per-pair LPIPS scores, or 0.0 when no images match.
        NOTE(review): lower presumably means more perceptually similar (the
        usual LPIPS convention) — confirm against ``pfiles.lpips_metric``.

    Raises:
        ValueError: if the two directories contain different numbers of matching images.
        FileNotFoundError: if an image file cannot be read.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Accept a single suffix string or any iterable of suffixes; str.endswith needs a tuple.
    if isinstance(extensions, str):
        suffixes = (extensions.lower(),)
    else:
        suffixes = tuple(ext.lower() for ext in extensions)

    loss_fn = LPIPS().to(device)
    # Metric computation should not use training-time behavior (e.g. dropout).
    # Assumes LPIPS is a torch.nn.Module — TODO confirm.
    loss_fn.eval()

    # ToTensor maps HxWxC uint8 [0, 255] to CxHxW float [0, 1]; Normalize(0.5, 0.5)
    # then rescales to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    def filter_images(directory):
        """Return the sorted names of regular files in ``directory`` with a matching suffix."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(suffixes)
            and ".ipynb_checkpoints" not in name
            and os.path.isfile(os.path.join(directory, name))
        )

    original_images = filter_images(original_dir)
    generated_images = filter_images(generated_dir)

    if len(original_images) != len(generated_images):
        raise ValueError(
            "Number of images in original and generated datasets do not match. "
            f"({len(original_images)} vs {len(generated_images)})"
        )

    total_lpips = 0.0
    count = 0

    # No gradients are needed for evaluation; this avoids building an autograd graph per pair.
    with torch.no_grad():
        for orig_img_name, gen_img_name in zip(original_images, generated_images):
            orig_img = _read_rgb(os.path.join(original_dir, orig_img_name))
            gen_img = _read_rgb(os.path.join(generated_dir, gen_img_name))

            if orig_img.shape != gen_img.shape:
                # cv2.resize takes (width, height); numpy shape is (height, width, channels).
                gen_img = cv2.resize(gen_img, (orig_img.shape[1], orig_img.shape[0]))

            orig_tensor = transform(orig_img).unsqueeze(0).to(device)  # add batch dimension
            gen_tensor = transform(gen_img).unsqueeze(0).to(device)    # add batch dimension

            total_lpips += loss_fn(orig_tensor, gen_tensor).item()
            count += 1

    return total_lpips / count if count > 0 else 0.0

### 3. Compute and display the average LPIPS score

In [3]:
def _count_jpg_files(directory):
    """Count the regular ``.jpg`` files directly inside ``directory``.

    These are the files ``compute_dataset_lpips`` evaluates with its default
    ``extensions``. ``os.scandir`` raises ``FileNotFoundError`` for a missing
    directory. ``next(os.walk(...))`` would instead raise an opaque
    ``StopIteration``, because ``os.walk`` silently ignores the error.
    """
    with os.scandir(directory) as entries:
        return sum(
            1 for entry in entries
            if entry.is_file() and entry.name.lower().endswith('.jpg')
        )


original_dir = '../org'
print('Total images in original_dir:', _count_jpg_files(original_dir))

generated_dir = 'gen_14'
print('Total images in generated_dir:', _count_jpg_files(generated_dir))

avg_lpips_score = compute_dataset_lpips(original_dir, generated_dir)
print(f"Average LPIPS: {avg_lpips_score:.4f}")

Total images in original_dir: 10000
Total images in generated_dir: 10000
Average LPIPS: 0.4923
