In [7]:

import json
from typing import Callable
from cfg import DATA_PATH
import numpy as np
import torch
from captum.attr import KernelShap, Lime
from tqdm import tqdm

from attribution_methods import AttributionMethod
from baselines import ZeroBaselineFactory
from evaluators import ProportionalityEvaluator
from experiment_runner import ExperimentRunner
from models import load_distilbert

In [8]:
# Load the evaluation dataset consumed by ExperimentRunner below.
# Read explicitly as UTF-8: without `encoding`, open() uses the locale's
# preferred encoding (e.g. cp1252 on Windows), which can fail or mis-decode
# non-ASCII characters in the JSON.
# NOTE(review): DATA_PATH is imported from cfg but this path is hard-coded;
# consider deriving it from DATA_PATH once its format is confirmed.
with open("../data/imdb-distilbert-1000.json", "r", encoding="utf-8") as fp:
    dataset = json.load(fp)

In [9]:
# Build the model under test and the evaluator that scores its attributions.
# ZeroBaselineFactory is passed as the class itself (not an instance) —
# presumably the evaluator instantiates it; verify against evaluators.py.
model = load_distilbert(from_notebook=1)
evaluator = ProportionalityEvaluator(
    model=model,
    baseline_factory=ZeroBaselineFactory,
)

In [10]:
class KernelShapWrapper(AttributionMethod):
    """Expose Captum's KernelShap through the project's ``AttributionMethod`` interface.

    ``model`` is a NumPy-level callable invoked on a single, unbatched
    observation; its output is argmax-ed to choose the class to explain
    (assumed to be per-class scores — confirm against ``load_distilbert``).
    Captum operates on batched tensors, so the forward function strips the
    leading batch axis before calling ``model`` and re-adds it on the result.
    """

    def __init__(self, model: Callable):
        """Wrap ``model`` in a Captum ``KernelShap`` attribution method.

        Args:
            model: callable mapping one observation array to an output array.
        """
        self.model = model

        def forward_func(batch: torch.Tensor) -> torch.Tensor:
            # Captum evaluates one perturbed observation at a time here, so drop
            # only the leading batch axis; a bare squeeze() would also collapse
            # a length-1 observation to a 0-d array.
            return torch.tensor(model(batch.squeeze(0).numpy())[None])

        # Use KernelShap (not Lime) so results match the class and experiment name.
        self.method = KernelShap(forward_func=forward_func)

    def get_attribution_values(self, observation: np.ndarray) -> np.ndarray:
        """Return KernelShap attributions for ``observation`` w.r.t. its predicted class.

        Args:
            observation: a single, unbatched observation; it is cast to int64
                before attribution (presumably token ids — confirm).

        Returns:
            Attribution values for the observation, with the batch axis removed.
        """
        # Explain the class the model itself predicts for this observation.
        target_class = int(np.argmax(self.model(observation)))
        # Add a batch axis of size 1 and cast to integer (long) dtype for Captum.
        inputs = torch.tensor(observation[None]).long()
        attribution = self.method.attribute(inputs, target=target_class)
        return attribution[0].detach().numpy()


# Attribution method under test, wrapping the model loaded above.
method = KernelShapWrapper(model=model)

In [11]:
# Run the attribution experiment over num_samples=4 dataset examples
# (presumably a quick smoke test, given the "test-" run name).
runner = ExperimentRunner(
    name="test-kernelshap",
    num_samples=4,
    attribution_method=method,
    dataset=dataset,
    evaluator=evaluator,
    experiment=None,
    softmax_attributions=False,
)
runner.run()

100%|██████████| 4/4 [00:01<00:00,  3.58it/s]
