# Setup: imports, display-name maps and W&B loading helpers

In [5]:
import warnings

import numpy as np
import pandas as pd 

import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

import wandb

# Short dataset id -> display name used in axis labels and the "Dataset" column.
dataset_match = dict(
    cobre="COBRE",
    abide="ABIDE",
    synth1="Synthetic1",
    synth2="Synthetic2",
)

# Short model id -> display name used in the "Model" column; the boxplot
# palette is keyed by these display names.
model_match = dict(
    lstm="LSTM",
    mean_lstm="Mean LSTM",
    transformer="Transformer",
    mean_transformer="Mean Transformer",
    dice="DICE",
    glob_dice="Pretuned DICE",
)

def load_normal(proj_name, timeout=19):
    """Fetch the per-run test AUC and accuracy logged to one W&B project.

    Parameters
    ----------
    proj_name : str
        Project path handed to ``wandb.Api().runs`` (``"<entity>/<project>"``,
        or just ``"<project>"`` as used in this notebook).
    timeout : int, optional
        Timeout forwarded to ``wandb.Api`` (default 19, the previous
        hard-coded value).

    Returns
    -------
    (list, list)
        ``(AUC_score, accuracy)``: the ``test_score`` and ``test_accuracy``
        summary values, aligned index-by-index, one entry per run that
        logged both metrics.

    Runs whose summary lacks either metric are skipped with a single
    ``UserWarning`` instead of aborting the whole load with ``KeyError``.
    """
    api = wandb.Api(timeout=timeout)

    AUC_score = []
    accuracy = []
    n_skipped = 0
    for run in api.runs(proj_name):
        # ._json_dict gives the plain summary values and omits large files.
        summary = run.summary._json_dict
        if "test_score" not in summary or "test_accuracy" not in summary:
            # e.g. runs that never reached the test phase (assumption: crashed
            # or unfinished) -- keep the two lists aligned by skipping them.
            n_skipped += 1
            continue
        AUC_score.append(summary["test_score"])
        accuracy.append(summary["test_accuracy"])

    if n_skipped:
        warnings.warn(
            f"{proj_name}: skipped {n_skipped} run(s) without test_score/test_accuracy"
        )

    return AUC_score, accuracy

def load_metrics(paths_dict, ds_dict, model_dict):
    """Collect test metrics for every (model, dataset) W&B project.

    Parameters
    ----------
    paths_dict : dict
        ``{model_key: {dataset_key: project_path}}``; every ``project_path``
        is passed to ``load_normal``.
    ds_dict : dict
        Maps dataset keys to the display names stored in the ``"Dataset"``
        column.
    model_dict : dict
        Maps model keys to the display names stored in the ``"Model"``
        column.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        * Per-run frame with columns ``AUC``, ``Accuracy``, ``Model``,
          ``Dataset`` (one row per run returned by ``load_normal``).
        * Per-(model, dataset) frame with the ``Mean`` and ``Var`` of the
          AUC (``np.var`` default ``ddof=0``, i.e. population variance),
          plus ``Model`` and ``Dataset``.
        Both frames carry a fresh, unique ``0..n-1`` index.
    """
    run_rows = []
    summary_rows = []

    for model_key, dataset_paths in paths_dict.items():
        print(model_key)
        model_label = model_dict[model_key]
        for dataset_key, project_path in dataset_paths.items():
            print("\t ", dataset_key)
            dataset_label = ds_dict[dataset_key]

            # load_normal returns two equally long, index-aligned lists.
            auc, acc = load_normal(project_path)

            run_rows.extend(
                {"AUC": a, "Accuracy": b, "Model": model_label, "Dataset": dataset_label}
                for a, b in zip(auc, acc)
            )
            summary_rows.append(
                {
                    "Mean": np.mean(auc),
                    "Var": np.var(auc),
                    "Model": model_label,
                    "Dataset": dataset_label,
                }
            )

    # Build each frame once from records. Unlike pd.concat of many small
    # frames (without ignore_index), this gives a unique RangeIndex, and it
    # also works when paths_dict is empty.
    runs = pd.DataFrame(run_rows, columns=["AUC", "Accuracy", "Model", "Dataset"])
    summary = pd.DataFrame(summary_rows, columns=["Mean", "Var", "Model", "Dataset"])
    return runs, summary

# 1. Plot boxplots of test AUROC per dataset and model

In [13]:
# W&B project for every (model, dataset) pair; all follow the pattern
# "introdl-exp-<model>-<dataset>".
# Pretuned DICE ("glob_dice") runs live in "introdl-exp-dice-global_<dataset>";
# add a "glob_dice" entry with those paths to include them.
projects = {
    model: {
        dataset: f"introdl-exp-{model}-{dataset}"
        for dataset in ("abide", "cobre", "synth1", "synth2")
    }
    for model in ("lstm", "mean_lstm", "transformer", "mean_transformer", "dice")
}

data, stat_data = load_metrics(
    paths_dict=projects,
    ds_dict=dataset_match,
    model_dict=model_match,
)

lstm
	  abide
	  cobre
	  synth1
	  synth2
mean_lstm
	  abide
	  cobre
	  synth1
	  synth2
transformer
	  abide
	  cobre
	  synth1
	  synth2
mean_transformer
	  abide
	  cobre
	  synth1
	  synth2
dice
	  abide
	  cobre
	  synth1
	  synth2


In [98]:
# Publication theme for the AUROC boxplot.
sns.set_theme(style="whitegrid", font_scale=1.5)

# Fixed model -> colour mapping (keys are the display names in data["Model"]).
palette = {
    "LSTM": "C0",
    "Mean LSTM": "C1",
    "Transformer": "C2",
    "Mean Transformer": "C3",
    "DICE": "C4",
    "Pretuned DICE": "C5",
}

fig, ax = plt.subplots(figsize=(10, 3))
sns.boxplot(
    data=data,
    x="Dataset",
    y="AUC",
    hue="Model",
    palette=palette,
    showfliers=False,
    ax=ax,
)

ax.set_ylabel("AUROC")
ax.set_yticks([0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax.set_ylim(0.4, 1.05)
# Chance level. An explicit colour is needed because Line2D defaults to "C0",
# which is exactly the LSTM box colour in `palette`.
ax.axhline(0.5, color="grey", linestyle="--", linewidth=1)
ax.legend(bbox_to_anchor=(1.01, 1.0), loc="upper left", borderaxespad=0)

fig.savefig("aucs.eps", format="eps", bbox_inches="tight")
plt.close(fig)

# Restore the default font scale for later cells, only after the figure is saved.
sns.set_theme(font_scale=1)

The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.


# 2. Create panel of saliency maps

In [18]:
# Undo the seaborn theme from the boxplot cell so the image grid uses plain matplotlib styling.
sns.reset_defaults()

models = [
    "lstm",
    "mean_lstm",
    "transformer",
    "mean_transformer",
    "dice",
    # "pretuned_dice",
]
datasets = [
    "abide",
    "cobre",
    "synth1",
    "synth2",
]

# Two rows (class 0 / class 1) per model, one column per dataset.
# squeeze=False keeps `axs` 2-D even when there is a single dataset column.
fig, axs = plt.subplots(
    2 * len(models),
    len(datasets),
    figsize=(3.6 * len(datasets), 3 * len(models)),
    squeeze=False,
)

for i, model in enumerate(models):
    for j, dataset in enumerate(datasets):
        # Pretuned DICE introspection output is stored under the DICE "global_" project name.
        if model == "pretuned_dice":
            project = f"introdl-introspection-dice-global_{dataset}"
        else:
            project = f"introdl-introspection-{model}-{dataset}"

        for target in range(2):
            ax = axs[2 * i + target, j]
            path = f"../assets/introspection/{project}/k_00/saliency/colormap/general_{target}.png"
            # imshow copies the PIL pixels into an array, so the file can be closed right after.
            with Image.open(path) as image:
                ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])

            # Row labels only on the first column; the model name on each model's first row.
            if j == 0:
                if target == 0:
                    ax.set_ylabel(f"Class {target},\n {model}", fontsize=16)
                else:
                    ax.set_ylabel(f"Class {target}", fontsize=16)
            # Dataset names as column headers above the top row.
            if i == 0 and target == 0:
                ax.set_xlabel(dataset_match[dataset], fontsize=16)
                ax.xaxis.set_label_position("top")

fig.tight_layout(pad=0.05)
fig.savefig(
    "saliency.png",
    format="png",
    dpi=150,
    bbox_inches="tight",
)
plt.close(fig)

In [97]:
models = [
    "lstm",
    "mean_lstm",
    "transformer",
    "mean_transformer",
    "dice",
    # "pretuned_dice",
]
datasets = [
    "abide",
    "cobre",
    "synth1",
    "synth2",
]

# Row titles with line breaks so they fit the narrow label column.
# Kept under a separate name: reassigning the global `model_match` here would
# change the "Model" labels (and break the palette lookup) if the data-loading
# or boxplot cells were re-run afterwards.
stacked_model_labels = {
    "lstm": "LSTM",
    "mean_lstm": "Mean\nLSTM",
    "transformer": "Transformer",
    "mean_transformer": "Mean\nTransformer",
    "dice": "DICE",
    "glob_dice": "Pretuned\nDICE",
    "pretuned_dice": "Pretuned\nDICE",
}

fig = plt.figure(constrained_layout=True, figsize=(3.7 * len(datasets), 3 * len(models)))
# Column 0 is a narrow label column; columns 1.. hold one dataset each.
subfigs = fig.subfigures(
    len(models),
    len(datasets) + 1,
    width_ratios=[0.1] + [1] * len(datasets),
    squeeze=False,
)

for i, model in enumerate(models):
    # Label column: model name plus "Class 0" / "Class 1" next to the two image rows.
    label_fig = subfigs[i, 0]
    label_fig.suptitle(
        stacked_model_labels[model],
        fontsize=16, x=-0.2, y=0.5, rotation="vertical", ha="center", va="center",
    )
    for target, ax in enumerate(label_fig.subplots(2, 1)):
        ax.text(1.5, 0.5, f"Class {target}", ha="center", va="center", fontsize=16, rotation="vertical")
        ax.axis("off")

    for j, dataset in enumerate(datasets, start=1):
        panel = subfigs[i, j]
        axs = panel.subplots(2, 1)
        if i == 0:
            panel.suptitle(dataset_match[dataset], fontsize=16, x=0.5, y=1.02, ha="center", va="center")

        # Same on-disk layout as the PNG panel; pretuned DICE lives under the DICE "global_" project.
        if model == "pretuned_dice":
            project = f"introdl-introspection-dice-global_{dataset}"
        else:
            project = f"introdl-introspection-{model}-{dataset}"

        for target, ax in enumerate(axs):
            path = f"../assets/introspection/{project}/k_00/saliency/colormap/general_{target}.png"
            # imshow copies the pixels, so the PNG handle can be closed immediately.
            with Image.open(path) as image:
                ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])

fig.savefig(
    "saliency.eps",
    format="eps",
    bbox_inches="tight",
)
plt.close(fig)