In [None]:
import os

# Move the kernel one level up from the directory it started in.
# The start directory is remembered on first execution so that re-running this
# cell does not climb one additional level every time.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
print(NOTEBOOK_START_DIR)
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))
print(os.getcwd())

In [None]:
import copy
from collections import OrderedDict
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from vitshapley.datamodules.ImageNette_datamodule import ImageNetteDataModule

from vitshapley.modules.classifier import Classifier
from vitshapley.modules.surrogate import Surrogate
from vitshapley.modules.explainer import Explainer

from vitshapley.config import ex
from vitshapley.config import config, env_chanwkim, dataset_ImageNette

# Which dataset split to analyse and which backbones to load further down.
dataset_split = "test"
backbone_to_use = ["vit_base_patch16_224"]

# Start from the project defaults, then layer the environment and dataset
# presets on top (later updates override earlier ones).
_config = config()
_config.update(env_chanwkim())
_config.update(dict(gpus_classifier=[0], gpus_surrogate=[0], gpus_explainer=[0]))
_config.update(dataset_ImageNette())

# Notebook-specific overrides: models are instantiated explicitly in later
# cells, so the backbone / checkpoint entries are cleared here.
_config.update({
    'classifier_backbone_type': None,
    'classifier_download_weight': False,
    'classifier_load_path': None,
    'surrogate_mask_location': "pre-softmax",
    'surrogate_backbone_type': None,
    'surrogate_download_weight': False,
    'surrogate_load_path': None,
    'explainer_num_mask_samples': 2,
    'explainer_paired_mask_samples': True,
})
print('done')

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from collections import OrderedDict
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Surrogate checkpoints per backbone, keyed by the mask location the surrogate
# was trained with. Backbones with an empty "surrogate_path" have no
# checkpoints registered here.
backbone_type_config_dict_ = OrderedDict()

backbone_type_config_dict_["vit_small_patch16_224"] = {
    "surrogate_path": {
        "original": "results/transformer_interpretability/17inn4ht/checkpoints/epoch=14-step=2204.ckpt",
        "pre-softmax": "results/transformer_interpretability/3kv2ns41/checkpoints/epoch=29-step=4409.ckpt",
        "post-softmax": "results/transformer_interpretability/31as48v7/checkpoints/epoch=32-step=4850.ckpt",
        "zero-input": "results/transformer_interpretability/j8sihn8t/checkpoints/epoch=33-step=4997.ckpt",
    },
}

backbone_type_config_dict_["deit_small_patch16_224"] = {"surrogate_path": {}}

backbone_type_config_dict_["vit_base_patch16_224"] = {
    "surrogate_path": {
        "original": "results/transformer_interpretability/3f67z73f/checkpoints/epoch=11-step=1763.ckpt",
        "pre-softmax": "results/transformer_interpretability/zeydyraj/checkpoints/epoch=15-step=2351.ckpt",
        "post-softmax": "results/transformer_interpretability/1ijt5xox/checkpoints/epoch=33-step=4997.ckpt",
        "zero-input": "results/transformer_interpretability/1w1sgm9q/checkpoints/epoch=15-step=2351.ckpt",
    },
}

backbone_type_config_dict_["deit_base_patch16_224"] = {"surrogate_path": {}}



In [None]:
def generate_mask(num_players: int, num_mask_samples: Optional[int] = None, paired_mask_samples: bool = True,
                  mode: str = 'uniform', random_state: Optional[np.random.RandomState] = None) -> np.ndarray:
    """
    Sample random binary coalition masks.

    Args:
        num_players: the number of players in the coalitional game
        num_mask_samples: the number of masks to generate. ``None`` requests a single,
            un-batched mask. Because one sample is not an even count, ``None`` only works
            together with ``paired_mask_samples=False``.
        paired_mask_samples: if True, the generated masks are pairs of x and 1-x
            (consecutive rows); the number of samples must then be even.
        mode: the distribution that the number of masked features follows. ('uniform' or 'shapley')
        random_state: random generator; falls back to the global ``np.random`` module.

    Returns:
        np.ndarray of 0/1 integers of shape
        (num_mask_samples, num_players) if num_mask_samples is an int
        (num_players,) if num_mask_samples is None

    Raises:
        AssertionError: if ``paired_mask_samples`` is True and the number of samples is odd.
        ValueError: if ``mode`` is neither 'uniform' nor 'shapley'.
    """
    random_state = random_state or np.random

    num_samples_ = num_mask_samples or 1

    if paired_mask_samples:
        assert num_samples_ % 2 == 0, "'num_samples' must be a multiple of 2 if 'paired' is True"
        # Only half of the masks are sampled; the other half are their complements.
        num_samples_ = num_samples_ // 2

    if mode == 'uniform':
        # One random threshold per row, compared against per-player uniform draws.
        masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')
    elif mode == 'shapley':
        # Per-row threshold j / num_players, with j drawn with weights proportional to 1 / (k * (n - k)).
        probs = 1 / (np.arange(1, num_players) * (num_players - np.arange(1, num_players)))
        probs = probs / probs.sum()
        masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(
            np.arange(num_players - 1), p=probs, size=[num_samples_, 1])).astype('int')
    else:
        raise ValueError("'mode' must be 'uniform' or 'shapley'")

    if paired_mask_samples:
        # Interleave each mask with its complement: rows (2i, 2i + 1) form a pair.
        masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)

    if num_mask_samples is None:
        masks = masks.squeeze(0)
        return masks  # (num_players,)
    else:
        return masks  # (num_mask_samples, num_players)


def set_datamodule(datasets,
                   dataset_location,
                   transforms_train,
                   transforms_val,
                   transforms_test,
                   num_workers,
                   per_gpu_batch_size,
                   test_data_split):
    """
    Instantiate the datamodule selected by ``datasets``.

    Args:
        datasets: dataset name; one of "APTOS2019", "CheXpert", "MIMIC", "CIFAR", "ImageNette".
        dataset_location, transforms_train, transforms_val, transforms_test, num_workers,
        per_gpu_batch_size, test_data_split: forwarded unchanged, as keyword arguments,
            to the datamodule constructor.

    Returns:
        The constructed datamodule instance.

    Raises:
        ValueError: if ``datasets`` is not one of the supported names.
    """
    dataset_parameters = {
        "dataset_location": dataset_location,
        "transforms_train": transforms_train,
        "transforms_val": transforms_val,
        "transforms_test": transforms_test,
        "num_workers": num_workers,
        "per_gpu_batch_size": per_gpu_batch_size,
        "test_data_split": test_data_split
    }

    # NOTE(review): only ImageNetteDataModule is imported in the visible cells of this
    # notebook; the other branches need their datamodule class imported first.
    if datasets == "APTOS2019":
        datamodule = APTOS2019DataModule(**dataset_parameters)
    elif datasets == "CheXpert":
        datamodule = CheXpertDataModule(**dataset_parameters)
    elif datasets == "MIMIC":
        datamodule = MIMICDataModule(**dataset_parameters)
    elif datasets == "CIFAR":
        datamodule = CIFARDataModule(**dataset_parameters)
    elif datasets == "ImageNette":
        datamodule = ImageNetteDataModule(**dataset_parameters)
    else:
        # The exception has to be raised; merely constructing it would fall through
        # to the return below with 'datamodule' unbound.
        raise ValueError("Invalid 'datasets' configuration")
    return datamodule

# Build the datamodule from the identically named entries of the run configuration.
datamodule = set_datamodule(**{
    option: _config[option]
    for option in (
        "datasets",
        "dataset_location",
        "transforms_train",
        "transforms_val",
        "transforms_test",
        "num_workers",
        "per_gpu_batch_size",
        "test_data_split",
    )
})

# The batch for training the classifier consists of images and labels, but the batch for training the explainer
# consists of images and masks. The masks are generated to follow the Shapley distribution.
#
# NOTE: the triple-quoted block below is a bare string literal, i.e. disabled code that is never executed.
# If enabled it would replace `datamodule.dataset_cls.__getitem__` so that every item also carries "masks"
# (freshly sampled for the 'train' split, cached per index for 'val'/'test').
# It reads a global `surrogate`, which is only assigned in a later cell of this notebook.
"""
original_getitem = copy.deepcopy(datamodule.dataset_cls.__getitem__)
def __getitem__(self, idx):
    if self.split == 'train':
        masks = generate_mask(num_players=surrogate.num_players,
                              num_mask_samples=_config["explainer_num_mask_samples"],
                              paired_mask_samples=_config["explainer_paired_mask_samples"], mode='shapley')
    elif self.split == 'val' or self.split == 'test':
        # get cached if available
        if not hasattr(self, "masks_cached"):
            self.masks_cached = {}
        masks = self.masks_cached.setdefault(idx, generate_mask(num_players=surrogate.num_players,
                                                                num_mask_samples=_config[
                                                                    "explainer_num_mask_samples"],
                                                                paired_mask_samples=_config[
                                                                    "explainer_paired_mask_samples"],
                                                                mode='shapley'))
    else:
        raise ValueError("'split' variable must be train, val or test.")
    return {"images": original_getitem(self, idx)["images"],
            "labels": original_getitem(self, idx)["labels"],
            "masks": masks}
datamodule.dataset_cls.__getitem__ = __getitem__
"""

# Materialise the three splits on the datamodule and keep handles to them.
datamodule.set_train_dataset()
datamodule.set_val_dataset()
datamodule.set_test_dataset()

train_dataset=datamodule.train_dataset
val_dataset=datamodule.val_dataset
test_dataset=datamodule.test_dataset

# Position, within each class, of the example picked below
# (i.e. the 5th sample of every class).
classidx=4

# Select the split under analysis.
if dataset_split=="train":
    dset = train_dataset
elif dataset_split=="val":
    dset = val_dataset
elif dataset_split=="test":
    dset = test_dataset
else:
    # A bare `raise` has no active exception here and would only produce
    # "RuntimeError: No active exception to reraise"; report the real problem.
    raise ValueError(f"'dataset_split' must be 'train', 'val' or 'test', got {dataset_split!r}")

# Class label of every record in the split, read from the dataset's record list.
labels = np.array([i['label'] for i in dset.data])
num_classes = labels.max() + 1

# For each class, the dataset indices of its samples; then one index per class.
images_idx_list = [np.where(labels == category)[0] for category in range(num_classes)]
images_idx = [category_idx[classidx] for category_idx in images_idx_list]

# Load the selected samples: x is the stacked image tensor, y the tuple of labels.
xy=[dset[idx] for idx in images_idx]
x, y = zip(*[(i['images'], i['labels']) for i in xy])
x = torch.stack(x)
len(dset)

In [None]:
# Keep only the backbones requested in `backbone_to_use`, preserving the
# declaration order of `backbone_type_config_dict_`.
backbone_type_config_dict = OrderedDict()
for backbone_type, backbone_type_config in backbone_type_config_dict_.items():
    if backbone_type not in backbone_to_use:
        continue
    print(backbone_type)
    backbone_type_config_dict[backbone_type] = backbone_type_config

In [None]:
# surrogate_dict[backbone_type][mask_location] -> loaded Surrogate module.
surrogate_dict = OrderedDict()

for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    mask_method_dict = OrderedDict()
    for mask_location in ("original", "pre-softmax", "post-softmax", "zero-input"):
        # The "original" checkpoint is instantiated with the "pre-softmax" mask location;
        # every other checkpoint uses the mask location it is keyed by.
        if mask_location == "original":
            surrogate_mask_location = "pre-softmax"
        else:
            surrogate_mask_location = mask_location

        # Training-only hyperparameters are irrelevant here and passed as None.
        loaded_surrogate = Surrogate(
            mask_location=surrogate_mask_location,
            backbone_type=backbone_type,
            download_weight=_config['surrogate_download_weight'],
            load_path=backbone_type_config["surrogate_path"][mask_location],
            target_type=_config["target_type"],
            output_dim=_config["output_dim"],
            target_model=None,
            checkpoint_metric=None,
            optim_type=None,
            learning_rate=None,
            weight_decay=None,
            decay_power=None,
            warmup_steps=None,
        )
        # The idx-th backbone is placed on the idx-th entry of 'gpus_surrogate'.
        mask_method_dict[mask_location] = loaded_surrogate.to(_config["gpus_surrogate"][idx])
    surrogate_dict[backbone_type] = mask_method_dict

In [None]:
# Sequential (unshuffled) loader over the selected split; incomplete final batches are dropped.
dset_loader = DataLoader(
    dataset=dset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    drop_last=True,
)

In [None]:
from tqdm import tqdm
import copy

# Sanity check: for a fixed random patch mask, the surrogate's masked prediction must match
#   (a) a forward pass that physically drops the masked patches ("held out"), and
#   (b) a forward pass on images whose masked pixels were overwritten with a junk value.
# One mask row over 196 = 14 x 14 patches (1 = keep, 0 = masked). The mask is not seeded.
mask = (torch.rand(1, 196)>0.5).int()

for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):
    surrogate = surrogate_dict[backbone_type]['pre-softmax']
    surrogate_ = copy.deepcopy(surrogate)

    # Loop-invariant setup, done once instead of once per batch.
    with torch.no_grad():
        # The copy keeps only the class-token position embedding plus those of the kept patches,
        # so it can be fed the kept patches alone.
        surrogate_.backbone.pos_embed=torch.nn.Parameter(torch.concat([surrogate.backbone.pos_embed[:,0:1],
                                                                       surrogate.backbone.pos_embed[:,1:][:, mask[0]==1]], dim=1))
    # Patch mask blown up to pixel resolution: (1, 1, 14*16, 14*16). It broadcasts over batch and
    # channel dimensions, so it no longer depends on a hard-coded batch size.
    pixel_mask = torch.repeat_interleave(torch.repeat_interleave(mask.reshape(1, 1, 14, 14), 16, dim=2), 16, dim=3)

    for batch_idx, batch in enumerate(tqdm(dset_loader, unit='batch')):
        with torch.no_grad():
            batch_mask = torch.repeat_interleave(mask, len(batch["images"]), dim=0)

            # Reference: regular masked forward pass.
            logits = surrogate(batch["images"].to(surrogate.device),
                               batch_mask.to(surrogate.device))['logits']

            # (a) Embed all patches, keep only the unmasked ones and run them with an all-ones mask.
            image_patchified=surrogate.backbone.patch_embed(batch["images"].to(surrogate.device))
            image_patchified_attention=surrogate_.backbone.forward_features(image_patchified[:,mask[0]==1,:],
                                                                            torch.ones(len(image_patchified),(mask[0]==1).sum().item()).to(surrogate_.device), 'pre-softmax')
            logits_held_out = surrogate.head(image_patchified_attention['x'])

            # (b) Overwrite the masked pixels with an arbitrary large value; the prediction must not change.
            images_perturbed = batch["images"].masked_fill(pixel_mask == 0, 4242)
            logits_perturbed = surrogate(images_perturbed.to(surrogate.device),
                                         batch_mask.to(surrogate.device))['logits']

            assert torch.isclose(logits, logits_held_out, rtol=1e-2).all(), \
                "masked forward pass differs from the held-out-patches forward pass"
            assert torch.isclose(logits, logits_perturbed, rtol=1e-2).all(), \
                "masked forward pass depends on the content of masked pixels"

    # Only the first selected backbone is checked.
    break

In [None]:
# Re-run the reduced-position-embedding copy on the kept patches of the last batch,
# using the 'pre-softmax' mask location. Relies on `surrogate_`, `image_patchified`
# and `mask` left over from the check loop.
with torch.no_grad():
    b = surrogate_.backbone.forward_features(
        image_patchified[:, mask[0].bool(), :],
        torch.ones(len(image_patchified), int(mask[0].sum()), device=surrogate_.device),
        'pre-softmax',
    )['x']

In [None]:
# Same forward pass, but with the mask location string 'pre-softmax_' (note the trailing
# underscore). Presumably this probes whether the mask location matters when the mask
# is all ones; TODO confirm against forward_features.
with torch.no_grad():
    a = surrogate_.backbone.forward_features(
        image_patchified[:, mask[0].bool(), :],
        torch.ones(len(image_patchified), int(mask[0].sum()), device=surrogate_.device),
        'pre-softmax_',
    )['x']

In [None]:
# True only if the two forward passes produced element-wise identical features.
torch.all(a == b)