# Text Data Explanation Benchmarking: Emotion Multiclass Classification

This notebook demonstrates how to use the benchmark utility to measure how well an explainer performs on text data. As an example, we benchmark the Partition explainer on an emotion multiclass classification model. The evaluation metrics are "keep positive" and "keep negative", and the masker is the Text masker.

The new benchmark utility uses the new API: `MaskedModel` wraps the user-imported model and evaluates it on masked versions of the inputs.
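To make the wrapper idea concrete, here is a minimal sketch of what "evaluating a model on masked inputs" means. `ToyMaskedModel`, `toy_model`, and the `[MASK]` placeholder are hypothetical illustrations, not shap's `MaskedModel`; whitespace tokenization is assumed for simplicity.

```python
from typing import Callable, List, Sequence

import numpy as np

MASK_TOKEN = "[MASK]"  # hypothetical placeholder; real maskers pick their own token


def toy_model(texts: Sequence[str]) -> np.ndarray:
    """Hypothetical stand-in for a text model: counts 'happy' words per text."""
    happy = {"great", "love", "happy"}
    return np.array([sum(w in happy for w in t.split()) for t in texts], dtype=float)


class ToyMaskedModel:
    """Sketch of a masked-model wrapper (hypothetical, not shap's MaskedModel)."""

    def __init__(self, model: Callable[[Sequence[str]], np.ndarray], text: str):
        self.model = model
        self.tokens: List[str] = text.split()

    def __call__(self, masks: np.ndarray) -> np.ndarray:
        # Each row of `masks` is boolean: True keeps a token, False hides it.
        masked_texts = [
            " ".join(tok if keep else MASK_TOKEN for tok, keep in zip(self.tokens, row))
            for row in masks
        ]
        return self.model(masked_texts)


mm = ToyMaskedModel(toy_model, "i love this great day")
masks = np.array([
    [True, True, True, True, True],       # full input
    [True, False, True, True, True],      # hide "love"
    [False, False, False, False, False],  # every token hidden
])
print(mm(masks))  # [2. 1. 0.]
```

The benchmark only ever talks to the wrapper, so the same masking logic works for any model that accepts a batch of strings.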

In [None]:
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import shap.benchmark as benchmark
import shap
import scipy as sp
import nlp
import torch

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('max_colwidth', None)

### Load Data and Model

In [None]:
train, test = nlp.load_dataset("emotion", split=["train", "test"])

data = pd.DataFrame({"text": train["text"], "emotion": train["label"]})

In [None]:
tokenizer = AutoTokenizer.from_pretrained("nateraw/bert-base-uncased-emotion", use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained("nateraw/bert-base-uncased-emotion")

### Class Label Mapping

In [None]:
# set mapping between label and id
id2label = model.config.id2label
label2id = model.config.label2id
labels = sorted(label2id, key=label2id.get)
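Sorting the label names by their integer id matters because `output_names` must line up with the columns of the model's output: position `i` in `labels` should name class `i`. The small dictionary below is hypothetical and only stands in for `model.config.label2id`.

```python
# Hypothetical label2id dict standing in for model.config.label2id
label2id = {"joy": 1, "anger": 3, "sadness": 0, "love": 2}

# Sort label names by id so labels[i] names output column i
labels = sorted(label2id, key=label2id.get)
print(labels)  # ['sadness', 'joy', 'love', 'anger']

# Inverting the mapping recovers the same order
id2label = {i: name for name, i in label2id.items()}
assert labels == [id2label[i] for i in range(len(id2label))]
```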

### Define Score Function

In [None]:
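def f(x):
    # Encode each string, padding/truncating to 128 token ids
    tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=128, truncation=True) for v in x])
    # Attend only to real tokens; BERT's [PAD] id is 0
    attention_mask = (tv != 0).type(torch.int64)
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = model(tv, attention_mask=attention_mask)[0].numpy()
    # Softmax over the classes, then map each probability to log-odds
    scores = sp.special.softmax(outputs, axis=-1)
    val = sp.special.logit(scores)
    return val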
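The score function turns raw logits into per-class probabilities and then into log-odds, so the resulting SHAP values are expressed in log-odds units. The following NumPy-only sketch reproduces that two-step transform on hypothetical logits; the helper names are illustrative, and `logit` computes the same quantity as `scipy.special.logit`.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def logit(p: np.ndarray) -> np.ndarray:
    """Log-odds log(p / (1 - p)) for probabilities strictly between 0 and 1."""
    return np.log(p) - np.log1p(-p)


# Hypothetical raw logits for 2 texts and 3 classes
raw = np.array([[2.0, 0.0, 0.0],
                [0.0, 0.0, 0.0]])
probs = softmax(raw)      # each row sums to 1
log_odds = logit(probs)   # row 2 has p = 1/3 per class, i.e. log-odds log(0.5)
print(log_odds)
```

Working in log-odds rather than probabilities keeps the output unbounded, which suits SHAP's additive decomposition better than values squeezed into [0, 1].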

### Create Explainer Object

In [None]:
explainer = shap.Explainer(f, tokenizer, output_names=labels)

### Run SHAP Explanation

In [None]:
shap_values = explainer(data['text'][0:20])
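The attributions in `shap_values` estimate each token's contribution to each class output. For intuition only, the sketch below computes exact Shapley values by enumerating every coalition of tokens on a hypothetical additive toy model. This brute-force method is a different algorithm from the Partition explainer, which works over a hierarchy of token groups to keep the cost manageable; enumeration is feasible only for a handful of tokens.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List


def exact_shapley(value: Callable[[frozenset], float], n: int) -> List[float]:
    """Exact Shapley values by enumerating all coalitions (O(2^n) evaluations)."""
    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Weight |S|! (n - |S| - 1)! / n! for each coalition S not containing i
            w = factorial(size) * factorial(n - size - 1) / factorial(n)
            for s in combinations(others, size):
                coalition = frozenset(s)
                total += w * (value(coalition | {i}) - value(coalition))
        phi.append(total)
    return phi


tokens = ["i", "love", "this", "great", "day"]
word_scores = {"love": 1.0, "great": 0.5}  # hypothetical per-word contributions


def value(kept: frozenset) -> float:
    """Hypothetical masked-model output: sum of scores of the unmasked tokens."""
    return sum(word_scores.get(tokens[j], 0.0) for j in kept)


phi = exact_shapley(value, len(tokens))
print(dict(zip(tokens, phi)))
```

Because the toy model is additive, each token's Shapley value equals its own score, and the values sum to the difference between the full-input and fully-masked outputs.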

### Define Metrics (Sort Order & Perturbation Method)

In [None]:
sort_order = 'positive'
perturbation = 'keep'
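`sort_order` decides which tokens a "keep" perturbation restores first. The sketch below assumes that `'positive'` restores tokens from the largest attribution downward and `'negative'` from the most negative upward; `keep_order` is a hypothetical helper, and how the library treats ties or zero-valued tokens is not shown here.

```python
from typing import List

import numpy as np


def keep_order(attributions: np.ndarray, sort_order: str) -> List[int]:
    """Token indices in the order a 'keep' perturbation would unmask them (sketch)."""
    if sort_order == "positive":
        # Largest (most positive) attributions first; stable sort keeps ties in order
        order = np.argsort(-attributions, kind="stable")
    elif sort_order == "negative":
        # Most negative attributions first
        order = np.argsort(attributions, kind="stable")
    else:
        raise ValueError(f"unknown sort_order: {sort_order!r}")
    return order.tolist()


phi = np.array([0.0, 1.0, -0.3, 0.5, 0.0])  # hypothetical SHAP values for 5 tokens
print(keep_order(phi, "positive"))  # [1, 3, 0, 4, 2]
print(keep_order(phi, "negative"))  # [2, 0, 4, 3, 1]
```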

### Benchmark Explainer

In [None]:
sequential_perturbation = benchmark.perturbation.SequentialPerturbation(explainer.model, explainer.masker, sort_order, perturbation)
xs, ys, auc = sequential_perturbation.model_score(shap_values, data['text'][0:20])
sequential_perturbation.plot(xs, ys, auc)
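The benchmark plots model output against the fraction of tokens kept and reports the area under that curve. The sketch below reproduces the idea on a hypothetical additive toy model: tokens are unmasked one at a time in attribution order, and the curve is integrated with the trapezoidal rule. `keep_curve` and the toy weights are assumptions for illustration; shap's `SequentialPerturbation` may normalize the axes or aggregate over samples and output classes differently.

```python
from typing import Callable, List, Tuple

import numpy as np


def keep_curve(
    model: Callable[[np.ndarray], float], order: List[int], n: int
) -> Tuple[np.ndarray, np.ndarray, float]:
    """Unmask tokens one at a time in `order`, recording the model output (sketch)."""
    mask = np.zeros(n, dtype=bool)  # start with every token masked
    ys = [model(mask)]
    for idx in order:
        mask[idx] = True  # keep one more token
        ys.append(model(mask))
    xs = np.linspace(0.0, 1.0, n + 1)  # fraction of tokens kept
    ys_arr = np.array(ys)
    # Trapezoidal area under the curve
    auc = float(np.sum((ys_arr[1:] + ys_arr[:-1]) / 2 * np.diff(xs)))
    return xs, ys_arr, auc


weights = np.array([0.0, 1.0, -0.3, 0.5, 0.0])  # hypothetical per-token effects


def toy(mask: np.ndarray) -> float:
    """Hypothetical additive model: sum of the effects of kept tokens."""
    return float(weights[mask].sum())


# For an additive model the exact attributions equal the weights
pos_order = np.argsort(-weights, kind="stable").tolist()
neg_order = np.argsort(weights, kind="stable").tolist()
xs, ys_pos, auc_pos = keep_curve(toy, pos_order, len(weights))
_, ys_neg, auc_neg = keep_curve(toy, neg_order, len(weights))
print(auc_pos, auc_neg)  # keep positive should score higher than keep negative
```

A good explainer should push the output up quickly under keep positive (large AUC) and hold it down under keep negative (small AUC), which is what the two benchmark cells compare.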

In [None]:
sort_order = 'negative'
perturbation = 'keep'

In [None]:
sequential_perturbation = benchmark.perturbation.SequentialPerturbation(explainer.model, explainer.masker, sort_order, perturbation)
xs, ys, auc = sequential_perturbation.model_score(shap_values, data['text'][0:20])
sequential_perturbation.plot(xs, ys, auc)