In [30]:
import pandas as pd
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from playground.preprocessing import clip_model, clip_processor, model_device

# for experiments
from sklearn import svm, linear_model
from sklearn.metrics import average_precision_score
from playground.linear_model import LinearModel
import importlib
import playground.linear_model
importlib.reload(playground.linear_model)
from playground.linear_model import LinearModel

In [2]:
## compute embedding first time: this will take a few minutes
## need to download objectnet first (and crop 224x224)
#from playground.load import objectnet_dataset, extract_image_vectors
#df = extract_image_vectors(objectnet_dataset)
#df.to_parquet('data/objectnet/embeddings.parquet')

In [16]:
# load precomputed embeddings
objectnet_df = pd.read_parquet('data/objectnet/embeddings.parquet')

In [130]:
# work with unit length vectors for dot product (works better for CLIP embeddings)
objectnet_df = objectnet_df.assign(normalized_vectors=[vec for vec in np.stack(objectnet_df['vectors']) / np.linalg.norm(objectnet_df['vectors'])])

# pick a random image from each class to be the query, the rest will be the test database
np.random.seed(13)
objectnet_df = objectnet_df.assign(random_id=np.random.permutation(objectnet_df.shape[0]))
objectnet_df = objectnet_df.assign(group_rank=objectnet_df.groupby('label')['random_id'].rank(method='first').astype('int'))
objectnet_df = objectnet_df.assign(split=objectnet_df.group_rank.apply(lambda x: 'query' if x <= 11 else 'test'))

search_query_df = objectnet_df[objectnet_df.split == 'query']
test_df = objectnet_df[objectnet_df.split == 'test']

# from the test set, take a random sample of the DB which we will use as pseudo-negative examples
# while training some of the linear models
number_svm_train_examples = 1000
random_sample = np.random.permutation(test_df.shape[0])[:number_svm_train_examples]
Xneg = np.stack(test_df.iloc[random_sample].normalized_vectors.values)
yneg = np.zeros(Xneg.shape[0])

# the full test set used for evaluation
Xtest = np.stack(test_df.normalized_vectors.values)

In [140]:
def get_ys(df, label : str):
    ''' get binary labels for a given class '''
    assert label in df.label.unique()
    return np.where(df.label == label, 1, 0)

def get_text_embedding(text : str, prompt_template='A picture of a {}'):
    ''' get CLIP vector representation of text query '''
    text = text.replace('_', ' ')
    text = prompt_template.format(text)
    query_tokens = clip_processor(text=[text], return_tensors='pt')
    query_vector = clip_model.get_text_features(query_tokens['input_ids'].to(model_device))
    query_vector = F.normalize(query_vector)
    query_vector = query_vector.cpu().detach().numpy().reshape(-1)
    return query_vector

def eval_method(query_df, vector_function):
    ''' run a given functrion over different categories on the dataset and compute AP '''
    aps = []
    for (idx, row) in tqdm(query_df.iterrows(), total=query_df.shape[0]):
        query_vector = vector_function(row)
        scores = Xtest @ query_vector
        y_true = get_ys(test_df, row.label)
        ap = average_precision_score(y_true, scores)
        aps.append(ap)
    return np.array(aps)


## Different methods to get a vector which we can use as a query for the image search

def get_vector_from_text(row):
    ''' get CLIP vector representation of text query, aka zero-shot search '''
    # (simply return the vector rep. of the text query)
    return get_text_embedding(row.label)

def get_vector_from_knn(row):
    ''' get the vector representation of the row, aka nearest neighbor search '''
    # (simply return the vector rep. of the image)
    return row.normalized_vectors

def get_vector_from_exemplar_svm(row):
    ''' ExemplarSVM: get the vector representation from using one positive example, and a random sample
    labeled as negative, train using SVM and use this for the vector lookup '''
    clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=1.)
    Xpos = row.normalized_vectors.reshape(1, -1)
    X = np.concatenate([Xpos, Xneg], axis=0)
    y = np.concatenate([np.ones(1), yneg])
    clf.fit(X, y)
    return clf.coef_.reshape(-1)

def get_vector_from_exemplar_logistic_reg(row, C):
    ''' Similar to ExemplarSVM, but using logistic regression instead.'''
    # fit_intercept=False is important for this LR to work nearly as well as SVM
    clf = linear_model.LogisticRegression(class_weight='balanced', fit_intercept=False, verbose=False, max_iter=10000, tol=1e-6, C=C)
    Xpos = row.normalized_vectors.reshape(1, -1)
    random_sample = np.random.permutation(test_df.shape[0])[:number_svm_train_examples]
    Xneg = np.stack(test_df.iloc[random_sample].normalized_vectors.values)
    yneg = np.zeros(Xneg.shape[0])
    X = np.concatenate([Xpos, Xneg], axis=0)
    y = np.concatenate([np.ones(1), yneg])
    clf.fit(X, y)
    return clf.coef_.reshape(-1)

def get_vector_from_exemplar_svm_plus_text_reg(row):
    ''' Similar to ExemplarSVM, but using a linear model with a regularizer term based on the text query '''
    regularizer_vector  = get_vector_from_text(row)
    clf = LinearModel(class_weight='balanced', label_loss_type='hinge_squared_loss', reg_norm_lambda=1.,
                      verbose=False, max_iter=3,
                      regularizer_vector=regularizer_vector, reg_vector_lambda=1000.)
    Xpos = row.normalized_vectors.reshape(1, -1)
    random_sample = np.random.permutation(test_df.shape[0])[:number_svm_train_examples]
    Xneg = np.stack(test_df.iloc[random_sample].normalized_vectors.values)
    yneg = np.zeros(Xneg.shape[0])
    X = np.concatenate([Xpos, Xneg], axis=0)
    y = np.concatenate([np.ones(1), yneg])
    clf.fit(X, y) # train
    coeff = clf._module.weight.detach().cpu().numpy().reshape(-1)
    return coeff

In [137]:
text_ap = eval_method(search_query_df, get_vector_from_text)
knn_ap = eval_method(search_query_df, get_vector_from_knn)

100%|██████████| 3443/3443 [05:55<00:00,  9.69it/s]  
100%|██████████| 3443/3443 [00:51<00:00, 66.55it/s]


In [131]:
svm_ap = eval_method(search_query_df, get_vector_from_exemplar_svm)

100%|██████████| 3443/3443 [12:49<00:00,  4.47it/s]   


In [132]:
svm_reg_ap = eval_method(search_query_df, get_vector_from_exemplar_svm_plus_text_reg)

100%|██████████| 3443/3443 [24:35<00:00,  2.33it/s]   


In [139]:
search_query_df = search_query_df.assign(svm_ap=svm_ap,  svm_reg_ap=svm_reg_ap,
                                         knn_ap=knn_ap, text_ap=text_ap
                                         )
by_query = search_query_df
display(by_query[['svm_ap', 'svm_reg_ap','knn_ap','text_ap']].mean())

display((by_query.svm_ap > by_query.knn_ap).mean())
display((by_query.svm_reg_ap > by_query.text_ap).mean())

svm_ap        0.099397
svm_reg_ap    0.250966
knn_ap        0.094210
text_ap       0.237345
dtype: float64

0.5832123148417078

0.8071449317455707