*Licensed under the MIT License.*

# Text Classification of MultiNLI Sentences using BERT

# Before You Start

> **Tip**: If you want to run through the notebook quickly, set the **`QUICK_RUN`** flag in the cell below to **`True`**. This runs the notebook on a small subset of the data (0.1% of the training and test splits).

If you run into CUDA out-of-memory errors or the Jupyter kernel keeps dying, try reducing `BATCH_SIZE` and `MAX_LEN`, but note that model performance may suffer (a smaller `MAX_LEN` truncates longer sentences).

In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of the data.
QUICK_RUN = True

In [2]:
## Multi-GPU training is not supported, so expose only the first GPU.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
import sys
sys.path.append("../../")
import os
import json
import pandas as pd
import numpy as np
import scrapbook as sb
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn

from interpret_text.experimental.common.utils_bert import Language, Tokenizer, BERTSequenceClassifier
from interpret_text.experimental.common.timer import Timer

from notebooks.test_utils.utils_mnli import load_mnli_pandas_df

In [4]:
from interpret_text.experimental.unified_information import UnifiedInformationExplainer

## Introduction
In this notebook, we fine-tune and evaluate a pretrained [BERT](https://arxiv.org/abs/1810.04805) model on a subset of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset.

We use a [sequence classifier](https://github.com/microsoft/nlp/blob/master/utils_nlp/models/bert/sequence_classification.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert).

### Set parameters
Here we set the parameters used for the modeling task.

In [5]:
TRAIN_DATA_FRACTION = 1
TEST_DATA_FRACTION = 1
NUM_EPOCHS = 1

if QUICK_RUN:
    TRAIN_DATA_FRACTION = 0.001
    TEST_DATA_FRACTION = 0.001
    NUM_EPOCHS = 1

if torch.cuda.is_available():
    BATCH_SIZE = 1
else:
    BATCH_SIZE = 8

DATA_FOLDER = "./temp"
BERT_CACHE_DIR = "./temp"
LANGUAGE = Language.ENGLISH
TO_LOWER = True
MAX_LEN = 150
BATCH_SIZE_PRED = 512
TRAIN_SIZE = 0.6
LABEL_COL = "genre"
TEXT_COL = "sentence1"

## Read Dataset
We start by loading a subset of the data. The following function also downloads and extracts the files if they don't already exist in the data folder.

The MultiNLI dataset is mainly used for natural language inference (NLI) tasks, where the inputs are sentence pairs and the labels are entailment indicators. The sentence pairs are also classified into *genres* that allow for more coverage and better evaluation of NLI models.

For our classification task, we use only the first sentence (the premise) as the text input and its genre as the label. The sentence pairs are unique, but the premises are not: each premise typically appears in one pair per entailment label. To avoid duplicate rows, we keep only the examples with one entailment label (*neutral* in this case).
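
A minimal, stdlib-only sketch of why keeping a single `gold_label` removes duplicate premises. The rows are made up (not real MultiNLI data), and the sketch assumes, as described above, that each premise is paired once with each entailment label:

```python
from collections import Counter

# Made-up rows in the MultiNLI layout: each premise (sentence1) appears once
# per entailment label, paired with a different hypothesis each time.
rows = [
    {"genre": "travel", "sentence1": "The museum opens at nine.", "gold_label": "entailment"},
    {"genre": "travel", "sentence1": "The museum opens at nine.", "gold_label": "neutral"},
    {"genre": "travel", "sentence1": "The museum opens at nine.", "gold_label": "contradiction"},
    {"genre": "fiction", "sentence1": "She closed the door quietly.", "gold_label": "entailment"},
    {"genre": "fiction", "sentence1": "She closed the door quietly.", "gold_label": "neutral"},
    {"genre": "fiction", "sentence1": "She closed the door quietly.", "gold_label": "contradiction"},
]

# Without filtering, every premise is repeated once per label.
all_counts = Counter(r["sentence1"] for r in rows)

# Keeping a single label (the notebook keeps "neutral") leaves one row per premise.
neutral_rows = [r for r in rows if r["gold_label"] == "neutral"]
neutral_counts = Counter(r["sentence1"] for r in neutral_rows)

print(max(all_counts.values()))      # 3
print(max(neutral_counts.values()))  # 1
print([(r["genre"], r["sentence1"]) for r in neutral_rows])
```

On the real data, the `df[df["gold_label"]=="neutral"]` line in the next cell plays the role of the list comprehension here.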

In [6]:
df = load_mnli_pandas_df(DATA_FOLDER, "train")
df = df[df["gold_label"]=="neutral"]  # get unique sentences

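The first few rows show the genre label and the first sentence of each example:

In [7]:
df[[LABEL_COL, TEXT_COL]].head()

|    | genre      | sentence1                                         |
|----|------------|---------------------------------------------------|
| 0  | government | Conceptually cream skimming has two basic dime... |
| 4  | telephone  | yeah i tell you what though if you go price so... |
| 6  | travel     | But a few Christian mosaics survive above the ... |
| 12 | slate      | It's not that the questions they asked weren't... |
| 13 | travel     | Thebes held onto power until the 12th Dynasty,... |

These are the five genres in the dataset, with the number of examples in each:
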
In [8]:
df[LABEL_COL].value_counts()

telephone     27783
government    25784
travel        25783
fiction       25782
slate         25768
Name: genre, dtype: int64

Next, we split the data into training and test sets and then encode the class labels:

In [9]:
# split
df_train, df_test = train_test_split(df, train_size = TRAIN_SIZE, random_state=0)
df_train = df_train.reset_index(drop=True)
df_test = df_test.reset_index(drop=True)

if QUICK_RUN:
    df_train = df_train.sample(frac=TRAIN_DATA_FRACTION).reset_index(drop=True)
    df_test = df_test.sample(frac=TEST_DATA_FRACTION).reset_index(drop=True)

In [10]:
# encode labels
label_encoder = LabelEncoder()
labels_train = label_encoder.fit_transform(df_train[LABEL_COL])
labels_test = label_encoder.transform(df_test[LABEL_COL])

num_labels = len(np.unique(labels_train))

In [11]:
print("Number of unique labels: {}".format(num_labels))
print("Number of training examples: {}".format(df_train.shape[0]))
print("Number of testing examples: {}".format(df_test.shape[0]))

Number of unique labels: 5
Number of training examples: 79
Number of testing examples: 52
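
scikit-learn's `LabelEncoder` assigns each class the index of its position in the sorted list of unique training labels (stored in `classes_`), which is why the classification report further down lists the genres alphabetically. The sketch below mimics that mapping in plain Python; the helper names are hypothetical and not part of scikit-learn. It also shows why the test labels go through `transform` rather than `fit_transform`: they must reuse the mapping learned on the training labels.

```python
def fit_label_encoder(labels):
    """Sort the unique labels and map each one to its position (like LabelEncoder.classes_)."""
    classes = sorted(set(labels))
    return classes, {label: i for i, label in enumerate(classes)}


def encode(labels, mapping):
    """Apply an existing mapping; an unseen label raises KeyError in this sketch."""
    return [mapping[label] for label in labels]


def decode(indices, classes):
    """Map integer codes back to class names (like LabelEncoder.inverse_transform)."""
    return [classes[i] for i in indices]


toy_train_genres = ["travel", "slate", "telephone", "fiction", "government", "travel"]
toy_test_genres = ["slate", "fiction"]

classes, mapping = fit_label_encoder(toy_train_genres)  # "fit" on training labels only
toy_train_codes = encode(toy_train_genres, mapping)
toy_test_codes = encode(toy_test_genres, mapping)       # reuse the training mapping

print(classes)               # ['fiction', 'government', 'slate', 'telephone', 'travel']
print(toy_train_codes)       # [4, 2, 3, 0, 1, 4]
print(toy_test_codes)        # [2, 0]
print(decode([2], classes))  # ['slate']
```

Because `num_labels` is computed from `labels_train`, a very small `QUICK_RUN` sample could in principle miss a genre; in the run above, all five genres appear among the 79 training examples.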


## Tokenize and Preprocess

Before training, we tokenize the text documents into lists of tokens. The following cell instantiates a BERT tokenizer for the given language and tokenizes the text of the training and test sets.
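
BERT's tokenizer first splits text on whitespace and punctuation (lowercasing when `TO_LOWER = True`), then breaks each word into WordPiece subwords with a greedy longest-match-first search, marking word-internal pieces with a `##` prefix. The sketch below is a simplified, stdlib-only illustration of that idea with a toy vocabulary, not the library's tokenizer: the real one also handles accents, Chinese characters, and very long words, and uses a 30,522-entry vocabulary (the size of the word embedding table in the model printout further down).

```python
import re


def basic_tokenize(text, to_lower=True):
    """Optionally lowercase, then split on whitespace and punctuation."""
    if to_lower:
        text = text.lower()
    # Runs of word characters become tokens; each punctuation mark is its own token.
    return re.findall(r"\w+|[^\w\s]", text)


def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # word-internal pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1  # shrink the candidate from the right and try again
        if piece is None:
            return [unk]  # no piece matched: the whole word maps to [UNK]
        pieces.append(piece)
        start = end
    return pieces


def wordpiece_tokenize(text, vocab, to_lower=True):
    return [p for w in basic_tokenize(text, to_lower) for p in wordpiece(w, vocab)]


# Toy vocabulary for illustration only.
vocab = {"sure", "but", "i", "'", "ll", "bet", "your", "high", "##er", "values"}
print(wordpiece_tokenize("Sure but I'll bet higher", vocab))
# ['sure', 'but', 'i', "'", 'll', 'bet', 'high', '##er']
```

With the toy vocabulary, "higher" splits into `high` + `##er`; the real vocabulary evidently keeps it whole, since `higher` appears as a single token in the explanation output at the end of the notebook.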

In [12]:
tokenizer = Tokenizer(LANGUAGE, to_lower=TO_LOWER, cache_dir=BERT_CACHE_DIR)

tokens_train = tokenizer.tokenize(list(df_train[TEXT_COL]))
tokens_test = tokenizer.tokenize(list(df_test[TEXT_COL]))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 4627.15it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:00<00:00, 5544.64it/s]


In addition, we perform the following preprocessing steps in the cell below:
- Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary
- Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence, respectively
- Pad or truncate the token lists to the specified max length. In this case, `MAX_LEN = 150`
- Return attention masks that indicate the padding positions
- Return token type ID lists that indicate which sentence each token belongs to (not needed for single-sequence classification, so the cell below discards them)

*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*
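
The preprocessing steps listed above can be sketched in plain Python as follows. This is an illustration, not the `Tokenizer.preprocess_classification_tokens` implementation: it assumes the wrapper truncates to `max_len - 2` tokens before adding `[CLS]` and `[SEP]`, pads with id 0, and uses 1 in the mask for real tokens and 0 for padding. The function name and token ids are hypothetical, and 101/102 are assumed to be the `[CLS]`/`[SEP]` ids of the uncased vocabulary.

```python
def build_bert_inputs(token_ids, max_len, cls_id, sep_id, pad_id=0):
    """Truncate, add [CLS]/[SEP], pad to max_len; return ids, masks, token types."""
    all_ids, all_masks, all_types = [], [], []
    for ids in token_ids:
        # Truncate first so that [CLS] + tokens + [SEP] fits in max_len.
        ids = [cls_id] + ids[: max_len - 2] + [sep_id]
        n_pad = max_len - len(ids)
        all_ids.append(ids + [pad_id] * n_pad)
        all_masks.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
        all_types.append([0] * max_len)  # single sentence: every position is segment 0
    return all_ids, all_masks, all_types


# A short sequence gets padded; a long one gets truncated.
toy_ids, toy_masks, toy_types = build_bert_inputs(
    [[11, 12], [1, 2, 3, 4, 5, 6]], max_len=6, cls_id=101, sep_id=102
)
print(toy_ids)    # [[101, 11, 12, 102, 0, 0], [101, 1, 2, 3, 4, 102]]
print(toy_masks)  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]
```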

In [13]:
tokens_train, mask_train, _ = tokenizer.preprocess_classification_tokens(tokens_train, MAX_LEN)
tokens_test, mask_test, _ = tokenizer.preprocess_classification_tokens(tokens_test, MAX_LEN)

## Sequence Classifier Model
Next, we use a sequence classifier that loads a pre-trained BERT model, given the language and number of labels.

In [14]:
classifier = BERTSequenceClassifier(language=LANGUAGE, num_labels=num_labels, cache_dir=BERT_CACHE_DIR)

## Train Model
We train the classifier on the training set. This fine-tunes the BERT Transformer and learns a linear classification layer on top of it:

In [15]:
with Timer() as t:
    classifier.fit(token_ids=tokens_train,
                   input_mask=mask_train,
                   labels=labels_train,
                   num_epochs=NUM_EPOCHS,
                   batch_size=BATCH_SIZE,
                   verbose=True)
print("[Training time: {:.3f} hrs]".format(t.interval / 3600))

t_total value of -1 results in schedule not being applied
This overload of add_ is deprecated:
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at /opt/conda/conda-bld/pytorch_1670525539683/work/torch/csrc/utils/python_arg_parser.cpp:1420.)
Iteration: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:06<00:00, 12.07it/s]

[Training time: 0.002 hrs]
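
As we understand the Hugging Face `BertForSequenceClassification` model, the classification head applies dropout and a linear layer to the pooled `[CLS]` representation, mapping the 768-dimensional hidden state to `num_labels` logits, and training minimizes softmax cross-entropy. The stdlib sketch below shows only that head on a fixed, made-up pooled vector: it computes logits and takes plain SGD steps using the gradient `p_k - 1[k == label]`. Unlike `classifier.fit`, it does not update the BERT encoder, omits dropout, uses toy sizes, and uses plain SGD rather than the library's optimizer. All function names are hypothetical.

```python
import math


def logits_fn(h, W, b):
    """Linear head: one logit per class, logits = W h + b."""
    return [sum(w * x for w, x in zip(row, h)) + bk for row, bk in zip(W, b)]


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(h, label, W, b):
    return -math.log(softmax(logits_fn(h, W, b))[label])


def sgd_step(h, label, W, b, lr):
    """One SGD step on softmax cross-entropy; only the head's W and b are updated."""
    probs = softmax(logits_fn(h, W, b))
    # Gradient of the loss with respect to logit k is p_k - 1[k == label].
    g = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    W = [[w - lr * gk * x for w, x in zip(row, h)] for row, gk in zip(W, g)]
    b = [bk - lr * gk for bk, gk in zip(b, g)]
    return W, b


# Toy sizes: a 3-dim pooled vector (768 in BERT-base) and 2 classes (5 genres here).
pooled, target = [0.5, -1.0, 2.0], 1
W = [[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]]
b = [0.0, 0.0]

print(logits_fn(pooled, W, b))  # [1.5, -2.0]
before = cross_entropy(pooled, target, W, b)
for _ in range(10):
    W, b = sgd_step(pooled, target, W, b, lr=0.1)
after = cross_entropy(pooled, target, W, b)
print(after < before)  # True
```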





## Score Model
We score the test set using the trained classifier:

In [16]:
preds = classifier.predict(token_ids=tokens_test, 
                           input_mask=mask_test, 
                           batch_size=BATCH_SIZE_PRED)

Iteration: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.26it/s]


## Evaluate Model
Finally, we compute the overall accuracy, precision, recall, and F1 score on the test set. We also look at the metrics for each genre in the dataset.
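
To make the report easier to read, here is a plain-Python sketch of how per-class precision, recall, and F1, the macro average (unweighted mean over classes), and the weighted average (mean weighted by each class's `support`) are computed. It sets precision to 0.0 for classes that are never predicted, which is the behavior behind the "Precision is ill-defined" warning after the report. The function name is hypothetical; use `classification_report` for real evaluation.

```python
def per_class_metrics(y_true, y_pred, labels):
    """Per-class precision/recall/F1 plus macro and support-weighted averages."""
    per_class = {}
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted = sum(1 for p in y_pred if p == c)
        support = sum(1 for t in y_true if t == c)
        # A class that is never predicted gets precision 0.0 (scikit-learn warns here).
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class[c] = {"precision": precision, "recall": recall,
                        "f1-score": f1, "support": support}

    n = len(y_true)
    keys = ("precision", "recall", "f1-score")
    macro = {k: sum(m[k] for m in per_class.values()) / len(labels) for k in keys}
    weighted = {k: sum(m[k] * m["support"] for m in per_class.values()) / n for k in keys}
    accuracy = sum(1 for t, p in zip(y_true, y_pred) if t == p) / n
    return per_class, macro, weighted, accuracy


# A toy classifier that always predicts class 2.
toy_per_class, toy_macro, toy_weighted, toy_accuracy = per_class_metrics(
    y_true=[0, 0, 1, 2], y_pred=[2, 2, 2, 2], labels=[0, 1, 2]
)
print(toy_per_class[2])                   # {'precision': 0.25, 'recall': 1.0, 'f1-score': 0.4, 'support': 1}
print(round(toy_macro["f1-score"], 4))    # 0.1333
print(toy_weighted["f1-score"])           # 0.1
print(toy_accuracy)                       # 0.25
```

In the quick-run report below, `slate` has recall 1.0 and a precision equal to the overall accuracy, while every other genre scores 0.0. That pattern is consistent with the model predicting `slate` for every test example, which is plausible after fine-tuning on only 79 examples for one epoch.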

In [17]:
report = classification_report(labels_test, preds, target_names=label_encoder.classes_, output_dict=True) 
accuracy = accuracy_score(labels_test, preds)
print("accuracy: {}".format(accuracy))
print(json.dumps(report, indent=4, sort_keys=True))

accuracy: 0.15384615384615385
{
    "accuracy": 0.15384615384615385,
    "fiction": {
        "f1-score": 0.0,
        "precision": 0.0,
        "recall": 0.0,
        "support": 16.0
    },
    "government": {
        "f1-score": 0.0,
        "precision": 0.0,
        "recall": 0.0,
        "support": 11.0
    },
    "macro avg": {
        "f1-score": 0.05333333333333333,
        "precision": 0.03076923076923077,
        "recall": 0.2,
        "support": 52.0
    },
    "slate": {
        "f1-score": 0.26666666666666666,
        "precision": 0.15384615384615385,
        "recall": 1.0,
        "support": 8.0
    },
    "telephone": {
        "f1-score": 0.0,
        "precision": 0.0,
        "recall": 0.0,
        "support": 8.0
    },
    "travel": {
        "f1-score": 0.0,
        "precision": 0.0,
        "recall": 0.0,
        "support": 9.0
    },
    "weighted avg": {
        "f1-score": 0.041025641025641026,
        "precision": 0.02366863905325444,
        "recall": 0.15384615

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


In [18]:
# for testing
sb.glue("accuracy", accuracy)
sb.glue("precision", report["macro avg"]["precision"])
sb.glue("recall", report["macro avg"]["recall"])
sb.glue("f1", report["macro avg"]["f1-score"])

## Explain Model

In [19]:
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

classifier.model.to(device)
for param in classifier.model.parameters():
    param.requires_grad = False
classifier.model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
   

In [20]:
interpreter_unified = UnifiedInformationExplainer(
    model=classifier.model,
    train_dataset=list(df_train[TEXT_COL]),
    device=device,
    target_layer=14,
    classes=label_encoder.classes_,
)

In [21]:
idx = 7
text = df_test[TEXT_COL][idx]
true_label = df_test[LABEL_COL][idx]
predicted_label = label_encoder.inverse_transform([preds[idx]])
print(text, true_label, predicted_label)

sure but i'll bet your values are a lot higher you know your your and your self esteem and the way the way you uh you know think about things is probably a lot more common sense telephone ['slate']


In [22]:
explanation_unified = interpreter_unified.explain_local(text, true_label)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 231508/231508 [00:00<00:00, 686562.48B/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 4802.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:11<00:00, 12.56it/s]


In [31]:
explanation_unified.data()

{'mli': []}

In [34]:
print(tuple(zip(
    explanation_unified.get_ranked_local_names(),
    explanation_unified.get_ranked_local_values()
)))

(('common', 0.4707134962081909), ('about', 0.45109739899635315), ('and', 0.4103332757949829), ('way', 0.40106984972953796), ('esteem', 0.38087254762649536), ('you', 0.37914803624153137), ('lot', 0.3749009668827057), ('the', 0.36465057730674744), ('sure', 0.3559378981590271), ('the', 0.3449525833129883), ('higher', 0.3395047187805176), ('but', 0.3224298059940338), ('self', 0.31688380241394043), ('a', 0.3088061511516571), ('a', 0.30677908658981323), ('ll', 0.29877227544784546), ('more', 0.2849915623664856), ('you', 0.2829703986644745), ('think', 0.2810496389865875), ('probably', 0.2701112926006317), ('know', 0.26355981826782227), ('your', 0.2596324682235718), ('your', 0.25898277759552), ('way', 0.2571272552013397), ('your', 0.25656571984291077), ("'", 0.24042871594429016), ('you', 0.23976299166679382), ('bet', 0.23569077253341675), ('lot', 0.2264755219221115), ('your', 0.21608087420463562), ('is', 0.2147148996591568), ('uh', 0.20808209478855133), ('are', 0.194368377327919), ('know', 0.18

In [None]:
print(explanation_unified.get_ranked_local_names())

## Visualize Explanation

In [23]:
from interpret_text.experimental.widget import ExplanationDashboard

In [24]:
ExplanationDashboard(explanation_unified)

ExplanationWidget(value={'text': ['sure', 'but', 'i', "'", 'll', 'bet', 'your', 'values', 'are', 'a', 'lot', '…

<interpret_text.experimental.widget.ExplanationDashboard.ExplanationDashboard at 0x71ae22629a00>