In [None]:
import pickle
import numpy as np

# --- Load the pre-computed features -----------------------------------------
# Both pickles are project-local artefacts; pickle.load must only ever be
# pointed at trusted files because unpickling can execute arbitrary code.
img_pkl_path = "../data/pets/llava/image_features.pkl"
txt_pkl_path = "../data/pets/llava/label_features.pkl"

with open(img_pkl_path, "rb") as img_file:
    # Mapping: image id -> dict of per-image embeddings (read by classify()).
    img_features = pickle.load(img_file)

with open(txt_pkl_path, "rb") as txt_file:
    # Mapping: class label -> dict holding (at least) a 'DF' embedding.
    txt_embedding_dict = pickle.load(txt_file)

# Class labels in the dictionary's own iteration order; this order defines
# the column order of the weight matrix below.
labels_name = list(txt_embedding_dict)

# --- Build the classifier ----------------------------------------------------
# One L2-normalised 'DF' text embedding per class.
weights = [
    txt_embedding_dict[class_label]['DF']
    / np.linalg.norm(txt_embedding_dict[class_label]['DF'])
    for class_label in labels_name
]

# Stack the class vectors as rows, then transpose so each class is a column.
model = {"weights": np.transpose(np.vstack(weights)), "class_names": labels_name}
print(model['weights'].shape)

# Ids of every image that has pre-computed features.
img_names = list(img_features)
def _unit_vector(vector):
    """Return an L2-normalised copy of *vector* without modifying the input."""
    vector = np.asarray(vector)
    # Out-of-place division: the caller's array stays untouched, and integer
    # inputs are promoted to float instead of failing an in-place cast.
    return vector / np.linalg.norm(vector)


def classify(image, classifier, features=None):
    """Predict the class name of one image from its fused embeddings.

    The image embedding (IF), the initial-prediction embedding (PF) and the
    image-description embedding (DF) are each L2-normalised, summed, and the
    sum is normalised again. The fused query is multiplied with the
    classifier weights and the best-scoring class name is returned.

    Args:
        image: Key identifying the image in ``features``.
        classifier: Dict with ``"weights"`` (matrix the query is multiplied
            with; columns are indexed in parallel with the class names) and
            ``"class_names"`` (sequence of labels).
        features: Optional mapping ``image -> {'img_emb', 'init_pred',
            'img_desc'}``. Defaults to the module-level ``img_features``.

    Returns:
        The entry of ``classifier["class_names"]`` with the highest score.

    Raises:
        KeyError: If ``image`` or one of the three embedding keys is missing.
    """
    if features is None:
        features = img_features
    entry = features[image]

    # IF / PF / DF: normalise each modality so that none dominates the sum.
    image_feature = _unit_vector(entry['img_emb'])
    prediction_feature = _unit_vector(entry['init_pred'])
    description_feature = _unit_vector(entry['img_desc'])

    # DF_PF_IF: fuse the three modalities into a single unit-length query.
    query_feature = _unit_vector(
        image_feature + prediction_feature + description_feature
    )

    # Predicting: np.argmax works on the flattened scores, so a query of
    # shape (d,) and one of shape (1, d) select the same class.
    scores = np.matmul(query_feature, classifier["weights"])
    return classifier["class_names"][int(np.argmax(scores))]

# Ground-truth table used when evaluation() is called without a path.
_GROUND_TRUTH_CSV = "../data/pets/test_stratified_subset_5.csv"
# csv path -> {image id: lower-cased class name}. Filled on first use so the
# CSV is parsed once per path instead of once per evaluated image.
_ground_truth_cache = {}


def _load_ground_truth(csv_path):
    """Return the ``image -> lower-cased name`` lookup for *csv_path*, cached."""
    if csv_path not in _ground_truth_cache:
        import pandas as pd  # function-local, as in the original cell

        table = pd.read_csv(csv_path)
        # Keep the first row per image so duplicates resolve the same way as
        # taking element 0 of the filtered rows.
        table = table.drop_duplicates(subset='image', keep='first')
        _ground_truth_cache[csv_path] = dict(
            zip(table['image'], table['name'].str.lower())
        )
    return _ground_truth_cache[csv_path]


def evaluation(img, csv_path=None):
    """Return the lower-cased ground-truth class name of image *img*.

    Args:
        img: Image id, matched against the ``image`` column of the CSV.
        csv_path: Optional path to a CSV with ``image`` and ``name`` columns.
            Defaults to the test-subset file of the pets dataset.

    Returns:
        The ``name`` value of the first row whose ``image`` equals ``img``,
        lower-cased.

    Raises:
        IndexError: If no row matches ``img``.

    Note:
        The parsed table is cached per path for the life of the process;
        edits to the CSV made after the first call are not picked up.
    """
    labels = _load_ground_truth(
        _GROUND_TRUTH_CSV if csv_path is None else csv_path
    )
    try:
        return labels[img]
    except KeyError:
        # IndexError is what the uncached row lookup raised for a miss.
        raise IndexError(f"no ground-truth row for image {img!r}") from None

# Score every image that has features and report top-1 accuracy.
right_count = 0
for img in img_names:
    # Classify and look up the label once per image. The comparison uses the
    # lower-cased prediction because evaluation() returns lower-cased names.
    pred = classify(img, model).lower()
    gt = evaluation(img)
    if pred == gt:
        right_count += 1
        # Only correctly classified images are listed.
        print(f'{img}: pred-{pred} gt-{gt}')

# Guard the division so an empty feature file reports 0.0 instead of raising
# ZeroDivisionError.
accuracy = right_count / len(img_names) if img_names else 0.0
print(f'Accuracy: {accuracy}')

(768, 37)
american_bulldog_13: pred-american bulldog gt-american bulldog
american_bulldog_94: pred-american bulldog gt-american bulldog
american_bulldog_49: pred-american bulldog gt-american bulldog
american_bulldog_203: pred-american bulldog gt-american bulldog
basset_hound_13: pred-basset hound gt-basset hound
basset_hound_94: pred-basset hound gt-basset hound
beagle_142: pred-beagle gt-beagle
beagle_50: pred-beagle gt-beagle
beagle_184: pred-beagle gt-beagle
chihuahua_13: pred-chihuahua gt-chihuahua
chihuahua_94: pred-chihuahua gt-chihuahua
chihuahua_50: pred-chihuahua gt-chihuahua
chihuahua_191: pred-chihuahua gt-chihuahua
english_cocker_spaniel_12: pred-english cocker spaniel gt-english cocker spaniel
great_pyrenees_13: pred-great pyrenees gt-great pyrenees
great_pyrenees_94: pred-great pyrenees gt-great pyrenees
great_pyrenees_180: pred-great pyrenees gt-great pyrenees
great_pyrenees_191: pred-great pyrenees gt-great pyrenees
newfoundland_12: pred-newfoundland gt-newfoundland
new