# Explainable AI applied to assessors

In [2]:
from typing import *
from pathlib import Path

from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer
import numpy as np
import pandas as pd
import bigbench.api.results as bb

from lass.log_handling import LogLoader
from lass.datasets import split_instance_level, huggingfaceify

In [2]:
print("Loading data...")

# Build the log selection one filter at a time: task set 'paper-full',
# model family 'BIG-G T=0', model size '128b', 0-shot, multiple-choice queries.
loader = LogLoader(logdir=Path('../artifacts/logs'))
loader = loader.with_tasks('paper-full')
loader = loader.with_model_families(['BIG-G T=0'])
loader = loader.with_model_sizes(['128b'])
loader = loader.with_shots([0])
loader = loader.with_query_types([bb.MultipleChoiceQuery])

# Fixed seed so the train/test partition (20% test) is the same on every run.
_train, test = split_instance_level(
    loader,
    seed=42,
    test_fraction=0.2,
)
print("Data loaded.")

Loading data...


Data loaded.


In [3]:
import transformers
# IPython magic: sets TOKENIZERS_PARALLELISM in the kernel's environment.
%env TOKENIZERS_PARALLELISM=true

# Convert the splits with `huggingfaceify`. Only the first training instance is
# passed along (`_train[:1]`); the visible cells below only read the 'test' split.
# NOTE(review): presumably the training data is not needed for the
# explainability analysis - confirm.
dataset = huggingfaceify(_train[:1], test)
# Show the first test example (a dict with 'text' and 'label') as the cell output.
dataset['test'][0]

env: TOKENIZERS_PARALLELISM=true


{'text': 'In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Barry was furious when the power went out-he wanted to read his new novel! He paced through the dark house, swearing angrily. But the more he ranted and raved, the worse he felt. Then Barry remembered he had a battery-powered lantern, so he dug it out and set it up. Pretty soon, one corner of the house was warmly lit and welcoming. Barry settled down happily with his book and waited for the power to return.\nThis narrative is a good illustration of the following proverb: ',
 'label': 1}

In [4]:
# Only report errors from the transformers library while loading.
transformers.logging.set_verbosity_error() # type: ignore

# Tokenizer used both for truncation and by the explainer below.
# NOTE(review): presumably "bert-base-cased" is the base model the assessor
# checkpoint was fine-tuned from - confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-cased"
)

# Sequence-classification model loaded from a local training checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    "../artifacts/assessors/bert-bs32-0sh/checkpoint-1500"
)

# The "transformers-interpret" library doesn't deal well with truncation
# of long sequences, so we truncate the sequences ourselves.
# 
# Note: Encoding is a destructive process,
# so we need to work with offset_mapping instead of just reversing.
# https://github.com/huggingface/tokenizers/issues/826#issuecomment-966082496
def truncate(batch):
    """Cut every text in `batch` down to the part the tokenizer keeps.

    The texts are tokenized with truncation enabled and the character offsets
    of the resulting tokens are used to slice the original strings, so the
    returned strings are prefixes of the inputs rather than decoded token ids.

    Args:
        batch: mapping with a "text" list of strings and a "label" list.

    Returns:
        A dict with the truncated "text" list and the untouched "label" list.
    """
    encoded = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        return_tensors="np",
        return_offsets_mapping=True,
    )

    # offset_mapping is indexed as [example, token, (start, end)]; keep only the
    # end offsets. The largest end offset per example is the character position
    # where the last kept token stops, i.e. where the string has to be cut.
    end_offsets = encoded["offset_mapping"][:, :, -1]
    cut_points = np.amax(end_offsets, axis=1) # type: ignore
    assert len(cut_points) == len(batch["text"])

    shortened = []
    for original, cut in zip(batch["text"], cut_points):
        shortened.append(original[:cut])

    return {"text": shortened, "label": batch["label"]}

# Run `truncate` over the dataset in batches, then display the first test
# example to check the result.
truncated_datasets = dataset.map(function=truncate, batched=True)
truncated_datasets["test"][0]



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

{'text': 'In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Barry was furious when the power went out-he wanted to read his new novel! He paced through the dark house, swearing angrily. But the more he ranted and raved, the worse he felt. Then Barry remembered he had a battery-powered lantern, so he dug it out and set it up. Pretty soon, one corner of the house was warmly lit and welcoming. Barry settled down happily with his book and waited for the power to return.\nThis narrative is a good illustration of the following proverb:',
 'label': 1}

In [5]:
from transformers_interpret import SequenceClassificationExplainer
from collections import defaultdict

cls_explainer = SequenceClassificationExplainer(model, tokenizer) #type: ignore

# Raw per-token attributions are streamed to this CSV. The file name matches
# the one the aggregation cell reads back with `pd.read_csv("xai_bert.csv")`,
# so the notebook works end-to-end on a fresh "Restart & Run All".
path = Path("xai_bert.csv")

# (Re)create the file containing only the header row; every instance below
# appends its rows, so partial results survive an interrupted run.
pd.DataFrame([],columns=['word', 'contribution', 'LM_score', 'Assr_pred']).to_csv(path, index=False)

test_split = truncated_datasets['test']
total = len(test_split)  # loop-invariant, computed once for the progress line

for index, instance in enumerate(test_split):
    if index % 50 == 0: # type: ignore
        print(f"{index}/{total}")

    # EXPLAINABILITY
    text, LM_correct = instance['text'], instance['label'] # type: ignore
    # Attributions are computed with respect to class 'LABEL_0'.
    # NOTE(review): presumably LABEL_0 is the "LM answers incorrectly" class - confirm.
    exp_neg = cls_explainer(text, class_name='LABEL_0')

    # One row per (token, attribution) pair returned by the explainer, tagged
    # with the instance label and the class index the assessor predicted.
    frame = pd.DataFrame(exp_neg, columns=['word', 'contribution'])
    frame['LM_score'] = LM_correct
    frame['Assr_pred'] = cls_explainer.predicted_class_index
    frame.to_csv(path, mode='a', header=False, index=False)

0/11004


50/11004


100/11004


150/11004


200/11004


250/11004


300/11004


350/11004


400/11004


450/11004


500/11004


550/11004


600/11004


650/11004


700/11004


750/11004


800/11004


850/11004


900/11004


950/11004


1000/11004


1050/11004


1100/11004


1150/11004


1200/11004


1250/11004


1300/11004


1350/11004


1400/11004


1450/11004


1500/11004


1550/11004


1600/11004


1650/11004


1700/11004


1750/11004


1800/11004


1850/11004


1900/11004


1950/11004


2000/11004


2050/11004


2100/11004


2150/11004


2200/11004


2250/11004


2300/11004


2350/11004


2400/11004


2450/11004


2500/11004


2550/11004


2600/11004


2650/11004


2700/11004


2750/11004


2800/11004


2850/11004


2900/11004


2950/11004


3000/11004


3050/11004


3100/11004


3150/11004


3200/11004


3250/11004


3300/11004


3350/11004


3400/11004


3450/11004


3500/11004


3550/11004


3600/11004


3650/11004


3700/11004


3750/11004


3800/11004


3850/11004


3900/11004


3950/11004


4000/11004


4050/11004


4100/11004


4150/11004


4200/11004


4250/11004


4300/11004


4350/11004


4400/11004


4450/11004


4500/11004


4550/11004


4600/11004


4650/11004


4700/11004


4750/11004


4800/11004


4850/11004


4900/11004


4950/11004


5000/11004


5050/11004


5100/11004


5150/11004


5200/11004


5250/11004


5300/11004


5350/11004


5400/11004


5450/11004


5500/11004


5550/11004


5600/11004


5650/11004


5700/11004


5750/11004


5800/11004


5850/11004


5900/11004


5950/11004


6000/11004


6050/11004


6100/11004


6150/11004


6200/11004


6250/11004


6300/11004


6350/11004


6400/11004


6450/11004


6500/11004


6550/11004


6600/11004


6650/11004


6700/11004


6750/11004


6800/11004


6850/11004


6900/11004


6950/11004


7000/11004


7050/11004


7100/11004


7150/11004


7200/11004


7250/11004


7300/11004


7350/11004


7400/11004


7450/11004


7500/11004


7550/11004


7600/11004


7650/11004


7700/11004


7750/11004


7800/11004


7850/11004


7900/11004


7950/11004


8000/11004


8050/11004


8100/11004


8150/11004


8200/11004


8250/11004


8300/11004


8350/11004


8400/11004


8450/11004


8500/11004


8550/11004


8600/11004


8650/11004


8700/11004


8750/11004


8800/11004


8850/11004


8900/11004


8950/11004


9000/11004


9050/11004


9100/11004


9150/11004


9200/11004


9250/11004


9300/11004


9350/11004


9400/11004


9450/11004


9500/11004


9550/11004


9600/11004


9650/11004


9700/11004


9750/11004


9800/11004


9850/11004


9900/11004


9950/11004


10000/11004


10050/11004


10100/11004


10150/11004


10200/11004


10250/11004


10300/11004


10350/11004


10400/11004


10450/11004


10500/11004


10550/11004


10600/11004


10650/11004


10700/11004


10750/11004


10800/11004


10850/11004


10900/11004


10950/11004


11000/11004


In [17]:
# Load the per-token attribution table (uses the columns 'word',
# 'contribution' and 'LM_score').
words = pd.read_csv("xai_bert.csv")

# Keep only rows with LM_score == 1, then compute for every distinct token its
# mean contribution and its number of occurrences. Tokens seen 10 times or
# fewer are dropped to avoid ranking on a handful of samples.
mean = (words
    .query('LM_score == 1')
    .groupby('word')
    .agg( # type: ignore
        mean_contribution=('contribution', 'mean'),
        word_count=('word', 'count')
    )
    .sort_values('mean_contribution', ascending=False)
    .query('word_count > 10')
)

# `DataFrame.to_csv` does not create missing directories, so make sure the
# output folder exists before writing.
output_path = Path("../artifacts/tmp/xai.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
mean.to_csv(output_path)