# Comparison of prototypes across images
This notebook plots, for each prototype used by the model, a selection of images in which that prototype is present: either the images where it is most important or a random sample. The aim is to show that the concepts the model relies on to make a prediction are consistent across samples.

In [None]:
import os
import pickle
from os.path import join as pj

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pyrootutils
import seaborn as sns
import torch
from matplotlib.patches import Rectangle
from torchvision import transforms
from tqdm import tqdm
import pandas as pd
pyrootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)
from pathlib import Path

from src.shared_utils.utils_visualisation import plot_prototypes, show_cam_on_image

In [None]:
# Folder where the trained model is stored; it must contain `results_test.pkl`.
path_sim = ""  # TODO: set this to the folder where the trained model is stored
if not path_sim:
    raise ValueError("Set `path_sim` to the folder where the trained model is stored.")

pkl_path = pj(path_sim, "results_test.pkl")
# NOTE: unpickling can execute arbitrary code; only load result files you produced yourself.
with open(pkl_path, "rb") as f:
    dict_results = pickle.load(f)

In [None]:
labels = dict_results["labels"]

np_image = dict_results["sample"]
np_similarity = dict_results["similarity_prototype"]

# Importances strictly below this value are treated as "not used" and zeroed.
IMPORTANCE_THRESHOLD = 0.1
# Prune on a copy so dict_results keeps the raw importances.
pruned_importance = dict_results["importance"].copy()
pruned_importance[pruned_importance < IMPORTANCE_THRESHOLD] = 0
# Scores recomputed from the pruned importances by summing over axis 1
# (presumably the prototype axis - TODO confirm against the saved results).
modified_preds = pruned_importance.sum(axis=1)

In [None]:
# Per-channel (mean, std) used to normalise the model inputs; pick the entry
# matching the normalisation used when the model was trained.
NORMALIZATION_STATS = {
    "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    "medclip": ([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
}
normalization = "imagenet"  # switch to "medclip" for models normalised with those statistics

mean_normalize = torch.tensor(NORMALIZATION_STATS[normalization][0])
std_normalize = torch.tensor(NORMALIZATION_STATS[normalization][1])
# Inverse of Normalize(mean, std): (x - (-mean / std)) / (1 / std) == x * std + mean,
# which maps normalised images back to their original pixel range for display.
unnormalize = transforms.Normalize(
    -mean_normalize / std_normalize,
    1.0 / std_normalize,
)

In [None]:
# For every sample n, keep the importance of each prototype towards that
# sample's ground-truth class, i.e. pruned_importance[n, :, labels[n]].
sample_rows = np.arange(len(pruned_importance))
class_importance = pruned_importance[sample_rows, :, labels]

In [None]:
# Prototypes with non-zero (post-pruning) true-class importance for at least one sample.
# np.flatnonzero always returns a 1-D array, so a single used prototype still
# supports len() and iteration (argwhere(...).squeeze() would give a 0-d array).
list_idx_used = np.flatnonzero((class_importance > 0).any(axis=0))
print(list_idx_used)

In [None]:
# One colour from seaborn's "colorblind" palette per used prototype.
# np.size (unlike len) also accepts a 0-d array, which list_idx_used can be when
# exactly one prototype is used. NOTE(review): not referenced by the cells shown here.
colorblind_palette = sns.color_palette("colorblind", np.size(list_idx_used))

In [None]:
# Grid layout of each prototype figure: `proto_per_figure` panels arranged
# over `nb_rows` rows; figure size (in inches) scales with `scale_factor`.
proto_per_figure = 16
nb_rows = 4
scale_factor = 2

panels_per_row = proto_per_figure / nb_rows
width = scale_factor * panels_per_row
# The extra 1.3 factor presumably leaves room for the per-panel titles.
height = nb_rows * scale_factor * 1.3

In [None]:
# Display the (samples, prototypes) shape of the true-class importance matrix.
np.shape(class_importance)

In [None]:
# How images are chosen for each prototype figure:
#   True  -> the samples where the prototype has the highest true-class importance;
#   False -> a random sample of images where it ranks among the most important prototypes.
top_images: bool = True

In [None]:
# The output folder depends on how images are picked for each prototype.
path_save = Path(path_sim) / (
    "prototypes_organised_top" if top_images else "prototypes_organised_random"
)
path_save.mkdir(exist_ok=True, parents=True)

# Random mode: an image is eligible for a prototype when that prototype is among
# the image's TOP_K_PROTOTYPES highest true-class importances.
TOP_K_PROTOTYPES = 6
# Ceil so the grid always holds at least `proto_per_figure` panels, even when
# proto_per_figure is not a multiple of nb_rows.
n_cols = int(np.ceil(proto_per_figure / nb_rows))


def _similarity_overlay(image, similarity_proto):
    """Blend one prototype's similarity map over an image.

    Args:
        image: un-normalised, channel-first (C, H, W) numpy image.
        similarity_proto: flattened similarity map of a single prototype; its
            length is assumed to be a perfect square (TODO confirm).

    Returns:
        The result of ``show_cam_on_image`` for the (H, W, C) image and the
        similarity map bilinearly upsampled to (H, W).
    """
    img_size = image.shape[1:]
    side = int(similarity_proto.shape[0] ** 0.5)
    similarity_map = torch.tensor(similarity_proto.reshape(side, side)[None, None, :, :])
    similarity_scaled = torch.nn.functional.interpolate(
        similarity_map,
        size=img_size,
        scale_factor=None,
        mode="bilinear",
    )
    return show_cam_on_image(
        np.transpose(image, (1, 2, 0)),
        similarity_scaled[0, 0].detach().cpu().numpy(),
        alpha=0.4,
    )


if not top_images:
    # Loop-invariant ranking, computed once rather than once per prototype.
    top_k_per_image = np.argsort(class_importance, axis=1)[:, -TOP_K_PROTOTYPES:]

# atleast_1d keeps the loop working if list_idx_used is a 0-d array.
for idx_proto in tqdm(np.atleast_1d(list_idx_used)):
    if top_images:
        # Samples where this prototype is most important, highest first.
        list_images = class_importance[:, idx_proto].argsort()[-proto_per_figure:][::-1]
    else:
        # flatnonzero always returns a 1-D array, so a single matching image
        # still works with shuffle and iteration.
        list_images = np.flatnonzero((top_k_per_image == idx_proto).any(axis=1))
        np.random.shuffle(list_images)
        list_images = list_images[:proto_per_figure]

    # squeeze=False guarantees a 2-D array of axes whatever the grid size.
    fig, axs = plt.subplots(
        ncols=n_cols, nrows=nb_rows, figsize=(width, height), squeeze=False,
    )
    panels = axs.ravel()
    for ax in panels:
        ax.axis("off")  # panels left unfilled stay blank
    for ax, idx_figure in zip(panels, list_images):
        image = unnormalize(torch.tensor(np_image[idx_figure])).numpy()
        label_tmp = labels[idx_figure]
        pred_tmp = np.argmax(modified_preds[idx_figure])
        overlay = _similarity_overlay(image, np_similarity[idx_figure][idx_proto])
        ax.imshow(overlay, interpolation="nearest")
        # Class indices are shown; map them to class names here if a label table is available.
        ax.set_title(f"Pred: {pred_tmp}, Label: {label_tmp}")

    fig.tight_layout()
    fig.savefig(path_save / f"proto_{idx_proto}.png")
    plt.close(fig)