In [19]:
import pickle

from pathlib import Path
from tqdm.auto import notebook_tqdm

from pytorch_metric_learning import losses
from pytorch_metric_learning import distances
from pytorch_metric_learning import losses
from pytorch_metric_learning import samplers

import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision import models
from torchvision.utils import make_grid

In [20]:
def create_headless_resnet18():
    model = models.resnet18(pretrained=True, progress=False)
    model = nn.Sequential(*list(model.children())[:-1])
    return model

In [21]:
model = create_headless_resnet18()

In [22]:
data_path = Path("/home/pau/Documents/datasets/MIT_split")
feature_path = Path("./results/retrieval")

In [23]:
transfs = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

catalogue = ImageFolder(str(data_path / "train"), transform=transfs)
queries = ImageFolder(str(data_path / "test"), transform=transfs)

In [24]:
catalogue_meta = [(x[0].split('/')[-1], x[1]) for x in catalogue.imgs]
query_meta = [(x[0].split('/')[-1], x[1]) for x in queries.imgs]

with (feature_path / "catalogue_meta.pkl").open('wb') as f_meta:
    pickle.dump(catalogue_meta, f_meta)

with (feature_path / "query_meta.pkl").open('wb') as f_meta:
    pickle.dump(query_meta, f_meta)

In [25]:
catalogue_data = np.empty((len(catalogue), 512))
with torch.no_grad():
    for ii, (img, _) in enumerate(catalogue):
        catalogue_data[ii, :] = model(img.unsqueeze(0)).squeeze().numpy()

with open(feature_path / "catalogue.npy", "wb") as f:
    np.save(f, catalogue_data)

In [26]:
query_data = np.empty((len(queries), 512))
with torch.no_grad():
    for ii, (img, _) in enumerate(queries):
        query_data[ii, :] = model(img.unsqueeze(0)).squeeze().numpy()

with open(feature_path / "queries.npy", "wb") as f:
    np.save(f, query_data)

In [None]:
with open(feature_path / "queries.npy", "rb") as f:
    query_data = np.load(f)
with open(feature_path / "catalogue.npy", "rb") as f:
    catalogue_data = np.load(f)

with (feature_path / "catalogue_meta.pkl").open('wb') as fc, \
        (feature_path / "query_meta.pkl").open('wb') as fq:
    catalogue_meta = pickle.load(fc)
    query_meta = pickle.load(fq)

In [40]:
from sklearn.neighbors import KNeighborsClassifier

catalogue_labels = np.asarray([x[1] for x in catalogue_meta])
query_labels = np.asarray([x[1] for x in query_meta])

knn = KNeighborsClassifier(n_neighbors=5)
knn = knn.fit(catalogue_data, catalogue_labels)
predictions = knn.predict(query_data)
pr_prob = knn.predict_proba(query_data)

In [43]:
from sklearn.metrics import f1_score, average_precision_score

one_hot = np.zeros((predictions.shape[0], max(predictions) + 1), dtype=int)
one_hot[predictions] = 1

f1 = f1_score(query_labels, predictions, average="macro")
ap = average_precision_score(one_hot, pr_prob)

In [45]:
pr_prob

array([[0. , 0. , 0.8, ..., 0. , 0. , 0. ],
       [0.8, 0.2, 0. , ..., 0. , 0. , 0. ],
       [0. , 0. , 0.4, ..., 0.2, 0. , 0.4],
       ...,
       [0. , 0. , 0. , ..., 0. , 0.2, 0.6],
       [0. , 0. , 0. , ..., 0. , 0.4, 0.4],
       [0. , 0. , 0. , ..., 0. , 0. , 1. ]])

In [44]:
print(ap, f1)

0.014844389654713325 0.3793890191797535
