In [2]:
from PIL import Image
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from os.path import join, exists
import pickle
import time
from copy import deepcopy
import random

import torch
from src.open_clip.factory import create_model_and_transforms, get_tokenizer
from torch.utils.data import Dataset, DataLoader
from src.training.precision import get_autocast

In [None]:
# Please specify the root directory of the benchmark datasets.
# The 'filepath' column of the selected CSV is joined onto this directory when the data frame is loaded.
ROOT_DATA_DIR = '/PATH/TO/THE/ROOT/DIRECTORY/OF/BENCHMARK/DATASETS'

In [None]:
# Please specify the path to the test CSV of each benchmark.
# Keys are the dataset names that `dataset_name` may take; the value is the file read with pd.read_csv.
# Each CSV is expected to provide a 'filepath' column plus the caption column chosen for that dataset
# ('title_multi_objects' for SkyScript, 'title' for the others).
data_csv_path_dict = {
    'SkyScript': '/THE/PATH/TO/SkyScript_test_30K_filtered_by_CLIP_openai.csv',
    'RSICD': '/THE/PATH/TO/RSICD/RSICD_img_txt_pairs_test.csv',
    'RSITMD': '/THE/PATH/TO/RSITMD/RSITMD_img_txt_pairs_test.csv',
    'ucmcaptions': '/THE/PATH/TO/ucmcaptions/ucmcaptions_img_txt_pairs_test.csv',
}

In [3]:
# Number of samples per DataLoader batch when encoding images and captions
batch_size = 128
# Precision setting handed to get_autocast here and to create_model_and_transforms when the model is built
precision = 'amp'
# Called as `autocast()` to open a context manager around the encoder forward passes
autocast = get_autocast(precision)

In [4]:
class CsvDataset_customized(Dataset):
    """Dataset of (image, caption) pairs read from a pandas DataFrame.

    Args:
        df: DataFrame with one row per image/caption pair. It is only read, never modified.
        transforms: Callable applied to the opened PIL image.
        img_key: Name of the column holding the image paths.
        caption_key: Name of the column holding the captions.
        tokenizer: Callable that takes a list of strings; the first element of its
            result is used as the tokenized caption. Must be provided before items are fetched.
        return_img_path: If True, ``__getitem__`` also returns the image path.
        root_data_dir: Optional directory joined in front of every image path.
    """

    def __init__(self, df, transforms, img_key, caption_key, tokenizer=None, return_img_path=False,
                 root_data_dir=None):
        image_paths = df[img_key]
        if root_data_dir is not None:
            # Build the prefixed paths as a new Series instead of assigning back into `df`,
            # so constructing the dataset has no side effect on the caller's frame.
            image_paths = image_paths.apply(lambda x: join(root_data_dir, x))
        self.images = image_paths.tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        self.tokenize = tokenizer
        self.return_img_path = return_img_path

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image, tokens)`` or ``(image, tokens, image_path)`` for row ``idx``."""
        img_path = str(self.images[idx])
        images = self.transforms(Image.open(img_path))
        # The tokenizer is called on a one-element batch; keep only that element's tokens
        texts = self.tokenize([str(self.captions[idx])])[0]
        if self.return_img_path:
            return images, texts, img_path
        return images, texts
    
class CsvDataset_image(Dataset):
    """Image-only dataset read from a pandas DataFrame.

    Args:
        df: DataFrame with one row per image. It is only read, never modified.
        transforms: Callable applied to the opened PIL image.
        img_key: Name of the column holding the image paths.
        return_img_path: If True, ``__getitem__`` also returns the image path.
        root_data_dir: Optional directory joined in front of every image path.
    """

    def __init__(self, df, transforms, img_key, return_img_path=False, root_data_dir=None):
        image_paths = df[img_key]
        if root_data_dir is not None:
            # Build the prefixed paths as a new Series instead of assigning back into `df`,
            # so constructing the dataset has no side effect on the caller's frame.
            image_paths = image_paths.apply(lambda x: join(root_data_dir, x))
        self.images = image_paths.tolist()
        self.transforms = transforms
        self.return_img_path = return_img_path

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image, or ``(image, image_path)`` when paths are requested."""
        img_path = str(self.images[idx])
        images = self.transforms(Image.open(img_path))
        if self.return_img_path:
            return images, img_path
        return images
    
class CsvDataset_text(Dataset):
    """Caption-only dataset read from a pandas DataFrame.

    Args:
        df: DataFrame with one row per caption. It is only read, never modified.
        caption_key: Name of the column holding the captions.
        tokenizer: Callable that takes a list of strings; the first element of its
            result is used as the tokenized caption. Must be provided before items are fetched.
        return_original_text: If True, ``__getitem__`` also returns the untokenized caption.
        root_data_dir: Accepted for signature parity with the image datasets and ignored,
            because this dataset holds no image paths.
    """

    def __init__(self, df, caption_key, tokenizer=None, return_original_text=False, root_data_dir=None):
        # `root_data_dir` is deliberately unused: the previous implementation tried to prefix
        # an image column through an undefined `img_key`, which raised NameError whenever a
        # root directory was supplied.
        self.captions = df[caption_key].tolist()
        self.tokenize = tokenizer
        self.return_original_text = return_original_text

    def __len__(self):
        """Return the number of captions."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the tokenized caption, or ``(tokens, caption)`` when the text is requested."""
        original_text = str(self.captions[idx])
        # The tokenizer is called on a one-element batch; keep only that element's tokens
        texts = self.tokenize([original_text])[0]
        if self.return_original_text:
            return texts, original_text
        return texts

def random_seed(seed=42, rank=0):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed + rank``."""
    effective_seed = seed + rank
    # Same order as before: torch first, then NumPy, then the stdlib generator
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(effective_seed)

def get_sample_identifier(filepath):
    """Return the last two '/'-separated components of ``filepath`` joined by '/'."""
    # Splitting from the right at most twice leaves the two trailing components last
    tail_parts = filepath.rsplit('/', 2)[-2:]
    return '/'.join(tail_parts)


In [197]:
# Run configuration: architecture, checkpoint and benchmark to evaluate
model_arch_name = 'ViT-L-14' # or 'ViT-B-32'
ckpt_name = 'THE/PATH/TO/MODEL/CHECKPOINT' # replace this with the path to the .pt file
dataset_name = 'SkyScript'
# `dataset_name` must be one of the keys of data_csv_path_dict
data_csv_path = data_csv_path_dict[dataset_name]

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [199]:
# SkyScript uses its own caption column; every other benchmark uses 'title'
caption_key = 'title_multi_objects' if dataset_name == 'SkyScript' else 'title'

force_quick_gelu = True

random_seed(42, 0)

# Keyword arguments shared by every architecture
model_kwargs = dict(
    precision=precision,
    device=device,
    output_dict=True,
)
# Only architectures other than ViT-B-32 are built with the force_quick_gelu flag
if 'ViT-B-32' not in model_arch_name:
    model_kwargs['force_quick_gelu'] = force_quick_gelu

model, _, preprocess_val = create_model_and_transforms(
    model_arch_name,
    ckpt_name,
    **model_kwargs,
)

tokenizer = get_tokenizer(model_arch_name)

In [200]:
df = pd.read_csv(data_csv_path)
# Prefix every image path from the CSV with the benchmark root directory
df['filepath'] = df['filepath'].map(lambda img_path: join(ROOT_DATA_DIR, img_path))
# Captions of these three benchmarks get a fixed prefix prepended
if dataset_name in ('RSICD', 'RSITMD', 'ucmcaptions'):
    df[caption_key] = df[caption_key].map(lambda caption: 'a satellite image. ' + caption)

In [201]:
# One row per distinct image path; the remaining columns hold the per-group non-null counts
df_image = df.groupby('filepath').count().reset_index()
# One row per distinct caption, built the same way
df_text = df.groupby(caption_key).count().reset_index()
# Shown as the cell output: (shape of df_image, shape of df_text)
df_image.shape, df_text.shape

((30000, 4), (6255, 4))

In [202]:
# Encode every unique image once, keeping the paths in the same order as the features
dataset_image = CsvDataset_image(
    df=df_image,
    transforms=preprocess_val,
    img_key='filepath',
    return_img_path=True,
)

dataloader = DataLoader(dataset_image, batch_size=batch_size, shuffle=False, num_workers=4)

model.eval()
image_feature_batches = []
all_image_paths = []
with torch.no_grad():
    for image_batch, path_batch in tqdm(dataloader, unit_scale=batch_size):
        image_batch = image_batch.to(device=device)
        with autocast():
            batch_features = model.encode_image(image_batch, normalize=True)
        # Collect each batch of embeddings on the CPU alongside its image paths
        image_feature_batches.append(batch_features.cpu())
        all_image_paths.extend(path_batch)
# Stack the per-batch embeddings into a single tensor
all_image_features = torch.cat(image_feature_batches)

100%|█████████████████████████████████████████████████████████████████████████| 30080/30080 [02:16<00:00, 220.89it/s]


In [203]:
# Encode every unique caption once, keeping the raw strings in the same order as the features
dataset_text = CsvDataset_text(
    df=df_text,
    caption_key=caption_key,
    tokenizer=tokenizer,
    return_original_text=True,
)

dataloader = DataLoader(dataset_text, batch_size=batch_size, shuffle=False, num_workers=4)

model.eval()
text_feature_batches = []
all_texts = []
with torch.no_grad():
    for token_batch, caption_batch in tqdm(dataloader, unit_scale=batch_size):
        token_batch = token_batch.to(device=device)
        with autocast():
            batch_features = model.encode_text(token_batch, normalize=True)
        # Collect each batch of embeddings on the CPU alongside its original captions
        text_feature_batches.append(batch_features.cpu())
        all_texts.extend(caption_batch)
# Stack the per-batch embeddings into a single tensor
all_text_features = torch.cat(text_feature_batches)

100%|██████████████████████████████████████████████████████████████████████████| 6272/6272 [00:03<00:00, 1826.58it/s]


In [204]:
# Lookup tables: caption -> position in all_texts, image path -> position in all_image_paths
text_indices = dict(zip(all_texts, range(len(all_texts))))
img_indices = dict(zip(all_image_paths, range(len(all_image_paths))))

In [205]:
# ground truth
img_path2text = {}  # image path -> set of indices into all_texts
text2img_path = {}  # caption    -> set of indices into all_image_paths
for text, img_path in tqdm(zip(df[caption_key], df['filepath']), total=len(df)):
    # Resolve both ids before touching either mapping
    text_id = text_indices[text]
    img_id = img_indices[img_path]
    img_path2text.setdefault(img_path, set()).add(text_id)
    text2img_path.setdefault(text, set()).add(img_id)

100%|███████████████████████████████████████████████████████████████████████| 30000/30000 [00:00<00:00, 38766.03it/s]


In [206]:
# Recall@k slots for both retrieval directions, all starting at zero
res = {
    direction + '_R@' + str(k): 0
    for direction in ('text2img', 'img2text')
    for k in [1, 5, 10, 100]
}

In [207]:
# text to image
# For every unique caption, rank all images by similarity and count a hit at k when
# at least one ground-truth image is among the top k.
# Hits are counted in a local dict and written to `res` once at the end, so re-running
# this cell overwrites the metrics instead of adding counts onto already-normalised values.
logit_scale = 100
recall_ks = [1, 5, 10, 100]
text2img_hits = {k: 0 for k in recall_ks}
for i in tqdm(range(len(all_texts))):
    text_feature = all_text_features[i]
    logits = logit_scale * text_feature @ all_image_features.t()
    ranking = torch.argsort(logits, descending=True).cpu().numpy()
    # Indices (into all_image_paths) of the images paired with this caption
    relevant_img_ids = text2img_path[all_texts[i]]
    for k in recall_ks:
        if set(ranking[:k]) & relevant_img_ids:
            text2img_hits[k] += 1
for k in recall_ks:
    res['text2img_R@' + str(k)] = text2img_hits[k] / len(all_texts)

100%|███████████████████████████████████████████████████████████████████████████| 6255/6255 [00:37<00:00, 168.32it/s]


In [208]:
# image to text
# For every unique image, rank all captions by similarity and count a hit at k when
# at least one ground-truth caption is among the top k.
# Hits are counted in a local dict and written to `res` once at the end, so re-running
# this cell overwrites the metrics instead of adding counts onto already-normalised values.
logit_scale = 100
recall_ks = [1, 5, 10, 100]
img2text_hits = {k: 0 for k in recall_ks}
for i in tqdm(range(len(all_image_paths))):
    image_feature = all_image_features[i]
    logits = logit_scale * image_feature @ all_text_features.t()
    ranking = torch.argsort(logits, descending=True).cpu().numpy()
    # Indices (into all_texts) of the captions paired with this image
    relevant_text_ids = img_path2text[all_image_paths[i]]
    for k in recall_ks:
        if set(ranking[:k]) & relevant_text_ids:
            img2text_hits[k] += 1
for k in recall_ks:
    res['img2text_R@' + str(k)] = img2text_hits[k] / len(all_image_paths)

100%|████████████████████████████████████████████████████████████████████████| 30000/30000 [00:23<00:00, 1274.62it/s]


In [None]:
# Display the Recall@k values for both retrieval directions as the cell output
res