# Explainable AI applied to assessors

In [1]:
from typing import *
from pathlib import Path

from transformers.models.auto.modeling_auto import AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import AutoTokenizer
import numpy as np
import pandas as pd
import bigbench.api.results as bb

from lass.log_handling import LogLoader
from lass.datasets import split_instance_level, huggingfaceify

In [2]:
print("Loading data...")

# Build the log query step by step. Keep only:
#   - tasks from the 'paper-full' set,
#   - the 'BIG-G T=0' model family at size '128b',
#   - 0-shot prompts,
#   - multiple-choice queries.
loader = LogLoader(logdir=Path('../artifacts/logs'))
loader = loader.with_tasks('paper-full')
loader = loader.with_model_families(['BIG-G T=0'])
loader = loader.with_model_sizes(['128b'])
loader = loader.with_shots([0])
loader = loader.with_query_types([bb.MultipleChoiceQuery])

# Instance-level split: 20% of instances go to the test set, with a fixed seed
# so the split is reproducible across runs.
_train, test = split_instance_level(loader, seed=42, test_fraction=0.2)
print("Data loaded.")

Loading data...


In [None]:
import os
import transformers

# Set the variable through os.environ instead of the IPython-only `%env` magic,
# so this cell is also valid when the notebook is run as a plain Python script.
# NOTE(review): presumably consumed by the Hugging Face `tokenizers` library to
# enable its parallelism — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Only the first training instance is passed on: below, only the 'test' split
# is explained. NOTE(review): presumably huggingfaceify needs a non-empty train
# split to build the DatasetDict — confirm.
dataset = huggingfaceify(_train[:1], test)
dataset['test'][0]

In [None]:
# From here on, only log errors from the transformers library.
transformers.logging.set_verbosity_error()  # type: ignore

# The fine-tuned assessor: a sequence-classification checkpoint on disk.
model = AutoModelForSequenceClassification.from_pretrained(
    "../artifacts/assessors/bert-bs32-0sh/checkpoint-1500"
)

# The tokenizer is loaded from the base model name, not from the checkpoint dir.
# NOTE(review): assumes the assessor was fine-tuned from bert-base-cased and
# therefore shares its vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-cased"
)

def truncate(batch, *, tok=None):
    """Cut each text in a batch down to the prefix the tokenizer actually keeps.

    The "transformers-interpret" library does not deal well with truncation of
    long sequences, so the texts are truncated here, ahead of explanation.

    Encoding is a destructive process, so decoding the truncated token ids would
    not reproduce the original text. Instead, the tokenizer's ``offset_mapping``
    locates the last character that survives truncation, and the original string
    is sliced there. See
    https://github.com/huggingface/tokenizers/issues/826#issuecomment-966082496

    Args:
        batch: A batch as passed by ``map(..., batched=True)``, holding parallel
            ``"text"`` and ``"label"`` lists.
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``.

    Returns:
        A dict with the truncated ``"text"`` list and the unchanged ``"label"``.
    """
    if tok is None:
        tok = tokenizer

    # Padding is needed only to get a rectangular numpy array, so pad to the
    # longest sequence in the batch rather than to the model's max length.
    # Padding (and special) tokens are expected to carry (0, 0) offsets, so they
    # never affect the per-row max computed below.
    encoded = tok(
        batch["text"],
        padding=True,
        truncation=True,
        return_tensors="np",
        return_offsets_mapping=True,
    )

    # offset_mapping is indexed [batch, token, (start, end)], where start/end are
    # character positions in the original string. Keep only the end offsets,
    # then take the max per row: that is the length of the retained prefix.
    ends = np.amax(encoded["offset_mapping"][:, :, -1], axis=1)  # type: ignore
    assert len(ends) == len(batch["text"])

    texts = [text[:end] for text, end in zip(batch["text"], ends)]
    return {"text": texts, "label": batch["label"]}

# Run `truncate` over every split, a batch of examples at a time, so each text
# is cut to what the tokenizer keeps under truncation.
truncated_datasets = dataset.map(truncate, batched=True)
first_truncated_test = truncated_datasets["test"][0]
first_truncated_test

In [None]:
from transformers_interpret import SequenceClassificationExplainer
from collections import defaultdict

cls_explainer = SequenceClassificationExplainer(model, tokenizer)  # type: ignore

# (Re)create xai.csv containing only the header row; each test instance then
# appends its word-level attributions as further rows.
path = Path("xai.csv")
xai_columns = ['word', 'contribution', 'LM_score', 'Assr_pred']
pd.DataFrame([], columns=xai_columns).to_csv(path, index=False)

test_split = truncated_datasets['test']
n_test = len(test_split)

# For a quick debugging run, iterate over `test_split.select(range(255, 260))`.
for index, instance in enumerate(test_split):
    # Progress report every 50 instances.
    if not index % 50:
        print(f"{index}/{n_test}")

    # EXPLAINABILITY
    text = instance['text']  # type: ignore
    LM_correct = instance['label']  # type: ignore

    # Word attributions toward class 'LABEL_0'.
    # NOTE(review): presumably the "LM answered incorrectly" class — confirm.
    exp_neg = cls_explainer(text, class_name='LABEL_0')

    # One row per (word, contribution), tagged with the instance's ground-truth
    # label and the assessor's predicted class index.
    frame = pd.DataFrame(exp_neg, columns=['word', 'contribution'])
    frame['LM_score'] = LM_correct
    frame['Assr_pred'] = cls_explainer.predicted_class_index
    frame.to_csv(path, mode='a', header=False, index=False)

In [None]:
# frame.groupby('word').agg( # type: ignore
#     mean_contribution=('contribution', 'mean'),
#     word_count=('word', 'count')
# ).sort_values('mean_contribution', ascending=False)