# Plot the model's predictions across samples of the FunnyBirds dataset under intervention
This notebook shows how the model's prediction changes as key parts of the images are suppressed in the FunnyBirds dataset.

In [None]:

import os
import matplotlib.pyplot as plt
import numpy as np
import pyrootutils
import seaborn as sns
import torch.nn.functional as F
import torch
from torchvision import transforms
from tqdm import tqdm
from matplotlib.patches import Rectangle
pyrootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)
import random
from src.shared_utils.utils_visualisation import plot_prototypes
from src.shared_utils.utils_experiments import  load_model_dataset

from torch.utils.data import DataLoader
from src.learning.data.FunnybirdsDataset import FunnyBirdsDataset
from src.shared_utils.utils_visualisation import show_cam_on_image, return_colorblind_palette

In [None]:
# User configuration: both values below are left blank in this notebook and
# must be filled in before the cell can run.
# `data_path` is passed as the first argument to FunnyBirdsDataset further down
# (presumably the dataset root directory — confirm).
data_path =
# `path_sim` is forwarded to load_model_dataset (presumably the directory of the
# trained run to load — confirm).
path_sim =
# Load the model; the second value returned by load_model_dataset is discarded.
model, _ = load_model_dataset(path_sim=path_sim,set="test")

In [None]:
# FunnyBirds "test" split, requested with part maps (get_part_map=True).
# NOTE(review): `transforms` is the torchvision module imported at the top of
# the notebook, not a transform callable such as transforms.Compose([...]) —
# confirm that this is what FunnyBirdsDataset expects.
test_dataset = FunnyBirdsDataset(
    data_path,
    "test",
    get_part_map=True,
    transform=transforms,
    eval_funny_birds=True,
)
# One sample per batch, in dataset order.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
# Run the model on one randomly chosen test sample and on each of its
# single-part interventions (one forward pass per removed part).
list_image = []  # model inputs: original image first, then one per removed part
output_all = []  # model outputs, in the same order as list_image
target_all = []  # class index of the original sample
sample_to_plot = random.randrange(0, len(test_loader))
# sample_to_plot = 64
# 2nd one is 25

# Inference only: the stored outputs are detached and converted to NumPy in the
# following cells, so running without autograd avoids keeping one graph per
# forward pass alive on the GPU.
# NOTE(review): assumes the model does not rely on autograd internally to
# produce its outputs — confirm.
with torch.no_grad():
    for i, sample in enumerate(test_loader):
        if i != sample_to_plot:
            continue

        image = sample["image"]
        target = sample["class_idx"]
        part_map = sample["part_map"]
        params = sample["params"]
        class_idxs = sample["class_idx"]
        image_idxs = sample["image_idx"]
        params = test_dataset.get_params_for_single(params)
        image = image.cuda()
        part_map = part_map.cuda()
        target = target.cuda()

        score = {}  # not filled in this cell; kept as a notebook-level name
        list_image.append(image)
        output = model(image)
        output_all.append(output)
        target_all.append(target)

        # Every part name known to the dataset is removed once, on its own.
        # bird_parts_keys = ['beak_model', 'eye_model', 'wing_model', 'tail_model', 'foot_model']
        bird_parts_keys = list(test_dataset.parts.keys())

        for remove_part in bird_parts_keys:
            # Same class / image index as the original sample, with one part removed.
            image2 = test_dataset.get_intervention(
                class_idxs.squeeze(0).item(),
                image_idxs.squeeze(0).item(),
                [remove_part],
            )["image"]

            image2 = image2.cuda(non_blocking=True)
            list_image.append(image2)
            output = model(image2)
            output_all.append(output)

        # Only the selected sample is processed.
        break

In [None]:
def _to_numpy(tensor):
    """Detach *tensor*, move it to the CPU and return it as a NumPy array."""
    return tensor.detach().cpu().numpy()


# Gather the per-input model outputs into NumPy arrays; entries follow the
# order in which the inputs were appended to list_image / output_all.
importance = _to_numpy(torch.cat([out["importance"] for out in output_all]))
preds = _to_numpy(torch.stack([torch.argmax(out["pred"]) for out in output_all]))
np_image = _to_numpy(torch.cat(list_image))
np_similarity = _to_numpy(
    torch.cat([out["similarity_prototype"] for out in output_all])
)
# Every input shares the label of the original sample.
labels = [target.item()] * len(preds)

In [None]:
# For every input row i keep the importances belonging to that row's label:
# class_importance[i] == importance[i, :, labels[i]].
row_idx = np.arange(len(importance))
class_importance = importance[row_idx, :, labels]

In [None]:
# Number of top-importance prototypes displayed for each plotted input.
proto_per_figure = 5

# Inputs to plot; by default every collected input, in order.
# Alternatives:
# idx_figures =  [3297 ,5168, 2078]
# idx_figures = np.random.choice(len(labels), 3, replace=False)
idx_figures = torch.arange(len(labels))

In [None]:
# Importances of the selected inputs only.
importance_all = class_importance[idx_figures]
# Ascending argsort per row; the last `proto_per_figure` columns are the
# indices of the largest importances.
ranked_proto_idx = np.argsort(importance_all, axis=1)
top_proto_idx = ranked_proto_idx[:, -proto_per_figure:]
# Sorted unique prototype indices appearing in at least one top list.
all_proto = np.unique(top_proto_idx)

In [None]:
# Give each selected prototype its own palette entry, assigned in the order of
# `all_proto`, so a prototype keeps the same colour across all rows.
colorblind_palette = return_colorblind_palette()
dict_proto_color = {
    proto: colorblind_palette[position]
    for position, proto in enumerate(all_proto)
}

In [None]:
# Show which inputs were selected for plotting.
print("idx_figures", idx_figures)

In [None]:
# One row per selected input. Column 0 shows the image with its prototype
# overlay; each further column shows the similarity heat-map of one of the
# input's most important prototypes, framed in that prototype's colour.
# squeeze=False keeps `axs` two-dimensional even when a single row is plotted.
fig, axs = plt.subplots(
    ncols=proto_per_figure + 1,
    nrows=len(idx_figures),
    figsize=(16, 18),  # width, height
    squeeze=False,
)
bird_parts_all = ["Initial"] + [f"w/o {bird_parts}" for bird_parts in bird_parts_keys]

for idx_plot, idx_figure in enumerate(idx_figures):
    # idx_figures may hold tensor or NumPy scalars; index with a plain int.
    idx_figure = int(idx_figure)

    image = np_image[idx_figure]  # channel-first, as implied by the transpose below
    # image = unnormalize(torch.tensor(image)).numpy()
    image_hwc = np.transpose(image, (1, 2, 0))
    img_size = image.shape[1:]
    img_height, img_width = img_size
    label_tmp = labels[idx_figure]
    pred_tmp = preds[idx_figure]
    importance_image = importance[idx_figure, :, label_tmp]

    # Most important prototypes first; keep only strictly positive importances.
    top_proto_idx_image = np.argsort(importance_image)[-proto_per_figure:][::-1]
    top_proto_idx_image = top_proto_idx_image[importance_image[top_proto_idx_image] > 0]
    similarity = np_similarity[idx_figure][top_proto_idx_image]
    # Each similarity vector is reshaped below into a square map of this side length.
    size_square_similarity = int(similarity.shape[1] ** 0.5)
    color_annotations = [dict_proto_color[proto] for proto in top_proto_idx_image]

    plot_prototypes(image, similarity, axs=axs[idx_plot, 0], alpha=0.7, label=label_tmp, pred=pred_tmp, color_annotations=color_annotations)
    # Name the row after the plotted input itself (not its position in the
    # figure), so the title stays correct when idx_figures is a subset.
    axs[idx_plot, 0].set_title(
        f"{bird_parts_all[idx_figure]}: {importance_image.sum():.2f}", fontsize=18
    )

    for idx, idx_proto in enumerate(top_proto_idx_image):
        ax = axs[idx_plot, idx + 1]
        similarity_map = similarity[idx].reshape(
            size_square_similarity, size_square_similarity
        )
        # Upsample the similarity map to the image resolution.
        similarity_scaled = torch.nn.functional.interpolate(
            torch.tensor(similarity_map[None, None, :, :]),
            size=img_size,
            mode="bilinear",
        )

        similarity_plot = show_cam_on_image(
            image_hwc,
            similarity_scaled[0, 0].detach().cpu().numpy(),
        )
        ax.imshow(similarity_plot, interpolation="nearest")
        ax.set_title(f"Importance: {importance_image[idx_proto]:.2f}", fontsize=18)
        ax.axis("off")
        # Coloured frame around the whole panel. The image is channel-first,
        # so its width is shape[2] and its height is shape[1].
        border = Rectangle(
            (0, 0),
            img_width,
            img_height,
            linewidth=4,
            edgecolor=color_annotations[idx],
            facecolor="none",
        )
        ax.add_patch(border)

    # Hide panels left unused when fewer than `proto_per_figure` prototypes
    # have a positive importance.
    for unused_ax in axs[idx_plot, len(top_proto_idx_image) + 1:]:
        unused_ax.axis("off")

# Compute the layout once, after every panel has been drawn.
plt.tight_layout()