In [None]:
import os

# Remember the directory the kernel started in (the value survives re-runs) so this cell is
# idempotent: a bare `os.chdir('../')` would climb one more level on every re-execution.
_NOTEBOOK_DIR = globals().get('_NOTEBOOK_DIR', os.getcwd())
print(_NOTEBOOK_DIR)
# Move to the parent directory (presumably the repository root) so the relative
# "results/..." checkpoint paths used later in the notebook resolve.
os.chdir(os.path.join(_NOTEBOOK_DIR, '../'))
print(os.getcwd())

# Configuration

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from collections import OrderedDict
import copy
import pickle
import time
from scipy import stats
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from vit_shapley.datamodules.ImageNette_datamodule import ImageNetteDataModule
from vit_shapley.datamodules.MURA_datamodule import MURADataModule
from vit_shapley.datamodules.Pet_datamodule import PetDataModule

from vit_shapley.modules.classifier import Classifier
from vit_shapley.modules.classifier_masked import ClassifierMasked
from vit_shapley.modules.surrogate import Surrogate
from vit_shapley.modules.explainer import Explainer

from vit_shapley.config import ex
from vit_shapley.config import config, env_chanwkim, dataset_ImageNette, dataset_MURA, dataset_Pet
_config = config()

dataset_split = "test"
parallel_mode = (0, 1)
backbone_to_use = ["vit_base_patch16_224"]
_config.update(dataset_Pet())

# All stages this notebook knows about; the index below picks the one to run.
_evaluation_stages = [
    "1_classifier_evaluate",
    "2_surrogate_evaluate",
    "3_explanation_generate",
    "4_insert_delete",
    "5_sensitivity",
    "6_noretraining",
    "7_classifiermasked",
    "8_elapsedtime",
    "9_estimationerror",
]
evaluation_stage = _evaluation_stages[1]

# Environment defaults first, then the run-specific overrides (GPU ids, model/mask settings).
_config.update(env_chanwkim())
_config.update({
    'gpus_classifier': [4],
    'gpus_surrogate': [4],
    'gpus_explainer': [4],

    'classifier_backbone_type': None,
    'classifier_download_weight': False,
    'classifier_load_path': None,
    'classifier_masked_mask_location': "pre-softmax",
    'classifier_enable_pos_embed': True,

    'surrogate_mask_location': "pre-softmax",
    'surrogate_backbone_type': None,
    'surrogate_download_weight': False,
    'surrogate_load_path': None,

    'explainer_num_mask_samples': 2,
    'explainer_paired_mask_samples': True,
})

In [None]:
# Display the backbones selected in the config cell; only these are loaded below.
backbone_to_use

In [None]:
# Shell escape: show current GPU usage before models are placed on the GPU ids set in the config cell
# (requires the `gpustat` command to be installed).
!gpustat

In [None]:
# Checkpoint paths per dataset and backbone. Empty dicts mark backbones without trained
# checkpoints; selecting one of them in `backbone_to_use` fails later with a KeyError.
# Commented-out entries are alternative masking variants that are currently not evaluated.
if _config["datasets"] == "ImageNette":
    backbone_type_config_dict_ = OrderedDict({
        "vit_small_patch16_224": {
            "classifier_path": "results/wandb_transformer_interpretability_project/1yndrggu/checkpoints/epoch=14-step=2204.ckpt",
            "classifier_masked_path": "results/wandb_transformer_interpretability_project/fdm70w72/checkpoints/epoch=19-step=2939.ckpt",
            "surrogate_path": {
                "pre-softmax": "results/wandb_transformer_interpretability_project/3lfv4nmn/checkpoints/epoch=39-step=5879.ckpt"
            },
            "explainer_path": "results/wandb_transformer_interpretability_project/3biv2s85/checkpoints/epoch=60-step=9027.ckpt"
        },
        "deit_small_patch16_224": {
        },
        "vit_base_patch16_224": {
            "classifier_path": "results/wandb_transformer_interpretability_project/2rq1issn/checkpoints/epoch=16-step=2498.ckpt",
            "classifier_masked_path": "results/wandb_transformer_interpretability_project/x59c992d/checkpoints/epoch=21-step=3233.ckpt",
            "surrogate_path": {
                #"original": "results/wandb_transformer_interpretability_project/2rq1issn/checkpoints/epoch=16-step=2498.ckpt",
                "pre-softmax": "results/wandb_transformer_interpretability_project/3i6zzjnp/checkpoints/epoch=38-step=5732.ckpt",
                #"zero-input": "results/wandb_transformer_interpretability_project/zyybgzcm/checkpoints/epoch=22-step=3380.ckpt",
                #"zero-embedding": "results/wandb_transformer_interpretability_project/1gi5gmrm/checkpoints/epoch=36-step=5438.ckpt"
            },
            "explainer_path": "results/wandb_transformer_interpretability_project/3ty85eft/checkpoints/epoch=83-step=12431.ckpt"
        },
        "deit_base_patch16_224": {
        },
        "vit_large_patch16_224": {
            "classifier_path": "results/wandb_transformer_interpretability_project/1at36lgp/checkpoints/epoch=2-step=440.ckpt",
            "surrogate_path": {
                "pre-softmax": "results/wandb_transformer_interpretability_project/284sm0on/checkpoints/epoch=37-step=5585.ckpt"
            },
            "explainer_path": "results/wandb_transformer_interpretability_project/34gbowsg/checkpoints/epoch=91-step=13615.ckpt"
        }
    })
elif _config["datasets"] == "MURA":
    backbone_type_config_dict_ = OrderedDict({
        "vit_small_patch16_224": {
        },
        "deit_small_patch16_224": {
        },
        "vit_base_patch16_224": {
            "classifier_path": "results/wandb_transformer_interpretability_project/1u2xgwks/checkpoints/epoch=15-step=8255.ckpt",
            "surrogate_path": {
                #"original": "results/wandb_transformer_interpretability_project/1u2xgwks/checkpoints/epoch=15-step=8255.ckpt",
                "pre-softmax": "results/wandb_transformer_interpretability_project/22ompjqu/checkpoints/epoch=47-step=24767.ckpt",
                #"zero-input": "results/wandb_transformer_interpretability_project/2z2qs6t0/checkpoints/epoch=44-step=23219.ckpt",
                #"zero-embedding": "results/wandb_transformer_interpretability_project/1pbmwnvb/checkpoints/epoch=45-step=23735.ckpt"
            },
            "explainer_path": "results/wandb_transformer_interpretability_project/1dmhcwej/checkpoints/epoch=93-step=48597.ckpt"
        },
        "deit_base_patch16_224": {
        }
    })
elif _config["datasets"] == "Pet":
    backbone_type_config_dict_ = OrderedDict({
        "vit_small_patch16_224": {
        },
        "deit_small_patch16_224": {
        },
        "vit_base_patch16_224": {
            "classifier_path": "results/wandb_transformer_interpretability_project/3g01rci7/checkpoints/epoch=9-step=909.ckpt",
            "surrogate_path": {
                # "original" is the classifier checkpoint itself (same path as classifier_path).
                "original": "results/wandb_transformer_interpretability_project/3g01rci7/checkpoints/epoch=9-step=909.ckpt",
                "pre-softmax": "results/wandb_transformer_interpretability_project/146vf465/checkpoints/epoch=40-step=3730.ckpt",
                # NOTE(review): the two disabled paths below are identical to the MURA ones — likely copy-paste.
                #"zero-input": "results/wandb_transformer_interpretability_project/2z2qs6t0/checkpoints/epoch=44-step=23219.ckpt",
                #"zero-embedding": "results/wandb_transformer_interpretability_project/1pbmwnvb/checkpoints/epoch=45-step=23735.ckpt"
            },
            "explainer_path": "results/wandb_transformer_interpretability_project/2oq7lhr7/checkpoints/epoch=85-step=7911.ckpt"
        },
        "deit_base_patch16_224": {
        }
    })
else:
    # Fail here rather than with a NameError on `backbone_type_config_dict_` in a later cell.
    raise ValueError(f"No checkpoint configuration for dataset {_config['datasets']!r}")

# Load dataset

In [None]:
def generate_mask(num_players: int, num_mask_samples: "int | None" = None, paired_mask_samples: bool = True,
                  mode: str = 'uniform', random_state: "np.random.RandomState | None" = None) -> np.ndarray:
    """Sample binary player masks (1 = player kept, 0 = player masked).

    Args:
        num_players: the number of players in the coalitional game.
        num_mask_samples: the number of masks to generate; None generates a single mask.
        paired_mask_samples: if True, masks come in consecutive pairs (x, 1 - x), so
            `num_mask_samples` must be even.
        mode: how the per-row keep-threshold is drawn: 'uniform' or 'shapley'.
        random_state: random generator; defaults to the global `np.random` module.

    Returns:
        Integer numpy array of shape
        (num_mask_samples, num_players) if num_mask_samples is an int,
        (num_players,) if num_mask_samples is None.

    Raises:
        AssertionError: if paired sampling is requested with an odd number of samples.
        ValueError: if `mode` is neither 'uniform' nor 'shapley'.
    """
    random_state = random_state or np.random

    num_samples_ = num_mask_samples or 1
    if paired_mask_samples:
        assert num_samples_ % 2 == 0, "'num_samples' must be a multiple of 2 if 'paired' is True"
        # Only half the rows are drawn; the other half are their complements.
        num_samples_ //= 2

    if mode == 'uniform':
        # One uniform threshold per row; each player is kept when its own uniform draw exceeds it.
        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')
    elif mode == 'shapley':
        # Weights proportional to 1 / (k * (n - k)) for k = 1..n-1 (Shapley-kernel-style weights).
        probs = 1 / (np.arange(1, num_players) * (num_players - np.arange(1, num_players)))
        probs = probs / probs.sum()
        # NOTE(review): the draw is over 0..n-2 while `probs` is indexed by sizes 1..n-1, so the
        # threshold is (k - 1) / n for weight index k — confirm this offset is intended.
        masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(
            np.arange(num_players - 1), p=probs, size=[num_samples_, 1])).astype('int')
    else:
        raise ValueError("'mode' must be 'uniform' or 'shapley'")

    if paired_mask_samples:
        # Interleave each mask with its complement: rows (0, 1), (2, 3), ... are pairs.
        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)

    if num_mask_samples is None:
        return masks.squeeze(0)  # (num_players,)
    return masks  # (num_mask_samples, num_players)

def set_datamodule(datasets,
                   dataset_location,
                   explanation_location_train,
                   explanation_mask_amount_train,
                   explanation_mask_ascending_train,

                   explanation_location_val,
                   explanation_mask_amount_val,
                   explanation_mask_ascending_val,

                   explanation_location_test,
                   explanation_mask_amount_test,
                   explanation_mask_ascending_test,

                   transforms_train,
                   transforms_val,
                   transforms_test,
                   num_workers,
                   per_gpu_batch_size,
                   test_data_split):
    """Instantiate the datamodule class that matches `datasets`.

    Args:
        datasets: dataset name; one of "MURA", "ImageNette" or "Pet".
        All remaining arguments are forwarded unchanged as keyword arguments to the
        datamodule constructor.

    Returns:
        The constructed datamodule instance.

    Raises:
        ValueError: if `datasets` is not one of the supported names.
    """
    dataset_parameters = {
        "dataset_location": dataset_location,
        "explanation_location_train": explanation_location_train,
        "explanation_mask_amount_train": explanation_mask_amount_train,
        "explanation_mask_ascending_train": explanation_mask_ascending_train,

        "explanation_location_val": explanation_location_val,
        "explanation_mask_amount_val": explanation_mask_amount_val,
        "explanation_mask_ascending_val": explanation_mask_ascending_val,

        "explanation_location_test": explanation_location_test,
        "explanation_mask_amount_test": explanation_mask_amount_test,
        "explanation_mask_ascending_test": explanation_mask_ascending_test,

        "transforms_train": transforms_train,
        "transforms_val": transforms_val,
        "transforms_test": transforms_test,
        "num_workers": num_workers,
        "per_gpu_batch_size": per_gpu_batch_size,
        "test_data_split": test_data_split
    }

    if datasets == "MURA":
        return MURADataModule(**dataset_parameters)
    if datasets == "ImageNette":
        return ImageNetteDataModule(**dataset_parameters)
    if datasets == "Pet":
        return PetDataModule(**dataset_parameters)
    # Previously the exception was constructed but not raised, which surfaced as UnboundLocalError.
    raise ValueError(f"Invalid 'datasets' configuration: {datasets!r}")

# Every set_datamodule parameter is named exactly like the _config key that feeds it,
# so the keyword arguments can be gathered from this key list (order preserved).
_datamodule_config_keys = (
    "datasets",
    "dataset_location",

    "explanation_location_train",
    "explanation_mask_amount_train",
    "explanation_mask_ascending_train",

    "explanation_location_val",
    "explanation_mask_amount_val",
    "explanation_mask_ascending_val",

    "explanation_location_test",
    "explanation_mask_amount_test",
    "explanation_mask_ascending_test",

    "transforms_train",
    "transforms_val",
    "transforms_test",
    "num_workers",
    "per_gpu_batch_size",
    "test_data_split",
)
datamodule = set_datamodule(**{key: _config[key] for key in _datamodule_config_keys})

# DISABLED CODE: the triple-quoted string below is a bare string expression and is never executed.
# When enabled, it monkey-patches `datamodule.dataset_cls.__getitem__` so that every item carries
# "images", "labels" and "masks", with masks drawn by generate_mask(..., mode='shapley')
# (freshly per access for 'train', cached per index for 'val'/'test').
# NOTE(review): it references `surrogate`, which is not defined at this point in the notebook, and
# `setdefault` still calls generate_mask on every access even when a cached mask exists.
"""
original_getitem = copy.deepcopy(datamodule.dataset_cls.__getitem__)
def __getitem__(self, idx):
    if self.split == 'train':
        masks = generate_mask(num_players=surrogate.num_players,
                              num_mask_samples=_config["explainer_num_mask_samples"],
                              paired_mask_samples=_config["explainer_paired_mask_samples"], mode='shapley')
    elif self.split == 'val' or self.split == 'test':
        # get cached if available
        if not hasattr(self, "masks_cached"):
            self.masks_cached = {}
        masks = self.masks_cached.setdefault(idx, generate_mask(num_players=surrogate.num_players,
                                                                num_mask_samples=_config[
                                                                    "explainer_num_mask_samples"],
                                                                paired_mask_samples=_config[
                                                                    "explainer_paired_mask_samples"],
                                                                mode='shapley'))
    else:
        raise ValueError("'split' variable must be train, val or test.")
    return {"images": original_getitem(self, idx)["images"],
            "labels": original_getitem(self, idx)["labels"],
            "masks": masks}
datamodule.dataset_cls.__getitem__ = __getitem__
"""

datamodule.set_train_dataset()
datamodule.set_val_dataset()
datamodule.set_test_dataset()

train_dataset = datamodule.train_dataset
val_dataset = datamodule.val_dataset
test_dataset = datamodule.test_dataset

# `dset` is built from the test dataset object (presumably to reuse its evaluation-time
# settings) but reads the records of `dataset_split`. A shallow copy is used so that swapping
# `.data` no longer silently overwrites `test_dataset.data` itself.
_split_datasets = {"train": train_dataset, "val": val_dataset, "test": test_dataset}
if dataset_split not in _split_datasets:
    raise ValueError(f"dataset_split must be one of {list(_split_datasets)}, got {dataset_split!r}")

dset = copy.copy(test_dataset)
dset.data = _split_datasets[dataset_split].data


def _select_examples_per_class(dataset, num_reference_classes, start_rank=4):
    """Collect a small fixed example batch: the `start_rank`-th example of every class.

    Args:
        dataset: dataset whose `.data` is a sequence of dicts with a 'label' entry, whose items
            (`dataset[i]`) are dicts with 'images' and 'labels', and which has a `.labels`
            sequence indexed by label (presumably class names).
        num_reference_classes: expected number of classes; int(num_reference_classes / num_classes)
            consecutive ranks are taken from each class (one rank when the counts match).
        start_rank: position, within each class's examples, of the first example taken.

    Returns:
        Tuple (labels, num_classes, images_idx_list, images_idx, xy, x, y, y_labels) where `x` is
        the stacked image tensor, `y` the tuple of labels and `y_labels` the matching
        `dataset.labels` entries.
    """
    labels = np.array([record['label'] for record in dataset.data])
    num_classes = labels.max() + 1

    images_idx_list = [np.where(labels == category)[0] for category in range(num_classes)]

    ranks_per_class = int(num_reference_classes / len(images_idx_list))
    images_idx = []
    for classidx in range(start_rank, start_rank + ranks_per_class):
        images_idx += [category_idx[classidx] for category_idx in images_idx_list]

    xy = [dataset[idx] for idx in images_idx]
    x, y = zip(*[(item['images'], item['labels']) for item in xy])
    x = torch.stack(x)
    y_labels = [dataset.labels[i] for i in y]
    return labels, num_classes, images_idx_list, images_idx, xy, x, y, y_labels


if _config["datasets"] == "ImageNette":
    (labels, num_classes, images_idx_list, images_idx,
     xy, x, y, y_labels) = _select_examples_per_class(dset, num_reference_classes=10)

    label_name_list = ['Cassette player',
                       'Garbage truck',
                       'Tench',
                       'English springer',
                       'Church',
                       'Parachute',
                       'French horn',
                       'Chain saw',
                       'Golf ball',
                       'Gas pump']

elif _config["datasets"] == "MURA":
    # NOTE(review): no example batch (x, y, y_labels) is built for MURA, yet later cells
    # (e.g. the ViT_new sanity check) use `x` — confirm how MURA is meant to be handled.
    label_name_list = ["Normal", "Abnormal"]

elif _config["datasets"] == "Pet":
    # NOTE(review): `label_name_list` is not defined for Pet.
    (labels, num_classes, images_idx_list, images_idx,
     xy, x, y, y_labels) = _select_examples_per_class(dset, num_reference_classes=37)

else:
    raise ValueError(f"Unsupported dataset {_config['datasets']!r}")

In [None]:
# Re-bind the three split datasets from the datamodule.
train_dataset, val_dataset, test_dataset = (
    datamodule.train_dataset,
    datamodule.val_dataset,
    datamodule.test_dataset,
)

In [None]:
# Number of examples per split, as (train, val, test).
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

# Load models

In [None]:
# Keep only the backbones listed in `backbone_to_use`, in the order of the checkpoint table.
backbone_type_config_dict = OrderedDict(
    (name, settings)
    for name, settings in backbone_type_config_dict_.items()
    if name in backbone_to_use
)
for name in backbone_type_config_dict:
    print(name)

In [None]:
classifier_dict = OrderedDict()

# One Classifier per selected backbone, loaded from its checkpoint; the optimisation-related
# arguments are left as None. The n-th backbone goes to the n-th id in gpus_classifier.
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    _classifier_model = Classifier(
        backbone_type=backbone_type,
        download_weight=_config['classifier_download_weight'],
        load_path=backbone_type_config["classifier_path"],
        target_type=_config["target_type"],
        output_dim=_config["output_dim"],
        enable_pos_embed=_config["classifier_enable_pos_embed"],
        checkpoint_metric=None,
        loss_weight=None,
        optim_type=None,
        learning_rate=None,
        weight_decay=None,
        decay_power=None,
        warmup_steps=None,
    )
    classifier_dict[backbone_type] = _classifier_model.to(_config["gpus_classifier"][idx])

In [None]:
classifier_dict_ = OrderedDict()

# The plain classifier checkpoint wrapped in a Surrogate (using the surrogate mask location),
# i.e. a surrogate that was never trained on masked inputs.
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    _wrapped_classifier = Surrogate(
        mask_location=_config["surrogate_mask_location"],
        backbone_type=backbone_type,
        download_weight=_config['classifier_download_weight'],
        load_path=backbone_type_config["classifier_path"],
        target_type=_config["target_type"],
        output_dim=_config["output_dim"],
        target_model=None,
        checkpoint_metric=None,
        optim_type=None,
        learning_rate=None,
        weight_decay=None,
        decay_power=None,
        warmup_steps=None,
    )
    classifier_dict_[backbone_type] = _wrapped_classifier.to(_config["gpus_classifier"][idx])

In [None]:
if evaluation_stage == "7_classifiermasked":
    # Masked classifiers are only loaded for stage 7. Note that only some entries of the
    # checkpoint table define "classifier_masked_path".
    classifier_masked_dict = OrderedDict()

    for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
        _masked_model = ClassifierMasked(
            mask_location=_config["classifier_masked_mask_location"],
            backbone_type=backbone_type,
            download_weight=_config['classifier_download_weight'],
            load_path=backbone_type_config["classifier_masked_path"],
            target_type=_config["target_type"],
            output_dim=_config["output_dim"],
            checkpoint_metric=None,
            loss_weight=None,
            optim_type=None,
            learning_rate=None,
            weight_decay=None,
            decay_power=None,
            warmup_steps=None,
        )
        classifier_masked_dict[backbone_type] = _masked_model.to(_config["gpus_classifier"][idx])

In [None]:
surrogate_dict = OrderedDict()

# surrogate_dict[backbone][mask_location] -> Surrogate loaded from the matching checkpoint.
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    surrogate_paths = backbone_type_config["surrogate_path"]
    mask_method_dict = OrderedDict()
    for mask_location in surrogate_paths.keys():
        # "original" entries are still built with pre-softmax masking.
        effective_location = "pre-softmax" if mask_location == "original" else mask_location
        _surrogate_model = Surrogate(
            mask_location=effective_location,
            backbone_type=backbone_type,
            download_weight=_config['surrogate_download_weight'],
            load_path=surrogate_paths[mask_location],
            target_type=_config["target_type"],
            output_dim=_config["output_dim"],
            target_model=None,
            checkpoint_metric=None,
            optim_type=None,
            learning_rate=None,
            weight_decay=None,
            decay_power=None,
            warmup_steps=None,
        )
        mask_method_dict[mask_location] = _surrogate_model.to(_config["gpus_surrogate"][idx])
    surrogate_dict[backbone_type] = mask_method_dict

In [None]:
# NOTE(review): this re-binds `Explainer`, shadowing the `vit_shapley.modules.explainer.Explainer`
# imported in the first code cell. `vitmedical` differs from the `vit_shapley` package used
# everywhere else — confirm it is installed and that this override is intended.
from vitmedical.modules.explainer import Explainer

In [None]:
_config.update({'explainer_normalization': "additive",
                'explainer_activation': "tanh",
                'explainer_link': 'sigmoid' if _config["output_dim"]==1 else 'softmax',
                'explainer_head_num_attention_blocks': 1,
                'explainer_head_include_cls': True,
                'explainer_head_num_mlp_layers': 3,
                'explainer_head_mlp_layer_ratio': 4,
                'explainer_residual': [],
                'explainer_freeze_backbone': "all"})

explainer_dict = OrderedDict()
# One Explainer per backbone, built on that backbone's "pre-softmax" surrogate.
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    explainer_kwargs = dict(
        normalization=_config["explainer_normalization"],
        normalization_class=_config["explainer_normalization_class"],
        activation=_config["explainer_activation"],
        surrogate=surrogate_dict[backbone_type]["pre-softmax"],
        link=_config["explainer_link"],
        backbone_type=backbone_type,
        download_weight=False,
        residual=_config['explainer_residual'],
        load_path=backbone_type_config["explainer_path"],
        target_type=_config["target_type"],
        output_dim=_config["output_dim"],
        # explainer head architecture
        explainer_head_num_attention_blocks=_config["explainer_head_num_attention_blocks"],
        explainer_head_include_cls=_config["explainer_head_include_cls"],
        explainer_head_num_mlp_layers=_config["explainer_head_num_mlp_layers"],
        explainer_head_mlp_layer_ratio=_config["explainer_head_mlp_layer_ratio"],
        explainer_norm=_config["explainer_norm"],
        # efficiency terms and backbone freezing
        efficiency_lambda=_config["explainer_efficiency_lambda"],
        efficiency_class_lambda=_config["explainer_efficiency_class_lambda"],
        freeze_backbone=_config["explainer_freeze_backbone"],
        # optimisation settings, taken from the config
        checkpoint_metric=_config["checkpoint_metric"],
        optim_type=_config["optim_type"],
        learning_rate=_config["learning_rate"],
        weight_decay=_config["weight_decay"],
        decay_power=_config["decay_power"],
        warmup_steps=_config["warmup_steps"],
    )
    explainer_dict[backbone_type] = Explainer(**explainer_kwargs).to(_config["gpus_explainer"][idx])

# Explanation methods

## Attention-based explanations (raw attention and rollout)

In [None]:
def compute_joint_attention(attentions, add_residual=True):
    """Chain per-layer attention matrices by matrix product (attention rollout).

    Args:
        attentions: array of shape (num_batches, num_layers, num_players, num_players).
        add_residual: if True, add the identity matrix to each layer's attention and
            re-normalise each row to sum to 1 before chaining.

    Returns:
        float64 array of the same shape whose layer i holds A_i @ A_{i-1} @ ... @ A_0,
        where A are the (optionally residual-augmented) per-layer matrices.
    """
    assert len(attentions.shape) == 4

    if add_residual:
        identity = np.eye(attentions.shape[2])[np.newaxis, np.newaxis, ...]
        with_residual = attentions + identity
        layer_attentions = with_residual / with_residual.sum(axis=-1, keepdims=True)
    else:
        layer_attentions = attentions

    rollout = np.zeros(layer_attentions.shape)  # (num_batches, num_layers, num_players, num_players)
    for layer in range(rollout.shape[1]):
        current = layer_attentions[:, layer]
        rollout[:, layer] = current if layer == 0 else current @ rollout[:, layer - 1]
    return rollout


def attentions_to_explanation(attentions, mode='rollout'):
    """Turn per-head attention maps into an explanation vector over tokens 1..P-1.

    Heads are averaged, an identity (residual) term is added and rows are re-normalised
    to sum to 1 before the requested read-out.

    Args:
        attentions: array of shape (num_batches, num_layers, num_heads, num_players, num_players).
        mode: an integer (Python or numpy) layer index — read that layer directly;
              'raw' — the last layer (same as mode=-1);
              'rollout' — attention rollout across all layers via compute_joint_attention.

    Returns:
        Array of shape (num_batches, num_players - 1): row 0 of the selected matrix with
        column 0 dropped (presumably CLS -> patch-token attention).

    Raises:
        ValueError: if `mode` is not an integer, 'raw' or 'rollout' (previously returned None).
    """
    assert len(attentions.shape) == 5 and attentions.shape[-1] == attentions.shape[-2]
    attentions_nohead = attentions.sum(axis=2) / attentions.shape[2]  # (num_batch, num_layers, num_players, num_players)
    attentions_nohead_residual = attentions_nohead + np.eye(attentions_nohead.shape[2])[np.newaxis, np.newaxis, ...]
    attentions_nohead_residual_normalized = (
        attentions_nohead_residual / attentions_nohead_residual.sum(axis=-1)[..., np.newaxis]
    )

    if isinstance(mode, (int, np.integer)):
        return attentions_nohead_residual_normalized[:, mode, 0, 1:]
    if mode == 'raw':
        return attentions_nohead_residual_normalized[:, -1, 0, 1:]
    if mode == 'rollout':
        # Residual/normalisation were already applied above, so rollout must not add them again.
        rollout = compute_joint_attention(attentions_nohead_residual_normalized, add_residual=False)
        return rollout[:, -1, 0, 1:]
    raise ValueError(f"Unknown mode {mode!r}; expected an int layer index, 'raw' or 'rollout'")

## LRP (layer-wise relevance propagation) and related baselines

In [None]:
import utils.transformer_explainability.baselines.ViT.ViT_new as ViT_new
import utils.transformer_explainability.baselines.ViT.ViT_LRP as ViT_LRP
import utils.transformer_explainability.baselines.ViT.ViT_orig_LRP as ViT_orig_LRP

from utils.transformer_explainability.baselines.ViT.ViT_explanation_generator import Baselines, LRP

In [None]:
baselines_dict = OrderedDict()
lrp_dict = OrderedDict()
# Only filled when the optional ViT_orig_LRP models at the end of the loop are enabled; until then
# the 'full_lrp' / 'lrp_last_layer' modes of get_lrp_module_explanation cannot be used.
orig_lrp_dict = OrderedDict()


def _load_weights_verbose(model, state_dict, checkpoint_path):
    """Load `state_dict` into `model` non-strictly, print unmatched keys, and switch to eval mode.

    Returns:
        The result of `model.load_state_dict` (exposes `missing_keys` / `unexpected_keys`).
    """
    ret = model.load_state_dict(state_dict, strict=False)
    print(f"Model parameters were updated from a checkpoint file {checkpoint_path}")
    print(f"Unmatched parameters - missing_keys:    {ret.missing_keys}")
    print(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}")
    model.eval()
    return ret


for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    classifier_path = backbone_type_config["classifier_path"]
    # torch.load unpickles the file: only use trusted, locally produced checkpoints here.
    checkpoint = torch.load(classifier_path, map_location="cpu")
    # Strip 'backbone.' from the checkpoint keys so they can be loaded into the ViT_new / ViT_LRP
    # models below (any keys that still do not match are printed by _load_weights_verbose).
    checkpoint["state_dict"] = OrderedDict([(k.replace('backbone.', ''), v) for k, v in checkpoint["state_dict"].items()])
    state_dict = checkpoint["state_dict"]

    model = getattr(ViT_new, backbone_type)(num_classes=_config["output_dim"]).to(_config["gpus_classifier"][idx])
    ret = _load_weights_verbose(model, state_dict, classifier_path)
    # Sanity check: the re-implemented ViT must reproduce the Classifier's logits on the example batch.
    # NOTE(review): `x` is only built for ImageNette and Pet in the dataset-loading cell.
    with torch.no_grad():
        model_device = next(model.parameters()).device
        output1 = model(x.to(model_device))
        output2 = classifier_dict[backbone_type](x.to(model_device))['logits']
        assert torch.allclose(output1, output2, atol=1e-03), \
            f"{backbone_type}: ViT_new logits differ from the Classifier logits (atol=1e-3)"
    baselines = Baselines(model)
    baselines_dict[backbone_type] = baselines

    model_LRP = getattr(ViT_LRP, backbone_type)(num_classes=_config["output_dim"]).to(_config["gpus_classifier"][idx])
    ret = _load_weights_verbose(model_LRP, state_dict, classifier_path)
    lrp = LRP(model_LRP)
    lrp_dict[backbone_type] = lrp

    # Optional (disabled): full-LRP models needed by the 'full_lrp' / 'lrp_last_layer' modes.
    # model_orig_LRP = getattr(ViT_orig_LRP, backbone_type)(num_classes=_config["output_dim"]).to(_config["gpus_classifier"][idx])
    # ret = _load_weights_verbose(model_orig_LRP, state_dict, classifier_path)
    # orig_lrp = LRP(model_orig_LRP)
    # orig_lrp_dict[backbone_type] = orig_lrp
    
    
def get_lrp_module_explanation(backbone_type, original_image, class_index=None, mode='transformer_attribution'):
    if mode=="transformer_attribution": # ours
        transformer_attribution = lrp_dict[backbone_type].generate_LRP(original_image.unsqueeze(0).to(next(lrp_dict[backbone_type].model.parameters()).device), method="transformer_attribution", index=class_index).detach()
    elif mode=="rollout": # rollout
        transformer_attribution = baselines_dict[backbone_type].generate_rollout(original_image.unsqueeze(0).to(next(baselines_dict[backbone_type].model.parameters()).device), start_layer=1).detach()
    elif mode=="attn_last_layer": # raw-attention
        transformer_attribution = lrp_dict[backbone_type].generate_LRP(original_image.unsqueeze(0).to(next(lrp_dict[backbone_type].model.parameters()).device), method="last_layer_attn", index=class_index).detach()
    elif mode == 'attn_gradcam': # GradCAM
        transformer_attribution = baselines_dict[backbone_type].generate_cam_attn(original_image.unsqueeze(0).to(next(baselines_dict[backbone_type].model.parameters()).device), index=class_index).detach()
        transformer_attribution = transformer_attribution.reshape(1,-1)
        #transformer_attribution=torch.nan_to_num(transformer_attribution,nan=0)
        #transformer_attribution+=torch.rand(size=transformer_attribution.shape, device=transformer_attribution.device)*1e-20        
        #transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    elif mode == 'full_lrp':
        transformer_attribution = orig_lrp_dict[backbone_type].generate_LRP(original_image.unsqueeze(0).to(next(orig_lrp_dict[backbone_type].model.parameters()).device), method="full", index=class_index).detach()
        #transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    elif mode == 'lrp_last_layer':
        transformer_attribution = orig_lrp_dict[backbone_type].generate_LRP(original_image.unsqueeze(0).to(next(orig_lrp_dict[backbone_type].model.parameters()).device), method="last_layer", index=class_index).detach()
        #transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())
    #print(transformer_attribution.max(), transformer_attribution.min())
    #print(transformer_attribution.shape)
    return transformer_attribution    

## Grad-CAM

In [None]:
from utils.pytorch_grad_cam import GradCAM
from utils.pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def reshape_transform(tensor, height=14, width=14):
    """Turn ViT token activations into a CNN-style feature map for GradCAM.

    The first token (index 0 along dim 1) is discarded; the remaining
    ``height * width`` tokens are laid out row-major on a ``height x width`` grid.

    Args:
        tensor: activations of shape (batch, 1 + height * width, channels).
        height, width: grid size of the patch tokens (default 14 x 14).

    Returns:
        Tensor of shape (batch, channels, height, width).
    """
    patch_tokens = tensor[:, 1:, :]
    token_grid = patch_tokens.reshape(tensor.size(0), height, width, tensor.size(2))
    # Move channels in front of the spatial dims, like a CNN feature map.
    return token_grid.permute(0, 3, 1, 2)

class WrapperLogits(nn.Module):
    """Adapter that exposes only the ``'logits'`` entry of a dict-returning model.

    Lets tools that expect a plain tensor output (here: GradCAM) consume a
    classifier whose forward returns a dict.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, images):
        return self.model(images)['logits']

# Build one GradCAM explainer per backbone. The classifier is wrapped in
# WrapperLogits so GradCAM receives a plain logits tensor; the target layer is
# `norm1` of the last transformer block, and `reshape_transform` turns its token
# activations into a (batch, channels, 14, 14) map.
# Loop variables (backbone_type, backbone_type_config) remain as notebook globals.
cam_dict = OrderedDict()
for backbone_type, backbone_type_config in backbone_type_config_dict.items():
    cam_dict[backbone_type] = GradCAM(model=WrapperLogits(classifier_dict[backbone_type]),
                                      target_layers=[classifier_dict[backbone_type].backbone.blocks[-1].norm1],
                                      reshape_transform=reshape_transform)

## Gradient-based

In [None]:
from captum.attr import IntegratedGradients, InputXGradient, Saliency, NoiseTunnel
import torch.nn as nn

class FromPixel(nn.Module):
    """Pixel-input wrapper: images -> class probabilities.

    The patch embedding is applied inside `forward`, so attribution methods that
    differentiate this module get gradients with respect to the raw pixels.
    Probabilities use a sigmoid when ``_config["output_dim"] == 1`` (read on every
    call) and a softmax over the last dim otherwise.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, images):
        backbone = self.model.backbone
        patch_embeddings = backbone.patch_embed(images)
        features = backbone.forward_features(patch_embeddings)['x']
        logits = self.model.head(features)

        if _config["output_dim"] != 1:
            return logits.softmax(dim=-1)
        return logits.sigmoid()
    
class FromEmbedding(nn.Module):
    """Embedding-input wrapper: patch embeddings -> class probabilities.

    The input is expected to be the output of ``model.backbone.patch_embed``, so
    attribution methods differentiate with respect to the patch embeddings.
    Probabilities use a sigmoid when ``_config["output_dim"] == 1`` (read on every
    call) and a softmax over the last dim otherwise.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, embedding):
        features = self.model.backbone.forward_features(embedding)['x']
        logits = self.model.head(features)

        if _config["output_dim"] != 1:
            return logits.softmax(dim=-1)
        return logits.sigmoid()

# Wrap every classifier twice for captum -- FromPixel (attributions w.r.t. pixels)
# and FromEmbedding (attributions w.r.t. patch embeddings) -- and build the
# Saliency, NoiseTunnel(Saliency) and IntegratedGradients explainers on top of
# each wrapper. Loop variables (idx, backbone_type, backbone_type_config) remain
# as notebook globals.
classifier_pixel_dict = OrderedDict()
classifier_embedding_dict = OrderedDict()
saliency_pixel_dict = OrderedDict()
saliency_embedding_dict = OrderedDict()
noisetunnel_pixel_dict = OrderedDict()
noisetunnel_embedding_dict = OrderedDict()
ig_pixel_dict = OrderedDict()
ig_embedding_dict = OrderedDict()

for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    # Classifier wrapping
    classifier_pixel_dict[backbone_type] = FromPixel(classifier_dict_[backbone_type])
    classifier_embedding_dict[backbone_type] = FromEmbedding(classifier_dict_[backbone_type])

    # Vanilla gradients
    saliency_pixel_dict[backbone_type] = Saliency(classifier_pixel_dict[backbone_type])
    saliency_embedding_dict[backbone_type] = Saliency(classifier_embedding_dict[backbone_type])

    # NoiseTunnel over the vanilla-gradient explainers (SmoothGrad / VarGrad)
    noisetunnel_pixel_dict[backbone_type] = NoiseTunnel(saliency_pixel_dict[backbone_type])
    noisetunnel_embedding_dict[backbone_type] = NoiseTunnel(saliency_embedding_dict[backbone_type])

    # Integrated Gradients
    ig_pixel_dict[backbone_type] = IntegratedGradients(classifier_pixel_dict[backbone_type])
    ig_embedding_dict[backbone_type] = IntegratedGradients(classifier_embedding_dict[backbone_type])

In [None]:
def attributions_pixel_process(attributions_pixel, patch_size=16):
    """Reduce pixel-space attributions to per-pixel and per-patch maps.

    Args:
        attributions_pixel: tensor of shape (batch, 3, H, W). The patch reductions
            use a 3-input-channel convolution, so exactly 3 channels are required.
        patch_size: side length of the square, non-overlapping patches summed into
            one value (default 16, the patch size used by the masking code in this
            notebook).

    Returns:
        dict of detached CPU tensors:
        'attributions_pixel_sum'          (batch, H, W)       signed sum over channels
        'attributions_pixel_abssum'       (batch, H, W)       sum of |values| over channels
        'attributions_pixel_patchsum'     (batch, H//p, W//p) signed sum per patch (all channels)
        'attributions_pixel_patchabssum'  (batch, H//p, W//p) sum of |values| per patch
        Author's note: the per-pixel maps are not patch-level and cannot be used
        directly for the patch benchmarks; the patch abs-sum is the most directly usable.
    """
    # One all-ones 3-channel kernel, shared by both patch reductions.
    patch_kernel = torch.ones(size=(1, 3, patch_size, patch_size),
                              dtype=attributions_pixel.dtype,
                              device=attributions_pixel.device)
    attributions_abs = attributions_pixel.abs()

    def _patch_sum(values):
        # stride == kernel size sums each non-overlapping patch across all channels;
        # the conv has a single output channel, which is squeezed away.
        return F.conv2d(values, weight=patch_kernel, stride=patch_size).squeeze(dim=1)

    return {'attributions_pixel_sum': attributions_pixel.sum(dim=-3).detach().cpu(),
            'attributions_pixel_abssum': attributions_abs.sum(dim=-3).detach().cpu(),
            'attributions_pixel_patchsum': _patch_sum(attributions_pixel).detach().cpu(),
            'attributions_pixel_patchabssum': _patch_sum(attributions_abs).detach().cpu()
            }
    
    
def attributions_embedding_process(attributions_embedding):
    """Collapse embedding-space attributions over the last (embedding) dim.

    Returns:
        dict of detached CPU tensors:
        'attributions_embedding_sum'    -- signed sum over the last dim,
        'attributions_embedding_abssum' -- sum of absolute values over the last dim.
    """
    def _reduce_last_dim(values):
        return values.sum(dim=-1).detach().cpu()

    return {'attributions_embedding_sum': _reduce_last_dim(attributions_embedding),
            'attributions_embedding_abssum': _reduce_last_dim(attributions_embedding.abs())}

def get_vanilla(image, saliency_pixel=None, saliency_embedding=None):
    """Vanilla-gradient (captum Saliency) attributions for every output class.

    Args:
        image: a single un-batched image tensor; a batch dim is added here.
        saliency_pixel: optional Saliency whose forward function takes pixels.
        saliency_embedding: optional Saliency whose forward function takes patch
            embeddings; its input is ``forward_func.model.backbone.patch_embed(image)``.

    Returns:
        dict of CPU tensors from `attributions_pixel_process` and/or
        `attributions_embedding_process`; empty when both explainers are None.
    """
    outputs = {}
    with torch.no_grad():
        if saliency_pixel is not None:
            pixel_device = next(saliency_pixel.forward_func.parameters()).device
            pixel_input = image.unsqueeze(0).to(pixel_device)
            per_class = [saliency_pixel.attribute(inputs=pixel_input, target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_pixel_process(torch.concat(per_class)))

        if saliency_embedding is not None:
            embed_device = next(saliency_embedding.forward_func.parameters()).device
            patch_embed = saliency_embedding.forward_func.model.backbone.patch_embed
            embedding_input = patch_embed(image.unsqueeze(0).to(embed_device)).detach()
            per_class = [saliency_embedding.attribute(inputs=embedding_input, target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_embedding_process(torch.concat(per_class)))

    return outputs

def get_sg(image, noisetunnel_pixel=None, noisetunnel_embedding=None):
    """SmoothGrad attributions (captum NoiseTunnel, 10 noisy samples) for every output class.

    Args:
        image: a single un-batched image tensor; a batch dim is added here.
        noisetunnel_pixel: optional NoiseTunnel whose forward function takes pixels.
        noisetunnel_embedding: optional NoiseTunnel whose forward function takes patch
            embeddings; its input is ``forward_func.model.backbone.patch_embed(image)``.

    Returns:
        dict of CPU tensors from `attributions_pixel_process` and/or
        `attributions_embedding_process`; empty when both explainers are None.
    """
    outputs = {}
    with torch.no_grad():
        if noisetunnel_pixel is not None:
            pixel_device = next(noisetunnel_pixel.forward_func.parameters()).device
            pixel_input = image.unsqueeze(0).to(pixel_device)
            per_class = [noisetunnel_pixel.attribute(inputs=pixel_input,
                                                     nt_type='smoothgrad',
                                                     nt_samples=10,
                                                     target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_pixel_process(torch.concat(per_class)))

        if noisetunnel_embedding is not None:
            embed_device = next(noisetunnel_embedding.forward_func.parameters()).device
            patch_embed = noisetunnel_embedding.forward_func.model.backbone.patch_embed
            embedding_input = patch_embed(image.unsqueeze(0).to(embed_device)).detach()
            per_class = [noisetunnel_embedding.attribute(inputs=embedding_input,
                                                         nt_type='smoothgrad',
                                                         nt_samples=10,
                                                         target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_embedding_process(torch.concat(per_class)))

    return outputs

def get_vargrad(image, noisetunnel_pixel=None, noisetunnel_embedding=None):
    """VarGrad attributions (captum NoiseTunnel, 10 noisy samples) for every output class.

    Args:
        image: a single un-batched image tensor; a batch dim is added here.
        noisetunnel_pixel: optional NoiseTunnel whose forward function takes pixels.
        noisetunnel_embedding: optional NoiseTunnel whose forward function takes patch
            embeddings; its input is ``forward_func.model.backbone.patch_embed(image)``.

    Returns:
        dict of CPU tensors from `attributions_pixel_process` and/or
        `attributions_embedding_process`; empty when both explainers are None.
    """
    outputs = {}
    with torch.no_grad():
        if noisetunnel_pixel is not None:
            pixel_device = next(noisetunnel_pixel.forward_func.parameters()).device
            pixel_input = image.unsqueeze(0).to(pixel_device)
            per_class = [noisetunnel_pixel.attribute(inputs=pixel_input,
                                                     nt_type='vargrad',
                                                     nt_samples=10,
                                                     target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_pixel_process(torch.concat(per_class)))

        if noisetunnel_embedding is not None:
            embed_device = next(noisetunnel_embedding.forward_func.parameters()).device
            patch_embed = noisetunnel_embedding.forward_func.model.backbone.patch_embed
            embedding_input = patch_embed(image.unsqueeze(0).to(embed_device)).detach()
            per_class = [noisetunnel_embedding.attribute(inputs=embedding_input,
                                                         nt_type='vargrad',
                                                         nt_samples=10,
                                                         target=cls)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_embedding_process(torch.concat(per_class)))

    return outputs

def get_ig(image, ig_pixel=None, ig_embedding=None):
    """Integrated Gradients attributions (captum, n_steps=10) for every output class.

    Baselines: an all-zero image for the pixel explainer, and the patch embedding
    of an all-zero image (not an all-zero embedding) for the embedding explainer.

    Args:
        image: a single un-batched image tensor; a batch dim is added here.
        ig_pixel: optional IntegratedGradients whose forward function takes pixels.
        ig_embedding: optional IntegratedGradients whose forward function takes patch
            embeddings; its input is ``forward_func.model.backbone.patch_embed(image)``.

    Returns:
        dict of CPU tensors from `attributions_pixel_process` and/or
        `attributions_embedding_process`; empty when both explainers are None.
    """
    outputs = {}
    with torch.no_grad():
        if ig_pixel is not None:
            pixel_device = next(ig_pixel.forward_func.parameters()).device
            pixel_input = image.unsqueeze(0).to(pixel_device)
            pixel_baseline = torch.zeros_like(image).unsqueeze(0).to(pixel_device)
            per_class = [ig_pixel.attribute(inputs=pixel_input,
                                            baselines=pixel_baseline,
                                            target=cls,
                                            n_steps=10)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_pixel_process(torch.concat(per_class)))

        if ig_embedding is not None:
            embed_device = next(ig_embedding.forward_func.parameters()).device
            patch_embed = ig_embedding.forward_func.model.backbone.patch_embed
            embedding_input = patch_embed(image.unsqueeze(0).to(embed_device)).detach()
            embedding_baseline = patch_embed(torch.zeros_like(image).unsqueeze(0).to(embed_device)).detach()
            per_class = [ig_embedding.attribute(inputs=embedding_input,
                                                baselines=embedding_baseline,
                                                target=cls,
                                                n_steps=10)
                         for cls in range(_config["output_dim"])]
            outputs.update(attributions_embedding_process(torch.concat(per_class)))

    return outputs

## leave-one-out

In [None]:
def leave_one_out(image, surrogate=None, classifier=None):
    """Leave-one-out attribution over the 196 (14 x 14) image patches.

    The model is evaluated on the full image (mask row 0) and on 196 variants
    that each drop exactly one patch (mask rows 1..196). A patch's score is the
    full-image probability minus the probability with that patch removed.

    Args:
        image: one un-batched image tensor. In the classifier path it is multiplied
            by a (197, 1, 224, 224) pixel mask, so it is presumably (3, 224, 224)
            on CPU -- TODO confirm.
        surrogate: model called with ``masks=``. It takes precedence when given, and
            patches are removed through the mask argument.
        classifier: used when `surrogate` is None. A `Classifier` sees images whose
            dropped 16x16 pixel patches are zeroed; a `Surrogate` sees the same
            zeroed images together with an all-ones mask.

    Returns:
        numpy array of shape (output_dim, 196): probability drop per class and patch.

    Raises:
        ValueError: if neither `surrogate` nor `classifier` is given.
        TypeError: if `classifier` is neither a `Classifier` nor a `Surrogate`.
    """
    if surrogate is None and classifier is None:
        raise ValueError("leave_one_out requires a `surrogate` or a `classifier`")

    with torch.no_grad():
        # Row 0 keeps every patch; row k (1..196) drops patch k-1.
        mask = torch.cat([torch.ones(1, 196), 1 - torch.eye(196)])
        if surrogate is not None:
            out = surrogate(image.unsqueeze(0).repeat(196 + 1, 1, 1, 1).to(surrogate.device),
                            masks=mask.to(surrogate.device))
        else:
            # Upsample the 14x14 patch mask to 224x224 pixels (16x16 pixels per patch).
            mask_scaled = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(-1, 14, 14), 16, dim=2), 16, dim=1)
            image_masked = image * mask_scaled.unsqueeze(1)

            if classifier.__class__ == Classifier:
                out = classifier(image_masked.to(classifier.device))
            elif classifier.__class__ == Surrogate:
                # Pixels are already zeroed; keep the all-ones mask on the model's
                # device, as the surrogate branch above does for its masks.
                out = classifier(image_masked.to(classifier.device),
                                 masks=torch.ones((len(image_masked), 196), device=classifier.device))
            else:
                raise TypeError(f"unsupported classifier type: {type(classifier).__name__}")

        if _config["output_dim"] == 1:
            prob = out['logits'].sigmoid().detach().cpu().numpy()
        else:
            prob = out['logits'].softmax(dim=-1).detach().cpu().numpy()

        # (196, num_classes): full-image probability minus each leave-one-out probability.
        result = prob[0:1] - prob[1:]

    return result.transpose(1, 0)

# RISE

In [None]:
def rise(image, surrogate=None, classifier=None, include_prob=0.5, N=2000, batch_size=100):
    """RISE-style attribution from random patch masks.

    Draws `N` random masks over the 196 (14 x 14) patches, keeping each patch
    independently with probability `include_prob`. The model is evaluated on the
    masked inputs `batch_size` masks at a time. Every (class, patch) pair is scored
    by the mean predicted probability over the masks that kept that patch.

    Args:
        image: one un-batched image tensor. In the classifier path it is multiplied
            by a (batch, 1, 224, 224) pixel mask, so it is presumably (3, 224, 224)
            on CPU -- TODO confirm.
        surrogate: model called with ``masks=`` (boolean patch masks).
        classifier: a `Classifier` (sees images with dropped 16x16 patches zeroed)
            or a `Surrogate` (sees the same zeroed images with an all-ones mask).
            Exactly one of `surrogate` / `classifier` must be given.
        include_prob: probability that a patch is kept in a mask.
        N: total number of masks. All N are used; the last batch may be smaller.
        batch_size: masks evaluated per forward pass (100 reproduces the original
            batching).

    Returns:
        numpy array of shape (output_dim, 196). A patch that no mask kept gets nan
        (0 / 0).

    Raises:
        ValueError: if not exactly one model is given, or `N` / `batch_size` < 1.
        TypeError: if `classifier` is neither a `Classifier` nor a `Surrogate`.
    """
    if (surrogate is None) == (classifier is None):
        raise ValueError("rise requires exactly one of `surrogate` or `classifier`")
    if classifier is not None and classifier.__class__ not in (Classifier, Surrogate):
        raise TypeError(f"unsupported classifier type: {type(classifier).__name__}")
    if N < 1 or batch_size < 1:
        raise ValueError(f"N and batch_size must be >= 1, got N={N}, batch_size={batch_size}")

    prob_list = []
    mask_list = []

    with torch.no_grad():
        for start in range(0, N, batch_size):
            # The final batch covers the remainder, so exactly N masks are drawn.
            n_masks = min(batch_size, N - start)
            mask = torch.rand(n_masks, 196) < include_prob
            if surrogate is not None:
                out = surrogate(image.unsqueeze(0).repeat(n_masks, 1, 1, 1).to(surrogate.device),
                                masks=mask.to(surrogate.device))
            else:
                # Upsample the 14x14 patch mask to 224x224 pixels (16x16 pixels per patch).
                mask_scaled = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(-1, 14, 14), 16, dim=2), 16, dim=1)
                image_masked = image * mask_scaled.unsqueeze(1)
                del mask_scaled
                if classifier.__class__ == Classifier:
                    out = classifier(image_masked.to(classifier.device))
                else:
                    # Surrogate (validated above): pixels are already zeroed, so pass an
                    # all-ones mask on the same device as the images.
                    out = classifier(image_masked.to(classifier.device),
                                     masks=torch.ones_like(mask, device=classifier.device))
                del image_masked

            if _config["output_dim"] == 1:
                prob = out['logits'].sigmoid().detach().cpu().numpy()
            else:
                prob = out['logits'].softmax(dim=-1).detach().cpu().numpy()
            del out

            prob_list.append(prob)
            mask_list.append(mask.numpy())

    prob_list_array = np.concatenate(prob_list)  # (N, num_classes)
    mask_list_array = np.concatenate(mask_list)  # (N, num_players)

    result = prob_list_array.T @ mask_list_array  # (num_classes, num_players)
    return result / mask_list_array.sum(axis=0)

# KernelSHAP

In [None]:
from utils.shapreg import removal, games, shapley

class SurrogateSHAPWrapper(nn.Module):
    """Adapter that lets shapreg's image prediction game query a masked surrogate.

    `forward` takes an ``(images, mask)`` pair, reshapes the mask to one row per
    sample, calls ``model(images, mask)`` and returns probabilities. The activation
    is a sigmoid when ``_config["output_dim"] == 1`` and a softmax otherwise; it is
    chosen once, at construction.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.activation = nn.Sigmoid() if _config["output_dim"] == 1 else nn.Softmax(dim=-1)

    def forward(self, x):
        images, raw_mask = x
        # NOTE(review): the mask presumably arrives with a singleton dim 1 (e.g.
        # (batch, 1, 14, 14)); it is squeezed and flattened to (batch, num_players).
        flat_mask = raw_mask.squeeze(1).flatten(1)
        logits = self.model(images, flat_mask)['logits']
        return self.activation(logits)

# One SurrogateSHAPWrapper per backbone, wrapping that backbone's "pre-softmax"
# surrogate (presumably passed to `get_shap` later). Loop variables
# (idx, backbone_type, backbone_type_config) remain as notebook globals.
surrogate_SHAP_wrapped_dict = OrderedDict()

for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    surrogate_SHAP_wrapped_dict[backbone_type]=SurrogateSHAPWrapper(surrogate_dict[backbone_type]["pre-softmax"])

def get_shap(surrogate_SHAP_wrapped, x, batch_size=64, thresh=0.2, variance_batches=60):
    """KernelSHAP estimate for input `x` using shapreg.

    Builds a `PredictionGame_torchimagetensor` from the wrapped surrogate and `x`,
    then runs `shapley.ShapleyRegression` on it, forwarding `batch_size`, `thresh`
    and `variance_batches`. The regression result is returned unchanged.
    """
    prediction_game = games.PredictionGame_torchimagetensor(surrogate_SHAP_wrapped, x)
    return shapley.ShapleyRegression(prediction_game,
                                     batch_size=batch_size,
                                     thresh=thresh,
                                     variance_batches=variance_batches)

# Result dictionaries: setup and loading of previously saved results

In [None]:
# Per-backbone store for generated explanations: one empty dict per explanation
# method, later filled as {image_path: {"explanation": ..., "elapsed_time": ...}}
# by `explanation_save_dict_update`. Key order is the order later cells iterate in.
explanation_save_dict={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    explanation_save_dict_backbone = {
        method_name: {}
        for method_name in ("random", "attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours", "kernelshap")
    }
    explanation_save_dict[backbone_type] = explanation_save_dict_backbone
    
def explanation_save_dict_update(backbone_type, explanation_method,
                                 path_list, explanation_list, elapsed_time_list,
                                 shape=None):
    """Record per-image explanations for one (backbone, method) pair.

    Each image path maps to ``{"explanation": <float ndarray>, "elapsed_time": <float>}``;
    an existing entry for the same path is overwritten.

    Args:
        backbone_type, explanation_method: keys into `explanation_save_dict`.
        path_list: image paths (str), aligned with the two lists below.
        explanation_list: numpy arrays, one per path.
        elapsed_time_list: Python floats, one per path.
        shape: if given, every explanation must have exactly this shape.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong element type, or shape mismatch.
    """
    method_store = explanation_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(explanation_list) == len(elapsed_time_list)

    for explanation, path, elapsed_time in zip(explanation_list, path_list, elapsed_time_list):
        assert type(explanation) == np.ndarray
        assert type(path) == str
        assert type(elapsed_time) == float
        assert shape is None or explanation.shape == shape
        method_store[path] = {"explanation": explanation.astype(float),
                              "elapsed_time": elapsed_time}

In [None]:
# Merge previously generated explanations from disk into `explanation_save_dict`,
# printing "<method>  <before> + <loaded> -> <after>" for each method. Loading is
# best-effort: a missing pickle contributes nothing, and a file that fails to load
# or merge is reported (method, path, error) and skipped instead of aborting the cell.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():
        explanation_save_dict_path=f'results/3_explanation_generate/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
        try:
            if os.path.isfile(explanation_save_dict_path):
                # pickle can execute arbitrary code: only load result files this project wrote.
                with open(explanation_save_dict_path, 'rb') as f:
                    explanation_save_dict_loaded=pickle.load(f)
            else:
                explanation_save_dict_loaded={}

            len_original=len(explanation_save_dict_backbone_method)
            len_loaded=len(explanation_save_dict_loaded)
            explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)
            len_updated=len(explanation_save_dict_backbone_method)

            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')
        except Exception as e:
            # Catch Exception, not a bare except, so Ctrl-C still interrupts the cell,
            # and say which file failed and why.
            print(f'{explanation_method:24}  failed to load {explanation_save_dict_path}: {e!r}')

In [None]:
# Per-backbone store for insertion/deletion results: one empty dict per explanation
# method, later filled as {image_path: {"insert": ..., "delete": ...}} by
# `insertdelete_save_dict_update`. Key order is the order later cells iterate in.
insertdelete_save_dict={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    insertdelete_save_dict_backbone = {
        method_name: {}
        for method_name in ("random", "attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours", "kernelshap")
    }
    insertdelete_save_dict[backbone_type] = insertdelete_save_dict_backbone
    
def insertdelete_save_dict_update(backbone_type, explanation_method,
                                 path_list, insert_list, delete_list,
                                 shape=None):
    """Record insertion/deletion results for one (backbone, method) pair.

    Each image path maps to ``{"insert": <float ndarray>, "delete": <float ndarray>}``;
    an existing entry for the same path is overwritten.

    Args:
        backbone_type, explanation_method: keys into `insertdelete_save_dict`.
        path_list: image paths (str), aligned with `insert_list` and `delete_list`.
        insert_list, delete_list: numpy arrays, one per path.
        shape: if given, every insert and delete array must have exactly this shape.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong element type, or shape mismatch.
    """
    method_store = insertdelete_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(insert_list) == len(delete_list)

    for insert, delete, path in zip(insert_list, delete_list, path_list):
        assert type(insert) == np.ndarray
        assert type(delete) == np.ndarray
        assert type(path) == str
        if shape is not None:
            assert insert.shape == shape
            assert delete.shape == shape
        method_store[path] = {"insert": insert.astype(float),
                              "delete": delete.astype(float)}

In [None]:
# Merge previously saved insertion/deletion results from disk into
# `insertdelete_save_dict`, printing "<method>  <before> + <loaded> -> <after>"
# for each method. A missing pickle simply contributes nothing.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():
        insertdelete_save_dict_path = f'results/4_insert_delete/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
        len_original = len(insertdelete_save_dict_backbone_method)

        insertdelete_save_dict_loaded = {}
        if os.path.isfile(insertdelete_save_dict_path):
            # pickle can execute arbitrary code: only load result files this project wrote.
            with open(insertdelete_save_dict_path, 'rb') as f:
                insertdelete_save_dict_loaded = pickle.load(f)

        len_loaded = len(insertdelete_save_dict_loaded)
        insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)
        len_updated = len(insertdelete_save_dict_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

In [None]:
# Per-backbone store for sensitivity results: one empty dict per explanation method
# (no "random" / "kernelshap" entries), later filled as
# {image_path: {num_included_players: ...}} by `sensitivity_save_dit_update`.
# Key order is the order later cells iterate in.
sensitivity_save_dit={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    sensitivity_save_dit_backbone = {
        method_name: {}
        for method_name in ("attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours")
    }
    sensitivity_save_dit[backbone_type] = sensitivity_save_dit_backbone
    
def sensitivity_save_dit_update(backbone_type, explanation_method, num_included_players,
                                path_list, sensitivity_list,
                                shape=None):
    """Record sensitivity results for one (backbone, method, num_included_players).

    Results are nested as ``store[path][num_included_players] = <float ndarray>``, so
    calls with different `num_included_players` accumulate under the same path,
    while a repeated (path, num_included_players) pair is overwritten.

    Args:
        backbone_type, explanation_method: keys into `sensitivity_save_dit`.
        num_included_players: inner key under each path.
        path_list: image paths (str), aligned with `sensitivity_list`.
        sensitivity_list: numpy arrays, one per path.
        shape: if given, every sensitivity array must have exactly this shape.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong element type, or shape mismatch.
    """
    method_store = sensitivity_save_dit[backbone_type][explanation_method]

    assert len(path_list) == len(sensitivity_list)

    for sensitivity, path in zip(sensitivity_list, path_list):
        assert type(sensitivity) == np.ndarray
        assert type(path) == str
        assert shape is None or sensitivity.shape == shape
        method_store.setdefault(path, {})[num_included_players] = sensitivity.astype(float)

In [None]:
# Merge previously saved sensitivity results from disk into `sensitivity_save_dit`,
# printing "<method>  <before> + <loaded> -> <after>" for each method.
# A missing pickle simply contributes nothing.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, sensitivity_save_dit_backbone_method in sensitivity_save_dit[backbone_type].items():
        sensitivity_save_dit_path = f'results/5_sensitivity/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
        len_original = len(sensitivity_save_dit_backbone_method)

        sensitivity_save_dit_loaded = {}
        if os.path.isfile(sensitivity_save_dit_path):
            # pickle can execute arbitrary code: only load result files this project wrote.
            with open(sensitivity_save_dit_path, 'rb') as f:
                sensitivity_save_dit_loaded = pickle.load(f)

        len_loaded = len(sensitivity_save_dit_loaded)
        sensitivity_save_dit_backbone_method.update(sensitivity_save_dit_loaded)
        len_updated = len(sensitivity_save_dit_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

In [None]:
# Per-backbone store for no-retraining results: one empty dict per explanation
# method (no "kernelshap" entry), later filled as
# {image_path: {"insert": ..., "delete": ...}} by `noretraining_save_dict_update`.
# Key order is the order later cells iterate in.
noretraining_save_dict={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    noretraining_save_dict_backbone = {
        method_name: {}
        for method_name in ("random", "attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours")
    }
    noretraining_save_dict[backbone_type] = noretraining_save_dict_backbone
    
def noretraining_save_dict_update(backbone_type, explanation_method,
                                 path_list, insert_list, delete_list,
                                 shape=None):
    """Record no-retraining insertion/deletion results for one (backbone, method) pair.

    Each image path maps to ``{"insert": <float ndarray>, "delete": <float ndarray>}``;
    an existing entry for the same path is overwritten.

    Args:
        backbone_type, explanation_method: keys into `noretraining_save_dict`.
        path_list: image paths (str), aligned with `insert_list` and `delete_list`.
        insert_list, delete_list: numpy arrays, one per path.
        shape: if given, every insert and delete array must have exactly this shape.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong element type, or shape mismatch.
    """
    method_store = noretraining_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(insert_list) == len(delete_list)

    for insert, delete, path in zip(insert_list, delete_list, path_list):
        assert type(insert) == np.ndarray
        assert type(delete) == np.ndarray
        assert type(path) == str
        if shape is not None:
            assert insert.shape == shape
            assert delete.shape == shape
        method_store[path] = {"insert": insert.astype(float),
                              "delete": delete.astype(float)}

In [None]:
# Merge previously saved no-retraining results from disk into
# `noretraining_save_dict`, printing "<method>  <before> + <loaded> -> <after>"
# for each method. A missing pickle simply contributes nothing.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, noretraining_save_dict_backbone_method in noretraining_save_dict[backbone_type].items():
        noretraining_save_dict_path = f'results/6_noretraining/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
        len_original = len(noretraining_save_dict_backbone_method)

        noretraining_save_dict_loaded = {}
        if os.path.isfile(noretraining_save_dict_path):
            # pickle can execute arbitrary code: only load result files this project wrote.
            with open(noretraining_save_dict_path, 'rb') as f:
                noretraining_save_dict_loaded = pickle.load(f)

        len_loaded = len(noretraining_save_dict_loaded)
        noretraining_save_dict_backbone_method.update(noretraining_save_dict_loaded)
        len_updated = len(noretraining_save_dict_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

In [None]:
# Per-backbone store for masked-classifier results: one empty dict per explanation
# method (no "kernelshap" entry), later filled as
# {image_path: {"insert": ..., "delete": ...}} by `classifiermasked_save_dict_update`.
# Key order is the order later cells iterate in.
classifiermasked_save_dict={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    classifiermasked_save_dict_backbone = {
        method_name: {}
        for method_name in ("random", "attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours")
    }
    classifiermasked_save_dict[backbone_type] = classifiermasked_save_dict_backbone
    
def classifiermasked_save_dict_update(backbone_type, explanation_method,
                                 path_list, insert_list, delete_list,
                                 shape=None):
    """Record masked-classifier insertion/deletion results for one (backbone, method) pair.

    Each image path maps to ``{"insert": <float ndarray>, "delete": <float ndarray>}``;
    an existing entry for the same path is overwritten.

    Args:
        backbone_type, explanation_method: keys into `classifiermasked_save_dict`.
        path_list: image paths (str), aligned with `insert_list` and `delete_list`.
        insert_list, delete_list: numpy arrays, one per path.
        shape: if given, every insert and delete array must have exactly this shape.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong element type, or shape mismatch.
    """
    method_store = classifiermasked_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(insert_list) == len(delete_list)

    for insert, delete, path in zip(insert_list, delete_list, path_list):
        assert type(insert) == np.ndarray
        assert type(delete) == np.ndarray
        assert type(path) == str
        if shape is not None:
            assert insert.shape == shape
            assert delete.shape == shape
        method_store[path] = {"insert": insert.astype(float),
                              "delete": delete.astype(float)}

In [None]:
# Merge previously saved masked-classifier results from disk into
# `classifiermasked_save_dict`, printing "<method>  <before> + <loaded> -> <after>"
# for each method. A missing pickle simply contributes nothing.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, classifiermasked_save_dict_backbone_method in classifiermasked_save_dict[backbone_type].items():
        classifiermasked_save_dict_path = f'results/7_classifiermasked/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
        len_original = len(classifiermasked_save_dict_backbone_method)

        classifiermasked_save_dict_loaded = {}
        if os.path.isfile(classifiermasked_save_dict_path):
            # pickle can execute arbitrary code: only load result files this project wrote.
            with open(classifiermasked_save_dict_path, 'rb') as f:
                classifiermasked_save_dict_loaded = pickle.load(f)

        len_loaded = len(classifiermasked_save_dict_loaded)
        classifiermasked_save_dict_backbone_method.update(classifiermasked_save_dict_loaded)
        len_updated = len(classifiermasked_save_dict_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

In [None]:
# Per-backbone store for measured run times: one empty dict per explanation method
# (no "kernelshap" entry), later filled as {image_path: {"time": ...}} by
# `elapsedtime_save_dict_update`. Key order is the order later cells iterate in.
elapsedtime_save_dict={}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    elapsedtime_save_dict_backbone = {
        method_name: {}
        for method_name in ("random", "attention_rollout", "attention_last", "LRP",
                            "gradcam", "gradcamgithub",
                            "vanillapixel", "vanillaembedding",
                            "sgpixel", "sgembedding",
                            "vargradpixel", "vargradembedding",
                            "igpixel", "igembedding",
                            "leaveoneoutclassifier", "leaveoneoutsurrogate",
                            "riseclassifier", "risesurrogate",
                            "ours")
    }
    elapsedtime_save_dict[backbone_type] = elapsedtime_save_dict_backbone
    
def elapsedtime_save_dict_update(backbone_type, explanation_method,
                                 path_list, elapsed_time_list,
                                 shape=None):
    """Record per-image elapsed times for one (backbone, explanation method) pair.

    Each path is stored in the module-level ``elapsedtime_save_dict`` as
    ``{"time": elapsed_time}``. Any existing entry for that path is overwritten.

    Args:
        backbone_type: first-level key of ``elapsedtime_save_dict``.
        explanation_method: second-level key (e.g. "ours", "LRP").
        path_list: image paths; each must be exactly a ``str``.
        elapsed_time_list: elapsed times aligned with ``path_list``; each must be
            exactly a built-in ``float``.
        shape: unused. Accepted so the call signature matches the other
            ``*_save_dict_update`` helpers.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, or a value of the wrong exact type.
    """
    timings = elapsedtime_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(elapsed_time_list)

    for seconds, image_path in zip(elapsed_time_list, path_list):
        assert type(seconds) == float
        assert type(image_path) == str
        timings[image_path] = {"time": seconds}

In [None]:
# Merge previously pickled timing results from disk into the in-memory registry and report,
# per method: entries before merge + entries on disk -> entries after merge.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():
        elapsedtime_save_dict_path = (f'results/8_elapsedtime/{_config["datasets"]}/'
                                      f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')

        # A missing file simply means there is nothing to merge yet.
        elapsedtime_save_dict_loaded = {}
        if os.path.isfile(elapsedtime_save_dict_path):
            # pickle.load can execute arbitrary code: only point this at result files you produced.
            with open(elapsedtime_save_dict_path, 'rb') as f:
                elapsedtime_save_dict_loaded = pickle.load(f)

        len_original = len(elapsedtime_save_dict_backbone_method)
        len_loaded = len(elapsedtime_save_dict_loaded)
        # On key clashes the on-disk entry wins.
        elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)
        len_updated = len(elapsedtime_save_dict_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

In [None]:
# Per-backbone registry of estimates for the estimation-error study:
#   estimationerror_save_dict[backbone_type][explanation_method] -> {path: {...}}
# Each method receives its own fresh empty dict.
estimationerror_save_dict = {}
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    estimationerror_save_dict_backbone = {
        method_name: {} for method_name in ("kernelshap", "kernelshapnopair", "ours")
    }
    estimationerror_save_dict[backbone_type] = estimationerror_save_dict_backbone
    
def estimationerror_save_dict_update(backbone_type, explanation_method,
                                     path_list, estimation_list, label_list,
                                     shape=None):
    """Store estimates and labels for one (backbone, explanation method) pair.

    Each path is stored in the module-level ``estimationerror_save_dict`` as
    ``{"estimation": estimation, "label": label}``. Existing entries are overwritten.

    Args:
        backbone_type: first-level key of ``estimationerror_save_dict``.
        explanation_method: second-level key ("kernelshap", "kernelshapnopair", "ours").
        path_list: image paths; each must be exactly a ``str``.
        estimation_list: estimates aligned with ``path_list``. Their type is not checked.
        label_list: labels aligned with ``path_list``; each must be exactly an ``int``.
        shape: if given, every estimate's ``.shape`` must equal it.

    Raises:
        KeyError: unknown backbone or method.
        AssertionError: length mismatch, wrong exact type, or shape mismatch.
    """
    records = estimationerror_save_dict[backbone_type][explanation_method]

    assert len(path_list) == len(estimation_list) == len(label_list)

    for image_path, estimate, target_label in zip(path_list, estimation_list, label_list):
        assert type(image_path) == str
        # Type check on the estimate is deliberately disabled:
        # assert type(estimate) == np.ndarray
        assert type(target_label) == int

        if shape is not None:
            assert estimate.shape == shape

        records[image_path] = {"estimation": estimate,
                               "label": target_label}

In [None]:
# Merge previously pickled estimation-error results from disk into the in-memory registry and
# report, per method: entries before merge + entries on disk -> entries after merge.
for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    for explanation_method, estimationerror_save_dict_backbone_method in estimationerror_save_dict[backbone_type].items():
        estimationerror_save_dict_path = (f'results/9_estimationerror/{_config["datasets"]}/'
                                          f'{backbone_type}_{explanation_method}_{dataset_split}.pickle')

        # A missing file simply means there is nothing to merge yet.
        estimationerror_save_dict_loaded = {}
        if os.path.isfile(estimationerror_save_dict_path):
            # pickle.load can execute arbitrary code: only point this at result files you produced.
            with open(estimationerror_save_dict_path, 'rb') as f:
                estimationerror_save_dict_loaded = pickle.load(f)

        len_original = len(estimationerror_save_dict_backbone_method)
        len_loaded = len(estimationerror_save_dict_loaded)
        # On key clashes the on-disk entry wins.
        estimationerror_save_dict_backbone_method.update(estimationerror_save_dict_loaded)
        len_updated = len(estimationerror_save_dict_backbone_method)

        print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

# utils

In [None]:
def get_relative_value(x, random_seed=None):
    """Rank the entries of a 1-D array (0 = smallest), breaking ties randomly.

    The array is visited in a random order before sorting, so equal values receive
    their ranks in a random order rather than by position.

    Args:
        x: 1-D numpy array.
        random_seed: if a Python ``int``, a dedicated ``np.random.default_rng(random_seed)``
            drives the shuffle, which makes the result reproducible. Otherwise the global
            ``np.random`` state is used (and advanced).

    Returns:
        Integer array ``ranks`` of the same length, where ``ranks[i]`` is the rank of ``x[i]``.
    """
    assert len(x.shape) == 1

    size = len(x)
    if isinstance(random_seed, int):
        shuffle = np.random.default_rng(random_seed).permutation(np.arange(size))
    else:
        shuffle = np.random.permutation(np.arange(size))

    # Indices of x in ascending order of value; ties are resolved by the random shuffle.
    ascending = shuffle[np.argsort(x[shuffle])]

    # Invert that ordering (ranks[ascending[k]] = k); equals argsort of a permutation.
    ranks = np.empty(size, dtype=np.intp)
    ranks[ascending] = np.arange(size)
    return ranks

In [None]:
# def adapt_path(path_original, dict_keys):
#     path_list = ['l0.cs.washington.edu', 'l1lambda.cs.washington.edu', 'l2lambda.cs.washington.edu',
#                  'l3.cs.washington.edu', 'deeper.cs.washington.edu', 'sync']

#     dict_keys=list(dict_keys)


#     for path1 in path_list:
#         if path1 in path_original:
#             for path2 in path_list:
#                 path_replaced=path_original.replace(path1, path2)
#                 if path_replaced in dict_keys:
#                     return path_replaced
#     return path_original
#     #raise ValueError(f"not found {path_original}")

import itertools

def adapt_path(path_original, dict_keys):
    """Return the variant of ``path_original`` that is present in ``dict_keys``.

    Stored paths may differ from the current one only by a host-name component (or
    ``sync``) and by a home-directory prefix. Presumably this is because results were
    produced on different machines. Every (host from -> to) x (home from -> to) rewrite
    is tried in order. The first rewritten path contained in ``dict_keys`` is returned.
    The identity rewrite comes first, so an already-matching path is returned unchanged.
    If no rewrite matches, ``path_original`` is returned as-is.

    Args:
        path_original: path to adapt.
        dict_keys: any container supporting ``in`` (dict keys view, set, list, ...).
    """
    hosts = ['l0.cs.washington.edu', 'l1lambda.cs.washington.edu', 'l2lambda.cs.washington.edu',
             'l3.cs.washington.edu', 'deeper.cs.washington.edu', 'sync']
    homes = ['/mmfs1/home/chanwkim/', '/homes/gws/chanwkim/']

    # Lazily generate each rewritten candidate once, in the same order as a nested
    # (host pair) -> (home pair) loop.
    candidates = (
        path_original.replace(host_from, host_to).replace(home_from, home_to)
        for host_from, host_to in itertools.product(hosts, repeat=2)
        for home_from, home_to in itertools.product(homes, repeat=2)
    )
    return next((candidate for candidate in candidates if candidate in dict_keys), path_original)

# Methods to run

In [None]:
# Explanation methods evaluated in this run.
# A broader selection used previously, kept for reference (it was immediately overwritten):
#   "random", "attention_rollout", "attention_last", "LRP", "gradcam", "gradcamgithub",
#   "vanillapixel", "vanillaembedding", "sgpixel", "sgembedding",
#   "vargradpixel", "vargradembedding", "igpixel", "igembedding",
#   "leaveoneoutclassifier", "riseclassifier", "ours"
#explanation_method_to_run_=["kernelshap"]
explanation_method_to_run_ = ["random", "attention_rollout", "attention_last",
                              "LRP", "gradcam",
                              "vanillaembedding",
                              "sgembedding",
                              "vargradembedding",
                              "igembedding",
                              "leaveoneoutclassifier",
                              "riseclassifier",
                              "ours"]
# Independent copy, so in-place edits to the working list never touch the selection above.
explanation_method_to_run = list(explanation_method_to_run_)

print(explanation_method_to_run)

In [None]:
# Evaluation loader: fixed order (no shuffling), and the final partial batch is kept
# so every sample of `dset` is visited.
data_loader = DataLoader(dset,
                         batch_size=16,
                         shuffle=False,
                         drop_last=False,
                         num_workers=4)
print(len(dset))         # number of samples
print(len(data_loader))  # number of batches

# 1_classifier_evaluate

In [None]:
if evaluation_stage=="1_classifier_evaluate":
    # For every backbone, collect {image path: {'label': int, 'prob': float ndarray}} over the split.
    classifier_result_list_all = {backbone_type: {} for backbone_type in backbone_type_config_dict}

    # Nothing below switches the classifiers back to training mode, so set eval mode once.
    for backbone_type in backbone_type_config_dict:
        classifier_dict[backbone_type].eval()

    for batch in tqdm(data_loader):
        images = batch['images']
        labels = batch['labels']
        paths = batch['path']

        for backbone_type in backbone_type_config_dict:
            with torch.no_grad():
                classifier_output = classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),
                                                                   output_attentions=True)
            classifier_logits = classifier_output['logits']
            # A single-logit head is read through a sigmoid; otherwise softmax over classes.
            if _config["output_dim"] == 1:
                probs = classifier_logits.sigmoid().cpu().numpy()
            else:
                probs = classifier_logits.softmax(dim=-1).cpu().numpy()

            # `probs` is the whole batch and `sample_prob` one row. The batch array is no longer
            # shadowed by the loop variable.
            for path, label, sample_prob in zip(paths, labels, probs):
                classifier_result_list_all[backbone_type][path] = {'label': label.item(),
                                                                   'prob': sample_prob.astype(float)}

    for backbone_type in backbone_type_config_dict:
        classifier_result_list_path = f'results/1_classifier_evaluate/{_config["datasets"]}/{backbone_type}_{dataset_split}.pickle'
        # Create the output directory rather than failing at open() after the full pass.
        os.makedirs(os.path.dirname(classifier_result_list_path), exist_ok=True)
        with open(classifier_result_list_path, "wb") as f:
            pickle.dump(classifier_result_list_all[backbone_type], f)

# 2_surrogate_evaluate

In [None]:
def _surrogate_eval_accuracy(logits, labels, output_dim):
    """Top-1 accuracy (Python float) of a batch of logits against integer labels.

    Args:
        logits: model logits. With ``output_dim == 1`` there is one logit per sample,
            thresholded at sigmoid > 0.5. Otherwise there is one logit per class along dim 1.
        labels: label tensor as yielded by the DataLoader (compared on the CPU).
        output_dim: ``_config["output_dim"]``.

    Predictions and labels are both flattened to 1-D before comparing. A ``(B, 1)``
    single-logit output is therefore matched element-wise against a ``(B,)`` label
    vector, instead of silently broadcasting to a ``(B, B)`` comparison.
    """
    if output_dim == 1:
        predictions = (logits.sigmoid() > 0.5).int()
    else:
        predictions = torch.argmax(logits, dim=1)
    return (predictions.cpu().reshape(-1) == labels.reshape(-1)).float().mean().item()


def _surrogate_eval_kl(logits_masked, logits_reference, output_dim):
    """Batch-mean KL(p_reference || p_masked), returned as a Python float.

    ``F.kl_div(input=log q, target=p)`` computes sum p * (log p - log q), i.e. KL(p || q),
    averaged over the batch with ``reduction='batchmean'``. With ``output_dim == 1`` each
    single logit z is expanded to the two-class distribution [sigmoid(z), sigmoid(-z)].
    """
    if output_dim == 1:
        log_q = torch.concat([F.logsigmoid(logits_masked), F.logsigmoid(-logits_masked)], dim=1)
        p = torch.concat([torch.sigmoid(logits_reference), torch.sigmoid(-logits_reference)], dim=1)
    else:
        log_q = torch.log_softmax(logits_masked, dim=1)
        p = torch.softmax(logits_reference, dim=1)
    # .item() in both branches. A bare (possibly CUDA) tensor stored in the results list
    # would otherwise end up in the CSV as its "tensor(...)" repr.
    return F.kl_div(input=log_q, target=p, reduction="batchmean", log_target=False).item()


if evaluation_stage=="2_surrogate_evaluate":
    # Number of maskable patch tokens (196 = 14 x 14 patches of a patch16 / 224px ViT).
    surrogate_num_patches = 196

    result_list_all = {backbone_type: [] for backbone_type in backbone_type_config_dict}

    # drop_last=True: every evaluated batch holds the full 64 samples.
    dset_loader = DataLoader(dset, batch_size=64, num_workers=4, shuffle=False, drop_last=True)

    for batch_idx, batch in enumerate(tqdm(dset_loader, unit='batch')):
        batch_images = batch["images"]
        batch_labels = batch["labels"]
        batch_size = len(batch_images)

        # Reference prediction of the "original" surrogate with an all-ones mask.
        # It depends only on the batch, so compute it once per batch and backbone
        # instead of once per masking level.
        logits_original_dict = {}
        for backbone_type in backbone_type_config_dict:
            surrogate_original = surrogate_dict[backbone_type]["original"]
            surrogate_original.eval()
            with torch.no_grad():
                logits_original_dict[backbone_type] = surrogate_original(
                    batch_images.to(surrogate_original.device),
                    torch.ones((batch_size, surrogate_num_patches)).to(surrogate_original.device))["logits"]

        for num_mask in range(0, surrogate_num_patches + 1, 14):
            # Each row has exactly `num_mask` ones at random positions. The same mask is shared
            # by all backbones.
            mask = torch.zeros((batch_size, surrogate_num_patches))
            mask[:, :num_mask] = 1
            for i in range(len(mask)):
                mask[i] = mask[i][torch.randperm(len(mask[i]))]

            for backbone_type in backbone_type_config_dict:
                logits_original = logits_original_dict[backbone_type]

                for mask_location_model in ["original", "pre-softmax", "zero-input", "zero-embedding"]:
                    if mask_location_model == "original":
                        # Reference row: the all-ones prediction compared with itself (KL = 0).
                        result_list_all[backbone_type].append({"batch_idx": batch_idx,
                                                               "backbone_type": backbone_type,
                                                               "num_mask": num_mask,
                                                               "mask_location_model": mask_location_model,
                                                               "mask_location_parameter": "original",
                                                               "kl_divergence": 0,
                                                               "accuracy": _surrogate_eval_accuracy(logits_original, batch_labels, _config["output_dim"])})

                    surrogate_model = surrogate_dict[backbone_type][mask_location_model]
                    surrogate_model.eval()
                    for mask_location_parameter in ["pre-softmax", "post-softmax", "zero-input", "zero-embedding", "random-sampling"]:
                        with torch.no_grad():
                            logits_surrogate = surrogate_model(batch_images.to(surrogate_model.device),
                                                               mask.to(surrogate_model.device),
                                                               mask_location_parameter)["logits"]

                        result_list_all[backbone_type].append({"batch_idx": batch_idx,
                                                               "backbone_type": backbone_type,
                                                               "num_mask": num_mask,
                                                               "mask_location_model": mask_location_model,
                                                               "mask_location_parameter": mask_location_parameter,
                                                               "kl_divergence": _surrogate_eval_kl(logits_surrogate, logits_original, _config["output_dim"]),
                                                               "accuracy": _surrogate_eval_accuracy(logits_surrogate, batch_labels, _config["output_dim"])})

    for backbone_type in backbone_type_config_dict:
        result_df = pd.DataFrame(result_list_all[backbone_type])
        result_csv_path = f'results/2_surrogate_evaluate/{_config["datasets"]}/{backbone_type}.csv'
        # Create the output directory instead of failing after the whole evaluation has run.
        os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
        result_df.to_csv(result_csv_path)
                        

# 3_explanation_generate

In [None]:
# def adapt_path(path_original, path_format):
#     path_list=['l0.cs.washington.edu', 'l1lambda.cs.washington.edu', 'l2lambda.cs.washington.edu', 'l3.cs.washington.edu', 'deeper.cs.washington.edu']
    
#     for path1 in path_list:
#         if path1 in path_original:
#             for path2 in path_list:
#                 if path2 in path_format:
#                     return path_original.replace(path1, path2)
#             raise
#     return path_original

In [None]:
def get_random_explanation(num_players, num_samples=None, seed=42):
    """Return a tiny-magnitude uniform-random attribution (the "random" baseline).

    Values are drawn from U(0, 1e-40) by a fresh ``np.random.RandomState(seed)``, so the
    result is fully deterministic. With a fixed seed, every call, and hence every image,
    receives the same attribution.

    Args:
        num_players: number of features/patches per explanation.
        num_samples: if None the result has shape ``(num_players,)``; otherwise
            ``(num_samples, num_players)``.
        seed: RandomState seed. Defaults to 42, the previously hard-coded value, so
            existing calls are unchanged.

    Returns:
        float64 ndarray of the shape described above.
    """
    size = (num_players,) if num_samples is None else (num_samples, num_players)
    return np.random.RandomState(seed).uniform(low=0, high=1e-40, size=size)

In [None]:
# Images on which KernelSHAP is run in stage 3. For each seed in [2, 3, 4, 5] and each label in
# `label_to_use`, one image with that label is drawn with np.random.RandomState(seed).
label_to_use = ['Garbage truck',
                'Tench',
                'English springer',
                'Parachute',
                'Golf ball',
                'Gas pump']
# NOTE(review): these look like ImageNette class names. label_name_list.index(...) raises
# ValueError if the configured dataset does not contain them; confirm against the active dataset.

# Loop-invariant work, previously recomputed for every backbone and seed.
label_data_list = np.array([i['label'] for i in dset.data])
label_idx_to_use = [label_name_list.index(label) for label in label_to_use]

kernelshap_sample_idx_list_per_backbone = []
for random_seed in [2, 3, 4, 5]:
    # One image per requested label, with a fresh RandomState(seed) for each label.
    kernelshap_sample_idx_list = [np.random.RandomState(random_seed).choice(np.flatnonzero(label_data_list == label_idx))
                                  for label_idx in label_idx_to_use]
    kernelshap_sample_idx_list_per_backbone += kernelshap_sample_idx_list
# The selection does not depend on the backbone. It is repeated once per backbone to keep
# the previous list layout.
kernelshap_sample_idx_list_all = kernelshap_sample_idx_list_per_backbone * len(backbone_type_config_dict)

# `dset[i]` may load and transform the whole sample, so fetch each distinct index only once.
kernelshap_path_by_idx = {i: dset[i]['path'] for i in dict.fromkeys(kernelshap_sample_idx_list_all)}
kernelshap_sample_path_list_all = [kernelshap_path_by_idx[i] for i in kernelshap_sample_idx_list_all]

In [None]:
if evaluation_stage=="3_explanation_generate":
    for idx, batch in enumerate(tqdm(data_loader)):#, total = int(1000/data_loader.batch_size+0.5) if dataset_split=="test" else None)):
#         if dataset_split=="test":
#             if idx>int(1000/data_loader.batch_size+0.5):
#                 break
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue
            
        images=batch['images']
        labels=batch['labels']
        paths=batch['path']
        updated_signal_list=[]
        
        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
            # Get classifier output
            classifier_dict[backbone_type].eval()
            with torch.no_grad():
                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),
                                                                 output_attentions=True)        
            for explanation_method in explanation_method_to_run:
                data_keys=explanation_save_dict[backbone_type][explanation_method].keys()
                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]
                if all([path in data_keys for path in paths]):
                    continue
                else:
                    print(explanation_method,'not exist')
                    updated_signal_list.append(explanation_method)
                if explanation_method=="random":
                    start_time=time.time()
                    explanation_random_list=[get_random_explanation(num_players=196) for path in paths]
                    elapsed_time=time.time()-start_time
                    explanation_save_dict_update(backbone_type, 'random', path_list=paths, explanation_list=explanation_random_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196, ))
                
                elif explanation_method=="attention_rollout":
                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)
                    start_time=time.time()
                    explanation_attention_rollout_list=attentions_to_explanation(attentions, mode='rollout')
                    elapsed_time=time.time()-start_time
                    explanation_save_dict_update(backbone_type, 'attention_rollout', path_list=paths, explanation_list=explanation_attention_rollout_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196,))

                elif explanation_method=="attention_last":
                    attentions = np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)
                    start_time=time.time()
                    explanation_attention_last_list=attentions_to_explanation(attentions, mode=-1)
                    elapsed_time=time.time()-start_time
                    explanation_save_dict_update(backbone_type, 'attention_last', path_list=paths, explanation_list=explanation_attention_last_list, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(196,))            

                elif explanation_method=="LRP":
                    explanation_lrp_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_lrp_list.append(np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,
                                                                              original_image=image.squeeze(0),
                                                                              class_index=i,
                                                                              mode='transformer_attribution').cpu().numpy() for i in range(_config["output_dim"])], axis=0))
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'LRP', path_list=paths, explanation_list=explanation_lrp_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="gradcam":
                    explanation_gradcam_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_gradcam = np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,
                                                                                              original_image=image.squeeze(0),
                                                                                              class_index=i,
                                                                                              mode='attn_gradcam').cpu().numpy() for i in range(_config["output_dim"])], axis=0)
                        explanation_gradcam = np.nan_to_num(explanation_gradcam,nan=0)+np.random.uniform(low=0, high=1e-20, size=explanation_gradcam.shape)                    
                        explanation_gradcam_list.append(explanation_gradcam)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'gradcam', path_list=paths, explanation_list=explanation_gradcam_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))


                elif explanation_method=="gradcamgithub":
                    explanation_gradcamgithub_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_gradcamgithub = np.concatenate([cam_dict[backbone_type](input_tensor=image.unsqueeze(0).to(next(cam_dict[backbone_type].model.parameters()).device),
                                                                                            targets=[ClassifierOutputTarget(i)], resize=False).flatten()[np.newaxis,:] for i in range(_config["output_dim"])], axis=0)
                        explanation_gradcamgithub_list.append(explanation_gradcamgithub)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'gradcamgithub', path_list=paths, explanation_list=explanation_gradcamgithub_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="vanillapixel":
                    explanation_vanillapixel_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_vanilla(image, saliency_pixel=saliency_pixel_dict[backbone_type])
                        explanation_vanillapixel_list.append(grad["attributions_pixel_patchabssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'vanillapixel', path_list=paths, explanation_list=explanation_vanillapixel_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="vanillaembedding":
                    explanation_vanillaembedding_list=[]
                    elapsed_time_list=[]
                    for image in images:
                        start_time=time.time()
                        grad=get_vanilla(image, saliency_embedding=saliency_embedding_dict[backbone_type])
                        explanation_vanillaembedding_list.append(grad["attributions_embedding_abssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'vanillaembedding', path_list=paths, explanation_list=explanation_vanillaembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="sgpixel":
                    explanation_sgpixel_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_sg(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])
                        explanation_sgpixel_list.append(grad["attributions_pixel_patchabssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'sgpixel', path_list=paths, explanation_list=explanation_sgpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))


                elif explanation_method=="sgembedding":
                    explanation_sgembedding_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_sg(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])
                        explanation_sgembedding_list.append(grad["attributions_embedding_abssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'sgembedding', path_list=paths, explanation_list=explanation_sgembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                                

                elif explanation_method=="vargradpixel":
                    explanation_vargradpixel_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_vargrad(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type])
                        explanation_vargradpixel_list.append(grad["attributions_pixel_patchabssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'vargradpixel', path_list=paths, explanation_list=explanation_vargradpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="vargradembedding":
                    explanation_vargradembedding_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_vargrad(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type])
                        explanation_vargradembedding_list.append(grad["attributions_embedding_abssum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'vargradembedding', path_list=paths, explanation_list=explanation_vargradembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                                

                elif explanation_method=="igpixel":
                    explanation_igpixel_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_ig(image, ig_pixel=ig_pixel_dict[backbone_type])
                        explanation_igpixel_list.append(grad["attributions_pixel_patchsum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'igpixel', path_list=paths, explanation_list=explanation_igpixel_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="igembedding":
                    explanation_igembedding_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        grad=get_ig(image, ig_embedding=ig_embedding_dict[backbone_type])
                        explanation_igembedding_list.append(grad["attributions_embedding_sum"].reshape(_config["output_dim"],196).numpy())
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'igembedding', path_list=paths, explanation_list=explanation_igembedding_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                

                elif explanation_method=="leaveoneoutclassifier":
                    explanation_leaveoneoutclassifier_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_leaveoneoutclassifier = leave_one_out(classifier=classifier_dict_[backbone_type], image=image).reshape(_config["output_dim"], 196)
                        explanation_leaveoneoutclassifier_list.append(explanation_leaveoneoutclassifier)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'leaveoneoutclassifier', path_list=paths, explanation_list=explanation_leaveoneoutclassifier_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="leaveoneoutsurrogate":
                    explanation_leaveoneoutsurrogate_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_leaveoneoutsurrogate = leave_one_out(surrogate=surrogate_dict[backbone_type]["pre-softmax"], image=image).reshape(_config["output_dim"], 196)
                        explanation_leaveoneoutsurrogate_list.append(explanation_leaveoneoutsurrogate)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'leaveoneoutsurrogate', path_list=paths, explanation_list=explanation_leaveoneoutsurrogate_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                

                elif explanation_method=="riseclassifier":
                    explanation_riseclassifier_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_riseclassifier = rise(classifier=classifier_dict_[backbone_type], image=image, N=2000, include_prob=0.5).reshape(_config["output_dim"], 196)
                        explanation_riseclassifier_list.append(explanation_riseclassifier)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'riseclassifier', path_list=paths, explanation_list=explanation_riseclassifier_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))

                elif explanation_method=="risesurrogate":
                    explanation_risesurrogate_list=[]
                    elapsed_time_list=[]
                    for image in images:               
                        start_time=time.time()
                        explanation_risesurrogate = rise(surrogate=surrogate_dict[backbone_type]["pre-softmax"], image=image, N=2000, include_prob=0.5).reshape(_config["output_dim"], 196)
                        explanation_risesurrogate_list.append(explanation_risesurrogate)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'risesurrogate', path_list=paths, explanation_list=explanation_risesurrogate_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                
                    
                elif explanation_method=="ours":
                    start_time=time.time()
                    explainer_dict[backbone_type].eval()
                    with torch.no_grad():
                        explanation_ours = explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)
                    elapsed_time=time.time()-start_time
                    explanation_save_dict_update(backbone_type, 'ours', path_list=paths, explanation_list=explanation_ours, elapsed_time_list=[elapsed_time/len(paths) for i in range(len(paths))], shape=(_config["output_dim"], 196))                                
                    
                elif explanation_method=="kernelshap":                    
                    explanation_kernelshap_list=[]
                    elapsed_time_list=[]
                    path_list=[]
                    surrogate_SHAP_wrapped_dict[backbone_type].eval()
                    for path,image in zip(paths, images):
                        if path not in kernelshap_sample_path_list_all:
                            continue
                        print(path)
                        start_time=time.time()                        
                        explanation_kernelshap_ret = get_shap(surrogate_SHAP_wrapped_dict[backbone_type], image, thresh=0.2)
                        explanation_kernelshap = explanation_kernelshap_ret.values.T
                        explanation_kernelshap_list.append(explanation_kernelshap)
                        path_list.append(path)
                        elapsed_time_list.append(time.time()-start_time)
                    explanation_save_dict_update(backbone_type, 'kernelshap', path_list=path_list, explanation_list=explanation_kernelshap_list, elapsed_time_list=elapsed_time_list, shape=(_config["output_dim"], 196))                    
                else:
                    raise

        try:
            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
                for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():
                    if explanation_method not in updated_signal_list:
                        continue
                    explanation_save_dict_path=f'results/3_explanation_generate/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                    if os.path.isfile(explanation_save_dict_path):
                        try:
                            with open(explanation_save_dict_path, 'rb') as f:
                                explanation_save_dict_loaded=pickle.load(f)
                        except:
                            explanation_save_dict_loaded={}
                    else:
                        explanation_save_dict_loaded={}

                    len_original=len(explanation_save_dict_backbone_method)            
                    len_loaded=len(explanation_save_dict_loaded)
                    explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)
                    len_updated=len(explanation_save_dict_backbone_method)

                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                    with open(explanation_save_dict_path, "wb") as f:
                        pickle.dump(explanation_save_dict_backbone_method, f)
        except:
            pass
            
            
            
    #         # Get ours
    #         start_time=time.time();explainer_dict[backbone_type].eval()
    #         with torch.no_grad():
    #             values, _, _=explainer_dict[backbone_type](torch.Tensor(image).unsqueeze(0).to(explainer_dict[backbone_type].device))
    #             values=values.cpu().numpy().transpose(0,2,1).squeeze(0)
    #         values_ours=(values, time.time()-start_time)        

In [None]:
if evaluation_stage=="3_explanation_generate":
    # Merge the in-memory explanations with whatever is already on disk (e.g. written by other
    # parallel_mode workers) and write the union back, one pickle per (backbone, method).
    for backbone_type, backbone_type_config in backbone_type_config_dict.items():
        for explanation_method, explanation_save_dict_backbone_method in explanation_save_dict[backbone_type].items():
            explanation_save_dict_path=f'results/3_explanation_generate/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

            if os.path.isfile(explanation_save_dict_path):
                # NOTE: pickle.load can execute arbitrary code; only load files produced by this pipeline.
                with open(explanation_save_dict_path, 'rb') as f:
                    explanation_save_dict_loaded=pickle.load(f)
            else:
                explanation_save_dict_loaded={}

            len_original=len(explanation_save_dict_backbone_method)
            len_loaded=len(explanation_save_dict_loaded)
            # On key collisions the on-disk entry wins.
            explanation_save_dict_backbone_method.update(explanation_save_dict_loaded)
            len_updated=len(explanation_save_dict_backbone_method)

            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

            # Write to a temp file and atomically swap it in: opening the target with "wb" truncates it
            # immediately, so a concurrent worker reading it would otherwise see an empty/partial pickle.
            os.makedirs(os.path.dirname(explanation_save_dict_path), exist_ok=True)
            tmp_path=f'{explanation_save_dict_path}.{os.getpid()}.tmp'
            with open(tmp_path, "wb") as f:
                pickle.dump(explanation_save_dict_backbone_method, f)
            os.replace(tmp_path, explanation_save_dict_path)

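# 4_insert_delete

Insertion/deletion evaluation: image patches are added (insertion) or removed (deletion) in order of decreasing attribution, and the surrogate's class probabilities are recorded after every step.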

In [None]:
def explanation_to_mask(explanation, mode='insertion'):
    """
    Build cumulative insertion/deletion masks that follow the attribution ranking.

    Args:
        explanation: (num_batches, num_players) attribution scores.
        mode: 'insertion' or 'deletion'.
    Returns:
        bool array of shape (num_batches, num_players+1, num_players).
        'insertion': row k keeps the players scoring strictly above the k-th largest score,
                     i.e. row 0 keeps nobody and the last row keeps everybody.
        'deletion':  row 0 keeps everybody; row k (k >= 1) keeps the players scoring strictly
                     below the k-th largest score, so the last row keeps nobody.
        With tied scores a row can hold fewer players than its index suggests.
    Raises:
        ValueError: if mode is neither 'insertion' nor 'deletion'.
    """
    n_batches, n_players = explanation.shape[0], explanation.shape[-1]

    # Per-row thresholds in descending order, broadcast against every player's score.
    thresholds = np.sort(explanation, axis=-1)[:, ::-1][:, :, np.newaxis]  # (num_batches, num_players, 1)
    scores = explanation[:, np.newaxis, :]                                  # (num_batches, 1, num_players)
    everyone = np.ones((n_batches, 1, n_players), dtype=bool)               # the "all players present" step

    if mode == 'insertion':
        return np.concatenate([scores > thresholds, everyone], axis=1)
    if mode == 'deletion':
        return np.concatenate([everyone, scores < thresholds], axis=1)
    raise ValueError(f'{mode} should be insertion or deletion.')

In [None]:
# Select the stage to run from here on; this overrides the value chosen in the config cell.
evaluation_stage = "4_insert_delete"

In [None]:
# Explanation methods evaluated by the stage loops below.
explanation_method_to_run = ["kernelshap"]

In [None]:
# Reproducible per-label sample of image paths (up to 10 per label, random_state=42). When
# KernelSHAP is evaluated, only these images are used. Labels with fewer than 10 images
# contribute all of their images instead of raising.
estimationerror_sample_path_list=(pd.DataFrame(data_loader.dataset.data)
                                  .groupby("label")
                                  .apply(lambda x: x.sample(n=min(10, len(x)), random_state=42))["img_path"]
                                  .tolist())

In [None]:
# (this worker's shard index, number of shards): batch idx is processed only when idx % parallel_mode[1] == parallel_mode[0].
parallel_mode = (0, 1)

In [None]:
if evaluation_stage=="4_insert_delete":
    # Insertion/deletion evaluation. For every image, patches are ordered by attribution and
    # progressively inserted or deleted. The surrogate's probabilities at each of the
    # num_players+1 steps are recorded and merged into per-(backbone, method) pickles under
    # results/4_insert_delete/.
    num_players=196
    num_random_orderings=10  # random orderings evaluated per image for the "random" baseline

    def _surrogate_probs(surrogate, image_batch, mask):
        """Return the surrogate's probabilities for one image under a stack of patch masks.

        Args:
            surrogate: surrogate model, called as ``surrogate(images, masks=...)`` and returning
                a dict with a ``'logits'`` tensor.
            image_batch: the image repeated once per mask row, already on ``surrogate.device``.
            mask: (num_players+1, num_players) boolean mask array, one row per step.
        Returns:
            numpy array: sigmoid of the logits when ``_config["output_dim"]==1``, softmax otherwise.
            Presumably (num_players+1, output_dim); callers transpose/slice it to the ``shape=``
            passed to ``insertdelete_save_dict_update``.
        """
        with torch.no_grad():
            output=surrogate(image_batch, masks=torch.Tensor(mask).to(surrogate.device))
        if _config["output_dim"]==1:
            return output['logits'].sigmoid().cpu().numpy()
        return output['logits'].softmax(dim=-1).cpu().numpy()

    max_batch_idx=int(1000/data_loader.batch_size+0.5)
    # On the test split only about 1000 images are evaluated, unless KernelSHAP is the only
    # method. KernelSHAP samples are spread over the whole split.
    limit_batches=dataset_split=="test" and (len(explanation_method_to_run)!=1 or explanation_method_to_run[0]!="kernelshap")
    for idx, batch in enumerate(tqdm(data_loader, total=max_batch_idx if limit_batches else None)):
        # NOTE(review): `>` lets batch max_batch_idx through (one batch more than the tqdm total);
        # kept so the evaluated image set matches earlier stages.
        if limit_batches and idx>max_batch_idx:
            break
        # Shard batches across parallel_mode[1] workers; this process handles shard parallel_mode[0].
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']

        updated_signal_list=[]  # methods with new results in this batch; only these are re-saved

        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            surrogate=surrogate_dict[backbone_type]["pre-softmax"]
            surrogate.eval()

            for explanation_method in explanation_method_to_run:
                if explanation_method=='kernelshap':
                    # KernelSHAP explanations are only generated for the sampled images. Evaluate just
                    # those; the rest of the batch has no explanation to load.
                    selected=[i for i, path in enumerate(paths) if path in estimationerror_sample_path_list]
                else:
                    selected=list(range(len(paths)))
                method_paths=[paths[i] for i in selected]

                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in insertdelete_save_dict[backbone_type][explanation_method].keys()]
                if all(path in data_keys for path in method_paths):
                    continue  # already evaluated, or nothing to evaluate in this batch
                print(explanation_method,'not exist')
                updated_signal_list.append(explanation_method)

                if explanation_method=="random":
                    # NOTE(review): seeded by position within the batch, so every batch reuses the same orderings.
                    explanations=[np.random.RandomState(i).uniform(low=0, high=1e-40, size=(num_random_orderings, num_players))
                                  for i in selected]
                else:
                    method_explanations=explanation_save_dict[backbone_type][explanation_method]
                    method_keys=list(method_explanations.keys())
                    explanations=[method_explanations[adapt_path(path, method_keys)]['explanation'] for path in method_paths]

                insertdelete_dict={'insertion': [], 'deletion': []}
                for i, explanation in zip(selected, explanations):
                    image_loaded=images[i][np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(surrogate.device)
                    if np.isnan(explanation).any():
                        print(explanation_method, "Null found")
                    explanation_ndim=np.ndim(explanation)
                    if explanation_method!="random" and explanation_ndim not in (1, 2):
                        raise ValueError(f'Unsupported explanation shape {np.shape(explanation)} for {explanation_method}.')

                    for metric_mode in insertdelete_dict:
                        if explanation_method=="random":
                            # One mask stack per random ordering -> (num_random_orderings, num_classes, num_players+1).
                            mask=explanation_to_mask(explanation=np.array([get_relative_value(e) for e in explanation]), mode=metric_mode)
                            prob=np.array([_surrogate_probs(surrogate, image_loaded, m).T for m in mask])
                        elif explanation_ndim==1:
                            # Single ordering; keep every class's probability -> (num_classes, num_players+1).
                            mask=explanation_to_mask(explanation=get_relative_value(explanation)[np.newaxis,:], mode=metric_mode)
                            prob=_surrogate_probs(surrogate, image_loaded, mask[0]).T
                        else:
                            # Row c orders patches for class c; keep only class c's probability -> (rows, num_players+1).
                            mask=explanation_to_mask(explanation=np.array([get_relative_value(e) for e in explanation]), mode=metric_mode)
                            prob=np.array([_surrogate_probs(surrogate, image_loaded, mask[c])[:, c] for c in range(mask.shape[0])])
                        insertdelete_dict[metric_mode].append(prob)

                shape=(_config["output_dim"], num_players+1)
                if explanation_method=="random":
                    shape=(num_random_orderings,)+shape
                insertdelete_save_dict_update(backbone_type, explanation_method,
                                              path_list=method_paths,
                                              insert_list=insertdelete_dict["insertion"],
                                              delete_list=insertdelete_dict["deletion"],
                                              shape=shape)

        # Merge this worker's new results with what is on disk and write them back atomically.
        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():
                if explanation_method not in updated_signal_list:
                    continue

                insertdelete_save_dict_path=f'results/4_insert_delete/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                if os.path.isfile(insertdelete_save_dict_path):
                    with open(insertdelete_save_dict_path, 'rb') as f:
                        insertdelete_save_dict_loaded=pickle.load(f)
                else:
                    insertdelete_save_dict_loaded={}

                len_original=len(insertdelete_save_dict_backbone_method)
                len_loaded=len(insertdelete_save_dict_loaded)
                insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)
                len_updated=len(insertdelete_save_dict_backbone_method)

                print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                # Temp file + os.replace: concurrent readers never see a truncated pickle.
                os.makedirs(os.path.dirname(insertdelete_save_dict_path), exist_ok=True)
                tmp_path=f'{insertdelete_save_dict_path}.{os.getpid()}.tmp'
                with open(tmp_path, "wb") as f:
                    pickle.dump(insertdelete_save_dict_backbone_method, f)
                os.replace(tmp_path, insertdelete_save_dict_path)
        
        

In [None]:
if evaluation_stage=="4_insert_delete":
    # Final merge: combine the in-memory insertion/deletion results with whatever is on disk
    # (e.g. from other parallel_mode workers) and write the union back, per (backbone, method).
    for backbone_type, backbone_type_config in backbone_type_config_dict.items():
        for explanation_method, insertdelete_save_dict_backbone_method in insertdelete_save_dict[backbone_type].items():
            insertdelete_save_dict_path=f'results/4_insert_delete/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

            if os.path.isfile(insertdelete_save_dict_path):
                # NOTE: pickle.load can execute arbitrary code; only load files produced by this pipeline.
                with open(insertdelete_save_dict_path, 'rb') as f:
                    insertdelete_save_dict_loaded=pickle.load(f)
            else:
                insertdelete_save_dict_loaded={}

            len_original=len(insertdelete_save_dict_backbone_method)
            len_loaded=len(insertdelete_save_dict_loaded)
            # On key collisions the on-disk entry wins.
            insertdelete_save_dict_backbone_method.update(insertdelete_save_dict_loaded)
            len_updated=len(insertdelete_save_dict_backbone_method)

            print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

            # Temp file + os.replace so a concurrent reader never sees a truncated pickle.
            os.makedirs(os.path.dirname(insertdelete_save_dict_path), exist_ok=True)
            tmp_path=f'{insertdelete_save_dict_path}.{os.getpid()}.tmp'
            with open(tmp_path, "wb") as f:
                pickle.dump(insertdelete_save_dict_backbone_method, f)
            os.replace(tmp_path, insertdelete_save_dict_path)

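# 5_sensitivity

Sensitivity evaluation: for random patch subsets (uniformly random, and of fixed sizes 14, 28, …, 182 patches), the Spearman correlation between the summed attribution of the kept patches and the surrogate's class probability.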

In [None]:
def generate_mask(num_players: int, num_mask_samples: "int | None" = None, paired_mask_samples: bool = True,
                  mode: str = 'uniform', random_state: "np.random.RandomState | None" = None) -> np.ndarray:
    """
    Sample random binary player masks.

    In each mask row, player j is included when an independent U(0, 1) draw exceeds a
    per-row threshold. The threshold is drawn per row as follows:
      - 'uniform': threshold ~ U(0, 1);
      - 'shapley': threshold = k / num_players, with k in {1, ..., num_players-1} drawn with
        probability proportional to 1 / (k * (num_players - k)).

    Args:
        num_players: the number of players in the coalitional game
        num_mask_samples: the number of masks to generate; None generates a single mask
        paired_mask_samples: if True, masks come in consecutive pairs (x, 1-x); the sample
            count (1 when num_mask_samples is None) must then be even.
        mode: 'uniform' or 'shapley' (see above)
        random_state: random generator; defaults to the global ``np.random`` module

    Returns:
        numpy int array of 0/1 with shape
        (num_mask_samples, num_players) if num_mask_samples is an int
        (num_players,) if num_mask_samples is None

    Raises:
        AssertionError: if paired_mask_samples is True and the sample count is odd.
        ValueError: if mode is not 'uniform' or 'shapley'.
    """
    random_state = random_state or np.random

    num_samples_ = num_mask_samples or 1

    if paired_mask_samples:
        assert num_samples_ % 2 == 0, "'num_samples' must be a multiple of 2 if 'paired' is True"
        num_samples_ = num_samples_ // 2

    if mode == 'uniform':
        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')
    elif mode == 'shapley':
        subset_sizes = np.arange(1, num_players)
        probs = 1 / (subset_sizes * (num_players - subset_sizes))
        probs = probs / probs.sum()
        draws = random_state.rand(num_samples_, num_players)
        # Sample k from the same support the probabilities were computed for (1..num_players-1).
        # Sampling from 0..num_players-2 shifted every threshold down by 1/num_players.
        sizes = random_state.choice(subset_sizes, p=probs, size=[num_samples_, 1])
        masks = (draws > 1 / num_players * sizes).astype('int')
    else:
        raise ValueError("'mode' must be 'uniform' or 'shapley'")

    if paired_mask_samples:
        # Interleave each mask with its complement: rows (m0, 1-m0, m1, 1-m1, ...).
        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)

    if num_mask_samples is None:
        return masks.squeeze(0)  # (num_players,)
    return masks  # (num_mask_samples, num_players)

In [None]:
if evaluation_stage=="5_sensitivity":
    # Sensitivity evaluation. For random patch subsets, compute the Spearman correlation between
    # the summed attribution of the kept patches and the surrogate's class probabilities.
    # Subsets are either uniformly random ("all") or of a fixed size (14, 28, ..., 182 patches).
    num_players=196
    num_mask_batches=20   # surrogate forward passes per (image, subset size)
    masks_per_batch=50    # masks per forward pass
    max_batch_idx=int(1000/data_loader.batch_size+0.5)
    for idx, batch in enumerate(tqdm(data_loader, total=max_batch_idx if dataset_split=="test" else None)):
        # NOTE(review): `>` lets batch max_batch_idx through (one batch more than the tqdm total); kept as-is.
        if dataset_split=="test" and idx>max_batch_idx:
            break
        # Shard batches across parallel_mode[1] workers; this process handles shard parallel_mode[0].
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']

        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            surrogate=surrogate_dict[backbone_type]["pre-softmax"]
            surrogate.eval()
            for image, path in zip(images, paths):
                # "random" has no stored explanation; KernelSHAP explanations exist only for the sampled images.
                methods=[explanation_method for explanation_method in explanation_method_to_run
                         if explanation_method!="random"
                         and (explanation_method!="kernelshap" or path in estimationerror_sample_path_list)]
                if not methods:
                    continue
                image_loaded=image[np.newaxis, :].repeat(masks_per_batch, 1, 1, 1).to(surrogate.device)

                for num_included_players in ["all"]+list(range(14, num_players, 14)):
                    prob_all=[]
                    mask_all=[]
                    for random_iter in range(num_mask_batches):
                        if num_included_players=="all":
                            mask=generate_mask(num_players=num_players,
                                               num_mask_samples=masks_per_batch,
                                               paired_mask_samples=False,
                                               mode="uniform",
                                               random_state=np.random.RandomState(random_iter))
                        else:
                            mask=np.zeros((masks_per_batch, num_players))
                            mask[:, :num_included_players]=1
                            for i in range(masks_per_batch):
                                # Seed stride equals masks_per_batch so every mask gets a distinct seed.
                                # A stride of 10 reused seeds across iterations and duplicated masks.
                                mask[i]=np.random.RandomState(42+masks_per_batch*random_iter+i).permutation(mask[i])

                        with torch.no_grad():
                            output=surrogate(image_loaded, masks=torch.Tensor(mask).to(surrogate.device))
                        if _config["output_dim"]==1:
                            prob=output['logits'].sigmoid().cpu().numpy()
                        else:
                            prob=output['logits'].softmax(dim=-1).cpu().numpy()
                        prob_all.append(prob)
                        mask_all.append(mask)
                    prob_all=np.concatenate(prob_all, axis=0)
                    mask_all=np.concatenate(mask_all, axis=0)

                    for explanation_method in methods:
                        method_explanations=explanation_save_dict[backbone_type][explanation_method]
                        explanation=method_explanations[adapt_path(path, list(method_explanations.keys()))]['explanation']
                        # Add deterministic uniform noise in [0, 1e-40), as in the original pipeline.
                        explanation=explanation+np.random.RandomState(42).uniform(low=0, high=1e-40, size=explanation.shape)
                        explanation_mask=explanation@(mask_all.T)  # summed attribution of kept patches, per mask

                        if explanation.ndim==1:
                            correlation=np.array([stats.spearmanr(explanation_mask, prob_all[:, i]).correlation
                                                  for i in range(prob_all.shape[1])])
                        elif explanation.ndim==2:
                            # Row i of the explanation is correlated with class i's probability.
                            correlation=np.array([stats.spearmanr(explanation_mask[i], prob_all[:, i]).correlation
                                                  for i in range(prob_all.shape[1])])
                        else:
                            raise ValueError(f'Unsupported explanation shape {explanation.shape} for {explanation_method}.')
                        assert correlation.shape==(_config["output_dim"],)
                        sensitivity_save_dit_update(backbone_type, explanation_method,
                                                     num_included_players=num_included_players,
                                                     path_list=[path], sensitivity_list=[correlation],
                                                     shape=(_config["output_dim"],))

        # Merge this worker's results with what is on disk and write them back atomically.
        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            for explanation_method, sensitivity_save_dit_backbone_method in sensitivity_save_dit[backbone_type].items():
                sensitivity_save_dit_path=f'results/5_sensitivity/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                if os.path.isfile(sensitivity_save_dit_path):
                    with open(sensitivity_save_dit_path, 'rb') as f:
                        sensitivity_save_dit_loaded=pickle.load(f)
                else:
                    sensitivity_save_dit_loaded={}

                len_original=len(sensitivity_save_dit_backbone_method)
                len_loaded=len(sensitivity_save_dit_loaded)
                sensitivity_save_dit_backbone_method.update(sensitivity_save_dit_loaded)
                len_updated=len(sensitivity_save_dit_backbone_method)

                print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                # Temp file + os.replace so a concurrent reader never sees a truncated pickle.
                os.makedirs(os.path.dirname(sensitivity_save_dit_path), exist_ok=True)
                tmp_path=f'{sensitivity_save_dit_path}.{os.getpid()}.tmp'
                with open(tmp_path, "wb") as f:
                    pickle.dump(sensitivity_save_dit_backbone_method, f)
                os.replace(tmp_path, sensitivity_save_dit_path)
                    
             
                

In [None]:
# Leftover interactive inspection: displays whatever `path` was last bound to in the kernel
# (e.g. the last image path from the 5_sensitivity loop). Depends on hidden state; remove before sharing.
path

In [None]:
# Leftover interactive inspection.
# NOTE(review): `path1_list` is not assigned anywhere in this part of the notebook; this likely
# raises NameError on a fresh kernel unless it is defined elsewhere. Confirm or remove.
path1_list

In [None]:
# Leftover interactive inspection: displays the full explanation dict for the (backbone_type,
# explanation_method) left over from an earlier loop. Output can be very large; depends on hidden state.
explanation_save_dict[backbone_type][explanation_method]

In [None]:
# Leftover interactive inspection: lists the image paths that have stored explanations for the
# (backbone_type, explanation_method) left over from an earlier loop.
explanation_save_dict[backbone_type][explanation_method].keys()

In [None]:
# Leftover interactive inspection: shows which evaluation stage is currently selected.
evaluation_stage

# 6_noretraining

In [None]:
if evaluation_stage=="6_noretraining":
    num_players=196
    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split=="test" else None)):   
        if dataset_split=="test":
            if idx>int(1000/data_loader.batch_size+0.5):
                break
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue                                

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']
        updated_signal_list=[]

        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):

            for explanation_method in ["random"]+explanation_method_to_run:
                data_keys=noretraining_save_dict[backbone_type][explanation_method].keys()
                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]
                if all([path in data_keys for path in paths]):
                    continue
                else:
                    print(explanation_method,'not exist')
                    updated_signal_list.append(explanation_method)                
                
                if explanation_method=="random":
                    explanations=[np.random.RandomState(idx).uniform(low=0, high=1e-40, size=(10, num_players)) for idx, path in enumerate(paths)]
                else:
                    explanations=[explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys()))]['explanation']
                                  for path in paths]
                noretraining_dict={'insertion': [], 'deletion': []}
                for image, explanation, label in zip(images, explanations, labels):
                    image_loaded=image[np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(surrogate_dict[backbone_type]["pre-softmax"].device)
                    if np.isnan(explanation).any():
                        print(explanation_method, "Null found")
                      
                    if explanation_method=="random":
                        for metric_mode in noretraining_dict.keys():
                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)
                            prob_=[]
                            for random_iter in range(mask.shape[0]):
                                surrogate_dict[backbone_type]["pre-softmax"].eval()
                                with torch.no_grad():
                                    output = surrogate_dict[backbone_type]["pre-softmax"](image_loaded,
                                                                                          masks=torch.Tensor(mask[random_iter]).to(surrogate_dict[backbone_type]["pre-softmax"].device))
                                if _config["output_dim"]==1:
                                    prob=output['logits'].sigmoid().cpu().numpy()
                                else:
                                    prob=output['logits'].softmax(dim=-1).cpu().numpy()
                                prob_.append(prob.T)
                            prob=np.array(prob_) # (10, num_classes, num_players+1)
                            noretraining_dict[metric_mode].append(prob)#(prob.argmax(axis=1)==label.item()).astype(float))
                        
                    elif len(explanation.shape)==1:
                        for metric_mode in noretraining_dict.keys():
                            mask=explanation_to_mask(explanation=get_relative_value(explanation)[np.newaxis,:], mode=metric_mode)
                            surrogate_dict[backbone_type]["pre-softmax"].eval()
                            with torch.no_grad():
                                output = surrogate_dict[backbone_type]["pre-softmax"](image_loaded,
                                                                                      masks=torch.Tensor(mask[0]).to(surrogate_dict[backbone_type]["pre-softmax"].device))
                            if _config["output_dim"]==1:
                                prob=output['logits'].sigmoid().cpu().numpy()
                            else:
                                prob=output['logits'].softmax(dim=-1).cpu().numpy()
                            prob=prob.T # (num_classes, num_players+1)                            
                            #noretraining_dict[metric_mode].append((prob.argmax(axis=0)==label.item()).astype(float)) 
                            noretraining_dict[metric_mode].append(prob) 

                    elif len(explanation.shape)==2:
                        for metric_mode in noretraining_dict.keys():
                            #mask=explanation_to_mask(explanation=explanation[label.item()], mode=metric_mode)
                            if len(explanation)==1:
                                mask=explanation_to_mask(explanation=get_relative_value(explanation[0])[np.newaxis,:], mode=metric_mode)
                            else:
                                mask=explanation_to_mask(explanation=get_relative_value(explanation[label.item()])[np.newaxis,:], mode=metric_mode)
                            
                            surrogate_dict[backbone_type]["pre-softmax"].eval()
                            with torch.no_grad():
                                output = surrogate_dict[backbone_type]["pre-softmax"](image_loaded,
                                                                                      masks=torch.Tensor(mask[0]).to(surrogate_dict[backbone_type]["pre-softmax"].device))
                            if _config["output_dim"]==1:
                                prob=output['logits'].sigmoid().cpu().numpy()
                            else:
                                prob=output['logits'].softmax(dim=-1).cpu().numpy()
                            prob=prob.T
                            #noretraining_dict[metric_mode].append((prob.argmax(axis=0)==label.item()).astype(float))
                            noretraining_dict[metric_mode].append(prob)
                    else:
                        raise
                          
                if explanation_method=="random":
                    noretraining_save_dict_update(backbone_type, explanation_method,
                                                  path_list=paths, 
                                                  insert_list=noretraining_dict["insertion"], 
                                                  delete_list=noretraining_dict["deletion"], 
                                                  shape=(10, _config["output_dim"], num_players+1))                       
                elif len(explanations[0].shape)==1:
                    noretraining_save_dict_update(backbone_type, explanation_method,
                                                  path_list=paths, 
                                                  insert_list=noretraining_dict["insertion"], 
                                                  delete_list=noretraining_dict["deletion"], 
                                                  shape=(_config["output_dim"], num_players+1))                    
                elif len(explanations[0].shape)==2:
                    noretraining_save_dict_update(backbone_type, explanation_method,
                                                  path_list=paths, 
                                                  insert_list=noretraining_dict["insertion"], 
                                                  delete_list=noretraining_dict["deletion"], 
                                                  shape=(_config["output_dim"], num_players+1))
                else:
                    raise
        
        try:        
            for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
                for explanation_method, noretraining_save_dict_backbone_method in noretraining_save_dict[backbone_type].items():

                    if explanation_method not in updated_signal_list:
                        continue                

                    noretraining_save_dict_path=f'results/6_noretraining/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                    if os.path.isfile(noretraining_save_dict_path):
                        try:
                            with open(noretraining_save_dict_path, 'rb') as f:
                                noretraining_save_dict_loaded=pickle.load(f)
                        except:
                            noretraining_save_dict_loaded={}
                    else:
                        noretraining_save_dict_loaded={}

                    len_original=len(noretraining_save_dict_backbone_method)            
                    len_loaded=len(noretraining_save_dict_loaded)
                    noretraining_save_dict_backbone_method.update(noretraining_save_dict_loaded)
                    len_updated=len(noretraining_save_dict_backbone_method)

                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')


                    with open(noretraining_save_dict_path, "wb") as f:
                        pickle.dump(noretraining_save_dict_backbone_method, f)
        except:
            pass
        
                            
                    

# 7_classifiermasked

In [None]:
def _classifiermasked_probs(classifier_masked, image_loaded, mask_row):
    """Run a masked classifier on `image_loaded` under one set of masks and return transposed probabilities.

    Args:
        classifier_masked: a model from classifier_masked_dict; it is called as
            model(images, masks=...) and its output dict must contain 'logits'.
        image_loaded: the image repeated num_players+1 times along dim 0 (built by the caller).
        mask_row: one element of explanation_to_mask's output; converted with torch.Tensor
            and moved to the model's device.

    Returns:
        prob.T — (num_classes, num_players+1) per this cell's shape comments. Uses sigmoid
        when _config["output_dim"]==1 and a softmax over the last axis otherwise.
    """
    classifier_masked.eval()
    with torch.no_grad():
        output = classifier_masked(image_loaded,
                                   masks=torch.Tensor(mask_row).to(classifier_masked.device))
    if _config["output_dim"]==1:
        prob=output['logits'].sigmoid().cpu().numpy()
    else:
        prob=output['logits'].softmax(dim=-1).cpu().numpy()
    return prob.T


if evaluation_stage=="7_classifiermasked":
    # Insertion/deletion evaluation with the masked classifier: each explanation is turned into
    # insertion/deletion masks (explanation_to_mask) and the masked classifier is run on
    # num_players+1 masked copies of the image. Per-path probabilities go through
    # classifiermasked_save_dict_update and are merged into
    # results/7_classifiermasked/<dataset>/<backbone>_<method>_<split>.pickle.
    num_players=196
    for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split=="test" else None)):
        # On the test split, stop after roughly the first 1000 images.
        if dataset_split=="test" and idx>int(1000/data_loader.batch_size+0.5):
            break
        # Shard batches across workers: this process takes idx % parallel_mode[1] == parallel_mode[0].
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']
        updated_signal_list=[]  # methods that received new results in this batch and must be saved

        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):

            for explanation_method in ["random"]+explanation_method_to_run:
                # Skip a method when every image of the batch already has stored results.
                data_keys=classifiermasked_save_dict[backbone_type][explanation_method].keys()
                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]
                if all([path in data_keys for path in paths]):
                    continue
                print(explanation_method,'not exist')
                updated_signal_list.append(explanation_method)

                if explanation_method=="random":
                    # 10 random rankings per image, seeded by the image's position within the batch.
                    explanations=[np.random.RandomState(position).uniform(low=0, high=1e-40, size=(10, num_players))
                                  for position in range(len(paths))]
                else:
                    explanations=[explanation_save_dict[backbone_type][explanation_method][adapt_path(path, list(explanation_save_dict[backbone_type][explanation_method].keys()))]['explanation']
                                  for path in paths]

                classifiermasked_dict={'insertion': [], 'deletion': []}
                for image, explanation, label in zip(images, explanations, labels):
                    image_loaded=image[np.newaxis, :].repeat(num_players+1, 1, 1, 1).to(classifier_masked_dict[backbone_type].device)
                    if np.isnan(explanation).any():
                        print(explanation_method, "Null found")

                    if explanation_method=="random":
                        for metric_mode in classifiermasked_dict.keys():
                            mask=explanation_to_mask(explanation=np.array([get_relative_value(explanation_) for explanation_ in explanation]), mode=metric_mode)
                            prob=np.array([_classifiermasked_probs(classifier_masked_dict[backbone_type], image_loaded, mask_row)
                                           for mask_row in mask]) # (10, num_classes, num_players+1)
                            classifiermasked_dict[metric_mode].append(prob)

                    elif len(explanation.shape) in (1, 2):
                        for metric_mode in classifiermasked_dict.keys():
                            if len(explanation.shape)==1:
                                explanation_row=explanation
                            elif len(explanation)==1:
                                explanation_row=explanation[0]  # only one row stored: use it for any label
                            else:
                                explanation_row=explanation[label.item()]  # row belonging to the true label
                            mask=explanation_to_mask(explanation=get_relative_value(explanation_row)[np.newaxis,:], mode=metric_mode)
                            prob=_classifiermasked_probs(classifier_masked_dict[backbone_type], image_loaded, mask[0]) # (num_classes, num_players+1)
                            classifiermasked_dict[metric_mode].append(prob)
                    else:
                        raise ValueError(f"{explanation_method}: unsupported explanation shape {explanation.shape}")

                if explanation_method=="random":
                    save_shape=(10, _config["output_dim"], num_players+1)
                elif len(explanations[0].shape) in (1, 2):
                    save_shape=(_config["output_dim"], num_players+1)
                else:
                    raise ValueError(f"{explanation_method}: unsupported explanation shape {explanations[0].shape}")
                classifiermasked_save_dict_update(backbone_type, explanation_method,
                                                  path_list=paths,
                                                  insert_list=classifiermasked_dict["insertion"],
                                                  delete_list=classifiermasked_dict["deletion"],
                                                  shape=save_shape)

        # Merge the new in-memory results with what is already on disk (other parallel workers
        # may have stored more paths) and write the union back. Best-effort: a failure is
        # reported instead of aborting the evaluation, but it is no longer silently swallowed.
        try:
            for backbone_type in backbone_type_config_dict:
                for explanation_method, classifiermasked_save_dict_backbone_method in classifiermasked_save_dict[backbone_type].items():
                    if explanation_method not in updated_signal_list:
                        continue

                    classifiermasked_save_dict_path=f'results/7_classifiermasked/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                    classifiermasked_save_dict_loaded={}
                    if os.path.isfile(classifiermasked_save_dict_path):
                        try:
                            # NOTE: pickle.load can execute code; only load files this notebook wrote.
                            with open(classifiermasked_save_dict_path, 'rb') as f:
                                classifiermasked_save_dict_loaded=pickle.load(f)
                        except Exception as error:
                            # Unreadable or partially written file: keep only the in-memory results.
                            print(f'could not read {classifiermasked_save_dict_path}: {error!r}')
                            classifiermasked_save_dict_loaded={}

                    len_original=len(classifiermasked_save_dict_backbone_method)
                    len_loaded=len(classifiermasked_save_dict_loaded)
                    classifiermasked_save_dict_backbone_method.update(classifiermasked_save_dict_loaded)
                    len_updated=len(classifiermasked_save_dict_backbone_method)

                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                    # Create the results directory on first use; otherwise open() fails.
                    os.makedirs(os.path.dirname(classifiermasked_save_dict_path), exist_ok=True)
                    with open(classifiermasked_save_dict_path, "wb") as f:
                        pickle.dump(classifiermasked_save_dict_backbone_method, f)
        except Exception as error:
            print(f'saving 7_classifiermasked results failed: {error!r}')

# 8_elapsedtime

In [None]:
# Inspect the DataLoader that the evaluation stages in this notebook iterate over.
data_loader

In [None]:
# Batch size of data_loader; on the "test" split the evaluation loops stop after about 1000/batch_size batches.
data_loader.batch_size

In [None]:
def get_random_explanation(num_players, num_samples=None, seed=42):
    """Return a deterministic, near-zero random explanation (the "random" baseline).

    Values are drawn uniformly from [0, 1e-40) using a fresh ``np.random.RandomState(seed)``,
    so repeated calls with the same arguments return the same array.

    Args:
        num_players: number of players (patches) per explanation.
        num_samples: if None, return shape (num_players,); otherwise (num_samples, num_players).
        seed: seed of the RandomState; defaults to 42, the value previously hard-coded.

    Returns:
        np.ndarray of float64 values in [0, 1e-40).
    """
    size = (num_players,) if num_samples is None else (num_samples, num_players)
    return np.random.RandomState(seed).uniform(low=0, high=1e-40, size=size)


# Number of patches per image: (224/16)**2 for vit_base_patch16_224; per-image explanations
# are reshaped to (output_dim, _NUM_PATCHES).
_NUM_PATCHES=196


def _time_per_image(explain_fn, images):
    """Call explain_fn(image) for each image, timing every call separately.

    Returns (explanations, elapsed_seconds) as two parallel lists. Uses time.perf_counter,
    which is monotonic and high-resolution (time.time can jump with clock adjustments).
    """
    explanations, elapsed_seconds = [], []
    for image in images:
        start_time=time.perf_counter()
        explanations.append(explain_fn(image))
        elapsed_seconds.append(time.perf_counter()-start_time)
    return explanations, elapsed_seconds


def _split_elapsed_time(elapsed_time, num_items):
    """Attribute a batch-level elapsed time evenly to each of num_items items."""
    return [elapsed_time/num_items for _ in range(num_items)]


def _grad_attribution(grad, key):
    """Take `key` from a get_vanilla/get_sg/get_vargrad/get_ig result and reshape it to (output_dim, _NUM_PATCHES)."""
    return grad[key].reshape(_config["output_dim"], _NUM_PATCHES).numpy()


def _gradcam_explanation(backbone_type, image):
    """Attention Grad-CAM for every output class, with NaNs zeroed and tiny uniform noise added."""
    explanation_gradcam=np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,
                                                                   original_image=image.squeeze(0),
                                                                   class_index=i,
                                                                   mode='attn_gradcam').cpu().numpy() for i in range(_config["output_dim"])], axis=0)
    # The noise is presumably there to break ties when patches are ranked — TODO confirm.
    return np.nan_to_num(explanation_gradcam,nan=0)+np.random.uniform(low=0, high=1e-20, size=explanation_gradcam.shape)


# Methods that are explained (and timed) one image at a time: name -> fn(backbone_type, image).
_PER_IMAGE_EXPLAINERS={
    "LRP": lambda backbone_type, image: np.concatenate([get_lrp_module_explanation(backbone_type=backbone_type,
                                                                                   original_image=image.squeeze(0),
                                                                                   class_index=i,
                                                                                   mode='transformer_attribution').cpu().numpy() for i in range(_config["output_dim"])], axis=0),
    "gradcam": _gradcam_explanation,
    "gradcamgithub": lambda backbone_type, image: np.concatenate([cam_dict[backbone_type](input_tensor=image.unsqueeze(0).to(next(cam_dict[backbone_type].model.parameters()).device),
                                                                                          targets=[ClassifierOutputTarget(i)], resize=False).flatten()[np.newaxis,:] for i in range(_config["output_dim"])], axis=0),
    "vanillapixel": lambda backbone_type, image: _grad_attribution(get_vanilla(image, saliency_pixel=saliency_pixel_dict[backbone_type]), "attributions_pixel_patchabssum"),
    "vanillaembedding": lambda backbone_type, image: _grad_attribution(get_vanilla(image, saliency_embedding=saliency_embedding_dict[backbone_type]), "attributions_embedding_abssum"),
    "sgpixel": lambda backbone_type, image: _grad_attribution(get_sg(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type]), "attributions_pixel_patchabssum"),
    "sgembedding": lambda backbone_type, image: _grad_attribution(get_sg(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type]), "attributions_embedding_abssum"),
    "vargradpixel": lambda backbone_type, image: _grad_attribution(get_vargrad(image, noisetunnel_pixel=noisetunnel_pixel_dict[backbone_type]), "attributions_pixel_patchabssum"),
    "vargradembedding": lambda backbone_type, image: _grad_attribution(get_vargrad(image, noisetunnel_embedding=noisetunnel_embedding_dict[backbone_type]), "attributions_embedding_abssum"),
    "igpixel": lambda backbone_type, image: _grad_attribution(get_ig(image, ig_pixel=ig_pixel_dict[backbone_type]), "attributions_pixel_patchsum"),
    "igembedding": lambda backbone_type, image: _grad_attribution(get_ig(image, ig_embedding=ig_embedding_dict[backbone_type]), "attributions_embedding_sum"),
    "leaveoneoutclassifier": lambda backbone_type, image: leave_one_out(classifier=classifier_dict_[backbone_type], image=image).reshape(_config["output_dim"], _NUM_PATCHES),
    "leaveoneoutsurrogate": lambda backbone_type, image: leave_one_out(surrogate=surrogate_dict[backbone_type]["pre-softmax"], image=image).reshape(_config["output_dim"], _NUM_PATCHES),
    "riseclassifier": lambda backbone_type, image: rise(classifier=classifier_dict_[backbone_type], image=image, N=2000, include_prob=0.5).reshape(_config["output_dim"], _NUM_PATCHES),
    "risesurrogate": lambda backbone_type, image: rise(surrogate=surrogate_dict[backbone_type]["pre-softmax"], image=image, N=2000, include_prob=0.5).reshape(_config["output_dim"], _NUM_PATCHES),
}


if evaluation_stage=="8_elapsedtime":
    # Measures how long each explanation method takes and stores per-path times through
    # elapsedtime_save_dict_update, merged into
    # results/8_elapsedtime/<dataset>/<backbone>_<method>_<split>.pickle.
    for idx, batch in enumerate(tqdm(data_loader)):
        if dataset_split=="test" and idx>int(1000/data_loader.batch_size+0.5):
            break
        # Only batches 0-6 are timed.
        if idx==7:
            break
        # Shard batches across workers: this process takes idx % parallel_mode[1] == parallel_mode[0].
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']
        updated_signal_list=[]  # methods that received new results in this batch and must be saved

        for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
            # Classifier forward pass with attentions, reused by the attention-based methods
            # (it is not included in their measured time).
            classifier_dict[backbone_type].eval()
            with torch.no_grad():
                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),
                                                                 output_attentions=True)
            for explanation_method in explanation_method_to_run:
                # Skip a method when every image of the batch already has a stored time.
                data_keys=elapsedtime_save_dict[backbone_type][explanation_method].keys()
                data_keys=[adapt_path(data_keys_path, paths) for data_keys_path in data_keys]
                if all([path in data_keys for path in paths]):
                    continue
                print(explanation_method,'not exist')
                updated_signal_list.append(explanation_method)

                if explanation_method in _PER_IMAGE_EXPLAINERS:
                    # Per-image methods: every image is explained and timed on its own.
                    explain_fn=_PER_IMAGE_EXPLAINERS[explanation_method]
                    elapsed_time_list=_time_per_image(lambda image: explain_fn(backbone_type, image), images)[1]
                    elapsedtime_save_dict_update(backbone_type, explanation_method, path_list=paths, elapsed_time_list=elapsed_time_list)
                    continue

                # Batch-level methods: the whole batch is explained at once and the elapsed time
                # is split evenly across its images.
                if explanation_method=="random":
                    start_time=time.perf_counter()
                    explanation_random_list=[get_random_explanation(num_players=_NUM_PATCHES) for path in paths]
                    elapsed_time=time.perf_counter()-start_time
                elif explanation_method in ("attention_rollout", "attention_last"):
                    # Moving the attentions to numpy happens before the timer starts.
                    attentions=np.asarray([att.cpu().detach().numpy() for att in classifier_output['self_attentions']]).transpose(1,0,2,3,4)
                    attention_mode='rollout' if explanation_method=="attention_rollout" else -1
                    start_time=time.perf_counter()
                    explanation_attention_list=attentions_to_explanation(attentions, mode=attention_mode)
                    elapsed_time=time.perf_counter()-start_time
                elif explanation_method=="ours":
                    start_time=time.perf_counter()
                    explainer_dict[backbone_type].eval()
                    with torch.no_grad():
                        # .cpu() waits for the device computation, so it is inside the measured time.
                        explanation_ours=explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)
                    elapsed_time=time.perf_counter()-start_time
                else:
                    raise ValueError(f'unknown explanation method for 8_elapsedtime: {explanation_method}')
                elapsedtime_save_dict_update(backbone_type, explanation_method, path_list=paths,
                                             elapsed_time_list=_split_elapsed_time(elapsed_time, len(paths)))

        # Merge the new in-memory results with what is already on disk (other parallel workers
        # may have stored more paths) and write the union back. Best-effort: a failure is
        # reported instead of aborting the loop, but it is no longer silently swallowed.
        try:
            for backbone_type in backbone_type_config_dict:
                for explanation_method, elapsedtime_save_dict_backbone_method in elapsedtime_save_dict[backbone_type].items():
                    if explanation_method not in updated_signal_list:
                        continue
                    elapsedtime_save_dict_path=f'results/8_elapsedtime/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'

                    elapsedtime_save_dict_loaded={}
                    if os.path.isfile(elapsedtime_save_dict_path):
                        try:
                            # NOTE: pickle.load can execute code; only load files this notebook wrote.
                            with open(elapsedtime_save_dict_path, 'rb') as f:
                                elapsedtime_save_dict_loaded=pickle.load(f)
                        except Exception as error:
                            # Unreadable or partially written file: keep only the in-memory results.
                            print(f'could not read {elapsedtime_save_dict_path}: {error!r}')
                            elapsedtime_save_dict_loaded={}

                    len_original=len(elapsedtime_save_dict_backbone_method)
                    len_loaded=len(elapsedtime_save_dict_loaded)
                    elapsedtime_save_dict_backbone_method.update(elapsedtime_save_dict_loaded)
                    len_updated=len(elapsedtime_save_dict_backbone_method)

                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                    # Create the results directory on first use; otherwise open() fails.
                    os.makedirs(os.path.dirname(elapsedtime_save_dict_path), exist_ok=True)
                    with open(elapsedtime_save_dict_path, "wb") as f:
                        pickle.dump(elapsedtime_save_dict_backbone_method, f)
        except Exception as error:
            print(f'saving 8_elapsedtime results failed: {error!r}')

In [None]:
# Explanation methods with recorded elapsed times, listed for every backbone explicitly
# (instead of relying on a `backbone_type` left over from an earlier loop).
{backbone: list(method_results.keys()) for backbone, method_results in elapsedtime_save_dict.items()}

# 9_estimationerror

In [None]:
# Fixed evaluation subset for the estimation-error stage: up to 10 images per label, where each
# label's rows are sampled with random_state=42. Using min(10, group size) avoids a ValueError
# for labels with fewer than 10 images; for larger groups the selection is unchanged.
estimationerror_sample_path_list=(
    pd.DataFrame(data_loader.dataset.data)
    .groupby("label")
    .apply(lambda per_label: per_label.sample(n=min(10, len(per_label)), random_state=42))
    ["img_path"]
    .tolist()
)

In [None]:
if evaluation_stage=="9_estimationerror":
    # The sampled paths are only used for membership tests, so a set gives O(1) lookups.
    estimationerror_sample_path_set = set(estimationerror_sample_path_list)

    # ShapleyRegression settings per KernelSHAP variant (return_all=True is added for both).
    KERNELSHAP_SETTINGS = {
        "kernelshap": dict(batch_size=64,
                           thresh=0.1,
                           variance_batches=60),
        "kernelshapnopair": dict(batch_size=128,
                                 detect_convergence=False,
                                 paired_sampling=False,
                                 n_samples=200000,
                                 variance_batches=60),
    }

    def _run_kernelshap(backbone_type, explanation_method, paths, images, labels):
        """Explain every sampled image of a batch with KernelSHAP and record the results.

        Images whose path is not in ``estimationerror_sample_path_set`` are skipped. Each
        remaining image is wrapped in ``games.PredictionGame_torchimagetensor`` over the
        backbone's wrapped surrogate and passed to ``shapley.ShapleyRegression`` with
        ``KERNELSHAP_SETTINGS[explanation_method]``. The outputs are stored through
        ``estimationerror_save_dict_update`` under ``explanation_method``.
        """
        surrogate = surrogate_SHAP_wrapped_dict[backbone_type]
        surrogate.eval()
        path_list, estimation_list, label_list = [], [], []
        for path, image, label in zip(paths, images, labels.cpu().numpy().tolist()):
            if path not in estimationerror_sample_path_set:
                continue
            game = games.PredictionGame_torchimagetensor(surrogate, image)
            estimation_list.append(shapley.ShapleyRegression(game,
                                                             return_all=True,
                                                             **KERNELSHAP_SETTINGS[explanation_method]))
            path_list.append(path)
            label_list.append(label)
        estimationerror_save_dict_update(backbone_type,
                                         explanation_method,
                                         path_list=path_list,
                                         estimation_list=estimation_list,
                                         label_list=label_list,
                                         shape=None)

    for idx, batch in enumerate(tqdm(data_loader)):
        # Shard batches across parallel workers: this worker handles idx % parallel_mode[1] == parallel_mode[0].
        if (idx%parallel_mode[1])!=parallel_mode[0]:
            continue

        images=batch['images']
        labels=batch['labels']
        paths=batch['path']
        updated_signal_list=[]

        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            # Get classifier output
            classifier_dict[backbone_type].eval()
            with torch.no_grad():
                classifier_output=classifier_dict[backbone_type](images.to(classifier_dict[backbone_type].device),
                                                                 output_attentions=True)
            for explanation_method in ["kernelshapnopair"]:
                data_keys=[adapt_path(data_keys_path, paths)
                           for data_keys_path in estimationerror_save_dict[backbone_type][explanation_method].keys()]
                # Nothing to do when every sampled image of this batch already has a stored result.
                if all((path in data_keys) or (path not in estimationerror_sample_path_set) for path in paths):
                    continue
                print(explanation_method,'not exist')
                updated_signal_list.append(explanation_method)

                if explanation_method=="ours":
                    start_time=time.time()
                    explainer_dict[backbone_type].eval()
                    with torch.no_grad():
                        explanation_ours = explainer_dict[backbone_type](images.to(explainer_dict[backbone_type].device))[0].detach().cpu().numpy().transpose(0, 2, 1)
                    elapsed_time=time.time()-start_time
                    sampled=[(path, estimation, label)
                             for path, estimation, label in zip(paths, explanation_ours, labels.cpu().numpy().tolist())
                             if path in estimationerror_sample_path_set]
                    estimationerror_save_dict_update(backbone_type,
                                                     'ours',
                                                     path_list=[path for path, _, _ in sampled],
                                                     estimation_list=[estimation for _, estimation, _ in sampled],
                                                     label_list=[label for _, _, label in sampled],
                                                     shape=(_config["output_dim"], 196))
                elif explanation_method in KERNELSHAP_SETTINGS:
                    _run_kernelshap(backbone_type, explanation_method, paths, images, labels)
                else:
                    raise ValueError(f"Unknown explanation_method: {explanation_method!r}")

        # Merge freshly computed results with whatever is already on disk and write the union back.
        # Saving is best-effort (the evaluation keeps going), but failures are reported, not swallowed.
        for backbone_type, backbone_type_config in backbone_type_config_dict.items():
            for explanation_method, estimationerror_save_dict_backbone_method in estimationerror_save_dict[backbone_type].items():
                if explanation_method not in updated_signal_list:
                    continue

                estimationerror_save_dict_path=f'results/9_estimationerror/{_config["datasets"]}/{backbone_type}_{explanation_method}_{dataset_split}.pickle'
                try:
                    estimationerror_save_dict_loaded={}
                    if os.path.isfile(estimationerror_save_dict_path):
                        try:
                            # Local results file written by this notebook (pickle is only safe for trusted files).
                            with open(estimationerror_save_dict_path, 'rb') as f:
                                estimationerror_save_dict_loaded=pickle.load(f)
                        except Exception as exc:
                            print(f'could not load {estimationerror_save_dict_path}: {exc!r}')
                            estimationerror_save_dict_loaded={}

                    len_original=len(estimationerror_save_dict_backbone_method)
                    len_loaded=len(estimationerror_save_dict_loaded)
                    estimationerror_save_dict_backbone_method.update(estimationerror_save_dict_loaded)
                    len_updated=len(estimationerror_save_dict_backbone_method)

                    print(f'{explanation_method:24}  {len_original:6}   + {len_loaded:6}   -> {len_updated:6}')

                    # The results directory may not exist yet; previously this failed silently.
                    os.makedirs(os.path.dirname(estimationerror_save_dict_path), exist_ok=True)
                    with open(estimationerror_save_dict_path, "wb") as f:
                        pickle.dump(estimationerror_save_dict_backbone_method, f)
                except Exception as exc:
                    print(f'failed to save {estimationerror_save_dict_path}: {exc!r}')

In [None]:
# Mirror the stored KernelSHAP estimates into explanation_save_dict so later stages can use them
# like any other explanation method; no timing exists for these, hence NaN.
for backbone_type, backbone_type_config in backbone_type_config_dict.items():
    kernelshap_estimates = estimationerror_save_dict[backbone_type]["kernelshap"]
    for path, data in kernelshap_estimates.items():
        kernelshap_explanation = data['estimation'][0].values.T
        explanation_save_dict[backbone_type]['kernelshap'][path] = {"explanation": kernelshap_explanation,
                                                                  "elapsed_time": np.nan}
        print(path, kernelshap_explanation.shape)

In [None]:
# Report how many estimation-error results are stored per backbone and explanation method.
for backbone_type, backbone_type_config in backbone_type_config_dict.items():
    method_results = estimationerror_save_dict[backbone_type]
    for explanation_method in method_results:
        print(backbone_type, explanation_method, len(method_results[explanation_method]))

In [None]:
# Per-image share of the last measured batch time: elapsed_time divided by the batch size.
elapsed_time / len(paths)

In [None]:
# Developer inspection (IPython `??`): shows the source of elapsedtime_save_dict_update.
# NOTE(review): leftover debugging aid; remove before sharing the notebook.
elapsedtime_save_dict_update??

In [None]:
# Scratch check: list the keys of classifier_masked_dict (presumably backbone types — verify).
classifier_masked_dict.keys()

In [None]:
%%javascript
// Deletes the current notebook session via the classic Notebook JS API (`Jupyter.notebook`),
// presumably shutting the kernel down. WARNING(review): under "Run All" the cells after this
// one would no longer run against a live kernel; run this cell manually only.
Jupyter.notebook.session.delete();

# Sensitivity-n: rank correlation between summed attributions and surrogate output under random masks

In [None]:
# Sensitivity-n (exploratory): for one image, draw random uniform patch masks, run the
# pre-softmax surrogate on each masked copy, and compare the summed attribution of the kept
# patches (explanation @ mask.T) with the surrogate's probability via Spearman correlation.
# The trailing `break`s deliberately stop after the first image, backbone and batch.
num_players=196
SENSITIVITY_NUM_MASK_SEEDS = 20     # RandomState seeds 0..19, one mask batch per seed
SENSITIVITY_MASKS_PER_SEED = 50     # masks per seed (also the surrogate batch size)
SENSITIVITY_MAX_PLOTTED_CLASS = 4   # right-hand panel shows classes 0..3 other than the label
for idx, batch in enumerate(tqdm(data_loader, total = int(1000/data_loader.batch_size+0.5) if dataset_split=="test" else None)):
    # On the test split only roughly the first 1000 images are considered.
    if dataset_split=="test" and idx>int(1000/data_loader.batch_size+0.5):
        break
    # Shard batches across parallel workers: this worker handles idx % parallel_mode[1] == parallel_mode[0].
    if (idx%parallel_mode[1])!=parallel_mode[0]:
        continue

    images=batch['images']
    labels=batch['labels']
    paths=batch['path']

    for backbone_type, backbone_type_config in backbone_type_config_dict.items():
        surrogate = surrogate_dict[backbone_type]["pre-softmax"]
        surrogate.eval()
        for image, path, label in zip(images, paths, labels):
            for num_included_players in ["all"]:
                print(num_included_players)
                if num_included_players != "all":
                    # Only "all" is implemented; anything else would leave prob_all/mask_all undefined.
                    raise NotImplementedError(f"num_included_players={num_included_players!r} is not supported")

                prob_all=[]
                mask_all=[]
                for random_iter in range(SENSITIVITY_NUM_MASK_SEEDS):
                    image_loaded=image[np.newaxis, :].repeat(SENSITIVITY_MASKS_PER_SEED, 1, 1, 1).to(surrogate.device)
                    mask=generate_mask(num_players=num_players,
                                       num_mask_samples=SENSITIVITY_MASKS_PER_SEED,
                                       paired_mask_samples=False,
                                       mode="uniform",
                                       random_state=np.random.RandomState(random_iter))
                    with torch.no_grad():
                        output = surrogate(image_loaded, masks=torch.Tensor(mask).to(surrogate.device))

                    if _config["output_dim"]==1:
                        prob=output['logits'].sigmoid().cpu().numpy()
                    else:
                        prob=output['logits'].softmax(dim=-1).cpu().numpy()
                    prob_all.append(prob)
                    mask_all.append(mask)
                prob_all=np.concatenate(prob_all, axis=0)
                mask_all=np.concatenate(mask_all, axis=0)

                for explanation_method in explanation_method_to_run:
                    print(explanation_method)
                    method_explanations=explanation_save_dict[backbone_type][explanation_method]
                    explanation=method_explanations[adapt_path(path, list(method_explanations.keys())[0])]['explanation']
                    # Seeded noise of at most 1e-40 only perturbs (near-)zero attributions, e.g. to break
                    # ties among exact zeros before ranking; for larger values it is lost to rounding.
                    explanation=explanation+np.random.RandomState(42).uniform(low=0, high=1e-40, size=explanation.shape)
                    explanation_mask=explanation@(mask_all.T)

                    if explanation.ndim==1:
                        correlation = np.array([stats.spearmanr(explanation_mask, prob_all[:, i]).correlation for i in range(prob_all.shape[1])])
                        assert correlation.shape==(_config["output_dim"],)

                    elif explanation.ndim==2:
                        correlation = np.array([stats.spearmanr(explanation_mask[i], prob_all[:, i]).correlation for i in range(prob_all.shape[1])])
                        assert correlation.shape==(_config["output_dim"],)

                        label_idx=int(label)
                        fig, (ax_label, ax_other) = plt.subplots(1, 2, figsize=(20, 5))
                        # Left: the true label's class.
                        ax_label.scatter(explanation_mask[label_idx], prob_all[:, label_idx])
                        ax_label.set_xlabel("sum_explanation")
                        ax_label.set_ylabel("model output")
                        ax_label.set_title(explanation_method)

                        # Right: a few other classes for comparison.
                        for i in range(min(prob_all.shape[1], SENSITIVITY_MAX_PLOTTED_CLASS)):
                            if i!=label_idx:
                                ax_other.scatter(explanation_mask[i], prob_all[:, i])
                        ax_other.set_xlabel("sum_explanation")
                        ax_other.set_ylabel("model output")
                        ax_other.set_title(explanation_method)

                        plt.show()

                    else:
                        raise ValueError(f"Unexpected explanation shape {explanation.shape} for {explanation_method!r}")
#                     sensitivity_save_dit_update(backbone_type, explanation_method,
#                                                  num_included_players=num_included_players,
#                                                  path_list=[path], sensitivity_list=[correlation],
#                                                  shape=(_config["output_dim"],))
            break  # exploration: first image only

        break  # exploration: first backbone only
    break  # exploration: first batch only

In [None]:
# Re-draw the true-label scatter from the last sensitivity-n iteration (uses that cell's leftover globals).
fig, ax = plt.subplots()
ax.scatter(explanation_mask[label], prob_all[:, label])
ax.set(xlabel="sum_explanation", ylabel="model output")
ax.set_title(explanation_method)

# Visualization

In [None]:
# Scratch axes: visualize_result draws a seaborn heatmap onto `ax_temp` only to obtain its
# colourbar on a separate figure; plt.clf() clears this figure so the heatmap itself never shows.
fig = plt.figure()
ax_temp=fig.add_subplot()
plt.clf()

def visualize_result(x, values, pred=None, vmin_vmax='separate',
                     image_labels=['normal','abnormal']*5,
                     class_labels=['normal','abnormal']*5):
    """Plot each image next to one explanation heatmap per class.

    Args:
        x: batch of normalized images; ``x[row].numpy()`` must give a (C, H, W) array
            (assumes a CPU torch tensor — TODO confirm for other inputs).
        values: array indexed as ``values[image, class]`` -> 2-D heatmap, i.e. shape
            (n_images, n_classes, h, w).
        pred: optional (n_images, n_classes) scores printed under each heatmap.
        vmin_vmax: ``"separate"`` gives every heatmap its own colour range (symmetric around 0
            unless all values are positive). A ``(vmin, vmax)`` tuple applies one shared range to all
            heatmaps and additionally draws a colourbar and a log-scale histogram of ``values``.
        image_labels: one row label per image.
        class_labels: one column title per class.

    Raises:
        ValueError: if ``vmin_vmax`` is neither ``"separate"`` nor a ``(vmin, vmax)`` tuple
            with ``vmin < vmax``.
    """
    from matplotlib.colors import ListedColormap

    color_num = 1000
    # NOTE(review): these are CIFAR-10 mean/std; confirm they match the datamodule's normalization.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # pyplot.get_cmap(name, lut) replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
    seismic = plt.get_cmap('seismic', color_num)
    shared_range = isinstance(vmin_vmax, tuple)
    if shared_range:
        vmin, vmax = vmin_vmax
        if not vmin < vmax:
            raise ValueError(f"vmin_vmax must satisfy vmin < vmax, got {vmin_vmax}")
        if vmin < 0 < vmax:
            # Diverging map with white pinned at 0: the negative side gets (1 - ratio) of the
            # colours and the positive side gets ratio, so 0 lands on seismic's midpoint.
            ratio = vmax / (vmax - vmin)
            n_neg = int(color_num * (1 - ratio))
            n_pos = int(color_num * ratio) + 1
            newcolors = seismic(np.linspace(0, 1, color_num))
            newcolors[:n_neg] = seismic(np.linspace(0, 0.5, n_neg))
            newcolors[-n_pos:] = seismic(np.linspace(0.5, 1, n_pos))
            newcmp = ListedColormap(newcolors)
        elif vmin >= 0:
            # Non-negative range (including vmin == 0): white-to-red half of seismic.
            newcmp = ListedColormap(seismic(np.linspace(0.5, 1, color_num)))
        else:
            # Non-positive range (vmax <= 0): blue-to-white half of seismic.
            newcmp = ListedColormap(seismic(np.linspace(0, 0.5, color_num)))
    elif vmin_vmax == "separate":
        if values.min() > 0:
            newcmp = ListedColormap(seismic(np.linspace(0.5, 1, color_num)))
        else:
            newcmp = ListedColormap(seismic(np.linspace(0, 1, color_num)))
    else:
        raise ValueError(f"vmin_vmax must be 'separate' or a (vmin, vmax) tuple, got {vmin_vmax!r}")

    n_rows, n_classes = values.shape[0], values.shape[1]
    # squeeze=False keeps `axes` 2-D even for a single image, so axes[row, col] always works.
    fig, axes = plt.subplots(n_rows, 1 + n_classes,
                             figsize=(2*(1+n_classes), 2*(n_rows+1)),
                             squeeze=False)

    assert len(image_labels)==n_rows==(len(image_labels) if pred is None else pred.shape[0])
    assert len(class_labels)==n_classes==(len(class_labels) if pred is None else pred.shape[1])

    all_positive = values.min() > 0
    for row in range(n_rows):
        # Column 0: de-normalized input image.
        im = x[row].numpy() * img_std + img_mean  # (C, H, W)
        im = im.transpose(1, 2, 0).astype(float)  # (H, W, C)
        im = np.clip(im, a_min=0, a_max=1)
        axes[row, 0].imshow(im, vmin=0, vmax=1)
        axes[row, 0].set_ylabel('{}'.format(image_labels[row]), fontsize=12)

        # Columns 1..n_classes: one explanation heatmap per class.
        for col in range(1, 1 + n_classes):
            values_select = values[row, col-1]
            values_select_min, values_select_max = values_select.min(), values_select.max()

            if shared_range:
                axes[row, col].imshow(values_select, cmap=newcmp, vmin=vmin, vmax=vmax)
            elif all_positive:
                axes[row, col].imshow(values_select, cmap=newcmp,
                                      vmin=values_select_min,
                                      vmax=values_select_max)
            else:
                bound = max([abs(values_select_min), abs(values_select_max)])
                axes[row, col].imshow(values_select, cmap=newcmp, vmin=-bound, vmax=bound)

            if pred is None:
                axes[row, col].set_xlabel('{:.2f}/{:.2f}'.format(values_select_min, values_select_max), fontsize=12)
            else:
                axes[row, col].set_xlabel('{:.2f} {:.2f}/{:.2f}'.format(pred[row, col-1], values_select_min, values_select_max), fontsize=12)

            # Class labels
            if row == 0:
                axes[row, col].set_title('{}'.format(class_labels[col-1]), fontsize=12)

        for ax in axes[row]:
            ax.set_xticks([])
            ax.set_yticks([])

    if shared_range:
        # Standalone horizontal colourbar for the shared range.
        fig = plt.figure(figsize=(5, 0.5))
        ax = fig.add_subplot()
        sns.heatmap([[0,0],[0,0]],
                    ax=ax_temp,
                    cmap=newcmp,
                    vmin=vmin,
                    vmax=vmax,
                    xticklabels=True,
                    linewidths=1,
                    linecolor=np.array([220,220,220,256])/256,
                    cbar_ax=ax,
                    cbar_kws={'fraction':0.1, "ticks":np.linspace(vmin, vmax, 5), "orientation": "horizontal"},
                    cbar=True,
                    alpha=1,edgecolor='black')
        plt.show()

        # Distribution of all attribution values within the shared range.
        fig = plt.figure()
        ax = fig.add_subplot()
        ax.hist(values.flatten(), bins=np.linspace(vmin, vmax, 20))
        ax.set_yscale('log')
        print(values.flatten().min(), values.flatten().max())

In [None]:
# For every image in `x`: draw 20 batches of 50 random uniform masks (seeds 0..19), evaluate
# the pre-softmax surrogate of the current `backbone_type` on the masked copies, and keep the
# per-image probabilities and masks in prob_all_all / mask_all_all.
prob_all_all = []
mask_all_all = []
for idx, image in enumerate(x):
    print(idx)
    surrogate = surrogate_dict[backbone_type]["pre-softmax"]
    prob_chunks = []
    mask_chunks = []
    for random_iter in range(20):
        image_loaded = image[np.newaxis, :].repeat(50, 1, 1, 1).to(surrogate.device)
        mask = generate_mask(num_players=num_players,
                             num_mask_samples=50,
                             paired_mask_samples=False,
                             mode="uniform",
                             random_state=np.random.RandomState(random_iter))
        surrogate.eval()
        with torch.no_grad():
            output = surrogate(image_loaded, masks=torch.Tensor(mask).to(surrogate.device))

        logits = output['logits']
        prob = (logits.sigmoid() if _config["output_dim"] == 1 else logits.softmax(dim=-1)).cpu().numpy()
        prob_chunks.append(prob)
        mask_chunks.append(mask)
    prob_all = np.concatenate(prob_chunks, axis=0)
    mask_all = np.concatenate(mask_chunks, axis=0)
    prob_all_all.append(prob_all)
    mask_all_all.append(mask_all)

In [None]:
# Flatten each (14, 14) patch map back to 196 players, keeping the (image, class) axes.
values.reshape(*values.shape[:2], 196)

In [None]:
# For each backbone: predict on the fixed batch `x`, then visualize the LRP
# (transformer_attribution) explanation and the learned explainer's output.
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    plot_class_labels = y_labels[1:2] if _config["output_dim"]==1 else y_labels

    # Classifier scores: sigmoid for a single output, softmax over classes otherwise.
    classifier_dict[backbone_type].eval()
    with torch.no_grad():
        output = classifier_dict[backbone_type](x.to(next(classifier_dict[backbone_type].parameters()).device), output_attentions=False)
        if _config["output_dim"]==1:
            pred=output['logits'].detach().sigmoid().cpu().numpy()
        else:
            # Tensor.softmax has no default dim; normalize over the class axis as elsewhere in this notebook.
            pred=output['logits'].detach().softmax(dim=-1).cpu().numpy()
    del output

    # Explanation from the modified LRP module, one map per class, reshaped to the 14x14 patch grid.
    values=np.concatenate([np.concatenate([get_lrp_module_explanation(backbone_type, image.squeeze(0), class_index=i, mode='transformer_attribution').cpu() for i in range(_config["output_dim"])], axis=0)[np.newaxis,:] for image in x])
    values=values.reshape(values.shape[0], values.shape[1], 14, 14)

    visualize_result(x, pred=pred, values=values,
                     class_labels=plot_class_labels,
                     image_labels=y_labels,
                     vmin_vmax="separate")

    values_lrp=values.reshape(values.shape[0], values.shape[1], 196)

    # Learned explainer: first output reshaped to (batch, output_dim, 14, 14).
    explainer_dict[backbone_type].eval()
    with torch.no_grad():
        values=explainer_dict[backbone_type](x.to(next(explainer_dict[backbone_type].parameters()).device))
        values=values[0].reshape(-1, _config["output_dim"], 14, 14).cpu().numpy()

    visualize_result(x, pred=pred, values=values,
                     class_labels=plot_class_labels,
                     image_labels=y_labels,
                     vmin_vmax=(-0.2, 0.2))

    values_ours=values.reshape(values.shape[0], values.shape[1], 196)