# PyTorch Metric Learning
See the documentation [here](https://kevinmusgrave.github.io/pytorch-metric-learning/)

## Install the packages

In [None]:
# Install pytorch-metric-learning and faiss-gpu via pip (-q makes the second install quiet).
!pip install pytorch-metric-learning
!pip install -q faiss-gpu

## Import the packages

In [10]:
%matplotlib inline
import sys
sys.path.append("/build/pytorch-metric-learning")

import collections
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import lj_inference as ljinf
import lj_common_model as lj_com
from torchvision import transforms
import os
import importlib

In [None]:
# Run this cell only after editing lj_inference.py so the changes are picked up;
# otherwise skip it.
importlib.reload(ljinf)

## Create helper functions

In [11]:
def print_decision(is_match):
    """Print "Same class" when ``is_match`` is truthy, otherwise "Different class"."""
    verdict = "Same class" if is_match else "Different class"
    print(verdict)


# Per-channel normalisation statistics (presumably the ones applied when the
# dataset images were normalised -- confirm against the dataset transforms).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalize computes (x - mean) / std, so using mean' = -mean / std and
# std' = 1 / std turns it into x * std + mean, i.e. the inverse transform.
_inv_mean = [-m / s for m, s in zip(mean, std)]
_inv_std = [1 / s for s in std]
inv_normalize = transforms.Normalize(mean=_inv_mean, std=_inv_std)


def imshow(img, figsize=(8, 4)):
    """Undo the input normalisation of ``img`` and display it with matplotlib.

    Args:
        img: a channels-first image tensor (it is transposed with axes
            ``(1, 2, 0)`` before plotting), normalised with ``mean`` / ``std``.
        figsize: size passed to ``plt.figure``.
    """
    img = inv_normalize(img)
    # detach() + cpu() so tensors that track gradients or live on the GPU can
    # still be converted; a plain .numpy() raises for those.
    npimg = img.detach().cpu().numpy()
    # Channels-first -> channels-last, the layout plt.imshow expects.
    npimg = np.transpose(npimg, (1, 2, 0))
    # De-normalised floats can land slightly outside [0, 1]; clip explicitly
    # instead of relying on matplotlib's clipping (which also emits a warning).
    npimg = np.clip(npimg, 0.0, 1.0)
    plt.figure(figsize=figsize)
    plt.imshow(npimg)
    plt.show()

## Constants used to build the argument string

In [24]:
# Values that end up in the command-line style argument string below.
input_size = 320
input_crop = 300
best_iteration = 22
dim = 512
nclasses = 0  # 0 means: use the train dir to figure out the number of classes
data_dir = "/layerjot/tb2/ljcv-datalake-prod/triplet_dataset/20240621/"
batch_size = 24

# Location of the saved model files for this run.
base_results_dir = "/data/ljcv-model-artifacts/pytorch-metric-learning/results/20240621"
model_results_dir = os.path.join(
    base_results_dir,
    "pymetric.effv2_m.b24.dim512.im320.m5s01",
    "lj_saved_models",
)

backbone = "tf_efficientnetv2_m"
# Checkpoint file names are derived from the chosen iteration.
embedder = "embedder_best{}.pth".format(best_iteration)
trunk = "trunk_best{}.pth".format(best_iteration)
eval_device = "cuda"  # value passed to --eval-device

# Assemble the argument string piece by piece; the positional data dir goes last.
_arg_segments = [
    "--input-size {} --input-crop {}  --base-output-dir {}".format(input_size, input_crop, model_results_dir),
    " --backbone {}".format(backbone),
    " --embedder {} --trunk {} --eval-device {}".format(embedder, trunk, eval_device),
    " --batch-size {} --dim {} --output-dim {}".format(batch_size, dim, nclasses),
    " {}".format(data_dir),
]
base_argstring = "".join(_arg_segments)

argstring = base_argstring

## Create the InferenceModel wrapper

In [None]:
# Parse the argument string exactly as if it had been given on the command line.
parser = ljinf.create_inference_parser()
print("parsing {}...".format(argstring))
cli_tokens = argstring.split()
args = parser.parse_args(cli_tokens)

# Build the model wrapper together with the dataset and the label lookups.
(
    inference_model,
    dataset_val,
    labels_to_indices,
    indices_to_labels,
) = ljinf.build_inference_model_from_cli(args)
print("len labels_to_indices {}".format(len(labels_to_indices)))

# Two example classes; each entry is the collection of dataset indices for that label.
classA_index, classB_index = 1, 6
classA = labels_to_indices[classA_index]
classB = labels_to_indices[classB_index]

## Get nearest neighbors of a query

In [None]:

# Show the first image of each example class followed by its nearest neighbours.
num_neighbors = 10
for img_type in [classA, classB]:
    # query image index
    img_idx = img_type[0]
    # Fetch the sample once and reuse both the image and its label.
    query_img, query_label = dataset_val[img_idx]
    img = query_img.unsqueeze(0)
    print("query index {} query image class {}".format(img_idx, query_label))
    imshow(torchvision.utils.make_grid(img))
    # Ask for one extra neighbour because the query itself may come back as a
    # result; it is filtered out below.
    distances, indices = inference_model.get_nearest_neighbors(img, k=num_neighbors + 1)
    neighbor_indices = [i for i in indices.cpu()[0].tolist() if i != img_idx]
    # Cap the list so the same number of neighbours is shown even when the
    # query was not among the results.
    nearest_imgs = [dataset_val[i][0] for i in neighbor_indices[:num_neighbors]]
    imshow(torchvision.utils.make_grid(nearest_imgs))

## Analyze all nearest neighbors

In [None]:
# Run the nearest-neighbour evaluation with the ALL query option.
query_option = ljinf.NN_query_image.ALL
precision_at_k, recall_at_k, d_for_match, d_for_mismatch = ljinf.nearest_neighbors(
    inference_model,
    dataset_val,
    labels_to_indices,
    indices_to_labels,
    query_image=query_option,
)
print("precisions@k {} recall@k {}".format(precision_at_k, recall_at_k))
# For sheet: a header row followed by a value row in the same column order.
# The header used to be formatted with five arguments but a single placeholder,
# which appended a stray recall@1 value to it.
print("recall@1 recall@3 recall@5 recall@10 precision@10")
print("{} {} {} {} {}".format(recall_at_k[0], recall_at_k[2], recall_at_k[4], recall_at_k[9], precision_at_k[9]))

In [None]:
# First analysis: just the average and sigma of the distances.
def basic_metric(d_l, tag):
    """Print the average, standard deviation and count of the values in ``d_l``.

    Args:
        d_l: sequence of distances.
        tag: label printed at the start of the line.
    """
    mean_distance = np.average(d_l)
    sigma = np.std(d_l)
    count = len(d_l)
    print("{} average distance {} sigma {} N {}".format(tag, mean_distance, sigma, count))

def basic_metric_by_n(partition, tag):
    """Report distance statistics for every key of ``partition`` in sorted order.

    Args:
        partition: mapping from k to a list of descriptors whose first element
            is the distance.
        tag: label prefix; each line is tagged ``"<tag>@<k>"``.
    """
    for k in sorted(partition.keys()):
        basic_metric([entry[0] for entry in partition[k]], "{}@{}".format(tag, k))

def partition_by_n(dist):
    """Group descriptors by their second field.

    Each descriptor is indexed as ``(d, n, ci, ri, cl, rl)``. The result maps
    ``n`` to a list of ``(d, ci, ri, cl, rl)`` tuples in input order.

    Returns:
        A ``collections.defaultdict(list)`` keyed by ``n``.
    """
    grouped = collections.defaultdict(list)
    for desc in dist:
        # Drop the grouping key (index 1) from the stored tuple.
        remainder = (desc[0], desc[2], desc[3], desc[4], desc[5])
        grouped[desc[1]].append(remainder)
    return grouped
        
# Distance statistics per n, separately for matched and mismatched results.
d_match_by_n = partition_by_n(d_for_match)
basic_metric_by_n(d_match_by_n, "matched classes")
d_miss_by_n = partition_by_n(d_for_mismatch)
basic_metric_by_n(d_miss_by_n, "missmatched classes")

# Show the mismatched examples at n = 1, 2 and 3.
for top_i in range(1, 4):
    misses = d_miss_by_n[top_i]
    # Fields 1-2 are used as dataset indices, fields 3-4 are printed as labels.
    mismatched_indices = [(desc[1], desc[2]) for desc in misses]
    mismatched_classes = [(desc[3], desc[4]) for desc in misses]
    print("top-n {} {} mismatched images: mismatched labels {}".format(top_i, len(misses), mismatched_classes))

    # Display each pair side by side.
    for qi, ri in mismatched_indices:
        pair = [dataset_val[qi][0], dataset_val[ri][0]]
        imshow(torchvision.utils.make_grid(pair))

In [None]:
# Number of samples in dataset_val.
print(len(dataset_val))

## Compare two images of the same class

In [None]:
# Compare the first two images listed for classA (same class).
x, _ = dataset_val[classA[0]]
y, _ = dataset_val[classA[1]]
pair_grid = torchvision.utils.make_grid(torch.stack((x, y)))
imshow(pair_grid)
# is_match expects batches, so add a leading batch dimension to each image.
decision = inference_model.is_match(x.unsqueeze(0), y.unsqueeze(0))
print_decision(decision)

## Compare two images of different classes

In [None]:
# Compare the first classA image with the first classB image (different classes).
x, _ = dataset_val[classA[0]]
y, _ = dataset_val[classB[0]]
pair_grid = torchvision.utils.make_grid(torch.stack((x, y)))
imshow(pair_grid)
# is_match expects batches, so add a leading batch dimension to each image.
decision = inference_model.is_match(x.unsqueeze(0), y.unsqueeze(0))
print_decision(decision)

## Compare multiple pairs of images

In [None]:
# compare multiple pairs of images
# x[k] and y[k] always come from the same class (classA and classB alternate),
# so "Same class" is the correct decision for every pair.
pairs_per_class = 5
y_offset = 10  # y images are taken this many positions further along each class list

print("lA {} lB {}".format(len(classA), len(classB)))
x_imgs, y_imgs = [], []
for i in range(0, 2 * pairs_per_class, 2):
    x_imgs.append(dataset_val[classA[i]][0])
    x_imgs.append(dataset_val[classB[i]][0])
    y_imgs.append(dataset_val[classA[i + y_offset]][0])
    y_imgs.append(dataset_val[classB[i + y_offset]][0])
# Stack exactly the images that were loaded. Pre-allocating 20 slots while
# filling only 10 left all-zero images in the batch, which were displayed and
# counted in the accuracy; stacking also avoids hard-coding the image size.
x = torch.stack(x_imgs)
y = torch.stack(y_imgs)
# nrow=len(x) keeps the x images on the first row and the y images on the second.
imshow(torchvision.utils.make_grid(torch.cat((x, y), dim=0), nrow=len(x)), figsize=(30, 3))
decision = inference_model.is_match(x, y)
for d in decision:
    print_decision(d)
print("accuracy = {}".format(np.sum(decision) / len(x)))

## Compare all pairs within a batch

In [None]:
# Compare every image in the batch x against every other image in x.
match_matrix = inference_model.get_matches(x)
assert match_matrix[0, 0]  # the 0th image should match with itself
first, second = 3, 4
pair_grid = torchvision.utils.make_grid(torch.stack((x[first], x[second])))
imshow(pair_grid)
# Does image 3 match image 4?
print_decision(match_matrix[first, second])

## Compare all pairs between queries and references

In [None]:
# Compare every query in x against every reference in y.
match_matrix = inference_model.get_matches(x, y)
pair_index = 6
pair_grid = torchvision.utils.make_grid(torch.stack((x[pair_index], y[pair_index])))
imshow(pair_grid)
# Does query 6 match reference 6?
print_decision(match_matrix[pair_index, pair_index])

## Get results in tuple form

In [None]:
# Rebuild the inference model with an explicit similarity threshold placed in
# front of the base arguments.
sim_threshold = 0.95
argstring = "--similarity-threshold {} {}".format(sim_threshold, base_argstring)
args = parser.parse_args(argstring.split())
# NOTE(review): the earlier model-building cell calls
# ljinf.build_inference_model_from_cli(args) and unpacks four values, while this
# call uses ljinf.build_inference_model and unpacks three -- confirm that this
# helper exists in lj_inference.py and returns three values.
inference_model, dataset_val, labels_to_indices = ljinf.build_inference_model(args)

# get all matches in tuple form
# x and y are the image batches built in the "compare multiple pairs" cell.
match_tuples = inference_model.get_matches(x, y, return_tuples=True)
print("MATCHING IMAGE PAIRS")
# Each tuple is unpacked as (index into x, index into y); show the pair side by side.
for i, j in match_tuples:
    print(i, j)
    imshow(torchvision.utils.make_grid(torch.stack((x[i], y[j]), dim=0)))