In [1]:
import captum
import torch
import thermostat
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import IntegratedGradients, FeatureAblation
import onnxruntime as ort
from tqdm import tqdm

import json

In [2]:
from datasets import load_dataset

# Load the IMDB sentiment dataset (train/test/unsupervised splits).
# NOTE(review): `ignore_verifications` is deprecated in newer `datasets`
# releases in favour of `verification_mode` — revisit when upgrading.
dataset = load_dataset(path="imdb", ignore_verifications=True)

Reusing dataset imdb (/home/tim/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# (original test-split index, example) pairs, ordered from the shortest review
# text to the longest. sorted() is stable, so equal-length reviews keep their
# original dataset order.
ds_sorted = sorted(
    enumerate(dataset["test"]),
    key=lambda pair: len(pair[1]["text"]),
)

In [4]:
# Peek at one short review: the 21st-shortest entry of the length-sorted test
# split, shown as an (original test-split index, example) pair.
ds_sorted[20]

(17327,
 {'text': 'Don\'t waste your time and money on it. It\'s not quite as bad as "Adrenalin", by the same director but that\'s not saying much.',
  'label': 0})

Each saved sample should have the keys `{text, input_ids, tokens, attributions, idx}` (the loop below additionally stores `attention_mask`).

In [5]:
tokenizer = AutoTokenizer.from_pretrained("textattack/distilbert-base-uncased-imdb")

# Quantized ONNX export of the classifier, executed with onnxruntime.
# NOTE(review): presumably exported from the checkpoint above — confirm the
# tokenizer vocabulary matches the exported model.
ort_session = ort.InferenceSession("../models/distilbert-base-uncased-imdb/model-optimized-quantized.onnx")


def forward_func(x):
    """Run the ONNX classifier on a batch of token ids for Captum.

    ``x`` (a tensor or nested list of token ids) is converted to a NumPy array
    and fed as ``input_ids`` together with an all-ones ``attention_mask`` of the
    same shape, i.e. no position is treated as padding.

    Returns the session's ``output_0`` as a torch tensor so Captum can select
    the ``target`` column from it.
    """
    ids_np = np.array(x)
    feeds = {
        "input_ids": ids_np,
        "attention_mask": np.ones_like(ids_np),
    }
    scores = ort_session.run(["output_0"], feeds)[0]
    return torch.tensor(scores)

In [6]:
# Perturbation-based attribution: Captum replaces each input feature with a
# baseline value and records how the model output changes.
method = FeatureAblation(forward_func=forward_func)

In [7]:
# Compute FeatureAblation attributions for the 1000 shortest test reviews and
# collect one JSON-serialisable record per review.
samples = []
for idx, data in tqdm(ds_sorted[:1000]):
    # Tokenize once; the same encoding feeds the model and the saved record.
    tokenized = tokenizer(data["text"])
    input_ids = torch.tensor([tokenized["input_ids"]]).long()

    # NOTE(review): attributions are taken w.r.t. the gold label, not the
    # model's predicted class — confirm this is intended.
    target_class = data["label"]

    # FeatureAblation has no integration `method` (that is an IntegratedGradients
    # option); passing one is silently ignored via **kwargs, so it is omitted.
    attributions = method.attribute(input_ids, target=target_class)

    sample = {
        "idx": idx,
        "text": data["text"],
        "attributions": attributions[0].tolist(),
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "tokens": tokenizer.convert_ids_to_tokens(tokenized["input_ids"]),
    }
    # One attribution score per token id.
    assert len(sample["attributions"]) == len(sample["input_ids"])

    samples.append(sample)

100%|█████████████████████████████████████████████████████████████| 1000/1000 [19:59<00:00,  1.20s/it]


In [8]:
# Persist all attribution records as a single JSON list.
# NOTE(review): the file name says "albert" but the attributions above come
# from the DistilBERT IMDB model — confirm the intended name before anything
# downstream depends on it.
output_path = "../data/albert-imdb-feature-ablation-1000.json"
with open(output_path, "w") as out_file:
    json.dump(samples, out_file)