# Score sheet

In [None]:
import os
import pickle
from os.path import join as pj

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pyrootutils
import torch
from torchvision import transforms
from matplotlib.patches import Rectangle
import seaborn as sns
pyrootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

from src.shared_utils.utils_visualisation import plot_prototypes, show_cam_on_image

In [None]:
# Directory containing the pickled test results; set this before running.
path_sim = ""
# Maximum number of prototype panels drawn next to each sample.
proto_per_figure = 4

pkl_path = pj(path_sim, "results_test.pkl")
# NOTE: unpickling can execute arbitrary code -- only load result files you
# produced yourself or otherwise trust.
with open(pkl_path, mode="rb") as results_file:
    dict_results = pickle.load(results_file)

In [None]:
# Per-channel statistics presumed to have been applied to the stored samples
# (three channels; confirm the values against the preprocessing pipeline).
mean_normalize = torch.tensor([0.485, 0.456, 0.406])
std_normalize = torch.tensor([0.229, 0.224, 0.225])

# Normalize computes (x - mean) / std.  To undo y = (x - m) / s we need
# x = (y - (-m / s)) / (1 / s), hence the mean and std passed below.
unnormalize = transforms.Normalize(
    mean=-mean_normalize / std_normalize,
    std=1.0 / std_normalize,
)

In [None]:
# Pull the arrays needed for the figures out of the results dictionary.
labels = dict_results["labels"]
np_image = dict_results["sample"]
np_similarity = dict_results["similarity_prototype"]

# Work on a copy so the loaded results stay untouched, then zero every
# importance below 0.1.  Elsewhere in this notebook the array is indexed as
# [sample, prototype, class].
pruned_importance = dict_results["importance"].copy()
below_threshold = pruned_importance < 0.1
pruned_importance[below_threshold] = 0

# Summing over axis 1 collapses the prototype axis, leaving one score per
# (sample, class) pair.
modified_preds = np.sum(pruned_importance, axis=1)


In [None]:
# For every sample keep only the importances of its own label's class:
# pairing row i with labels[i] yields one row of per-prototype values per sample
# (assumes `labels` holds one integer class index per sample -- TODO confirm).
sample_rows = np.arange(pruned_importance.shape[0])
class_importance = pruned_importance[sample_rows, :, labels]

In [None]:
# Pick two distinct sample positions to visualise.  No seed is set here, so
# the selection changes on every run.
n_samples = len(labels)
idx_figures = np.random.choice(n_samples, size=2, replace=False)


In [None]:
# Importances of the selected samples, one row per sample.
importance_all = class_importance[idx_figures]

# argsort is ascending, so the last `proto_per_figure` columns of the ranking
# are the most important prototypes of each selected sample.
proto_ranking = np.argsort(importance_all, axis=1)
top_proto_idx = proto_ranking[:, -proto_per_figure:]

# Sorted set of every prototype that appears in at least one selection.
all_proto = np.unique(top_proto_idx)

In [None]:
# One colour per displayed prototype, taken from seaborn's "colorblind" palette.
colorblind_palette = sns.color_palette("colorblind", len(all_proto))

# Map each prototype index to its colour so a prototype keeps the same colour
# in every row of the figure.
dict_proto_color = dict(zip(all_proto, colorblind_palette))

In [None]:
# Report the sample indices that were selected for plotting.
print("idx_figures", idx_figures)

In [None]:
# Figure layout: one row per selected sample.  Column 0 is the overview drawn
# by plot_prototypes; columns 1..proto_per_figure each show the similarity
# heatmap of one of the sample's most important prototypes.
scale_factor = 4
width = scale_factor * (proto_per_figure + 1)
height = scale_factor * len(idx_figures) * 1.1
# squeeze=False keeps `axs` two-dimensional even when a single sample is
# selected, so the axs[row, col] indexing below always works.
fig, axs = plt.subplots(
    ncols=proto_per_figure + 1,
    nrows=len(idx_figures),
    figsize=(width, height),
    squeeze=False,
)

for idx_plot, idx_figure in enumerate(idx_figures):
    importance_plotted = 0

    # Undo the input normalisation.  The image is handled channel-first here:
    # image.shape[1:] is used as the spatial size and the array is transposed
    # to channel-last before being drawn.
    image = np_image[idx_figure]
    image = unnormalize(torch.tensor(image)).numpy()
    label_tmp = labels[idx_figure]
    pred_tmp = np.argmax(modified_preds[idx_figure])

    # Per-prototype importances for the sample's labelled class, and the
    # indices of the strongest ones in descending order of importance.
    importance_image = pruned_importance[idx_figure, :, label_tmp]
    top_proto_idx_image = np.argsort(importance_image)[-proto_per_figure:]
    top_proto_idx_image = top_proto_idx_image[::-1]

    similarity = np_similarity[idx_figure]
    similarity = similarity[top_proto_idx_image]
    # Only prototypes that survived pruning (importance > 0) are displayed.
    nb_proto = min((importance_image > 0).sum(), similarity.shape[0])
    similarity = similarity[:nb_proto]

    img_size = image.shape[1:]
    # Each similarity map is stored flattened and reshaped to a square grid.
    size_square_similarity = int(similarity.shape[1] ** 0.5)

    # Iterate over the selected prototypes themselves rather than
    # range(proto_per_figure), so fewer available prototypes cannot cause an
    # out-of-range lookup.
    color_annotations = [dict_proto_color[proto] for proto in top_proto_idx_image]
    plot_prototypes(
        image,
        similarity,
        axs=axs[idx_plot, 0],
        alpha=0.2,
        label=label_tmp,
        pred=pred_tmp,
        color_annotations=color_annotations,
    )
    axs[idx_plot, 0].title.set_fontsize(18)

    for idx in range(proto_per_figure):
        ax = axs[idx_plot, idx + 1]
        if idx >= nb_proto:
            # No prototype left for this column: leave the panel blank.
            ax.axis("off")
            continue
        idx_proto = top_proto_idx_image[idx]

        # Upsample the coarse similarity grid to the image resolution.
        similarity_proto = similarity[idx]
        similarity_grid = similarity_proto.reshape(
            size_square_similarity, size_square_similarity
        )
        similarity_scaled = torch.nn.functional.interpolate(
            torch.tensor(similarity_grid[None, None, :, :]),
            size=img_size,
            scale_factor=None,
            mode="bilinear",
        )

        similarity_plot = show_cam_on_image(
            np.transpose(image, (1, 2, 0)),
            similarity_scaled[0, 0].detach().cpu().numpy(),
        )
        ax.imshow(similarity_plot, interpolation="nearest")
        ax.title.set_text(
            f"Importance: {importance_image[idx_proto]:.2f},Proto: {idx_proto}",
        )
        importance_plotted += importance_image[idx_proto]
        ax.title.set_fontsize(18)
        ax.axis("off")

        # Coloured frame matching the prototype's annotation colour.  The
        # image is channel-first, so width/height come from the spatial size
        # (height, width) = img_size, not from image.shape[0] (the channel
        # count).  imshow places pixel edges at -0.5 .. N-0.5, hence the
        # half-pixel origin.
        border = Rectangle(
            (-0.5, -0.5),
            img_size[1],
            img_size[0],
            linewidth=4,
            edgecolor=color_annotations[idx],
            facecolor="none",
        )
        ax.add_patch(border)

    # Share of the class score explained by the plotted prototypes.  Guard
    # against a zero total (everything pruned) to avoid a nan in the title.
    total_importance = importance_image.sum()
    if total_importance > 0:
        share_plotted = (importance_plotted / total_importance) * 100
    else:
        share_plotted = 0.0
    axs[idx_plot, 0].title.set_text(
        f"Total score: {total_importance:.2f} ({share_plotted:.2f}%)"
    )

plt.tight_layout()

In [None]:
# Debug probe: peak similarity of prototype 283 for one of the selected
# samples.  Position 5 is used when that many samples were selected; otherwise
# fall back to the last one so the lookup cannot run past the end of
# `idx_figures`.
figure_pos = min(5, len(idx_figures) - 1)
np_similarity[idx_figures[figure_pos], 283].max()