In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torchvision.transforms as T
from wildlife_datasets import datasets, splits

# Project-local helpers (presumably WD, get_extractor, get_normalized_features,
# compute_predictions_closed). NOTE(review): star imports hide where these names
# come from — consider importing them explicitly.
from analysis import *
from utils import *

In [2]:
# Where the wildlife-datasets catalogues live on disk.
root_datasets = '../wildlife-datasets/data'
# Backbone under evaluation; its trailing number is the input resolution used for resizing.
model_name = 'MegaDescriptor-L-384'
# NOTE(review): k is not referenced by any cell visible here — confirm it is still needed.
k = 5

In [3]:
# Datasets on which the effect of horizontal flipping is measured.
dataset_classes = [
    datasets.HyenaID2022,
    datasets.LeopardID2022,
    datasets.NyalaData,
    datasets.SarahZelvy,
    datasets.SeaTurtleID2022,
    datasets.SeaTurtleIDHeads,
    datasets.StripeSpotter,
    datasets.WhaleSharkID,
    datasets.ZindiTurtleRecall,
]

# The model name ends with its input resolution, e.g. 'MegaDescriptor-L-384' -> 384.
img_size = int(model_name.split('-')[-1])

# ImageNet normalisation statistics shared by both preprocessing pipelines.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def make_transform(img_size, flip):
    """Return the resize + normalise pipeline, preceded by a horizontal flip when ``flip`` is True.

    ``RandomHorizontalFlip(1)`` flips with probability 1, so the flipped variant is deterministic.
    """
    steps = [T.RandomHorizontalFlip(1)] if flip else []
    steps += [
        T.Resize([img_size, img_size]),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# NOTE(review): assumes get_normalized_features caches its result under 'features/';
# creating the directory up front is harmless if it already exists.
os.makedirs('features', exist_ok=True)

# The extractor depends neither on the dataset nor on the flip. Build it once instead of
# once per (dataset, flip) pair, which presumably reloads the model weights every time.
extractor = get_extractor(model_name='hf-hub:BVRA/'+model_name, batch_size=32, device='cuda')

for dataset_class in dataset_classes:
    dataset_name = dataset_class.__name__
    # The catalogue does not depend on the flip, so load it once per dataset.
    d = dataset_class(os.path.join(root_datasets, dataset_name))
    # Use the 'bbox' loading mode when the catalogue has a 'bbox' column.
    load_kwargs = {'img_load': 'bbox'} if 'bbox' in d.df else {}
    for flip in [True, False]:
        dataset = WD(d.df, d.root, transform=make_transform(img_size, flip), **load_kwargs)
        file_name = os.path.join('features', f'features_{dataset_name}_flip={flip}_{model_name}.npy')
        features = get_normalized_features(file_name, dataset, extractor)

In [4]:
def compute_accuracy(features_database, features_query, idx_database, idx_query, labels):
    """Closed-set identification accuracy of a query subset against a database subset.

    Args:
        features_database: feature matrix whose rows at ``idx_database`` form the database.
        features_query: feature matrix whose rows at ``idx_query`` form the queries.
        idx_database: row positions of the database images (array or list of ints).
        idx_query: row positions of the query images (array or list of ints).
        labels: identity of every row, in the same row order as the feature matrices.

    Returns:
        Fraction of queries whose first predicted database match (column 0 of the
        predictions, presumably the nearest neighbour) has the same identity.
    """
    # Coerce to arrays: the index sets are fancy-indexed by arrays below, which
    # fails for plain Python lists.
    features_database = np.asarray(features_database)
    features_query = np.asarray(features_query)
    idx_database = np.asarray(idx_database)
    idx_query = np.asarray(idx_query)
    labels = np.asarray(labels)

    idx_true, idx_pred = compute_predictions_closed(features_query[idx_query], features_database[idx_database])
    # Predictions are positions within the subsets; map them back to full-dataset rows.
    idx_true = idx_query[idx_true]
    idx_pred = idx_database[idx_pred]
    return np.mean(labels[idx_true] == labels[idx_pred[:, 0]])

# Features extracted from the original images ('normal') and from their horizontal flips ('flipped').
FLIP_VARIANTS = {'normal': False, 'flipped': True}

results = {}
splitter = splits.ClosedSetSplit(0.8)
for dataset_class in dataset_classes:
    dataset_name = dataset_class.__name__
    d = dataset_class(os.path.join(root_datasets, dataset_name))
    labels = d.df['identity'].to_numpy()

    idx_database, idx_query = splitter.split(d.df)[0]
    # The splitter yields DataFrame index labels (wildlife-datasets uses them with .loc),
    # but features and labels are indexed by row position. Convert them so that a
    # non-default index cannot silently misalign queries and identities.
    idx_database = d.df.index.get_indexer(idx_database)
    idx_query = d.df.index.get_indexer(idx_query)
    assert (idx_database >= 0).all() and (idx_query >= 0).all(), 'split returned labels missing from the catalogue index'

    features_by_variant = {}
    for variant, flip in FLIP_VARIANTS.items():
        file_name = os.path.join('features', f'features_{dataset_name}_flip={flip}_{model_name}.npy')
        features_by_variant[variant] = get_normalized_features(file_name)

    # Every database/query combination, in the order normal/normal, normal/flipped,
    # flipped/normal, flipped/flipped.
    results[dataset_name] = {
        f'd_{db}_q_{q}': compute_accuracy(features_by_variant[db], features_by_variant[q], idx_database, idx_query, labels)
        for db in FLIP_VARIANTS
        for q in FLIP_VARIANTS
    }

In [5]:
# Accuracy per dataset (rows) for each database/query flip combination (columns);
# naming the index axis labels the otherwise anonymous first column.
pd.DataFrame.from_dict(results, orient='index').rename_axis('dataset')

Unnamed: 0,d_normal_q_normal,d_normal_q_flipped,d_flipped_q_normal,d_flipped_q_flipped
HyenaID2022,0.784127,0.784127,0.785714,0.788889
LeopardID2022,0.757851,0.752268,0.755059,0.750174
NyalaData,0.39486,0.376168,0.383178,0.383178
SarahZelvy,0.767677,0.767677,0.737374,0.757576
SeaTurtleID2022,0.901024,0.893629,0.902162,0.893629
SeaTurtleIDHeads,0.909151,0.909809,0.909151,0.911784
StripeSpotter,0.993902,0.987805,0.987805,0.987805
WhaleSharkID,0.615142,0.612618,0.617035,0.611987
ZindiTurtleRecall,0.75438,0.753731,0.749838,0.755678
