In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install requirements
!pip install scikit-image
!pip install ftfy regex tqdm scikit-image --quiet
!pip install git+https://github.com/openai/CLIP.git --quiet
!pip install lightly --quiet
!pip install torchvision --quiet
!pip install pytorch-fid --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m116.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.9 MB/s[0m eta 

In [4]:
import os
import torch
import clip
import gc
import timm
import torch.nn.functional as F
import numpy as np
import pandas as pd
from PIL import Image
from glob import glob
from tqdm import tqdm
from torchvision import transforms
from torchvision.models import inception_v3
from skimage.metrics import structural_similarity as ssim
from scipy.stats import pearsonr

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Paths
base_path = "/content/drive/MyDrive/comparison"
models = ['clip', 'resnet50', 'vit']
subjects = ['1', '2', '3', '4', '5']
gt_images = sorted(glob(os.path.join(base_path, "Ground Truth", "*.png")))

# Load Models
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
inception = inception_v3(pretrained=True, transform_input=False).to(device).eval()

# Load SwAV from TIMM (pretrained by Facebook)
swav_model = timm.create_model("resnet50", pretrained=True)
swav_model.fc = torch.nn.Identity()
swav_model = swav_model.to(device).eval()

# Transforms
swav_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
inception_transform = transforms.Compose([transforms.Resize((299, 299)), transforms.ToTensor()])

# Metrics

def compute_clip_distance(img1, img2):
    t1 = clip_preprocess(img1).unsqueeze(0).to(device)
    t2 = clip_preprocess(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = clip_model.encode_image(t1)
        f2 = clip_model.encode_image(t2)
        f1 /= f1.norm(dim=-1, keepdim=True)
        f2 /= f2.norm(dim=-1, keepdim=True)
    return 1 - (f1 @ f2.T).item()

def compute_swav_distance(img1, img2):
    t1 = swav_transform(img1).unsqueeze(0).to(device)
    t2 = swav_transform(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = swav_model(t1)
        f2 = swav_model(t2)
        f1 /= f1.norm(dim=-1, keepdim=True)
        f2 /= f2.norm(dim=-1, keepdim=True)
    return 1 - (f1 @ f2.T).item()

def compute_inception_distance(img1, img2):
    t1 = inception_transform(img1).unsqueeze(0).to(device)
    t2 = inception_transform(img2).unsqueeze(0).to(device)
    with torch.no_grad():
        f1 = inception(t1)
        f2 = inception(t2)
        f1 = F.normalize(f1, dim=-1)
        f2 = F.normalize(f2, dim=-1)
    return 1 - (f1 @ f2.T).item()

# Evaluation Loop
results = []

for model in models:
    for subj in subjects:
        ssim_scores, pixcorr_scores = [], []
        clip_distances, swav_distances, inception_distances = [], [], []

        gen_path = os.path.join(base_path, model, subj)

        for gt_path in tqdm(gt_images, desc=f"{model} - Subj{subj}"):
            name = os.path.basename(gt_path).replace("_ground_truth.png", "")
            pred_path = os.path.join(gen_path, f"{name}_generated_{model}.png")
            if not os.path.exists(pred_path):
                continue

            gt_img = Image.open(gt_path).convert("RGB").resize((224, 224))
            pred_img = Image.open(pred_path).convert("RGB").resize((224, 224))

            clip_distances.append(compute_clip_distance(gt_img, pred_img))
            swav_distances.append(compute_swav_distance(gt_img, pred_img))
            inception_distances.append(compute_inception_distance(gt_img, pred_img))

            torch.cuda.empty_cache()
            gc.collect()

        results.append({
            "Model": model,
            "Subject": subj,
            "CLIPDist_Mean": np.mean(clip_distances),
            "CLIPDist_Std": np.std(clip_distances),
            "SwAVDist_Mean": np.mean(swav_distances),
            "SwAVDist_Std": np.std(swav_distances),
            "InceptionDist_Mean": np.mean(inception_distances),
            "InceptionDist_Std": np.std(inception_distances)
        })

# Save results
df = pd.DataFrame(results)
df.to_csv("/content/comparison_metrics_all.csv", index=False)
df.head()


clip - Subj1: 100%|██████████| 6/6 [00:03<00:00,  1.98it/s]
clip - Subj2: 100%|██████████| 6/6 [00:01<00:00,  3.05it/s]
clip - Subj3: 100%|██████████| 6/6 [00:01<00:00,  3.06it/s]
clip - Subj4: 100%|██████████| 6/6 [00:02<00:00,  2.62it/s]
clip - Subj5: 100%|██████████| 6/6 [00:02<00:00,  2.57it/s]
resnet50 - Subj1: 100%|██████████| 6/6 [00:01<00:00,  3.04it/s]
resnet50 - Subj2: 100%|██████████| 6/6 [00:01<00:00,  3.10it/s]
resnet50 - Subj3: 100%|██████████| 6/6 [00:01<00:00,  3.06it/s]
resnet50 - Subj4: 100%|██████████| 6/6 [00:01<00:00,  3.13it/s]
resnet50 - Subj5: 100%|██████████| 6/6 [00:02<00:00,  2.92it/s]
vit - Subj1: 100%|██████████| 6/6 [00:02<00:00,  2.38it/s]
vit - Subj2: 100%|██████████| 6/6 [00:01<00:00,  3.16it/s]
vit - Subj3: 100%|██████████| 6/6 [00:01<00:00,  3.09it/s]
vit - Subj4: 100%|██████████| 6/6 [00:01<00:00,  3.08it/s]
vit - Subj5: 100%|██████████| 6/6 [00:01<00:00,  3.05it/s]


Unnamed: 0,Model,Subject,CLIPDist_Mean,CLIPDist_Std,SwAVDist_Mean,SwAVDist_Std,InceptionDist_Mean,InceptionDist_Std
0,clip,1,0.328776,0.121041,0.641161,0.110793,0.829532,0.134451
1,clip,2,0.298258,0.12029,0.647775,0.162695,0.8669,0.158939
2,clip,3,0.322184,0.128756,0.577566,0.218331,0.753687,0.316654
3,clip,4,0.354085,0.132548,0.685017,0.107717,0.85821,0.171767
4,clip,5,0.320312,0.097462,0.699619,0.104233,0.86199,0.214965
