In [3]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from playground.load import dataset, clip_model, clip_processor, model_device, extract_image_vectors

In [4]:
# Image metadata and image vectors are cached on disk because extracting
# them takes a few minutes. If either cache file is missing, both are
# regenerated and written to the very paths they are later read from, so
# the cache and the loader can never point at different locations.
META_PATH = Path('data/objectnet/image_meta.parquet')
VECTORS_PATH = Path('data/objectnet/image_vectors.npy')

if META_PATH.exists() and VECTORS_PATH.exists():
    df = pd.read_parquet(META_PATH)
    image_vectors = np.load(VECTORS_PATH)
else:
    # Slow path: will take a few minutes.
    df, image_vectors = extract_image_vectors(dataset)
    META_PATH.parent.mkdir(parents=True, exist_ok=True)
    df.to_parquet(META_PATH)
    np.save(VECTORS_PATH, image_vectors)

In [46]:
def get_y_true(df, label):
    """Build a binary relevance vector for ``label``.

    Returns a NumPy array with one entry per row of ``df``: 1 where the
    row's ``label`` column equals ``label`` and 0 everywhere else.

    Raises ``AssertionError`` when ``label`` does not occur in ``df.label``.
    """
    row_labels = df.label
    known_labels = row_labels.unique()
    assert label in known_labels
    is_match = row_labels == label
    return np.where(is_match, 1, 0)

def get_query_vector(text):
    """Embed a single text query as a flat, normalized NumPy vector.

    The text is tokenized with ``clip_processor``, its token ids are moved
    to ``model_device`` and passed to ``clip_model.get_text_features``. The
    features are normalized with ``F.normalize`` (default arguments), moved
    to the CPU and flattened to one dimension.
    """
    query_tokens = clip_processor(text=[text], return_tensors='pt')
    input_ids = query_tokens['input_ids'].to(model_device)
    # Inference only: disabling autograd avoids building a computation
    # graph that would be thrown away immediately afterwards.
    with torch.no_grad():
        text_features = clip_model.get_text_features(input_ids)
        unit_features = F.normalize(text_features)
    return unit_features.cpu().numpy().reshape(-1)

In [6]:
from sklearn.metrics import average_precision_score

In [47]:
# Sanity check with a text query: score every image by its dot product with
# the embedding of a text prompt, then measure how well that ranking
# recovers the images labelled 'banana'.
query_vector = get_query_vector('a photo of a banana')
y_true = get_y_true(df, 'banana')
image_scores = image_vectors @ query_vector
average_precision_score(y_true=y_true, y_score=image_scores)

0.8783107024897996

In [8]:
# For each label, a few randomly chosen images will serve as queries and the
# remaining images will be used to measure retrieval performance.
# Fix NumPy's global random seed up front for reproducibility.
np.random.seed(0)

In [32]:
# Assign every image to the 'query' or the 'test' split.
#
# The permutation comes from a generator seeded inside this cell, so
# re-running the cell always reproduces the same split instead of advancing
# the global RNG. RandomState(0) yields the same values as a fresh
# np.random.seed(0) followed directly by np.random.permutation.
df = df.assign(
    random_id=np.random.RandomState(0).permutation(df.shape[0]),
    # Rank the images of each label by their random id (1 = smallest id).
    group_rank=lambda d: d.groupby('label')['random_id'].rank(method='first').astype('int'),
    # The 5 lowest-ranked images of each label become queries; the rest are test images.
    split=lambda d: np.where(d['group_rank'] <= 5, 'query', 'test'),
)

In [33]:
# Separate the frame by split: query_df holds the images used as queries,
# test_df holds the images the queries are evaluated against.
query_df = df.loc[df['split'].eq('query')]
test_df = df.loc[df['split'].eq('test')]


In [36]:
# Vectors of the test images, in test_df row order.
# image_vectors is a plain array, so it must be indexed by row *position*;
# get_indexer translates test_df's index labels into positions within df.
# With a default 0..n-1 index this selects exactly the same rows as before.
# Assumes image_vectors[i] belongs to the i-th row of df — TODO confirm.
test_vec_db = image_vectors[df.index.get_indexer(test_df.index)]

In [38]:
# Vectors of the query images, in query_df row order.
# image_vectors is a plain array, so it must be indexed by row *position*;
# get_indexer translates query_df's index labels into positions within df.
# With a default 0..n-1 index this selects exactly the same rows as before.
# Assumes image_vectors[i] belongs to the i-th row of df — TODO confirm.
image_query_vecs = image_vectors[df.index.get_indexer(query_df.index)]

In [52]:
# Average precision of every query image: score all test images by their dot
# product with the query vector and compare that ranking with the test
# images that share the query's label.
#
# Queries with the same label share one relevance vector, so it is built
# once per label rather than once per query.
y_true_by_label = {}
aps = []
for label, query_vector in tqdm(zip(query_df['label'], image_query_vecs), total=query_df.shape[0]):
    if label not in y_true_by_label:
        y_true_by_label[label] = get_y_true(test_df, label)
    image_scores = test_vec_db @ query_vector
    aps.append(average_precision_score(y_true_by_label[label], image_scores))

100%|██████████| 1565/1565 [00:32<00:00, 47.53it/s]


In [53]:
# Attach each query's average precision; aps was filled in query_df row order.
query_df = query_df.assign(ap=aps)

In [55]:
# Display the per-query results (long frames are truncated in the rendered output).
query_df

Unnamed: 0,label,path,random_id,group_rank,split,ap
32,air_freshener,air_freshener/2883d4225ff948a.png,1156,3,query,0.085341
40,air_freshener,air_freshener/38af90ed09a74e1.png,40,1,query,0.076872
67,air_freshener,air_freshener/5dc8b11f33b6464.png,1643,5,query,0.034204
128,air_freshener,air_freshener/a7f7c424fe094d3.png,1277,4,query,0.007236
158,air_freshener,air_freshener/d7d1bee4daa64c0.png,41,2,query,0.009850
...,...,...,...,...,...,...
50178,ziploc_bag,ziploc_bag/872ab49642bf4db.png,2109,3,query,0.124665
50186,ziploc_bag,ziploc_bag/949a8ec3b6154fc.png,2913,4,query,0.017945
50188,ziploc_bag,ziploc_bag/96fc471e354b49f.png,1542,2,query,0.014206
50211,ziploc_bag,ziploc_bag/b4895cfd653e417.png,3498,5,query,0.039904
