[GitHub](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)
# Fine-tuning Electra on SST-2 and interpreting with Integrated Gradients

---


This notebook contains an example of [fine-tuning](https://huggingface.co/transformers/training.html) an [Electra](https://huggingface.co/transformers/model_doc/electra.html) model for binary text classification. It is adapted from a notebook written for the [GLUE SST-2](https://nlp.stanford.edu/sentiment/index.html) sentiment dataset; the code in this version loads the `ucberkeley-dlab/measuring-hate-speech` dataset instead. After fine-tuning, the [Integrated Gradients](https://arxiv.org/pdf/1703.01365.pdf) **interpretability** method is applied to compute token attributions for each target class.
* We will instantiate a pre-trained Electra model from the [Transformers](https://huggingface.co/transformers/) library. 
* The data is loaded with the Hugging Face `datasets` library, the successor to [nlp](https://huggingface.co/nlp/). The input text is tokenized with the [ElectraTokenizerFast](https://huggingface.co/transformers/model_doc/electra.html#electratokenizerfast) tokenizer, backed by the Hugging Face [tokenizers](https://huggingface.co/transformers/main_classes/tokenizer.html) library.
* **Fine-tuning** for sentiment analysis is handled by the [Trainer](https://huggingface.co/transformers/main_classes/trainer.html) class. 
* After fine-tuning, the [Integrated Gradients](https://captum.ai/api/integrated_gradients.html) interpretability algorithm will assign importance scores to input tokens. We will use a **PyTorch** implementation from the [Captum](https://captum.ai/) library.
  - The algorithm requires a reference sample (a baseline): attributions are computed from how the model's output changes as the input moves from the baseline values to the actual sample.
  - The Integrated Gradients method satisfies the **completeness** property. We will look at the sum of attributions for a sample and show that it approximates (explains) the shift of the prediction from its baseline value.
* The final sections of the notebook contain a color-coded **visualization** of attribution results made with the `captum.attr.visualization` module.

The notebook is based on the [Hugging Face documentation](https://huggingface.co/), and the implementation of the Integrated Gradients attribution method is adapted from the Captum.ai tutorial [Interpreting BERT Models (Part 1)](https://captum.ai/tutorials/Bert_SQUAD_Interpret).
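
The completeness property mentioned in the overview can be illustrated without any deep-learning library. The following is a minimal sketch, not the Captum implementation: it assumes a toy logistic-regression model with made-up weights, an all-zeros baseline, and a midpoint Riemann sum for the path integral. All names (`model`, `model_grad`, `integrated_gradients`) are hypothetical and chosen only for this illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, weights, bias):
    """Toy differentiable model: logistic regression returning P(class 1)."""
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)


def model_grad(x, weights, bias):
    """Analytic gradient of the toy model with respect to each input feature."""
    p = model(x, weights, bias)
    return [p * (1.0 - p) * w for w in weights]


def integrated_gradients(x, baseline, weights, bias, n_steps=200):
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from the baseline to the input
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = model_grad(point, weights, bias)
        avg_grad = [a + g / n_steps for a, g in zip(avg_grad, grad)]
    # Attribution = (input - baseline) * average gradient along the path
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Hypothetical weights, input and baseline, used only for this sketch
weights, bias = [1.5, -2.0, 0.5], 0.1
x = [2.0, 0.5, -1.0]
baseline = [0.0, 0.0, 0.0]

attributions = integrated_gradients(x, baseline, weights, bias)
shift = model(x, weights, bias) - model(baseline, weights, bias)
print("sum of attributions:", sum(attributions))
print("prediction shift:   ", shift)
```

The two printed numbers should agree up to the discretization error of the Riemann sum, which is what completeness states: the attributions add up to the difference between the model's output at the input and at the baseline.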

## Installation & imports

---

In [1]:
!pip install transformers
!pip install pyarrow
!pip install nlp
!pip install datasets

Defaulting to user installation because normal site-packages is not writeable


In [26]:
from typing import Dict

import matplotlib.pyplot as plt
import nlp
import numpy as np
import pandas as pd
import torch
import transformers
from torch.utils.data import Dataset
from transformers import (ElectraForSequenceClassification,
                          ElectraTokenizerFast, EvalPrediction, InputFeatures,
                          Trainer, TrainingArguments, glue_compute_metrics)

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (precision_recall_fscore_support, accuracy_score)

transformers.__version__

'4.27.3'

## Model

---

Sentiment analysis is a classification task that requires assigning a label to an entire sentence (sequence). We will use a PyTorch implementation of [`ElectraForSequenceClassification`](https://huggingface.co/transformers/model_doc/electra.html#electraforsequenceclassification) from the Hugging Face library. A matching tokenizer implemented in the [`ElectraTokenizerFast`](https://huggingface.co/transformers/model_doc/electra.html#electratokenizerfast) class will handle tokenization.

In [4]:
model = ElectraForSequenceClassification.from_pretrained(
    "google/electra-small-discriminator", num_labels=2)

tokenizer = ElectraTokenizerFast.from_pretrained(
    "google/electra-small-discriminator", do_lower_case=True)                   


Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier


## Data

---


**Download**

Let's now download the data with the `datasets` library and take a brief look at it. The code below loads the `train` split of `ucberkeley-dlab/measuring-hate-speech`, which provides a continuous `hate_speech_score` for each text. We binarize that score into a label: 1 if the score is at least 0, and 0 otherwise. Because only one split is loaded, we hold out 20% of the rows as a test set with `train_test_split`.

In [5]:
# Load the train split of the measuring-hate-speech dataset with the datasets library
dataset = load_dataset('ucberkeley-dlab/measuring-hate-speech')['train']
dataset = pd.DataFrame(dataset)

# Keep only the text and the continuous score, then binarize the score at 0
dataset = dataset[['text', 'hate_speech_score']].copy()
dataset['label'] = (dataset['hate_speech_score'] >= 0).astype(int)
dataset = dataset[['text', 'label']]

train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

# Look at the labels
print("Training set labels: {}".format(set(train_data["label"])))
print("Test set labels: {}".format(set(test_data["label"])))

# Explore the dataset
df = pd.DataFrame({"sentence": train_data["text"],
                   "label": train_data["label"]})
pd.options.display.max_colwidth = 0
df.head()

Found cached dataset parquet (/home/r.bhairagond/.cache/huggingface/datasets/ucberkeley-dlab___parquet/ucberkeley-dlab--measuring-hate-speech-c32713cabe528196/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Training set labels: {0, 1}
Test set labels: {0, 1}


Unnamed: 0,sentence,label
104822,"GET RID OF THOSE FUCKING MUSLIMS! the nasty ones that won't assimilate, rape women, beat children, threaten democracy, plan terror attacks, assault innocent secular people and force their disgusting religion on everyone else by way of violence. Assimilate OR GET THE FUCK OUT OR BE PREPARED TO DIE!",1
71077,"when addressing a crowd of white women @RepBarbaraLee says ""we have the resources"" to pay teachers more if we cut the $717B defense budget. when will #ADOS get this kind of advocacy @OfficialCBC ? why can't we find this same creativity to fund reparations? URL",0
128973,I love that you enjoy being trans.,0
87800,Super sexy redhead babes @bbwjulieginger and @ElizaAllure show just how sexy they can be in this hot lesbian scene now at URL URL,0
117885,Learn the fucking language you fucking useless immigrant.,1
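
Thresholding a continuous score can leave the two classes unevenly represented, so it is worth counting the labels before training. The sketch below mirrors the `>= 0` rule from the preprocessing cell; the scores are made up for illustration and `binarize` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter


def binarize(score, threshold=0.0):
    """Map a continuous score to a binary label, as in the preprocessing cell."""
    return 1 if score >= threshold else 0


# Hypothetical scores, used only to illustrate the rule
scores = [-3.2, -0.4, 0.0, 1.7, -1.1, 2.5, -0.9, -2.0]
labels = [binarize(s) for s in scores]

counts = Counter(labels)
total = len(labels)
for label in sorted(counts):
    print(f"label {label}: {counts[label]} samples ({counts[label] / total:.1%})")
```

On the real data the same count can be obtained with `train_data["label"].value_counts()`.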


**Create dataset**

We will now create a custom [map-style PyTorch dataset](https://pytorch.org/docs/stable/data.html#map-style-datasets) that supplies the model's keyword inputs (`input_ids`, `token_type_ids`, `attention_mask`) together with the label.

The `TrainerDataset` class is derived from `torch.utils.data.Dataset`. The overridden `__getitem__` method returns an instance of the `InputFeatures` class.

Conversion to torch tensors and placing on cuda/cpu is handled by the `trainer` object used for fine-tuning.

In [6]:
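class TrainerDataset(Dataset):
    """Map-style dataset that serves pre-tokenized inputs to the Trainer."""

    def __init__(self, inputs, targets, tokenizer):
        self.inputs = inputs
        self.targets = targets
        self.tokenizer = tokenizer

        # Tokenize all inputs in one call; padding=True pads every
        # sequence to the length of the longest one in `inputs`
        self.tokenized_inputs = tokenizer(inputs, padding=True)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # InputFeatures is a plain container; the Trainer's default
        # collator converts its fields to tensors
        return InputFeatures(
            input_ids=self.tokenized_inputs['input_ids'][idx],
            token_type_ids=self.tokenized_inputs['token_type_ids'][idx],
            attention_mask=self.tokenized_inputs['attention_mask'][idx],
            label=self.targets[idx])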
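
To make the map-style protocol and the effect of `padding=True` concrete, here is a library-free sketch. It is not the tokenizer's implementation: `pad_to_longest` and `ToyMapDataset` are hypothetical names, the token ids are made up, and a pad id of 0 is assumed.

```python
def pad_to_longest(sequences, pad_id=0):
    """Pad every sequence to the length of the longest one and build masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq))
                      for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


class ToyMapDataset:
    """Map-style dataset: supports len(ds) and integer indexing ds[i]."""

    def __init__(self, sequences, targets):
        self.targets = targets
        self.encoded = pad_to_longest(sequences)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encoded["input_ids"][idx],
            "attention_mask": self.encoded["attention_mask"][idx],
            "label": self.targets[idx],
        }


# Two hypothetical token-id sequences of different lengths
toy = ToyMapDataset([[101, 7592, 102], [101, 2023, 2003, 2307, 102]], [0, 1])
print(len(toy))
print(toy[0])
```

Because the whole list is padded in a single call, one very long text sets the length of every example in the split. Padding per batch with a collator is a common alternative when memory is tight.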

We need to create the training and evaluation datasets. The code below wraps the two parts of the 80/20 split produced by `train_test_split`; the held-out 20% is used for evaluation.

In [7]:
train_dataset = TrainerDataset(train_data["text"].tolist(),
                               train_data["label"].tolist(), tokenizer)
eval_dataset = TrainerDataset(test_data["text"].tolist(),
                              test_data["label"].tolist(), tokenizer)

## Fine-tuning

---

Fine-tuning with the `Trainer` class requires setting training arguments and creating a `trainer` object. The model, the training and evaluation datasets, and the training arguments are passed to the trainer's constructor. We also pass a `compute_metrics` function so that the `trainer` reports accuracy and weighted precision, recall and F1 at each evaluation, which `evaluation_strategy='epoch'` schedules at the end of every epoch. The `Trainer` class takes care of conversion to tensor format and placement on a CPU/GPU device.

### Set parameters

Training parameters have been taken from the [Electra GitHub](https://github.com/google-research/electra/blob/master/configure_finetuning.py) repository or are default values.

In [27]:
# Set seed for reproducibility
np.random.seed(123)
torch.manual_seed(123)

training_args = TrainingArguments(
    output_dir="./models/model_electra",
    num_train_epochs=3,  # 1 (1 epoch gives slightly lower accuracy)
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=32,
    evaluation_strategy='epoch',     
    dataloader_drop_last=True,  # Make sure all batches are of equal size
)


def compute_metrics(p: EvalPrediction) -> Dict:
    preds = np.argmax(p.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(p.label_ids, preds, average='weighted')
    acc = accuracy_score(p.label_ids, preds)
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }


# Instantiate the Trainer class
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics)
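
The `average='weighted'` option used in `compute_metrics` averages per-class scores with weights proportional to each class's number of true samples. A useful consequence is that weighted recall always equals accuracy. The sketch below reimplements the idea in plain Python for illustration; it is not scikit-learn's code, `weighted_metrics` is a hypothetical name, and the labels and predictions are made up.

```python
def weighted_metrics(labels, preds):
    """Support-weighted precision, recall and F1, plus plain accuracy."""
    n = len(labels)
    accuracy = sum(l == p for l, p in zip(labels, preds)) / n
    precision = recall = f1 = 0.0
    for c in sorted(set(labels)):
        tp = sum(l == c and p == c for l, p in zip(labels, preds))
        support = sum(l == c for l in labels)    # true samples of class c
        predicted = sum(p == c for p in preds)   # samples predicted as c
        prec_c = tp / predicted if predicted else 0.0
        rec_c = tp / support
        f1_c = (2 * prec_c * rec_c / (prec_c + rec_c)
                if prec_c + rec_c else 0.0)
        weight = support / n
        precision += weight * prec_c
        recall += weight * rec_c
        f1 += weight * f1_c
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1_score": f1}


# Hypothetical labels and predictions
labels = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
preds = [0, 0, 0, 1, 1, 1, 1, 1, 0, 1]
metrics = weighted_metrics(labels, preds)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

In the sum for weighted recall, each class contributes `(support / n) * (tp / support) = tp / n`, so the total is the fraction of correct predictions, which is accuracy.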

### Run fine-tuning

Run the `train` method of the `trainer` object to fine-tune the model on the training set.

In [28]:
trainer.train()



| Epoch | Training Loss | Validation Loss | Accuracy | Precision | Recall | F1 Score |
|---|---|---|---|---|---|---|
| 1 | 0.0434 | 0.292543 | 0.950465 | 0.950479 | 0.950465 | 0.950422 |
| 2 | 0.0648 | 0.23179 | 0.954596 | 0.955164 | 0.954596 | 0.954471 |
| 3 | 0.0591 | 0.200842 | 0.963559 | 0.963551 | 0.963559 | 0.963552 |


TrainOutput(global_step=10164, training_loss=0.05859436345166083, metrics={'train_runtime': 19234.7462, 'train_samples_per_second': 16.914, 'train_steps_per_second': 0.528, 'total_flos': 3569565926158848.0, 'train_loss': 0.05859436345166083, 'epoch': 3.0})
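
The figures in the `TrainOutput` above can be cross-checked with a little arithmetic. This sketch assumes a single device and no gradient accumulation (neither is stated in the notebook), so that one optimizer step consumes one batch of 32 samples. The numbers are copied from the output and the training arguments.

```python
# Values copied from the TrainOutput and TrainingArguments above
global_step = 10164
num_epochs = 3
batch_size = 32
train_runtime_s = 19234.7462

# Assumes one device and gradient_accumulation_steps=1
steps_per_epoch = global_step // num_epochs
samples_per_epoch = steps_per_epoch * batch_size
steps_per_second = global_step / train_runtime_s
runtime_hours = train_runtime_s / 3600

print(f"steps per epoch:      {steps_per_epoch}")
print(f"samples per epoch:    {samples_per_epoch}")
print(f"steps per second:     {steps_per_second:.3f}")
print(f"training time (h):    {runtime_hours:.2f}")
```

Under these assumptions each epoch covers 3388 full batches. With `dataloader_drop_last=True`, up to 31 trailing samples per epoch are skipped, so the training set holds at least 108416 examples.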

### Evaluate 

Evaluation reports the metrics defined in `compute_metrics`: accuracy and weighted precision, recall and F1, computed on the held-out split. The results are returned by the `evaluate` method of the `trainer` object used for fine-tuning.

In [32]:
model_result = trainer.evaluate()
for key, value in model_result.items():
    print("{}: {}".format(key, value))

eval_loss: 0.20084160566329956
eval_accuracy: 0.9635585718501033
eval_precision: 0.963550516987135
eval_recall: 0.9635585718501033
eval_f1_score: 0.9635524213192739
eval_runtime: 656.4294
eval_samples_per_second: 41.302
eval_steps_per_second: 5.163
epoch: 3.0
