## Plot correlation across classes from the classification head
This notebook plots the correlation between the classification-head weights of different classes, for models trained on the CUB dataset. The aim is to show how the model reuses the same concepts across similar classes, which keeps the global size of the explanations small.

In [None]:
import pickle as pkl
from pathlib import Path

import numpy as np
import pyrootutils
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
pyrootutils.setup_root(Path.cwd() , indicator=".project-root", pythonpath=True)

from src.shared_utils.utils_experiments import load_model_dataset

In [None]:
# --- Configuration: edit these two paths before running the notebook ---
# TODO: point this at the folder where the trained model is stored.
path_sim = Path("path/to/trained_model")
# TODO: point this at the dataset folder whose entries name the classes; it is
# read further down to build `name_classes` (presumably one "<number>.<name>"
# entry per class -- confirm against your local dataset layout).
path_data = Path("path/to/dataset/classes")

# Fail early with an explicit message instead of a confusing error further down
# when the placeholder above has not been replaced.
if not path_sim.is_dir():
    raise FileNotFoundError(
        f"`path_sim` must be the folder of a trained model, got: {path_sim}"
    )

# NOTE(review): `path_sim` is passed as a pathlib.Path because later cells use
# `path_sim / "..."`; confirm `load_model_dataset` accepts Path objects.
model, dataset = load_model_dataset(path_sim, "test")

In [None]:
# Export the weight matrix of the classification head to its own .npy file,
# next to the trained model, so it can be reloaded without the full model.
head_weight = model.classification_head.weight.detach().cpu().numpy()
weight_file = Path(path_sim) / "classification_head_weight.npy"
np.save(weight_file, head_weight)

In [None]:
# Importance values below this threshold are zeroed out before computing the
# metrics; set it to 0 to use the stored predictions unchanged.
threshold = 0.1

# `Path(...)` keeps this cell working whether `path_sim` is a str or a Path,
# matching how the weight-saving cell builds its path.
pkl_path = Path(path_sim) / "results_test.pkl"
# NOTE(security): unpickling can execute arbitrary code -- only load result
# files that you produced yourself.
with pkl_path.open("rb") as f:
    dict_results = pkl.load(f)
labels = dict_results["labels"]

# Work on a copy so the thresholding does not modify the loaded results.
importance = dict_results["importance"].copy()
if threshold > 0:
    importance[importance < threshold] = 0
    # Recompute predictions from the thresholded importances: sum over axis 1
    # (presumably the prototypes -- confirm), then take the best class.
    preds = importance.sum(axis=1).argmax(axis=1)
else:
    preds = dict_results["preds"]
accuracy = (preds == labels).sum() / len(labels)

# For every sample keep only the importances towards its own label:
# result has one row per sample and one column per entry of axis 1.
class_importance = importance[np.arange(importance.shape[0]), :, labels]
# Average number of non-zero entries per sample.
local_size = (class_importance > 0).sum(axis=1).mean()
# Number of columns that are non-zero for at least one sample.
global_size = ((class_importance > 0).sum(axis=0) > 0).sum()

# Show the metrics; the original cell computed them without displaying anything.
{"accuracy": accuracy, "local_size": local_size, "global_size": global_size}

In [None]:
# Indices of the columns of `class_importance` that are non-zero for at least
# one sample. `np.flatnonzero` always returns a 1-D array, whereas
# `np.argwhere(...).squeeze()` collapses to a 0-d array when exactly one index
# is selected, which would make `selected_weight` 1-D and break row indexing.
idx_proto = np.flatnonzero((class_importance > 0).sum(axis=0) > 0)

# `Path(...)` keeps this working whether `path_sim` is a str or a Path.
classification_weight = np.load(Path(path_sim) / "classification_head_weight.npy")
# Keep every row, restricted to the selected columns.
selected_weight = classification_weight[:, idx_proto]

In [None]:
# Collect the entry names of the dataset folder and order them by the integer
# that precedes the first "." in each name. Entries without a numeric prefix
# (hidden files such as ".DS_Store", READMEs, ...) are skipped; previously they
# made `int(...)` raise a ValueError.
name_classes = sorted(
    (
        entry.name
        for entry in Path(path_data).iterdir()
        if entry.name.split(".")[0].isdecimal()
    ),
    key=lambda name: int(name.split(".")[0]),
)

In [None]:
name_classes
# Displays the class names; the previous cell ordered the list by the number at
# the start of each string.

In [None]:
# Correlation coefficients between every pair of rows of `selected_weight`
# (np.corrcoef treats each row as one variable), drawn as a single heatmap.
df = pd.DataFrame(selected_weight)
corr = np.corrcoef(df)
fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(data=corr, annot=False, ax=ax)

In [None]:
# Zoom in: correlation between the weight rows of a contiguous range of classes.
# (The required imports live in the import cell at the top of the notebook.)
# The author also left an alternative range: start_idx = 55, end_idx = 70.
start_idx = 112
end_idx = 135
chosen_idx = np.arange(start_idx, end_idx)
# Names look like "<number>.<name>"; keep only the part after the first "."
# for the tick labels.
# NOTE(review): assumes row i of `selected_weight` corresponds to
# `name_classes[i]` -- confirm against the dataset's class ordering.
chosen_name = [name_classes[i].split(".")[1] for i in chosen_idx]

df = pd.DataFrame(selected_weight)
corr = np.corrcoef(df.iloc[start_idx:end_idx, :])

fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(corr, annot=True, ax=ax)
# Heatmap cell i spans [i, i + 1], so its centre is i + 0.5. Ticks placed at
# the integers (as before) sit on the cell edges and shift every label by half
# a cell.
tick_pos = np.arange(len(chosen_idx)) + 0.5
ax.set_xticks(tick_pos)
ax.set_xticklabels(chosen_name, rotation=45, ha="right")
ax.set_yticks(tick_pos)
ax.set_yticklabels(chosen_name, rotation=45, va="top");

In [None]:
# Rank the selected columns by how much rows 50 and 52 of `selected_weight`
# differ on them (largest absolute difference first), then map the ranking
# back to the original indices stored in `idx_proto`.
weight_gap = np.abs(selected_weight[50] - selected_weight[52])
ascending_order = np.argsort(weight_gap)
ordered_proto = ascending_order[::-1]
idx_proto[ordered_proto]