In [None]:
import nets
import torch
import sys
import yaml
import utils
import torchvision.io
import matplotlib.pyplot as plt
from torchvision.transforms import v2, ToTensor
from PIL import Image
import numpy as np
import cv2 as cv

# Torch device passed to every model load and checkpoint map_location below.
device = 'cpu'

def register_extraction_hook(hook, module: torch.nn.Module):
    """Attach ``hook`` to ``module`` as a forward hook.

    PyTorch calls ``hook(module, args, output)`` after each forward pass of
    ``module``. The returned handle's ``.remove()`` detaches the hook again.
    """
    handle = module.register_forward_hook(hook)
    return handle
    
def plot_imgs(images, shape: tuple | None = None):
    """Show a sequence of (C, H, W) image tensors in a grid of subplots, axes hidden.

    Args:
        images: sequence of tensors shaped (C, H, W); each is permuted to
            (H, W, C) for ``imshow``.
        shape: optional ``(nrows, ncols)`` grid. By default, at most 5 images
            per row and as many rows as needed.

    Raises:
        ValueError: if ``images`` is empty and no ``shape`` is given.
    """
    import math

    if shape is None:
        n_images = len(images)
        if n_images == 0:
            raise ValueError("plot_imgs needs at least one image")
        max_img_row = 5
        ncols = min(n_images, max_img_row)
        nrows = math.ceil(n_images / ncols)
    else:
        nrows, ncols = shape
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # (a single image) returns a bare Axes and `.ravel()` fails.
    f, axarr = plt.subplots(nrows, ncols, figsize=(3.5 * ncols, 3 * nrows), squeeze=False)
    axes = axarr.ravel()
    # Hide every axis, including grid cells left empty when there are fewer
    # images than cells.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(images, axes):
        ax.imshow(img.permute(1, 2, 0))

# Class index -> (R, G, B) colour used to render segmentation masks.
# Indices are consecutive, from 0 (unlabeled) to 15, in list order.
color_conv = dict(enumerate([
    (0, 0, 0),        # 0  unlabeled
    (200, 0, 0),      # 1  industrial land
    (250, 0, 150),    # 2  urban residential
    (200, 150, 150),  # 3  rural residential
    (250, 150, 150),  # 4  traffic land
    (0, 200, 0),      # 5  paddy field
    (150, 250, 0),    # 6  irrigated cropland
    (150, 200, 150),  # 7  dry cropland
    (200, 0, 200),    # 8  garden plot
    (150, 0, 250),    # 9  arbor woodland
    (150, 150, 250),  # 10 shrub land
    (250, 200, 0),    # 11 natural grass land
    (200, 200, 0),    # 12 artificial grass land
    (0, 0, 200),      # 13 river
    (0, 150, 200),    # 14 lake
    (0, 200, 250),    # 15 pond
]))

def mask_to_color(mask, palette=None):
    """Render a 2-D class-index mask as a (3, H, W) uint8 RGB image.

    Colours come from a lookup table indexed by the whole mask at once. The
    previous version built one small tensor per pixel in a Python loop, which
    was very slow on full-size masks.

    Args:
        mask: integer tensor of shape (H, W) holding class indices.
        palette: optional mapping class index -> (R, G, B). Defaults to the
            module-level ``color_conv``.

    Returns:
        uint8 tensor of shape (3, H, W), on the CPU.

    Raises:
        KeyError: if ``mask`` contains a class index that has no palette entry.
    """
    if palette is None:
        palette = color_conv
    mask = mask.detach().cpu().long()
    # Validate first: an unknown (or negative) index would otherwise silently
    # read the wrong row of the lookup table.
    unknown = set(torch.unique(mask).tolist()) - set(palette)
    if unknown:
        raise KeyError(f"no color defined for class index(es) {sorted(unknown)}")
    lut = torch.zeros((max(palette) + 1, 3), dtype=torch.uint8)
    for cls_idx, rgb in palette.items():
        lut[cls_idx] = torch.tensor(rgb, dtype=torch.uint8)
    # lut[mask] is (H, W, 3); move channels first, as torchvision images are laid out.
    return lut[mask].permute(2, 0, 1).contiguous()

def class_hist(mask, nclasses):
    """Fraction of ``mask`` pixels equal to each class index ``0 .. nclasses-1``.

    Pixels with other values are not counted, so the result can sum to less than 1.
    """
    counts = [int(torch.count_nonzero(mask == cls_idx)) for cls_idx in range(nclasses)]
    return torch.tensor(counts, dtype=torch.get_default_dtype()) / mask.numel()


import torchmetrics.functional.segmentation
import torchmetrics.functional.classification
from prettytable import PrettyTable

def per_image_jaccard(target_mask_idx, query_masks_idx, num_classes=25):
    """Support-weighted Jaccard index (IoU) of each query mask against the target mask.

    The target mask is the reference, so it is passed as torchmetrics' ``target``.
    The previous call passed it as ``preds``. With it as ``target``, the
    ``'weighted'`` average weights classes by their support in the target, and
    ``ignore_index=0`` drops the pixels the target leaves unlabeled. This is
    consistent with ``pixel_precision``.

    Args:
        target_mask_idx: class-index mask of the target image (presumably
            (1, H, W), as produced by ``argmax(dim=1)``; TODO confirm).
        query_masks_idx: class-index masks of the queries, indexed along dim 0.
        num_classes: number of classes given to torchmetrics. The default of 25
            is kept for backward compatibility; it must be larger than every
            class index present.

    Returns:
        float tensor with one score per query.
    """
    result = torch.zeros(query_masks_idx.shape[0])
    for i in range(query_masks_idx.shape[0]):
        result[i] = torchmetrics.functional.classification.multiclass_jaccard_index(
            preds=query_masks_idx[i].unsqueeze(0),
            target=target_mask_idx,
            num_classes=num_classes,
            average='weighted',
            ignore_index=0,
        )
    return result

def pixel_precision(target_mask_idx, query_masks_idx):
    """Per query: share of the target's labeled (non-zero) pixels whose class the query matches.

    Pixels where the target is 0 (unlabeled) are excluded from both numerator
    and denominator. The last two dims are treated as the spatial (H, W) axes.
    """
    labeled = target_mask_idx != 0
    agreeing = torch.logical_and(query_masks_idx == target_mask_idx, labeled)
    n_labeled = target_mask_idx.count_nonzero()
    return agreeing.sum(dim=(-1, -2)) / n_labeled


def per_image_emd(target_mask_idx, query_masks_idx, nbins):
    """Earth Mover's Distance between the class histogram of the target and of each query.

    Based on https://arxiv.org/abs/1611.05916. With ordered classes and an L1
    ground distance, 1-D EMD has the closed form sum_k |CDF_t(k) - CDF_q(k)|.
    This formulation was found to match OpenCV's ``cv.EMD`` with ``cv.DIST_L1``.

    Histograms have ``nbins - 1`` unit-width bins covering classes 1 .. nbins-1.
    Class 0 (unlabeled) falls below ``min`` and is ignored. Each histogram is
    normalized to sum to 1.

    Args:
        target_mask_idx: class-index mask of the target image.
        query_masks_idx: class-index masks of the queries, indexed along dim 0.
        nbins: number of class indices including 0 (e.g. 16 for classes 0..15).

    Returns:
        float64 numpy array with one distance per query (lower = more similar).
        A query with no labeled pixel gets ``inf``; if the target has none,
        every query does. Normalization would otherwise produce NaN, which
        ``torch.argmin`` and sorting would treat as a valid (even best) score.
    """
    def mask_to_norm_hist(mask):
        # Returns the normalized histogram, or None when no pixel is labeled.
        mask = mask.detach().cpu().type(torch.float32)
        hist = torch.histc(mask, bins=nbins - 1, min=1, max=nbins)
        total = torch.sum(hist)
        if total == 0:
            return None
        return hist.div(total).numpy()

    n_queries = query_masks_idx.shape[0]
    result = np.full(n_queries, np.inf)
    target_sig = mask_to_norm_hist(target_mask_idx)
    if target_sig is None:
        return result
    cs_ts = target_sig.cumsum()
    for i in range(n_queries):
        query_sig = mask_to_norm_hist(query_masks_idx[i])
        if query_sig is not None:
            result[i] = np.sum(np.abs(cs_ts - query_sig.cumsum()))
    return result
    
def compute_scores(targ_mask_idx, query_masks_idx, num_classes):
    """Score every query mask against the target mask.

    Returns a dict with keys, in this order: 'miou' (mean IoU), 'wiou'
    (weighted Jaccard), 'pprec' (pixel precision) and 'emd' (Earth Mover's
    Distance between class histograms). Each value has one entry per query.
    """
    n_queries = query_masks_idx.shape[0]
    # mean_iou compares batches element-wise, so repeat the target once per query.
    targ_repeated = targ_mask_idx.expand((n_queries, -1, -1))
    return {
        'miou': torchmetrics.functional.segmentation.mean_iou(targ_repeated, query_masks_idx, num_classes, False),
        'wiou': per_image_jaccard(targ_mask_idx, query_masks_idx),
        'pprec': pixel_precision(targ_mask_idx, query_masks_idx),
        'emd': per_image_emd(targ_mask_idx, query_masks_idx, num_classes),
    }

def scores_table(sc, n_img):
    """Print a table with one row per score and one column per query image.

    Args:
        sc: mapping score name -> per-query values (anything with ``.tolist()``,
            e.g. tensors or numpy arrays such as those from ``compute_scores``).
        n_img: number of query columns; should equal the length of each value.
    """
    # PrettyTable(0) passed 0 as `field_names`; the header is set explicitly instead.
    table = PrettyTable()
    table.field_names = ["score"] + [f"query{i}" for i in range(n_img)]
    for name, values in sc.items():
        table.add_row([name] + values.tolist())
    print(table)

def get_best_match_idx(selected_score, scores):
    """Index of the best-scoring query for ``selected_score``.

    'miou', 'wiou' and 'pprec' are similarities (higher is better, argmax).
    'emd' is a distance (lower is better, argmin).

    Args:
        selected_score: one of 'miou', 'wiou', 'pprec', 'emd'.
        scores: mapping score name -> per-query values. Torch tensors and numpy
            arrays (``per_image_emd`` returns one) are both accepted.

    Returns:
        0-d integer tensor holding the index of the best query.

    Raises:
        ValueError: if ``selected_score`` is not one of the names above.
    """
    if selected_score in ('miou', 'wiou', 'pprec'):
        selector, nan_fill = torch.argmax, float('-inf')
    elif selected_score == 'emd':
        selector, nan_fill = torch.argmin, float('inf')
    else:
        raise ValueError(f"'{selected_score}' is not a valid metric")
    # torch.argmin/argmax only accept tensors; the EMD scores are a numpy array.
    values = torch.as_tensor(scores[selected_score])
    # argmax/argmin select NaN entries as the extreme value; push undefined
    # scores to the worst end so they are never reported as the best match.
    if values.is_floating_point():
        values = values.masked_fill(torch.isnan(values), nan_fill)
    return selector(values)

def plot_best_match_by_score(targ, quer, targ_m_color, quer_m_color, selected_score, scores):
    """Show a 2x2 grid: target and best query image (top), their colour masks (bottom).

    The best query is chosen by ``get_best_match_idx(selected_score, scores)``.
    """
    best = get_best_match_idx(selected_score, scores)
    panels = [targ, quer[best],
              targ_m_color, quer_m_color[best]]
    plot_imgs(panels, (2, 2))

import math   
def _activation_grid(acts, border_thickness=1):
    """Tile a (C, H, W) array into one 2-D float32 image with NaN-filled borders.

    Returns ``(grid_image, col_size, row_size)``. Tiles are placed row-major in
    a nearly square grid of ``row_size`` rows by ``col_size`` columns.
    """
    nch, m_h, m_w = acts.shape
    if nch == 0:
        raise ValueError("cannot visualize an activation with zero channels")
    col_size = math.ceil(nch ** 0.5)
    row_size = math.ceil(nch / col_size)
    grid_image = np.full(
        (row_size * (m_h + border_thickness) + border_thickness,
         col_size * (m_w + border_thickness) + border_thickness),
        np.nan, dtype=np.float32)
    for ch_idx in range(nch):
        row, col = divmod(ch_idx, col_size)
        y = row * (m_h + border_thickness) + border_thickness
        x = col * (m_w + border_thickness) + border_thickness
        grid_image[y:y + m_h, x:x + m_w] = acts[ch_idx]
    return grid_image, col_size, row_size


def visualize_activations(input_act, layer_name=""):
    """Plot every channel of an activation tensor as one tile of a single grid image.

    Tiles are laid out row-major in a nearly square grid, separated by 1-pixel
    black borders. Each y-tick labels one grid row with the channel range it holds.

    Args:
        input_act: tensor that unpacks as (channels, height, width), e.g. one
            sample of a convolutional feature map.
        layer_name: text shown in the figure title.
    """
    import copy

    border_thickness = 1
    input_act = input_act.detach().cpu().numpy()
    nch, m_h, _ = input_act.shape
    grid_image, col_size, row_size = _activation_grid(input_act, border_thickness)

    # Enlarge the figure for layers with many channels so the tiles stay legible.
    figscale_factor = 512
    figscale = math.ceil(nch / figscale_factor) if nch > figscale_factor else 1
    fig = plt.figure(figsize=(figscale * 12, figscale * 8))
    fig.suptitle("Activation maps for '" + layer_name +
                 "'\nLayer shape: " + str(input_act.shape), fontsize=12 * figscale)

    # Copy so that set_bad does not modify matplotlib's shared viridis colormap.
    # NaN borders are drawn in black.
    cmap = copy.copy(plt.cm.viridis)
    cmap.set_bad('black', 1.)
    plt.imshow(grid_image, cmap=cmap)

    plt.xticks([], [])
    # One tick per grid row, labelled with its first and last channel index and
    # placed at the row's vertical centre. The old labels mixed up the row and
    # column counts and positioned ticks using the image width.
    ytick_labels = [f"{r * col_size} - {min((r + 1) * col_size, nch) - 1}"
                    for r in range(row_size)]
    ytick_locs = [border_thickness + r * (m_h + border_thickness) + (m_h - 1) / 2
                  for r in range(row_size)]
    plt.yticks(ytick_locs, ytick_labels, fontsize=10 * figscale)

# Load images
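To get some masks to work on, create a directory containing one image whose filename starts with "target" (any format supported by PIL) and the images to compare against it, whose filenames start with "query". Then set `img_dir` in the next cell to that directory.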

In [None]:
import glob
import os

img_dir = "img_directory"
tensor_converter = v2.Compose([v2.PILToTensor()])


def _read_rgb_tensor(fname):
    """Read ``img_dir/fname`` as a (C, H, W) tensor, keeping at most the first 3 channels."""
    # os.path.join inserts the separator that `img_dir + fname` dropped
    # ("img_directory" + "target.png" -> "img_directorytarget.png").
    # The `with` block closes the file handle once the pixels are read.
    with Image.open(os.path.join(img_dir, fname)) as im:
        return tensor_converter(im)[:3, :, :]


# read base image
target_fname = glob.glob("target*", root_dir=img_dir)
if not target_fname:
    raise FileNotFoundError(f"no file matching 'target*' in {img_dir!r}")
target = _read_rgb_tensor(target_fname[0])

# read query images (sorted, so query indices are the same on every run)
query_fnames = sorted(glob.glob("query*", root_dir=img_dir))
queries = [_read_rgb_tensor(q) for q in query_fnames]

In [None]:
# Show the target image; permute moves channels last (C, H, W -> H, W, C) for imshow.
plt.axis('off')
plt.imshow(target.permute(1,2,0))

In [None]:
# Show all query images in a grid.
plot_imgs(queries)

# U-Net

In [None]:
# Load the U-Net segmentation model (set `chkpt` to your checkpoint) and define the
# forward hooks that copy the output of `layer` into the globals
# `target_features` / `query_features`.
# NOTE(review): n_cls = 15, while color_conv has 16 class indices (0..15) — confirm
# which count the checkpoint was trained with.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
net = utils.load_network({"net": "Unetv2", "num_classes": n_cls}, device)
checkpoint = torch.load(chkpt, map_location=torch.device(device))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

def take_target_features(module, args, output):
    """Forward hook: store a copy of the hooked layer's output in `target_features`."""
    global target_features
    target_features = output.clone()

def take_query_features(module, args, output):
    """Forward hook: store a copy of the hooked layer's output in `query_features`."""
    global query_features
    query_features = output.clone()

# Layer whose output the hooks capture.
layer = net.encode5

print(net)

# DeepLabV3-MobileNet

In [None]:
# Load DeepLabV3 with a MobileNet backbone (set `chkpt` to your checkpoint) and
# define forward hooks that keep the backbone output's 'out' entry.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
net = utils.load_network({"net": "MobileNet", "num_classes": n_cls, 'backbone': 'mobilenet'}, device)
checkpoint = torch.load(chkpt, map_location=torch.device(device))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

# (sic) "moudle": PyTorch passes hook arguments positionally, so the name is harmless.
def take_target_features(moudle, args, output):
    """Forward hook: store a copy of output['out'] in `target_features`."""
    global target_features
    target_features = output['out'].clone()

def take_query_features(module, args, output):
    """Forward hook: store a copy of output['out'] in `query_features`."""
    global query_features
    query_features = output['out'].clone()

# Layer whose output the hooks capture; its output is indexed as a dict ('out').
layer = net.model.backbone

print(net)

# DeepLabV3-ResNet101

In [None]:
# Load DeepLabV3 with a ResNet-101 backbone (set `chkpt` to your checkpoint; the
# project wrapper's custom_load is used here) and define forward hooks that keep
# the backbone output's 'out' entry.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
net = utils.load_network({"net": "Resnet101", "num_classes": n_cls}, device)
checkpoint = torch.load(chkpt, map_location=device)
net.custom_load(checkpoint)
net.eval()

# (sic) "moudle": PyTorch passes hook arguments positionally, so the name is harmless.
def take_target_features(moudle, args, output):
    """Forward hook: store a copy of output['out'] in `target_features`."""
    global target_features
    target_features = output['out'].clone()

def take_query_features(module, args, output):
    """Forward hook: store a copy of output['out'] in `query_features`."""
    global query_features
    query_features = output['out'].clone()

# Layer whose output the hooks capture; its output is indexed as a dict ('out').
layer = net.model.backbone

print(net)

# Segformer

In [None]:
# Load the Segformer model (set `chkpt` to your checkpoint) and define forward hooks
# that keep the encoder output's `last_hidden_state`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
net = utils.load_network({"net": "SegformerMod", "num_classes": n_cls}, device)
checkpoint = torch.load(chkpt, map_location=device)
net.custom_load(checkpoint)
net.eval()

def take_query_features(module, args, output):
    """Forward hook: store a copy of output.last_hidden_state in `query_features`."""
    global query_features
    query_features = output.last_hidden_state.clone()

def take_target_features(module, args, output):
   """Forward hook: store a copy of output.last_hidden_state in `target_features`."""
   global target_features
   target_features = output.last_hidden_state.clone()

# Layer whose output the hooks capture.
layer = net.segformer.segformer.encoder

print(net)

Here you can run a forward pass with feature extraction. If you don't want to extract any features, simply put `pass` in the body of both hooks.

In [None]:
target_features = None
query_features = None

# Forward the target and all queries through `net`, capturing `layer`'s output
# through the hooks. Each hook is removed in `finally`, so a failing forward pass
# cannot leave it attached to fire on later, unrelated forward passes.
# NOTE(review): torch.stack requires all query images to share one size.
with torch.no_grad():
    handler = register_extraction_hook(take_target_features, layer)
    try:
        target_mask = net(target.unsqueeze(0).type(torch.float32))
    finally:
        handler.remove()
    handler = register_extraction_hook(take_query_features, layer)
    try:
        query_masks = net(torch.stack(queries).type(torch.float32))
    finally:
        handler.remove()

## Visualize activations
You can visualize the produced activations by running the following code. If you want to customize the visualization itself, look at `visualize_activations` in the first cell of this notebook!

In [None]:
# Channels 569..577 of the captured target features. The channel range and the
# "MobileNet" title are hard-coded: adjust both to the model/layer you loaded.
visualize_activations(target_features.squeeze()[569:578], "MobileNet")

Maybe you want to visualize a specific activation map...

In [None]:
def visualize_activation_channel(features, channel):
    """Show one channel (``features[channel]``) of an activation tensor as a viridis heat map."""
    channel_map = features[channel]
    plt.matshow(channel_map, cmap='viridis')

# Channel 820 is an arbitrary example; it must be < the number of channels.
visualize_activation_channel(target_features.squeeze(), 820)

Get the final masks by argmax along output channels

In [None]:
# Predicted class per pixel = index of the highest-scoring output channel (dim 1).
target_mask_idx = target_mask.argmax(dim=1)
query_masks_idx = query_masks.argmax(dim=1)

Color conversion for plotting...

In [None]:
# Colour-coded versions of the predicted masks, for plotting.
target_mask_color = mask_to_color(target_mask_idx.squeeze())
query_masks_color = [mask_to_color(single_mask) for single_mask in query_masks_idx]

In [None]:
# Show the target's predicted mask in colour (channels moved last for imshow).
plt.axis('off')
plt.imshow(target_mask_color.permute(1,2,0))

In [None]:
# Show the queries' predicted masks in colour.
plot_imgs(query_masks_color)

In [None]:
# Quick check of the image tensor dtype (presumably uint8 from PILToTensor); the
# models are fed `.type(torch.float32)` copies.
queries[0].dtype

# Estimating similarity
Our main goal is that of finding which query image can be most closely associated with the target image.
We start our analysis by computing various scores between the segmentation mask of our target image and the queries.
## Selected scores
- **mean IoU**
- **weighted IoU**
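- **pixel overlap** (`pprec` in the code): the fraction of the target's labeled pixels (class $\neq 0$) to which the query mask assigns the same class, $$pp(t,q) = \frac{\sum_{s=1}^{N}\mathbb{1}[t_s=q_s]\,\mathbb{1}[t_s\neq 0]}{\sum_{s=1}^{N}\mathbb{1}[t_s\neq 0]}$$ where $N$ is the number of pixels.
- **Earth Mover's Distance**: EMD between the normalized class histograms of the two segmentation masks (unlabeled pixels excluded).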

## Some thoughts on scores and the concept of similarity in semantic segmentation
The main question that should guide this quest of estimating similarity is: what do we mean by "similar" images in the case of semantic segmentation of satellite colour images? I want to start with some considerations on the metrics used.

Both IoU metrics, as well as pixel overlap, are heavily influenced by the spatial characteristics of images (and masks). As a thought experiment, think of a shoreline observed from a bird's-eye view: we can easily picture a very symmetric image, with a line where the water meets the sand; now think of the same image rotated by 180 degrees. Sand and water have "switched places", but the overall content is the very same. How would such an image score with these metrics, taking the original as our ground truth? It is easy to conclude that there would be little overlap for both classes (observe that the overlap for at least one class grows the farther the separation line is from the center of the image). If we instead take an image where only sand (or only water) is visible, this image would counterintuitively score higher than the rotated image, even though we could consider it semantically different, since one object class is completely missing.

EMD, on the other hand, completely discards spatial information, since it measures the minimal effort required to "morph" one distribution into another: in our case the distribution is the dense pixel classification output by the segmentation model (represented as a normalized histogram). Going back to the previous example, the rotated image would now be the best match (EMD of zero), since it contains exactly the same pixels.
> In fact the EMD paper defines a ground distance that considers the spatial information of a pixel plus its color...

Moreover, we impose an ordering on classes, such that similar classes will be closer together. Look at [this](https://arxiv.org/abs/1611.05916) paper for more info. The bottom line is that EMD for an ordered-class, L1 ground distance matrix has a closed-form solution, so it's very easy to compute (code in the first cell of this notebook...)

In [None]:
# 16 class indices (0..15), matching the keys of color_conv; 0 is "unlabeled".
scores = compute_scores(target_mask_idx, query_masks_idx, 16)
scores_table(scores, len(query_masks_idx))

Get the best match for a specific score

In [None]:
# Metric used to pick the best query: 'miou', 'wiou', 'pprec' (higher is better)
# or 'emd' (lower is better).
score = 'emd'
plot_best_match_by_score(target, queries, target_mask_color, query_masks_color, score, scores)

Sort images by score, then plot them and their masks

In [None]:
# Rank queries best-first for the selected score: the similarities ('miou', 'wiou',
# 'pprec') in descending order, the 'emd' distance in ascending order. The old
# ascending-only sort showed the worst matches first for similarities.
# Top row: query images; bottom row: their colour masks.
score_values = torch.as_tensor(scores[score]).tolist()
sorted_by_score = sorted(zip(score_values, range(len(queries))),
                         key=lambda pair: pair[0],
                         reverse=score != 'emd')
plot_imgs([queries[i[1]] for i in sorted_by_score] + [query_masks_color[i[1]] for i in sorted_by_score], (2, len(queries)))

You can also inspect activation maps for queries...

In [None]:
# Activation maps of the query that EMD picks as best match (indexes the batch
# dimension of the captured query features).
visualize_activations(query_features[get_best_match_idx('emd', scores)])

This works on output planes too!

In [None]:
# Same grid view of the target's network output (presumably one tile per class channel).
visualize_activations(target_mask.squeeze(), "output masks")

In [None]:
# Output planes of one query; index 2 is arbitrary (any valid query index works).
visualize_activations(query_masks[2])

# Exploring DINO ViT features for retrieval

In [None]:
# Load the four pretrained DINO ViTs from torch.hub. The DINO hub code presumably
# imports its own top-level `utils` module, which would collide with this
# project's `utils` in sys.modules. So the project's module is evicted before
# loading and DINO's is evicted afterwards, before the project's `utils` is
# re-imported. Keep this statement order.
# NOTE(review): `del utils` raises NameError if 'utils' is in sys.modules but the
# name was never bound in this namespace.
import sys
if 'utils' in sys.modules:
    sys.modules.pop('utils')
    del utils
dino_vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').eval()
dino_vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16').eval()
dino_vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').eval()
dino_vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8').eval()
if 'utils' in sys.modules:
    sys.modules.pop('utils')
import utils

> WARNING! RUN ONLY IF YOU HAVE A FINETUNED FULL CHECKPOINT (student + teacher nets)

In [None]:
# Fine-tuned DINO: load the checkpoint's student weights into dino_vits16, keeping
# only backbone entries and stripping their 'module.backbone.' name prefix.
finetuned_dino_chkp = torch.load("finetuned.pth", map_location=torch.device(device))
checkpoint_state_dict = finetuned_dino_chkp['student']
checkpoint_state_dict_mod = {
    str(key).replace('module.backbone.', ''): weights
    for key, weights in checkpoint_state_dict.items()
    if 'module.backbone' in str(key)
}
dino_vits16.load_state_dict(checkpoint_state_dict_mod)

In [None]:
# DINO ViT-S/8 embeddings for the target and the queries. The images are uint8
# tensors and the ViT's convolutional patch embedding cannot take uint8 input,
# so they are cast to float32, as the retrieval cells further down do.
# NOTE(review): no mean/std normalization is applied here (nor further down).
with torch.no_grad():
    vit_feats_target = dino_vits8(target.unsqueeze(0).type(torch.float32))
    vit_feats_queries = dino_vits8(torch.stack(queries).type(torch.float32))

DINO ViT features have been shown to be reliable for k-NN classification and copy detection, so we should be able to find the most similar image by computing the L2 distance between the produced features.

In [None]:
# Query whose DINO embedding is closest (Euclidean distance) to the target's.
l2 = torch.cdist(vit_feats_queries[None], vit_feats_target[None]).squeeze()
best_el_l2 = l2.argmin()
plot_imgs([target, queries[best_el_l2],
           target_mask_color, query_masks_color[best_el_l2]], (2, 2))

In [None]:
# Query whose DINO embedding has the highest cosine similarity with the target's.
cosine = torch.nn.functional.cosine_similarity(vit_feats_target, vit_feats_queries, dim=1)
best_el_cos = cosine.argmax()
plot_imgs([target, queries[best_el_cos],
           target_mask_color, query_masks_color[best_el_cos]], (2, 2))

# Retrieval

## Setting up databases
Here you can setup the two databases. Nothing fancy, just read files and put them into tensors and lists.

In [None]:
import glob
import os

db_dir = 'db/directory'
query_dir = 'query/directory'
# os.path.join adds the separator: 'db/directory' + '*' matched the directory itself
# (as a name pattern in 'db/'), not the files inside it. Only regular files are kept.
# Sorting makes the image indices used by every ranking below reproducible.
db_files = sorted(p for p in glob.glob(os.path.join(db_dir, '*')) if os.path.isfile(p))
query_files = sorted(p for p in glob.glob(os.path.join(query_dir, '*')) if os.path.isfile(p))


def _read_rgb_image(path):
    """Read an image file as a (C, H, W) tensor, keeping at most the first 3 channels."""
    return torchvision.io.read_image(path)[:3, :, :]


# read all files to retrieve
retr = [_read_rgb_image(p) for p in db_files]
to_retrieve = [_read_rgb_image(p) for p in query_files]

Visualize your images:

In [None]:
# Show the query images to retrieve matches for.
plot_imgs(to_retrieve)

## Model inference
Extract segmentation masks with one of the models you instantiated.
To change model, just run the corresponding cell at the beginning of this notebook.

In [None]:
# Predicted class-index mask (argmax over output channels) for every db image and
# every query image, one image per forward pass.
with torch.no_grad():
    retr_masks = [torch.argmax(net(img.unsqueeze(0).type(torch.float32)), dim=1) for img in retr]

    to_retrieve_masks = [torch.argmax(net(img.unsqueeze(0).type(torch.float32)), dim=1) for img in to_retrieve]


## Computing EMD
EMD is computed taking the images in the "query" directory as reference: for each of them, a score is computed against every image in the "db" directory.
Then we sort the images based on scores. For each query image, a list of tuples containing score and corresponding "db" image index number is produced.

In [None]:
# EMD between each query mask and every db mask (16 class indices; class 0 is
# ignored by per_image_emd). The db masks are stacked once, not once per query.
retr_masks_stacked = torch.stack(retr_masks)
emds = [per_image_emd(m, retr_masks_stacked, 16) for m in to_retrieve_masks]

# For each query: (emd, db_index) pairs, most similar (lowest EMD) first.
sorted_emds = [sorted(zip(emds[img_idx].tolist(), range(len(retr)))) for img_idx in range(len(to_retrieve))]

## DINO embeddings
Get dino embeddings. Running inference on single images (NO batch mode!) is a lot more memory friendly...

In [None]:
# DINO embedding for every db and query image, one image per forward pass.
vit = dino_vits16
with torch.no_grad():
    emb_retr = [vit(img.type(torch.float32).unsqueeze(0)).squeeze() for img in retr]
    emb_to_retrieve = [vit(img.type(torch.float32).unsqueeze(0)).squeeze() for img in to_retrieve]

We get the similarity with the L2 distance between the "query" embeddings and all "db" images.

In [None]:
# Pairwise L2 distances, one row per query and one column per db image; then, per
# query, (distance, db_index) pairs sorted nearest first.
l2_dist = torch.cdist(torch.stack(emb_to_retrieve), torch.stack(emb_retr))
sorted_l2 = [sorted(zip(row.tolist(), range(len(retr)))) for row in l2_dist]

Select one image and visualize the results. You can select the top N retrieved result by similarity.

In [None]:
# Choose the query to inspect (an index into to_retrieve) and how many top db
# matches to plot.
selected_img_idx = 7
selected_img = to_retrieve[selected_img_idx]
topN = 5

In [None]:
# Show the selected query image.
plt.axis('off')
plt.imshow(selected_img.permute(1,2,0))

In [None]:
# Top-N db images by EMD (lowest first) for the selected query.
plot_imgs([retr[best[1]] for best in sorted_emds[selected_img_idx][:topN]])

In [None]:
# Top-N db images by DINO L2 distance (nearest first) for the selected query.
plot_imgs([retr[best[1]] for best in sorted_l2[selected_img_idx][:topN]])

In [None]:
def scores_evaluation(emd, other_score):
    """Measure how well another ranking agrees with the EMD ranking, averaged over queries.

    ``emd[i]`` and ``other_score[i]`` are the rankings for query ``i``: lists of
    ``(score, db_index)`` tuples ordered best-first.

    Prints seven fractions, in this order:
      top1, top5, top10: share of queries whose EMD top-1 db image is the other
          ranking's top-1 / is within its top-5 / is within its top-10;
      top5to5, top5to10, top5to20, top5to50: share of the EMD top-5 db images
          that appear within the other ranking's top-5 / 10 / 20 / 50.

    Hits are counted as integers and divided once at the end, so the printed
    values carry no float accumulation noise (e.g. 0.3, not 0.30000000000000004).
    The overlap denominator is the number of EMD top-5 entries actually present,
    which is fewer than 5 per query when the db is smaller than 5 images.
    """
    n_queries = len(emd)
    if n_queries == 0:
        print(0., 0., 0., 0., 0., 0., 0.)
        return
    cutoffs = (5, 10, 20, 50)
    top1_hits = top5_hits = top10_hits = 0
    overlap_hits = dict.fromkeys(cutoffs, 0)
    n_emd_best = 0
    for i in range(n_queries):
        emd_best5 = [el[1] for el in emd[i][0:5]]
        other_bests = [el[1] for el in other_score[i][0:50]]
        top1_hits += emd_best5[0] == other_bests[0]
        top5_hits += emd_best5[0] in other_bests[:5]
        top10_hits += emd_best5[0] in other_bests[:10]
        for k in cutoffs:
            overlap_hits[k] += sum(d in other_bests[:k] for d in emd_best5)
        n_emd_best += len(emd_best5)

    print(top1_hits / n_queries, top5_hits / n_queries, top10_hits / n_queries,
          *(overlap_hits[k] / n_emd_best for k in cutoffs))

In [None]:
# Agreement between the EMD ranking and the DINO L2 ranking.
scores_evaluation(sorted_emds, sorted_l2)

### GID
| model | top1 | top5 | top10 | top5to5 | top5to10 | top5to20 | top5to50 |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| s8 | 0.1 | 0.1 | 0.2 | 0.1 | 0.16 | 0.26 | 0.48 |
| s16 | 0.1 | 0.1 | 0.1 | 0.08 | 0.18 | 0.24 | 0.52 |
| b8 | 0.1 | 0.1 | 0.2 | 0.08 | 0.14 | 0.24 | 0.54 |
| b16 | 0.1 | 0.1 | 0.1 | 0.06 | 0.08 | 0.16 | 0.46 |

### ESA
| model | top1 | top5 | top10 | top5to5 | top5to10 | top5to20 | top5to50 |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| s8 | 0.0 | 0.0 | 0.0 | 0.04 | 0.04 | 0.04 | 0.26 |
| s16 | 0.0 | 0.0 | 0.0 | 0.02 | 0.04 | 0.08 | 0.26 |
| b8 | 0.0 | 0.0 | 0.0 | 0.02 | 0.04 | 0.08 | 0.26 |
| b16 | 0.0 | 0.0 | 0.0 | 0.02 | 0.06 | 0.08 | 0.26 |

## MobileNet backbone feature planes as retrieval embeddings
We tried converting the backbone's feature planes into embeddings by taking the maximum of each full plane (global max pooling), which yields a 960-dimensional vector to be used for retrieval. Distances between these vectors are computed with the L2 and L1 norms (see the tables below).

In [None]:
# Load DeepLabV3-MobileNet separately (as `mobilenet`) to use its backbone feature
# planes as retrieval embeddings. Set `chkpt` to your checkpoint.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
mobilenet = utils.load_network({"net": "MobileNet", "num_classes": n_cls, 'backbone': 'mobilenet'}, device)
checkpoint = torch.load(chkpt, map_location=torch.device(device))
mobilenet.load_state_dict(checkpoint['model_state_dict'])
mobilenet.eval()
# Print the model just loaded (previously printed `net`, a different model).
print(mobilenet)

In [None]:
mobilenet_emb = None

def take_mobilenet_emb(module, args, output):
    """Forward hook: store a squeezed copy of the backbone's output['out'] in `mobilenet_emb`."""
    global mobilenet_emb
    mobilenet_emb = output['out'].squeeze().clone()

layer = mobilenet.model.backbone

# One forward pass per image; the hook leaves that image's backbone features in
# `mobilenet_emb`. The hook is removed in `finally` so that a failing pass cannot
# leave it attached to later forward passes of `mobilenet`.
with torch.no_grad():
    mobilenet_emb_to_retrieve = []
    mobilenet_emb_db = []
    handler = register_extraction_hook(take_mobilenet_emb, layer)
    try:
        for img in to_retrieve:
            mobilenet(img.unsqueeze(0).type(torch.float32))
            mobilenet_emb_to_retrieve.append(mobilenet_emb)
        for img in retr:
            mobilenet(img.unsqueeze(0).type(torch.float32))
            mobilenet_emb_db.append(mobilenet_emb)
    finally:
        handler.remove()

In [None]:
# Global max pooling: each channel's full feature plane is reduced to its maximum,
# giving one value per channel.
feats_to_retrieve = [emb.flatten(start_dim=1).amax(dim=1) for emb in mobilenet_emb_to_retrieve]
feats_db = [emb.flatten(start_dim=1).amax(dim=1) for emb in mobilenet_emb_db]

In [None]:
# Pairwise L1 (p=1) distances between pooled query and db features; then, per
# query, (distance, db_index) pairs sorted nearest first.
mobilenet_l1 = torch.cdist(torch.stack(feats_to_retrieve), torch.stack(feats_db), p=1)
sorted_mobilenet_l1 = [sorted(zip(row.tolist(), range(len(retr)))) for row in mobilenet_l1]

In [None]:
# Top-N db images by MobileNet-feature L1 distance for the selected query.
plot_imgs([retr[best[1]] for best in sorted_mobilenet_l1[selected_img_idx][:topN]])

In [None]:
# Agreement between the EMD ranking and the MobileNet-feature L1 ranking.
scores_evaluation(sorted_emds, sorted_mobilenet_l1)

### GID
| dist | top1 | top5 | top10 | top5to5 | top5to10 | top5to20 | top5to50 |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| L2 | 0.1 | 0.5 | 0.7 | 0.12 | 0.34 | 0.44 | 0.6 |
| L1 | 0.2 | 0.5 | 0.7 | 0.16 | 0.36 | 0.46 | 0.58 |
### ESA
| dist | top1 | top5 | top10 | top5to5 | top5to10 | top5to20 | top5to50 |
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| L2 | 0.1 | 0.1 | 0.2 | 0.04 | 0.06 | 0.18 | 0.5 |
| L1 | 0.1 | 0.1 | 0.2 | 0.06 | 0.1 | 0.22 | 0.44 |

## Other experiments
We tried also with segformer features, without particular success. You are free to use these cells for experiments!

In [None]:
# Segformer experiment: use the encoder's last_hidden_state as an embedding for
# every query (to_retrieve) and db (retr) image. Set `chkpt` to your checkpoint.
# NOTE(review): this cell rebinds `emb_retr`, which earlier held the DINO db
# embeddings; re-run the DINO embedding cell before re-running the DINO L2 cell.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
segformer_net = utils.load_network({"net": "SegformerMod", "num_classes": n_cls}, device)
checkpoint = torch.load(chkpt, map_location=device)
segformer_net.custom_load(checkpoint)
segformer_net.eval()

encoded_queries = None
encoded_retr = None

def take_encoded_t(module, args, output):
    """Forward hook: store a copy of output.last_hidden_state (query images) in `encoded_queries`."""
    global encoded_queries
    encoded_queries = output.last_hidden_state.clone()

def take_encoded_q(module, args, output):
    """Forward hook: store a copy of output.last_hidden_state (db images) in `encoded_retr`."""
    global encoded_retr
    encoded_retr = output.last_hidden_state.clone()

layer = segformer_net.segformer.segformer.encoder

# Images are uint8; cast to float32 like every other forward pass in this notebook.
# Hooks are removed in `finally` so that a failing pass cannot leave them attached.
with torch.no_grad():
    emb_queries = []
    handler = register_extraction_hook(take_encoded_t, layer)
    try:
        for img in to_retrieve:
            segformer_net(img.unsqueeze(0).type(torch.float32))
            emb_queries.append(encoded_queries.squeeze())
    finally:
        handler.remove()
    emb_retr = []
    handler = register_extraction_hook(take_encoded_q, layer)
    try:
        for img in retr:
            segformer_net(img.unsqueeze(0).type(torch.float32))
            emb_retr.append(encoded_retr.squeeze())
    finally:
        handler.remove()

In [None]:
# Pairwise L2 distances between flattened Segformer embeddings (queries x db).
segfrm_l2 = torch.cdist(torch.stack(emb_queries).flatten(1), torch.stack(emb_retr).flatten(1))

In [None]:
# Top-5 db images nearest (L2) to the selected query in Segformer embedding space.
# NOTE(review): uses a fixed 5 rather than `topN`.
sorted_segfrm_l2 = sorted(zip(segfrm_l2[selected_img_idx].tolist(), range(len(retr))))
plot_imgs([retr[best[1]] for best in sorted_segfrm_l2[:5]])

In [None]:
# Cosine similarity between every flattened query embedding and every db embedding:
# (n_queries, 1, D) against (1, n_db, D) along dim=2 gives a (n_queries, n_db) matrix.
segfrm_cos_sim = torch.nn.functional.cosine_similarity(torch.stack(emb_queries).flatten(start_dim=1).unsqueeze(1), torch.stack(emb_retr).flatten(start_dim=1).unsqueeze(0), dim=2)

In [None]:
# Top-5 db images most similar to the selected query (highest cosine first,
# hence reverse=True).
sorted_segfrm_cos_sim = sorted(zip(segfrm_cos_sim[selected_img_idx].tolist(), range(len(retr))), reverse=True)
plot_imgs([retr[best[1]] for best in sorted_segfrm_cos_sim[:5]])

## Unet embeddings

In [None]:
# U-Net experiment: max-pool the output of `encode5` into one embedding per image
# and rank db images by L2 distance. Set `chkpt` to your checkpoint.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
n_cls = 15
chkpt = "your/checkpoint"
unet = utils.load_network({"net": "Unetv2", "num_classes": n_cls}, device)
checkpoint = torch.load(chkpt, map_location=torch.device(device))
unet.load_state_dict(checkpoint['model_state_dict'])
unet.eval()

def take_embedding(module, args, output):
    """Forward hook: store a copy of the hooked layer's output in `unet_embedding`."""
    global unet_embedding
    unet_embedding = output.clone()

layer = unet.encode5

unet_embedding = None

# The hook is removed in `finally` so that a failing forward pass cannot leave it
# attached to later forward passes of `unet`.
with torch.no_grad():
    unet_emb_queries = []
    unet_emb_retr = []
    handler = register_extraction_hook(take_embedding, layer)
    try:
        for img in to_retrieve:
            unet(img.unsqueeze(0).type(torch.float32))
            unet_emb_queries.append(unet_embedding.squeeze())
        for img in retr:
            unet(img.unsqueeze(0).type(torch.float32))
            unet_emb_retr.append(unet_embedding.squeeze())
    finally:
        handler.remove()

# Global max pooling: one value (the plane's maximum) per channel.
unet_feats_to_retrieve = [emb.flatten(start_dim=1).amax(dim=1) for emb in unet_emb_queries]
unet_feats_db = [emb.flatten(start_dim=1).amax(dim=1) for emb in unet_emb_retr]

# Pairwise L2 distances; per query, (distance, db_index) pairs sorted nearest first.
unet_l2 = torch.cdist(torch.stack(unet_feats_to_retrieve), torch.stack(unet_feats_db))
sorted_unet_l2 = [sorted(zip(unet_l2[img_idx].tolist(), range(len(retr)))) for img_idx in range(len(to_retrieve))]

In [None]:
# Agreement between the EMD ranking and the U-Net max-pooled L2 ranking.
scores_evaluation(sorted_emds, sorted_unet_l2)

GID (U-Net, L2): top1 0.2, top5 0.2, top10 0.3, top5to5 0.2, top5to10 0.28, top5to20 0.4, top5to50 0.64