In [1]:
import sys
sys.path.append('..')
sys.path.append('../../wildlife-tools')

import os
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt

from wildlife_datasets import datasets
from analysis import *
from utils import *

In [2]:
# Settings shared by the cells below.

# Base directory; each dataset is looked up in a sub-folder named after its class.
root_datasets = '../data'

# Model identifier: it is appended to 'hf-hub:BVRA/' when the extractor is built,
# and the number after the last '-' is parsed as the square input image size.
model_name = 'MegaDescriptor-L-384'

# NOTE(review): not referenced by the cells visible here; presumably a
# top-k / nearest-neighbour setting used further down. TODO confirm.
k = 5

In [6]:
# Datasets for which features are extracted (once flipped, once unflipped).
dataset_classes = [
    datasets.HyenaID2022,
    datasets.LeopardID2022,
    datasets.NyalaData,
    datasets.SarahZelvy,
    datasets.SeaTurtleIDHeads,
    datasets.StripeSpotter,
    datasets.WhaleSharkID,
    datasets.ZindiTurtleRecall,
]

# The model name ends with the input resolution, e.g. '...-384' -> 384.
img_size = int(model_name.split('-')[-1])

# Preprocessing shared by both variants. The flipped variant only prepends a
# horizontal flip applied with probability 1 (i.e. always).
base_transforms = [
    T.Resize([img_size, img_size]),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]

# Build the extractor once: none of its arguments depend on the dataset or on
# the flip setting, so re-creating it in every iteration only repeats the
# (expensive) model construction.
extractor = get_extractor(model_name='hf-hub:BVRA/'+model_name, batch_size=32, device='cuda')

for dataset_class in dataset_classes:
    dataset_name = dataset_class.__name__
    root = os.path.join(root_datasets, dataset_name)

    # The dataset metadata does not depend on the flip setting either, so it is
    # loaded once per dataset instead of once per (dataset, flip) pair.
    d = dataset_class(root)

    # Request bounding-box loading only when the metadata has a 'bbox' column.
    wd_kwargs = {'img_load': 'bbox'} if 'bbox' in d.df else {}

    for flip in [True, False]:
        flip_transforms = [T.RandomHorizontalFlip(1)] if flip else []
        transform = T.Compose(flip_transforms + base_transforms)

        # One feature file per (dataset, flip, model) combination.
        file_name = os.path.join('features', f'features_{dataset_name}_flip={flip}_{model_name}.npy')

        dataset = WD(d.df, d.root, transform=transform, **wd_kwargs)
        # NOTE(review): presumably computes the features and stores them under
        # file_name unless that file already exists — confirm in utils.
        features = get_normalized_features(file_name, dataset, extractor)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|█████████████████████████████████████████████████████████████| 214/214 [06:50<00:00,  1.92s/it]
100%|█████████████████████████████████████████████████████████████| 214/214 [06:39<00:00,  1.87s/it]
100%|███████████████████████████████████████████████████████████████| 61/61 [01:39<00:00,  1.62s/it]
100%|███████████████████████████████████████████████████████████████| 61/61 [01:37<00:00,  1.60s/it]
100%|███████████████████████████████████████████████████████████████| 26/26 [00:48<00:00,  1.85s/it]
100%|███████████████████████████████████████████████████████████████| 26/26 [00:42<00:00,  1.63s/it]
100%|█████████████████████████████████████████████████████████████| 241/241 [11:24<00:00,  2.84s/it]
100%|█████████████████████████████████████████████████████████████| 241/241 [10:11<00:00,  2.54s/it]


In [7]:
from wildlife_datasets import splits

# Closed-set evaluation: for every dataset, compare the accuracy obtained with
# horizontally flipped features (printed first) against unflipped features
# (printed second).
for dataset_class in dataset_classes:
    print(dataset_class)
    dataset_name = dataset_class.__name__
    root = os.path.join(root_datasets, dataset_name)

    # Metadata, labels and the split do not depend on the flip setting, so they
    # are computed once per dataset. Besides avoiding repeated work, this
    # guarantees that both flip variants are scored on exactly the same
    # database/query partition.
    d = dataset_class(root)
    labels = d.df['identity'].to_numpy()

    splitter = splits.ClosedSetSplit(0.8)
    # Only the first returned split is used.
    # NOTE(review): the split indices are used to index `features` and `labels`
    # positionally — assumes d.df's index matches row positions; TODO confirm.
    idx_database, idx_query = splitter.split(d.df)[0]

    for flip in [True, False]:
        file_name = os.path.join('features', f'features_{dataset_name}_flip={flip}_{model_name}.npy')
        # Called with the file name only — presumably loads the features saved
        # by the extraction cell; confirm in utils.
        features = get_normalized_features(file_name)

        idx_true, idx_pred = compute_predictions_closed(features[idx_query], features[idx_database])
        # The returned indices are relative to the query/database subsets;
        # map them back to indices into the full dataset.
        idx_true = idx_query[idx_true]
        idx_pred = idx_database[idx_pred]

        # Only the first column of the predictions is scored
        # (presumably the best match — TODO confirm).
        accuracy = np.mean(labels[idx_true] == labels[idx_pred[:,0]])

        print(accuracy)

<class 'wildlife_datasets.datasets.datasets.HyenaID2022'>
0.7888888888888889
0.7841269841269841
<class 'wildlife_datasets.datasets.datasets.LeopardID2022'>
0.7501744591765527
0.7578506629448709
<class 'wildlife_datasets.datasets.datasets.NyalaData'>
0.38317757009345793
0.39485981308411217
<class 'wildlife_datasets.datasets.datasets.SarahZelvy'>
0.7575757575757576
0.7676767676767676
<class 'wildlife_datasets.datasets.datasets.SeaTurtleIDHeads'>
0.9117840684660962
0.9091507570770243
<class 'wildlife_datasets.datasets.datasets.StripeSpotter'>
0.9878048780487805
0.9939024390243902
<class 'wildlife_datasets.datasets.datasets.WhaleSharkID'>
0.61198738170347
0.6151419558359621
<class 'wildlife_datasets.datasets.datasets.ZindiTurtleRecall'>
0.7556781310837118
0.754380272550292
