# Setup

**Start the colab kernel with GPU**: Runtime -> Change runtime type -> GPU

## Install dependencies

In [1]:
!git clone https://github.com/suinleelab/vit-shapley

Cloning into 'vit-shapley'...
remote: Enumerating objects: 349, done.[K
remote: Counting objects: 100% (114/114), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 349 (delta 60), reused 88 (delta 44), pack-reused 235[K
Receiving objects: 100% (349/349), 137.42 MiB | 21.19 MiB/s, done.
Resolving deltas: 100% (136/136), done.


In [2]:
!pip uninstall -y torchtext torchaudio tensorflow arviz cxvpy
!pip install -r vit-shapley/requirements.txt 

Found existing installation: torchtext 0.14.1
Uninstalling torchtext-0.14.1:
  Successfully uninstalled torchtext-0.14.1
Found existing installation: torchaudio 0.13.1+cu116
Uninstalling torchaudio-0.13.1+cu116:
  Successfully uninstalled torchaudio-0.13.1+cu116
Found existing installation: tensorflow 2.12.0
Uninstalling tensorflow-2.12.0:
  Successfully uninstalled tensorflow-2.12.0
Found existing installation: arviz 0.15.1
Uninstalling arviz-0.15.1:
  Successfully uninstalled arviz-0.15.1
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy~=1.21.2
  Downloading numpy-1.21.6-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m85.3 MB/s[0m eta [36m0:00:00[0m
Collecting scikit-learn~=1.0.2
  Downloading scikit_learn-1.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━

In [3]:
import sys
sys.path.append("./vit-shapley")

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from collections import OrderedDict
import copy
import pickle
import time
from scipy import stats
from tqdm import tqdm
import requests

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader

from vit_shapley.datamodules.ImageNette_datamodule import ImageNetteDataModule
from vit_shapley.datamodules.MURA_datamodule import MURADataModule
from vit_shapley.datamodules.Pet_datamodule import PetDataModule

from vit_shapley.modules.classifier import Classifier
from vit_shapley.modules.classifier_masked import ClassifierMasked
from vit_shapley.modules.surrogate import Surrogate
from vit_shapley.modules.explainer import Explainer

from vit_shapley.config import ex
from vit_shapley.config import config, env_chanwkim, dataset_ImageNette, dataset_MURA, dataset_Pet


def download_file(url, path):
  # Streaming, so we can iterate over the response.
  response = requests.get(url, stream=True)
  total_size_in_bytes= int(response.headers.get('content-length', 0))
  block_size = 1024 #1 Kibibyte
  progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
  with open(path, 'wb') as file:
      for data in response.iter_content(block_size):
          progress_bar.update(len(data))
          file.write(data)
  progress_bar.close()
  if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
      print("ERROR, something went wrong")  

dataset_name="Pet"
#backbone_type="vit_base_patch16_224"
backbone_to_use=["vit_base_patch16_224"]


_config=config()
if dataset_name=="ImageNette":
  _config.update(dataset_ImageNette())
elif dataset_name=="MURA":
  _config.update(dataset_MURA())
elif dataset_name=="Pet":
  _config.update(dataset_Pet())    

_config.update({'gpus_classifier':[0,],
                'gpus_surrogate':[0,],
                'gpus_explainer':[0,]})

_config.update({'classifier_backbone_type': None,
                'classifier_download_weight': False,
                'classifier_load_path': None})
_config.update({'classifier_masked_mask_location': "pre-softmax",
                'classifier_enable_pos_embed': True,
                })
_config.update({'surrogate_mask_location': "pre-softmax"})
_config.update({'surrogate_backbone_type': None,
                'surrogate_download_weight': False,
                'surrogate_load_path': None})
_config.update({'explainer_num_mask_samples': 2,
                'explainer_paired_mask_samples': True})

In [5]:
if _config["datasets"]=="ImageNette":
    backbone_type_config_dict_=OrderedDict({
        "vit_base_patch16_224":{
            "surrogate_path":{
                "pre-softmax": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/ImageNette_vit_base_patch16_224_surrogate_3i6zzjnp.ckpt",
                
                },
            "explainer_path": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/ImageNette_vit_base_patch16_224_explainer_3ty85eft.ckpt"
        },
    })  
elif _config["datasets"]=="MURA":
    backbone_type_config_dict_=OrderedDict({
        "vit_base_patch16_224":{
            "surrogate_path":{
                "pre-softmax": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/MURA_vit_base_patch16_224_surrogate_22ompjqu.ckpt",
                
                },
            "explainer_path": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/MURA_vit_base_patch16_224_explainer_1dmhcwej.ckpt"
        },
    })    

elif _config["datasets"]=="Pet":    
    backbone_type_config_dict_=OrderedDict({
        "vit_base_patch16_224":{
            "surrogate_path":{
                "pre-softmax": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/Pet_vit_base_patch16_224_surrogate_146vf465.ckpt",
                
                },
            "explainer_path": "https://aimslab.cs.washington.edu/vitshapley/checkpoints/Pet_vit_base_patch16_224_explainer_2oq7lhr7.ckpt"
        },
    })    
    

In [8]:
def set_datamodule(datasets,
                   dataset_location,
                   explanation_location_train,
                   explanation_mask_amount_train,
                   explanation_mask_ascending_train,
                   
                   explanation_location_val,
                   explanation_mask_amount_val,
                   explanation_mask_ascending_val,                   
                   
                   explanation_location_test,
                   explanation_mask_amount_test,
                   explanation_mask_ascending_test,                   
                   
                   transforms_train,
                   transforms_val,
                   transforms_test,
                   num_workers,
                   per_gpu_batch_size,
                   test_data_split):
    dataset_parameters = {
        "dataset_location": dataset_location,
        "explanation_location_train": explanation_location_train,
        "explanation_mask_amount_train": explanation_mask_amount_train,
        "explanation_mask_ascending_train": explanation_mask_ascending_train,
        
        "explanation_location_val": explanation_location_val,
        "explanation_mask_amount_val": explanation_mask_amount_val,
        "explanation_mask_ascending_val": explanation_mask_ascending_val,
        
        "explanation_location_test": explanation_location_test,
        "explanation_mask_amount_test": explanation_mask_amount_test,
        "explanation_mask_ascending_test": explanation_mask_ascending_test,        
        
        "transforms_train": transforms_train,
        "transforms_val": transforms_val,
        "transforms_test": transforms_test,
        "num_workers": num_workers,
        "per_gpu_batch_size": per_gpu_batch_size,
        "test_data_split": test_data_split
    }

    if datasets == "MURA":
        datamodule = MURADataModule(**dataset_parameters)
    elif datasets == "ImageNette":
        datamodule = ImageNetteDataModule(**dataset_parameters)
    elif datasets == "Pet":
        !mkdir pets
        download_file("https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz", "pets_images.tar.gz")
        download_file("https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz", "pets_annotations.tar.gz")
        !tar -xvf pets_images.tar.gz -C ./pets
        !tar -xvf pets_annotations.tar.gz -C ./pets
        dataset_parameters["dataset_location"]="./pets"
        datamodule = PetDataModule(**dataset_parameters)        
    else:
        ValueError("Invalid 'datasets' configuration")
    return datamodule

datamodule = set_datamodule(datasets=_config["datasets"],
                            dataset_location=_config["dataset_location"],

                            explanation_location_train=_config["explanation_location_train"],
                            explanation_mask_amount_train=_config["explanation_mask_amount_train"],
                            explanation_mask_ascending_train=_config["explanation_mask_ascending_train"],

                            explanation_location_val=_config["explanation_location_val"],
                            explanation_mask_amount_val=_config["explanation_mask_amount_val"],
                            explanation_mask_ascending_val=_config["explanation_mask_ascending_val"],

                            explanation_location_test=_config["explanation_location_test"],
                            explanation_mask_amount_test=_config["explanation_mask_amount_test"],
                            explanation_mask_ascending_test=_config["explanation_mask_ascending_test"],                            

                            transforms_train=_config["transforms_train"],
                            transforms_val=_config["transforms_val"],
                            transforms_test=_config["transforms_test"],
                            num_workers=_config["num_workers"],
                            per_gpu_batch_size=_config["per_gpu_batch_size"],
                            test_data_split=_config["test_data_split"])

datamodule.set_test_dataset()

test_dataset=datamodule.test_dataset

dset=test_dataset

mkdir: cannot create directory ‘pets’: File exists


100%|██████████| 792M/792M [00:13<00:00, 57.7MiB/s]
100%|██████████| 19.2M/19.2M [00:00<00:00, 65.7MiB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
annotations/trimaps/._pomeranian_180.png
annotations/trimaps/pomeranian_180.png
annotations/trimaps/._pomeranian_181.png
annotations/trimaps/pomeranian_181.png
annotations/trimaps/._pomeranian_182.png
annotations/trimaps/pomeranian_182.png
annotations/trimaps/._pomeranian_183.png
annotations/trimaps/pomeranian_183.png
annotations/trimaps/._pomeranian_184.png
annotations/trimaps/pomeranian_184.png
annotations/trimaps/._pomeranian_185.png
annotations/trimaps/pomeranian_185.png
annotations/trimaps/._pomeranian_186.png
annotations/trimaps/pomeranian_186.png
annotations/trimaps/._pomeranian_187.png
annotations/trimaps/pomeranian_187.png
annotations/trimaps/._pomeranian_188.png
annotations/trimaps/pomeranian_188.png
annotations/trimaps/._pomeranian_189.png
annotations/trimaps/pomeranian_189.png
annotations/trimaps/._pomeranian_19.png
annotations/trimaps/pomeranian_19.png
annotations/trimaps/._pomeranian_190.png
annotations/trim

In [9]:
backbone_type_config_dict = OrderedDict()
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict_.items()):
    if backbone_type in backbone_to_use:
        print(backbone_type)
        backbone_type_config_dict[backbone_type]=backbone_type_config

vit_base_patch16_224


In [10]:
surrogate_dict = OrderedDict()

for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    mask_method_dict = OrderedDict()
    for mask_location in backbone_type_config["surrogate_path"].keys():
        download_file(backbone_type_config["surrogate_path"][mask_location], "surrogate.ckpt")
        mask_method_dict[mask_location] = Surrogate(mask_location=mask_location if mask_location!="original" else "pre-softmax",
                                          backbone_type=backbone_type,
                                          download_weight=_config['surrogate_download_weight'],
                                          load_path="surrogate.ckpt",
                                          target_type=_config["target_type"],
                                          output_dim=_config["output_dim"],

                                          target_model=None,
                                          checkpoint_metric=None,
                                          optim_type=None,
                                          learning_rate=None,
                                          weight_decay=None,
                                          decay_power=None,
                                          warmup_steps=None).to(_config["gpus_surrogate"][idx])
    surrogate_dict[backbone_type]=mask_method_dict

100%|██████████| 1.37G/1.37G [01:18<00:00, 17.4MiB/s]


In [11]:
_config.update({'explainer_normalization': "additive",
                'explainer_activation': "tanh",
                'explainer_link': 'sigmoid' if _config["output_dim"]==1 else 'softmax',
                'explainer_head_num_attention_blocks': 1,
                'explainer_head_include_cls': True,
                'explainer_head_num_mlp_layers': 3,
                'explainer_head_mlp_layer_ratio': 4,
                'explainer_residual': [],
                'explainer_freeze_backbone': "all"})

explainer_dict = OrderedDict()
for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    download_file(backbone_type_config["explainer_path"], "explainer.ckpt")
    explainer_dict[backbone_type] = Explainer(normalization=_config["explainer_normalization"],
                                              normalization_class=_config["explainer_normalization_class"],
                                              activation=_config["explainer_activation"],
                                              surrogate=surrogate_dict[backbone_type]["pre-softmax"],
                                              link=_config["explainer_link"],
                                              backbone_type=backbone_type,
                                              download_weight=False,
                                              residual=_config['explainer_residual'],
                                              load_path="explainer.ckpt",
                                              target_type=_config["target_type"],
                                              output_dim=_config["output_dim"],

                                              explainer_head_num_attention_blocks=_config["explainer_head_num_attention_blocks"],
                                              explainer_head_include_cls=_config["explainer_head_include_cls"],
                                              explainer_head_num_mlp_layers=_config["explainer_head_num_mlp_layers"],
                                              explainer_head_mlp_layer_ratio=_config["explainer_head_mlp_layer_ratio"],
                                              explainer_norm=_config["explainer_norm"],

                                              efficiency_lambda=_config["explainer_efficiency_lambda"],
                                              efficiency_class_lambda=_config["explainer_efficiency_class_lambda"],
                                              freeze_backbone=_config["explainer_freeze_backbone"],

                                              checkpoint_metric=_config["checkpoint_metric"],
                                              optim_type=_config["optim_type"],
                                              learning_rate=_config["learning_rate"],
                                              weight_decay=_config["weight_decay"],
                                              decay_power=_config["decay_power"],
                                              warmup_steps=_config["warmup_steps"]).to(_config["gpus_explainer"][idx])

100%|██████████| 1.60G/1.60G [01:20<00:00, 19.8MiB/s]


In [12]:
label_dict={}

label_dict["ImageNette"]=['Cassette player', 
                          'Garbage truck', 
                          'Tench', 
                          'English springer', 
                          'Church', 
                          'Parachute', 
                          'French horn', 
                          'Chain saw', 
                          'Golf ball', 
                          'Gas pump']
label_dict["MURA"]=["Normal", "Abnormal"]


label_dict["Pet"]=['Abyssinian',
                       'American Bulldog',
                       'American Pit Bull Terrier',
                       'Basset Hound',
                       'Beagle',
                       'Bengal',
                       'Birman',
                       'Bombay',
                       'boxer',
                       'British Shorthair',
                       'Chihuahua',
                       'Egyptian Mau',
                       'English Cocker Spaniel',
                       'English Setter',
                       'German Shorthaired',
                       'Great Pyrenees',
                       'Havanese',
                       'Japanese Chin',
                       'Keeshond',
                       'Leonberger',
                       'Maine Coon',
                       'Miniature Pinscher',
                       'Newfoundland',
                       'Persian',
                       'Pomeranian',
                       'Pug',
                       'Ragdoll',
                       'Russian_Blue',
                       'Saint Bernard',
                       'Samoyed',
                       'Scottish Terrier',
                       'Shiba_inu',
                       'Siamese',
                       'Sphynx',
                       'Staffordshire Bull Terrier',
                       'Wheaten Terrier',
                       'Yorkshire Terrier']

In [13]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

import matplotlib.gridspec as gridspec

from matplotlib.patches import Patch
from matplotlib.lines import Line2D

from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import matplotlib.ticker as ticker
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


In [14]:
def plot_figure(sample_idx_list, explainer):
    plt.rcParams["font.size"] = 12
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis] 

    label_choice=[dset[sample_idx]["labels"] for sample_idx in sample_idx_list]
    class_list = np.unique(label_choice).tolist()

    fig = plt.figure(figsize=(2.3*(len(["image"]+class_list)+0.2*len(["empty"])), 3*len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(["image"]+["empty"]+class_list), 
                              wspace=0.1, 
                              hspace=0,
                              width_ratios=[1]+[0.2]+[1]*len(class_list))

    axd={}
    for idx1, plot_type in enumerate(["image"]+["empty"]+class_list):
        box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list),1, 
                                                subplot_spec=box1[idx1], wspace=0, hspace=0.3)
        for idx2, sample_idx in enumerate(sample_idx_list):
            box3 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                subplot_spec=box2[idx2], wspace=0, hspace=0)
            ax=plt.Subplot(fig, box3[0])
            fig.add_subplot(ax)
            axd[f"{sample_idx}_{plot_type}"]=ax

    for plot_key in axd.keys():
        #continue
        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top','bottom','left','right']:
                axd[plot_key].spines[axis].set_linewidth(0) 

    for idx1, sample_idx in enumerate(sample_idx_list):
        dataset_item=dset[sample_idx]

        image = dataset_item["images"]
        label = dataset_item["labels"]
        path = dataset_item["path"]

        #print(idx1, sample_idx)

        image_unnormlized=((image.numpy() * img_std) + img_mean).transpose(1,2,0)
        assert image_unnormlized.min()>0 and image_unnormlized.max()<1
        image_unnormlized_scaled=(image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())

        #class_idx=label     

        for idx2, plot_type in enumerate(["image"]+["empty"]+class_list):
            if plot_type=="image":
                plot_key=f"{sample_idx}_image"

                axd[plot_key].imshow(image_unnormlized_scaled)

                axd[plot_key].set_xticks([]) 
                axd[plot_key].set_yticks([])             
                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(1)
                if dataset_name=="ImageNette" or "Pet":
                    #print('label', idx1, label_choice)
                    axd[plot_key].set_title(f"{label_dict[dataset_name][label_choice[idx1]]}", pad=7, zorder=10)
                else:
                    axd[plot_key].set_title(f"Abnormal", pad=7, zorder=10)
            elif plot_type=="empty":
                pass
            else:         
                explanation=explainer(image.unsqueeze(0).to(explainer.device))[0][0].T

                if len(explanation.shape)==2:
                    explanation_class=explanation[plot_type].detach().cpu().numpy()
                else:
                    explanation_class=explanation.detach().cpu().numpy()


                explanation_class_expanded=np.repeat(np.repeat(explanation_class.reshape(14, 14), 16, axis=0), 16, axis=1)
                explanation_class_expanded=torch.nn.functional.interpolate(torch.Tensor(explanation_class.reshape(1, 1, 14, 14)), 
                                                                          scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)                                                        


                colormap_max=np.max(np.abs(explanation_class_expanded))

                explanation_class_expanded_normalized=(0.5+(explanation_class_expanded)/colormap_max*0.5)
                cmap=sns.color_palette("icefire", as_cmap=True)#cmap=cmr.redshift#cmap=cm.get_cmap('seismic', 1000)
                explanation_class_expanded_heatmap=cmap(explanation_class_expanded_normalized)#[:,:,:-1]
                explanation_class_expanded_heatmap[:,:,3]=0.6

                image_unnormlized_normalized=(image_unnormlized.sum(axis=2))/3
                cmap=cm.get_cmap('Greys', 1000) 
                image_unnormlized_normalized=cmap(1-image_unnormlized_normalized)#[:,:,:-1]
                image_unnormlized_normalized[:,:,3]=0.5

                plot_key=f"{sample_idx}_{plot_type}"

                axd[plot_key].imshow(image_unnormlized_normalized, alpha=0.85)
                axd[plot_key].imshow(explanation_class_expanded_heatmap, alpha=0.9)

                axd[plot_key].set_xticks([])
                axd[plot_key].set_yticks([])
                for axis in ['top','bottom','left','right']:
                    axd[plot_key].spines[axis].set_linewidth(1)  


                axd[plot_key].set_title(label_dict[dataset_name][plot_type])

In [25]:
plot_figure(sample_idx_list=[7, 12, 13, 14, 15, 16, 17], explainer=explainer_dict[backbone_type])

Output hidden; open in https://colab.research.google.com to view.