In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import open_clip
from torch import rand
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import PIL
import glob
from collections import defaultdict
import pandas as pd
import json

In [2]:
# Run configuration.
#   path: results folder of the intervention run to evaluate (contains 1/ ... 9/ subfolders).
#   name: label written to the "name" column of the scores and used as the CSV file name.
path = (
    "./results/"
    "modesae_spatialTrue_subtractTrue_downTrue_upTrue_up0True_midTrue_T4_ktrans80_str2.0"
)
name = "sae_80_2.0"

In [3]:
class CLIPScorer:
    """Text/image similarity scorer built on an open_clip CLIP model.

    Wraps model loading, text and image embedding, and the pairwise
    similarity matrices used by the evaluation below.
    """

    def __init__(self, model_name='ViT-L-14', pretrained='openai', device='cuda'):
        """Load `model_name` with `pretrained` weights onto `device` in eval mode.

        Args:
            model_name: open_clip architecture name (also selects the tokenizer).
            pretrained: open_clip pretrained-weights tag (default: OpenAI weights).
            device: torch device the model and all inputs are moved to.
        """
        self.model_name = model_name
        self.device = torch.device(device)
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(model_name)

    def _autocast(self):
        """Mixed-precision context: CUDA autocast on CUDA devices, a no-op elsewhere."""
        import contextlib

        if self.device.type == "cuda":
            # torch.cuda.amp.autocast() is deprecated; torch.autocast is its replacement.
            return torch.autocast(device_type="cuda")
        return contextlib.nullcontext()

    @staticmethod
    def _normalize(features):
        """Scale every row of `features` to unit L2 norm (returns a new tensor)."""
        return features / features.norm(dim=-1, keepdim=True)

    def embed_texts(self, texts):
        """Tokenize and encode a list of strings; one embedding row per text."""
        tokens = self.tokenizer(texts).to(self.device)
        with torch.no_grad(), self._autocast():
            return self.model.encode_text(tokens)

    def embed_images(self, images):
        """Preprocess and encode images; one embedding row per image.

        Args:
            images: iterable of PIL images or numpy arrays (arrays are turned
                into PIL images with ``Image.fromarray`` before preprocessing).

        Raises:
            ValueError: if `images` is empty.
        """
        pixel_batches = []
        for img in images:
            if isinstance(img, np.ndarray):
                img = Image.fromarray(img)
            pixel_batches.append(self.preprocess(img).unsqueeze(0))
        if not pixel_batches:
            raise ValueError("embed_images() needs at least one image")
        pixels = torch.cat(pixel_batches, dim=0).to(self.device)
        with torch.no_grad(), self._autocast():
            return self.model.encode_image(pixels)

    def get_scores(self, texts, images, normalize=True):
        """Text-vs-image similarity matrix: rows index `texts`, columns index `images`.

        With ``normalize=True`` the embeddings are L2-normalised first, so the
        entries are cosine similarities. The product is taken in float32 so
        autocast's fp16 embeddings do not lose precision in the reduction.
        """
        text_features = self.embed_texts(texts).float()
        image_features = self.embed_images(images).float()
        if normalize:
            text_features = self._normalize(text_features)
            image_features = self._normalize(image_features)
        return text_features @ image_features.T

    def get_scores_images(self, images1, images2, normalize=True):
        """Image-vs-image similarity matrix: rows index `images1`, columns `images2`.

        Same normalisation and float32 computation as :meth:`get_scores`.
        """
        image_features1 = self.embed_images(images1).float()
        image_features2 = self.embed_images(images2).float()
        if normalize:
            image_features1 = self._normalize(image_features1)
            image_features2 = self._normalize(image_features2)
        return image_features1 @ image_features2.T

In [4]:
def img2lpips(img):
    """Convert an image to the (1, 3, H, W) float tensor in [0, 1] used for LPIPS.

    Accepts a PIL image of any mode, which is converted to RGB so that RGBA,
    palette and grayscale PNGs all work. Also accepts a uint8 numpy array of
    shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4): grayscale is replicated
    to 3 channels and alpha is dropped. Anything else, such as an
    already-prepared tensor or a non-uint8 array, is returned unchanged.

    Raises:
        ValueError: for a uint8 array that is not 2-D/3-D or whose channel
            count is not 1, 3 or 4.
    """
    if isinstance(img, np.ndarray):
        arr = img
    elif isinstance(img, PIL.Image.Image):
        # Converting first avoids 4-channel (RGBA) or 2-D (L / P mode) arrays,
        # which LPIPS cannot consume.
        arr = np.array(img.convert("RGB"))
    else:
        return img
    if arr.dtype != np.uint8:
        return img

    if arr.ndim == 2:
        arr = arr[:, :, None]
    if arr.ndim != 3:
        raise ValueError(f"expected an (H, W[, C]) image array, got shape {arr.shape}")
    channels = arr.shape[-1]
    if channels == 1:
        arr = np.repeat(arr, 3, axis=-1)
    elif channels == 4:
        arr = arr[..., :3]
    elif channels != 3:
        raise ValueError(f"unsupported channel count {channels}")

    tensor = torch.tensor(np.ascontiguousarray(arr), dtype=torch.float32) / 255.0
    # HWC -> CHW, then add the batch dimension.
    return tensor.permute(2, 0, 1).unsqueeze(0)

# LPIPS metric with the AlexNet backbone. normalize=True means inputs are expected
# in [0, 1] (per the torchmetrics docs), matching img2lpips' division by 255.
lpips = LearnedPerceptualImagePatchSimilarity(normalize=True, net_type="alex")


In [5]:
ref_path = "./results/reference"


def _load_image(image_path):
    """Open an image, force-load its pixels and close the file handle immediately."""
    # Image.open is lazy and keeps the file open until the pixels are read;
    # preloading hundreds of references that way can exhaust file descriptors.
    with Image.open(image_path) as im:
        return im.copy()


# Pre-load every complete reference pair from the task subfolders 1-9:
#   "<task>/<id>_img1" -> reference for the edited prompt
#   "<task>/<id>_img2" -> reference for the original prompt
ref_images = {}
for i in range(1, 10):
    subfolder_path = os.path.join(ref_path, str(i))
    if not os.path.exists(subfolder_path):
        continue
    # sorted() keeps the load order deterministic across runs / filesystems.
    for img1_path in sorted(glob.glob(os.path.join(subfolder_path, "*_img1.png"))):
        base_name = os.path.basename(img1_path).replace("_img1.png", "")
        img2_path = img1_path.replace("_img1.png", "_img2.png")
        if not os.path.exists(img2_path):
            continue  # only keep complete img1/img2 pairs
        ref_images[f"{i}/{base_name}_img1"] = _load_image(img1_path)
        ref_images[f"{i}/{base_name}_img2"] = _load_image(img2_path)

In [None]:
# Single shared CLIP scorer (ViT-L-14, default weights), reused by score_path below.
clip_scorer = CLIPScorer(model_name='ViT-L-14')

In [7]:
# RIEBench metadata, indexed by each record's "id" so score_path can look up the
# "original_prompt" / "editing_prompt" for an image. The `with` block closes the
# file; JSON is UTF-8, so do not rely on the platform's default encoding.
with open("./dataset/riebench.json", "r", encoding="utf-8") as fh:
    dataset = json.load(fh)
id2data = {d["id"]: d for d in dataset}

In [8]:
def _score_image(name, task_id, img_path, ref_images):
    """Compute LPIPS and CLIP scores for one intervention image; returns one result row.

    Relies on the notebook-level `lpips`, `clip_scorer`, `id2data` and `img2lpips`.
    Raises KeyError when the reference pair or dataset entry for the image is missing.
    """
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    print("processing", base_name, "...")

    # Load eagerly and release the file handle; RGB so both metrics see 3 channels.
    with Image.open(img_path) as im:
        intervention_rgb = im.convert("RGB")
    ref_edited_rgb = ref_images[f"{task_id}/{base_name}_img1"].convert("RGB")    # edited prompt
    ref_original_rgb = ref_images[f"{task_id}/{base_name}_img2"].convert("RGB")  # original prompt

    with torch.no_grad():
        lpips_original = lpips(img2lpips(intervention_rgb), img2lpips(ref_original_rgb))
        lpips_edited = lpips(img2lpips(intervention_rgb), img2lpips(ref_edited_rgb))

    clip_img_original = clip_scorer.get_scores_images([intervention_rgb], [ref_original_rgb])
    clip_img_edited = clip_scorer.get_scores_images([intervention_rgb], [ref_edited_rgb])

    entry = id2data[base_name]
    clip_txt_original = clip_scorer.get_scores([entry["original_prompt"]], [intervention_rgb])
    clip_txt_edited = clip_scorer.get_scores([entry["editing_prompt"]], [intervention_rgb])

    return {
        "name": name,
        "img": base_name,
        "lpips_original": lpips_original.item(),
        "lpips_edited": lpips_edited.item(),
        "clip_img_original": clip_img_original.item(),
        "clip_img_edited": clip_img_edited.item(),
        "clip_txt_original": clip_txt_original.item(),
        "clip_txt_edited": clip_txt_edited.item(),
        "editing_type_id": task_id,
    }


def score_path(name, path, ref_images, n_imgs=None, allowed_tasks=None):
    """Score every intervention image under `path` and return one row per image.

    `path` is expected to contain task subfolders 1-9. Inside each one, every
    ``*.png`` whose file name has no underscore is treated as an intervention
    image, and its id is the file stem. That image is compared, with LPIPS and
    CLIP image similarity, against ``ref_images["<task>/<id>_img1"]`` (the
    edited-prompt reference) and ``..._img2`` (the original-prompt reference).
    It is also compared, with CLIP text similarity, against both prompts from
    `id2data`.

    Args:
        name: label stored in the "name" column.
        path: results directory to scan.
        ref_images: mapping produced by the reference preload cell.
        n_imgs: stop after this many successfully scored images (None = all).
        allowed_tasks: optional collection of task ids to restrict scoring to.

    Returns:
        pandas DataFrame. Images that fail to score are reported with a print
        and skipped (best effort).
    """
    rows = []
    if n_imgs is not None and n_imgs <= 0:
        return pd.DataFrame(rows)
    for task_id in range(1, 10):
        if allowed_tasks is not None and task_id not in allowed_tasks:
            continue
        subfolder_path = os.path.join(path, str(task_id))
        if not os.path.exists(subfolder_path):
            continue
        # sorted() makes row order and the n_imgs subset reproducible.
        for img_path in sorted(glob.glob(os.path.join(subfolder_path, "*.png"))):
            # Only the intervention image has no "_" in its file name.
            if "_" in os.path.basename(img_path):
                continue
            try:
                rows.append(_score_image(name, task_id, img_path, ref_images))
            except Exception as e:
                print(f"Error processing {img_path}: {e}")
                continue
            if n_imgs is not None and len(rows) >= n_imgs:
                return pd.DataFrame(rows)
    return pd.DataFrame(rows)

In [None]:
# Score every intervention image of this run and save the per-image table
# inside the run folder as <path>/<name>.csv.
scores = score_path(name, path, ref_images, n_imgs=None)
csv_path = f"{path}/{name}.csv"
scores.to_csv(csv_path, index=False)