In [None]:
import os
# Run from the project root. Relative paths used later ("weights_paths.json",
# "checkpoints/", "all_results_export.xlsx") resolve against it, and the
# project-local packages below are presumably imported from it too.
# Set NOTEBOOK_PROJECT_ROOT to override the default location on another machine.
os.chdir(os.environ.get("NOTEBOOK_PROJECT_ROOT", "/home2/jgcw74/l3_project"))
import importlib

import helpers
import dataset_processing
import models
import xai

# Device used for every model and tensor in this notebook.
torch_device = helpers.utils.get_torch_device()

In [None]:
import typing as t
from pathlib import Path
import json

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Model architectures compared throughout the notebook (iterated in this order).
AVAILABLE_MODELS = (
    "ResNet50",
    "ConvNeXtSmall",
    "SwinTransformerSmall",
)

## Load in results

In [None]:
# Load every explainer's evaluation tables from its HDF5 output file.
# dfs maps explainer name -> {table name: results DataFrame}; table names are
# "<dataset>_<model>" (they are looked up that way at the end of the notebook).
dfs = dict()
for explainer_name in t.get_args(xai.EXPLAINER_NAMES):
    results_for_exp = dict()
    h5_output_path = helpers.env_var.get_project_root() / "results" / explainer_name / "evaluation_output.h5"
    # The context manager closes the HDF5 file even when a table fails the NaN check below.
    with pd.HDFStore(str(h5_output_path), mode="r") as store:
        for key_name in store.keys():
            table_name = key_name.strip("/")  # HDF keys are stored as "/<name>"
            df: pd.DataFrame = store[table_name]
            if df.isna().to_numpy().any():
                raise RuntimeError("A results table contains NaN values!")
            # adjust value in line with methodology formula
            df["output_completeness : preservation_check_conf_drop"] = 1 - df["output_completeness : preservation_check_conf_drop"]
            results_for_exp[table_name] = df
    dfs[explainer_name] = results_for_exp

for key, tables in dfs.items():
    print(f"{key}: {len(tables)} tables loaded")

### Export results to Excel and build one combined, tidied-up DataFrame

In [None]:
# Write one Excel sheet per explainer and collect each explainer's results,
# re-indexed by (dataset, model, class_label), for the combined frame below.
big_df_dict = dict()
# The context manager saves and closes the workbook even if a sheet fails to write.
with pd.ExcelWriter("all_results_export.xlsx") as ew:
    for sheet_name, df_dict in dfs.items():
        # pd.concat on a dict puts the table name ("<dataset>_<model>") in level_0.
        temp_df = pd.concat(df_dict).reset_index()
        key_parts = temp_df["level_0"].str.split("_")  # split once, reuse both parts
        temp_df["dataset"] = key_parts.str.get(0)
        temp_df["model"] = key_parts.str.get(1)
        temp_df = (
            temp_df.rename(columns={"level_1": "class_label"})
            .set_index(["dataset", "model", "class_label"])
            .drop(columns="level_0")
        )

        temp_df.to_excel(ew, sheet_name=sheet_name, index=True, merge_cells=False)

        big_df_dict[sheet_name] = temp_df

#### Clean up the DataFrame into the desired MultiIndex format

In [None]:
# Stack every explainer's results into one frame indexed by
# (xai_method, dataset, model, class_label).
xai_ds_m_c_df = pd.concat(big_df_dict.values(), keys=big_df_dict.keys(), names=["xai_method"])

# Drop the continuity metrics and any L2-distance columns.
_cols = xai_ds_m_c_df.columns
_unwanted = _cols.str.startswith("continuity") | _cols.str.endswith("l2_distance")
xai_ds_m_c_df = xai_ds_m_c_df.drop(columns=_cols[_unwanted])

# Shorten metric names for display. Substitutions are applied in this order.
COLUMN_ABBREVIATIONS = {
    "randomised_model_similarity": "random_sim",
    "adversarial_attack_similarity": "adv_attk_sim",
    "correctness": "COR",
    "output_completeness": "O-C",
    "contrastivity": "CON",
    "compactness": "COM",
    "spearman_rank": "SR",
    "top_k_intersection": "top_m",
    "structural_similarity": "ssim",
}
_short_cols = xai_ds_m_c_df.columns
for _long_name, _short_name in COLUMN_ABBREVIATIONS.items():
    _short_cols = _short_cols.str.replace(_long_name, _short_name)
xai_ds_m_c_df.columns = _short_cols

# Treat -inf scores as missing values.
xai_ds_m_c_df = xai_ds_m_c_df.replace(-np.inf, np.nan)
xai_ds_m_c_df.columns, xai_ds_m_c_df.index[-1]

In [None]:
# Notebook-wide display settings used by all later tables.
pd.options.display.precision = 5
pd.options.display.max_colwidth = 20
# Bare last expression gives a rich (HTML) table instead of print()'s plain text.
xai_ds_m_c_df.loc[("PartitionSHAP", "EuroSATMS", "ConvNeXtSmall")]

### Try loading some data

In [None]:
def get_dataset_and_model(dataset_n, model_n):
    """Load a dataset's test split and a model with its saved checkpoint weights.

    Args:
        dataset_n: dataset name; also a key in ``weights_paths.json`` and a
            sub-directory of ``checkpoints/``.
        model_n: model name (e.g. an entry of ``AVAILABLE_MODELS``); also a key in
            ``weights_paths.json`` and a sub-directory of ``checkpoints/<dataset_n>``.

    Returns:
        ``(ds, m)``: the test dataset object and the model, moved to ``torch_device``.
    """
    model_type = models.get_model_type(model_n)
    # NOTE(review): 32 and 4 are positional — presumably batch size and worker count; confirm.
    ds = dataset_processing.get_dataset_object(dataset_n, "test", model_type.expected_input_dim, 32, 4, torch_device)

    m = model_type(False, ds.N_BANDS, ds.N_CLASSES).to(torch_device)
    # read_text() closes the file immediately, unlike json.load(path.open()),
    # which leaves the handle open until garbage collection.
    weights_paths = json.loads(Path("weights_paths.json").read_text(encoding="utf-8"))
    m.load_weights(Path("checkpoints") / dataset_n / model_n / weights_paths[dataset_n][model_n])
    # NOTE(review): this function never calls m.eval(); confirm load_weights does,
    # since callers run inference with the returned model.

    return ds, m

In [None]:
# Pick a dataset/model pair to inspect, then list (index, class name) pairs.
model_name = "SwinTransformerSmall"
dataset_name = "PatternNet"

dataset, model = get_dataset_and_model(dataset_n=dataset_name, model_n=model_name)

print([*enumerate(dataset.classes)])

In [None]:
# Load the saved PartitionSHAP explanations for one class/batch of the model above.
class_idx, batch_num = 1, 0
base_exp = xai.get_explainer_object(
    "PartitionSHAP",
    model,
    extra_path=Path(dataset_name, f"c{class_idx:02}", f"b{batch_num:03}"),
)
base_exp.force_load()

In [None]:
# Overlay the raw (unranked) explanation on the input images.
helpers.plotting.visualise_importance(
    base_exp.input.numpy(force=True),
    base_exp.explanation,
    with_colorbar=True,
    alpha=0.6,
)

In [None]:
# Total attribution per item: flatten everything after the leading axis
# (presumably one entry per image) and sum.
base_exp.explanation.reshape(len(base_exp.explanation), -1).sum(1)

In [None]:
# Same overlay, using the ranked explanation and a lighter alpha.
helpers.plotting.visualise_importance(
    base_exp.input.numpy(force=True),
    base_exp.ranked_explanation,
    with_colorbar=True,
    alpha=0.4,
)

In [None]:
# Load the saved original/adversarial image pairs for this class and model.
adv_npz_path = (
    helpers.env_var.get_xai_output_root() / Path(dataset_name) / f"c{class_idx:02}"
    / "combined" / f"{model_name}_adversarial_examples.npz"
)
# np.load on an .npz returns an NpzFile that keeps the file open. Indexing it reads
# each array fully into memory, so closing it afterwards is safe.
with np.load(adv_npz_path) as img_dict:
    og_imgs = img_dict["original_imgs"]
    adv_imgs = img_dict["clipped_adv_imgs"]

# Images are channel-first (C, H, W), matching the transpose(1, 2, 0) used below.
# np.hstack joins along axis 1 (height), so each original sits above its adversarial
# version, separated by 10 rows filled with -1. The separator matches the image's
# channel count, so non-RGB inputs also work.
helpers.plotting.show_image(
    np.stack([
        np.hstack([im1, -np.ones((im1.shape[0], 10, im1.shape[-1])), im2])
        for im1, im2 in zip(og_imgs, adv_imgs)
    ]),
    padding=20,
)

The adversarial images really are visually indistinguishable from the originals.

In [None]:
# Visualise the adversarial perturbation for the first 8 images. The difference is
# amplified and centred on 0.5 so small positive/negative changes are visible.
# np.hstack over the (N, C, H, W) batch stacks images along the height axis;
# the transpose moves channels last for imshow.
perturbation_gain = 50
perturbation_vis = np.hstack(og_imgs[:8] - adv_imgs[:8]).transpose(1, 2, 0) * perturbation_gain + 1 / 2
# imshow clips float RGB data to [0, 1] anyway (with a warning); clip explicitly instead.
plt.imshow(np.clip(perturbation_vis, 0, 1))
plt.title(f"original - adversarial (x{perturbation_gain}, centred at 0.5)");

In [None]:
# Predicted class index for each original image. no_grad avoids building an autograd
# graph (if the model's parameters require grad); predictions are unchanged.
# NOTE(review): this cell does not put the model in eval mode; confirm it already is.
with torch.no_grad():
    og_preds = base_exp.model(torch.from_numpy(og_imgs).to(torch_device)).argmax(1)
og_preds

In [None]:
# Predicted class index for each adversarial image, for comparison with the originals.
# no_grad avoids building an autograd graph; predictions are unchanged.
# NOTE(review): this cell does not put the model in eval mode; confirm it already is.
with torch.no_grad():
    adv_preds = base_exp.model(torch.from_numpy(adv_imgs).to(torch_device)).argmax(1)
adv_preds

## Compare generated explanations visually

### Helper functions

In [None]:
def compare_explanations(dataset_n, model_n, class_i, batch_n, use_ranked=True, n_show=8):
    """Plot every explainer's saved explanations for one dataset/model/class/batch.

    Loads the dataset and model and prints the class name. It then loads each
    explainer in ``xai.EXPLAINER_NAMES`` from its saved output and draws the first
    ``n_show`` items of every explainer in a single importance plot.

    Args:
        dataset_n: dataset name.
        model_n: model name.
        class_i: class index; selects ``ds.classes[class_i]`` and the ``cXX`` directory.
        batch_n: batch number; selects the ``bXXX`` directory.
        use_ranked: plot ``ranked_explanation`` if True, otherwise the raw ``explanation``.
        n_show: number of leading items per explainer to plot (default 8).

    Returns:
        The loaded explainer objects, in ``xai.EXPLAINER_NAMES`` order.
    """
    ds, m = get_dataset_and_model(dataset_n, model_n)
    print(ds.classes[class_i])
    # The saved-output location is the same for every explainer, so build it once.
    extra_path = Path(dataset_n) / f"c{class_i:02}" / f"b{batch_n:03}"
    exp_list = []
    for en in t.get_args(xai.EXPLAINER_NAMES):
        exp = xai.get_explainer_object(en, m, extra_path=extra_path)
        exp.force_load()
        exp_list.append(exp)

    helpers.plotting.visualise_importance(
        np.concatenate([exp.input.numpy(force=True)[:n_show] for exp in exp_list]),
        np.concatenate(
            [(exp.ranked_explanation if use_ranked else exp.explanation)[:n_show] for exp in exp_list]
        ),
        alpha=0.5, with_colorbar=True
    )
    return exp_list

In [None]:
def compare_models(dataset_n, class_i, batch_n, model_names=None):
    """Run ``compare_explanations`` for each model and show one titled figure per model.

    Args:
        dataset_n: dataset name.
        class_i: class index.
        batch_n: batch number.
        model_names: models to compare; defaults to ``AVAILABLE_MODELS``, read at call time.

    Returns:
        The explainer list for the LAST model compared, or ``[]`` if no models were given.
    """
    if model_names is None:
        model_names = AVAILABLE_MODELS
    # Initialised up front so an empty model list returns [] rather than raising UnboundLocalError.
    exp_list = []
    for mn in model_names:
        exp_list = compare_explanations(dataset_n, mn, class_i, batch_n)
        plt.title(mn)
        plt.show()
    return exp_list

In [None]:
def get_ds_classes(dataset_n):
    """Return ``list(enumerate(classes))`` for the test split of ``dataset_n``.

    Only the dataset is built. No model is instantiated and no checkpoint is read,
    so this works even when weights for the dataset are unavailable.
    """
    # Use the same arguments get_dataset_and_model passes for "ResNet50", so the
    # dataset is configured identically.
    model_type = models.get_model_type("ResNet50")
    ds = dataset_processing.get_dataset_object(dataset_n, "test", model_type.expected_input_dim, 32, 4, torch_device)
    return list(enumerate(ds.classes))

### UCMerced

In [None]:
# Switch to UCMerced and list its (index, class name) pairs.
dataset_name, batch_num = "UCMerced", 0
print(get_ds_classes(dataset_n=dataset_name))

In [None]:
# Class index 1, presumably "airplane" (the class investigated below).
_ = compare_models(dataset_n=dataset_name, class_i=1, batch_n=batch_num)

Line artifacts indicate that the rest of the image was deemed unimportant (i.e. assigned an importance of 0); see the raw explanations below.

In [None]:
# Raw (unranked) explanations for ConvNeXt. The returned explainer list is discarded
# (as in the compare_models call above) so its repr does not clutter the output.
_ = compare_explanations(dataset_name, "ConvNeXtSmall", 1, batch_num, use_ranked=False)

In [None]:
# Raw (unranked) explanations for Swin Transformer. The returned list is discarded
# so its repr does not clutter the output.
_ = compare_explanations(dataset_name, "SwinTransformerSmall", 1, batch_num, use_ranked=False)

KPCACAM behaves oddly for certain Swin Transformer images: it shows reverse localisation.
The other explanation methods still highlight the expected regions.

We therefore expect the output-completeness and incremental-deletion metrics to be worse for KPCACAM on SwinT.

#### Investigate poor SwinT performance

In [None]:
# Column-name prefixes of interest (matched with str.startswith).
CoIs = (
    "COR : incremental",  # correctness: incremental-deletion metrics
    "O-C",  # output-completeness metrics
)

In [None]:
# Select every (xai_method, model) row for UCMerced's "airplane" class from the
# (xai_method, dataset, model, class_label)-indexed results.
# NOTE(review): with scalar dataset/class_label entries, pandas resolves this tuple key
# via a cross-section and likely drops those two levels, leaving an
# (xai_method, model) index; confirm against the displayed output.
ucm_airplane = xai_ds_m_c_df.loc[:, "UCMerced", :, "airplane"]
ucm_airplane

In [None]:
# Names (not a boolean mask) of the columns starting with any prefix in CoIs;
# str.startswith accepts a tuple of prefixes.
column_mask = [
    col_name
    for col_name in ucm_airplane.columns
    if col_name.startswith(CoIs)
]
ucm_airplane[column_mask]

The best score for a deletion/preservation check is 1/-1. Only KPCACAM on ConvNeXt appears to achieve this.

GradCAM does a similarly poor job for all models except ConvNeXt.

In [None]:
# PartitionSHAP scores on the current dataset, one box-plot panel per model.
# Uses every column of interest except the first in column_mask.
(
    xai_ds_m_c_df[column_mask[1:]]
    .loc["PartitionSHAP", dataset_name]
    .groupby("model")
    .boxplot(rot=90, sharey=True, layout=(1, 3), subplots=True)
)

KPCACAM is much less consistent than GradCAM. PartitionSHAP appears the most reliable but still rarely scores above 0.5.

In [None]:
# All columns of interest on the current dataset, one box-plot panel per XAI method.
(
    xai_ds_m_c_df[column_mask]
    .loc[:, dataset_name, :]
    .groupby("xai_method")
    .boxplot(rot=90, subplots=True, layout=(1, 3))
)

In [None]:
# Same columns as the PartitionSHAP plot, across every (xai_method, dataset) group, on one axis.
(
    xai_ds_m_c_df[column_mask[1:]]
    .groupby(level=["xai_method", "dataset"])
    .boxplot(rot=90, subplots=False)
)

### Inspect performance on more targetable land cover classes (not objects)

In [None]:
# Class index 16: a land-cover class rather than an object. The returned explainer
# list is discarded (as in the earlier compare_models call) so its repr is not shown.
_ = compare_models(dataset_name, 16, batch_num)

SHAP explanations get much messier for land-based concepts and are less cohesive than those from the GradCAM-based methods.

However, this may faithfully reflect the underlying model, or the comparison may be unfair because of visually similar concepts.

### EuroSATRGB

In [None]:
# Switch to EuroSAT (RGB) and show its (index, class name) pairs.
dataset_name = "EuroSATRGB"
get_ds_classes(dataset_n=dataset_name)

#### AnnualCrop

In [None]:
# Class index 0, presumably AnnualCrop (see the class list above).
# Keep the last model's explainers for the follow-up cell.
last_exp_list = compare_models(dataset_n=dataset_name, class_i=0, batch_n=batch_num)
last_exp_list

Explanations are much less localised on EuroSAT for a general area class such as AnnualCrop.

In [None]:
# Show the first 8 input images. Convert the input tensor to a host numpy array
# first, as every other plotting call in this notebook does.
helpers.plotting.show_image(last_exp_list[-1].input.numpy(force=True)[:8], padding=20)

#### Highways and rivers

In [None]:
# Class index 8, presumably River (see the class list above).
last_exp_list = compare_models(dataset_n=dataset_name, class_i=8, batch_n=batch_num)

Rivers are easier: all models appear to learn the banks of the river.

In [None]:
# Raw explanation of the first explainer (xai.EXPLAINER_NAMES order) for the last model.
# The input tensor is converted to a host numpy array, as in the other
# visualise_importance calls in this notebook.
helpers.plotting.visualise_importance(
    last_exp_list[0].input.numpy(force=True),
    last_exp_list[0].explanation,
    alpha=0.5,
)

Raw SHAP explanations are much more precise and localised.

In [None]:
# Class index 3, presumably Highway (see the class list above).
last_exp_list = compare_models(dataset_n=dataset_name, class_i=3, batch_n=batch_num)

## Compare RGB and MS efficacy

## Plot evaluation results

In [None]:
# Explainer / dataset / model combination whose raw results table is inspected below.
explainer_name, dataset_name, model_name = "PartitionSHAP", "EuroSATRGB", "ConvNeXtSmall"

In [None]:
# Raw (per-class) results table, keyed by "<dataset>_<model>".
dfs[explainer_name]["_".join((dataset_name, model_name))]