# Assignment 6: CheckList a BERT-based Sentiment Classifier (100 Points)

Instructor: Ziyu Yao; Class: CS478 Fall 2024

This assignment builds on the previous sentiment classification assignments. The goal is to analyze a BERT-based sentiment classifier: to check when it may fail and to get a better sense of how well it works.

To this end, this assignment will have two parts:
- **Part 1 (70 Points):** We will use the CheckList tool (https://github.com/marcotcr/checklist) to _perturb_ test examples and see how robust the classifier is;
- **Part 2 (30 Points):** Your creative work -- it's your turn to propose new tests and reveal additional potential issues with the classifier!

Follow the instructions on the assignment PDF for submission.

To get started, we will first install the checklist library and the spacy model `en_core_web_sm`.

In [23]:
!pip install checklist



In [24]:
!python -m spacy download en_core_web_sm # run this if you haven't installed the spaCy model `en_core_web_sm` yet

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
⚠ Restart to reload dependencies
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


(After installing the spacy package, if you see a message of "Restart to reload dependencies", you may ignore it.)

Import the following libraries:

In [25]:
import torch

from typing import List, Dict
import random
import numpy as np
from collections import Counter
import os

# transformers
import transformers

# Set up overall seed
seed = 12345
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [26]:
import spacy

import checklist
from checklist.editor import Editor
from checklist.perturb import Perturb
from checklist.test_types import MFT, INV, DIR

Set up device as before:

In [27]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Part 0: Data and Model Setup

The following code was adapted from the previous assignment for storing the sentiment data and creating the BERT-based classifier. There is nothing to fill in here, but you should study the code.

In [28]:
class SentimentExample:
    """
    Data wrapper for a single example for sentiment analysis.

    Attributes:
        sentence (string): a string-type sentence (untokenized)
        label (int): 0 or 1 (0 = negative, 1 = positive)
        word_indices (List[int]): list of word indices in the vocab, which will be generated by the `indexing_sentiment_examples` method
    """

    def __init__(self, sentence, label):
        self.sentence = sentence
        self.label = label
        self.words = None
        self.word_indices = None

    def __repr__(self):
        return self.sentence + "; label=" + repr(self.label)

    def __str__(self):
        return self.__repr__()

def read_sentiment_examples(infile: str) -> List[SentimentExample]:
    """
    Reads sentiment examples in the format [0 or 1]<TAB>[raw sentence]; tokenizes and cleans the sentences and forms
    SentimentExamples. Note that all words have been lowercased.

    :param infile: file to read from
    :return: a list of SentimentExamples parsed from the file
    """
    f = open(infile)
    exs = []
    for line in f:
        if len(line.strip()) > 0:
            line = line.strip()
            fields = line.split("\t")
            if len(fields) != 2:
                fields = line.split()
                label = 0 if "0" in fields[0] else 1
                sent = " ".join(fields[1:])
            else:
                # Slightly more robust to reading bad output than int(fields[0])
                label = 0 if "0" in fields[0] else 1
                sent = fields[1]
            sent = sent.lower() # lowercasing
            exs.append(SentimentExample(sent, label))
    f.close()
    return exs

def calculate_metrics(golds: List[int], predictions: List[int], print_only: bool=False):
    """
    Calculate evaluation statistics comparing golds and predictions, each of which is a sequence of 0/1 labels.
    Returns accuracy, precision, recall, and F1.

    :param golds: gold labels
    :param predictions: pred labels
    :param print_only: set to True if printing the stats without returns
    :return: accuracy, precision, recall, and F1 (all floating numbers), or None (when print_only is True)
    """
    num_correct = 0
    num_pos_correct = 0
    num_pred = 0
    num_gold = 0
    num_total = 0
    if len(golds) != len(predictions):
        raise Exception("Mismatched gold/pred lengths: %i / %i" % (len(golds), len(predictions)))
    for idx in range(0, len(golds)):
        gold = golds[idx]
        prediction = predictions[idx]
        if prediction == gold:
            num_correct += 1
        if prediction == 1:
            num_pred += 1
        if gold == 1:
            num_gold += 1
        if prediction == 1 and gold == 1:
            num_pos_correct += 1
        num_total += 1
    acc = float(num_correct) / num_total
    prec = float(num_pos_correct) / num_pred if num_pred > 0 else 0.0
    rec = float(num_pos_correct) / num_gold if num_gold > 0 else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec > 0 and rec > 0 else 0.0

    print("Accuracy: %i / %i = %f" % (num_correct, num_total, acc))
    print("Precision (fraction of predicted positives that are correct): %i / %i = %f" % (num_pos_correct, num_pred, prec)
          + "; Recall (fraction of true positives predicted correctly): %i / %i = %f" % (num_pos_correct, num_gold, rec)
          + "; F1 (harmonic mean of precision and recall): %f" % f1)

    if not print_only:
        return acc, prec, rec, f1

In [29]:
# The following code is for Google Colab users only.
from google.colab import drive
drive.mount('/content/drive')
dev_path = "/content/drive/My Drive/data/dev.txt"

# If you run the notebook locally, please set up the path to your dev set accordingly.
# dev_path = "data/dev.txt"

# Load dev exs
dev_exs = read_sentiment_examples(dev_path)
print(repr(len(dev_exs)) + " dev examples")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
872 dev examples


Now, load in the transformer library and the BERT-based classifier checkpoint.

In [30]:
from transformers import BertTokenizer, BertForSequenceClassification
pretrained_checkpoint = "textattack/bert-base-uncased-yelp-polarity"
tokenizer = BertTokenizer.from_pretrained(pretrained_checkpoint, use_fast=True)
model = BertForSequenceClassification.from_pretrained(pretrained_checkpoint).to(device)

Next, helper functions for batch prediction and evaluation using the loaded classifier:

In [31]:
import torch.nn.functional as F

def batch_predict_prob(sentences, return_probs=False): # for a list of test sentences
    batch_inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad(): # inference only, so skip building the autograd graph
        batch_outputs = model(**batch_inputs)
    batch_logits = batch_outputs.logits
    pred_probs = F.softmax(batch_logits, dim=-1).to('cpu').detach().numpy()

    if return_probs:
        return pred_probs # probability matrix of shape (num_sents x 2)
    else:
        return pred_probs.argmax(-1) # prediction labels of shape (num_sents)

def evaluate(exs: List[SentimentExample]):
    """
    Evaluates the loaded classifier on the given examples and prints the metrics
    :param exs: the list of SentimentExamples to evaluate on
    :return: (all_preds, all_labels), the predicted and gold labels in the same order as `exs`
    """
    all_labels = []
    all_preds = []

    batch_size = 32

    ex_idx = 0
    while ex_idx < len(exs):
        batch_data = exs[ex_idx:ex_idx+batch_size]
        batch_sentences = [ex.sentence for ex in batch_data]
        batch_labels = [ex.label for ex in batch_data]

        preds = batch_predict_prob(batch_sentences)

        all_labels += list(batch_labels)
        all_preds += list(preds)

        ex_idx += batch_size

    calculate_metrics(all_labels, all_preds, print_only=True)
    return all_preds, all_labels

In [32]:
all_preds, all_labels = evaluate(dev_exs)

Accuracy: 698 / 872 = 0.800459
Precision (fraction of predicted positives that are correct): 363 / 456 = 0.796053; Recall (fraction of true positives predicted correctly): 363 / 444 = 0.817568; F1 (harmonic mean of precision and recall): 0.806667


You should be able to see an accuracy of 0.800 for the loaded classifier tested on our movie review sentiment dev set.
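
As a sanity check on the printed metrics, the counts can be recombined by hand. The short sketch below recomputes accuracy, precision, recall, and F1 from the counts in the output above (698 correct out of 872, 363 true positives, 456 predicted positives, 444 gold positives) and derives the remaining confusion-matrix cells. The variable names are chosen here for illustration and are not part of the assignment code.

```python
# Counts copied from the printed evaluation output.
num_total = 872
num_correct = 698
true_pos = 363
num_pred_pos = 456
num_gold_pos = 444

false_pos = num_pred_pos - true_pos   # predicted positive, actually negative
false_neg = num_gold_pos - true_pos   # actually positive, predicted negative
true_neg = num_correct - true_pos     # correct predictions on negative examples

accuracy = num_correct / num_total
precision = true_pos / num_pred_pos
recall = true_pos / num_gold_pos
f1 = 2 * precision * recall / (precision + recall)

print(f"TP={true_pos} FP={false_pos} FN={false_neg} TN={true_neg}")
print(f"accuracy={accuracy:.6f} precision={precision:.6f} "
      f"recall={recall:.6f} f1={f1:.6f}")
```

The four cells sum to 872, which confirms that the printed counts are mutually consistent.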

## Part 1: Invariance Test (INV) using CheckList (70 points)

In this part, you will use the CheckList tool to analyze the loaded classifier. To understand the concepts behind CheckList tests, we recommend reading the paper "Beyond Accuracy: Behavioral Testing of NLP Models with CheckList" (https://arxiv.org/pdf/2005.04118.pdf), which won the best paper award at ACL 2020.

While CheckList contains three test types, in this assignment we will look at only the second one, the invariance test (INV). This test applies _label-preserving_ perturbations to existing test inputs and checks whether the model's prediction stays the same.
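
To make the idea concrete, here is a minimal, library-free sketch of what an invariance check does. It is not CheckList's implementation: `invariance_failures`, `toy_predict`, and `strip_final_punct` are hypothetical names introduced only for this illustration, and the toy classifier is deliberately brittle.

```python
from typing import Callable, List, Sequence


def invariance_failures(
    predict: Callable[[List[str]], Sequence[int]],
    originals: List[str],
    perturb: Callable[[str], List[str]],
) -> List[int]:
    """Return indices of inputs whose prediction changes under any perturbation."""
    failures = []
    for idx, sentence in enumerate(originals):
        variants = perturb(sentence)
        if not variants:
            continue  # nothing to compare against
        preds = predict([sentence] + variants)
        # The first prediction belongs to the unperturbed sentence.
        if any(p != preds[0] for p in preds[1:]):
            failures.append(idx)
    return failures


# Toy stand-ins (hypothetical, for illustration only).
def toy_predict(sentences: List[str]) -> List[int]:
    # Predicts positive (1) only if the sentence ends with "!".
    return [1 if s.endswith("!") else 0 for s in sentences]


def strip_final_punct(sentence: str) -> List[str]:
    stripped = sentence.rstrip(" .!?")
    return [stripped] if stripped != sentence else []


print(invariance_failures(toy_predict, ["great film !", "dull plot"], strip_final_punct))
```

Only the first sentence is reported, because removing its final "!" changes the toy prediction. The second sentence has no final punctuation, so no variant is produced and it is skipped.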

**We recommend studying the CheckList tutorials at https://github.com/marcotcr/checklist#tutorials, especially the part about data perturbation (https://github.com/marcotcr/checklist/blob/master/notebooks/tutorials/2.%20Perturbing%20data.ipynb, go to "General-purpose perturbations").**

**Requirements:**
- Perform two types of perturbations: punctuation perturbation and typos.
- For each type, report the model performance on the perturbed examples:
    - For how many of the dev-set examples does the model flip its label (i.e., pos -> neg or neg -> pos)?
    - For dev examples on which the classifier initially made a correct prediction, how many of them get their labels flipped after perturbation?
- Describe and discuss your findings.

(Note: you are provided with sample code to calculate these statistics; however, it is assumed that you will create perturbation data following the same format as the sample code. That said, you are also free to modify the code as long as you can collect these statistics!)

### INV - Punctuations (35 points)

<font color='blue'>Complete the following code for creating INV examples based on perturbing punctuations in the dev set. </font>

In [33]:
editor = Editor()

data = [ex.sentence for ex in dev_exs]

nlp = spacy.load('en_core_web_sm')
pdata = list(nlp.pipe(data))

# TODO: complete the assignment for creating
# a set of dev-set examples with perturbed punctuations
# using the Perturb class
ret1 = Perturb.perturb(pdata, Perturb.punctuation)

Print the perturbation results for the first three sentences on the dev set:

In [34]:
ret1.data[:3]

[["it 's a lovely film with lovely performances by buy and accorsi .",
  "it 's a lovely film with lovely performances by buy and accorsi",
  "it 's a lovely film with lovely performances by buy and accorsi."],
 ["and if you 're not nearly moved to tears by a couple of scenes , you 've got ice water in your veins .",
  "and if you 're not nearly moved to tears by a couple of scenes , you 've got ice water in your veins",
  "and if you 're not nearly moved to tears by a couple of scenes , you 've got ice water in your veins."],
 ['a warm , funny , engaging film .',
  'a warm , funny , engaging film',
  'a warm , funny , engaging film.']]
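
The three entries per sentence above are the original, a variant with the final punctuation removed, and a variant ending in a period. The sketch below reproduces that pattern in plain Python. It is an approximation inferred from the printed output, not CheckList's own code, and `punctuation_variants` is a hypothetical helper name.

```python
import string


def punctuation_variants(sentence: str) -> list:
    """Approximate the variants seen above: drop final punctuation, or end with a period."""
    # Remove any trailing punctuation characters and spaces.
    core = sentence.rstrip(string.punctuation + " ")
    candidates = [core, core + "."]
    # Keep only variants that differ from the input, preserving order.
    return [c for c in candidates if c != sentence]


print(punctuation_variants("a warm , funny , engaging film ."))
```

Both variants differ from the original only in the final token, which is why any label flip here can be attributed to the sentence-final punctuation alone.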

Now, let's apply the same classifier to the perturbed examples:

In [35]:
assert len(ret1.data) == len(dev_exs)

all_labels = []
all_preds = []

for ex_idx, ex in enumerate(dev_exs):
    assert ex.sentence == ret1.data[ex_idx][0]
    gold_label = ex.label
    preds = batch_predict_prob(ret1.data[ex_idx])

    all_labels.append(gold_label)
    all_preds.append(preds)

The following helper functions calculate the two statistics required by this assignment: how many predictions were flipped in total, and how many *initially correct* predictions were flipped.

In [36]:
# helper functions

# label flips
def count_label_flips(all_preds):
    indices = []
    for idx, preds in enumerate(all_preds):
        ori = preds[0]
        for pert in preds[1:]:
            if pert != ori:
                indices.append(idx)
                break

    print("%s examples got their label flipped!" % len(indices))
    return indices

# label flips for initially correct predictions
def count_correct_label_flips(all_labels, all_preds):
    indices = []
    for idx, preds in enumerate(all_preds):
        gold = all_labels[idx]
        ori = preds[0]
        if gold != ori:
            continue
        for pert in preds[1:]:
            if pert != ori:
                indices.append(idx)
                break

    print("%s examples got their *correct* label flipped!" % len(indices))
    return indices

In [37]:
indices1 = count_label_flips(all_preds)
indices2 = count_correct_label_flips(all_labels, all_preds)

26 examples got their label flipped!
15 examples got their *correct* label flipped!


Let's print all predictions whose labels got flipped due to the punctuation perturbation. In what prints out, you will see:
- **Original**: showing the original sentence and the ground-truth label;
- **Perturbations**: a list of three sentences, which are the original sentence and two perturbed sentences;
- **Preds**: the model's predictions for the three sentences.

In [38]:
for idx in indices1:
    print("Original:", dev_exs[idx])
    print("Perturbations:", ret1.data[idx])
    print("Preds:", all_preds[idx])
    print("=" * 10)

Original: without ever becoming didactic , director carlos carrera expertly weaves this novelistic story of entangled interrelationships and complex morality .; label=1
Perturbations: ['without ever becoming didactic , director carlos carrera expertly weaves this novelistic story of entangled interrelationships and complex morality .', 'without ever becoming didactic , director carlos carrera expertly weaves this novelistic story of entangled interrelationships and complex morality', 'without ever becoming didactic , director carlos carrera expertly weaves this novelistic story of entangled interrelationships and complex morality.']
Preds: [1 0 1]
Original: it deserves to be seen by anyone with even a passing interest in the events shaping the world beyond their own horizons .; label=1
Perturbations: ['it deserves to be seen by anyone with even a passing interest in the events shaping the world beyond their own horizons .', 'it deserves to be seen by anyone with even a passing interest

**Describe:** Read through the printed examples carefully and include the two most interesting ones in the assignment PDF.

**Discuss**: Changing or removing the punctuation of a sentence may or may not change the sentence's original semantic meaning. Use your common sense to judge the model's predictions, and discuss:
- Is it true that *none* of the labels should be flipped? Or do you see cases where the prediction labels are hard to decide or should actually be flipped when the punctuation is changed?
- Overall, do you think your model is robust to punctuation perturbation?

### INV - Typos (35 points)

<font color='blue'>Complete the following code for creating INV examples based on perturbing typos in the dev set. </font>

Now, similarly implement the typo perturbation, which intentionally replaces words in the original sentence with typos.

In [39]:
# TODO: complete the assignment for creating
# a set of dev-set examples with perturbed typos
# using the Perturb class
ret2 = Perturb.perturb(data, Perturb.add_typos)

Print the perturbation results for the first three sentences on the dev set:

In [40]:
ret2.data[:3]

[["it 's a lovely film with lovely performances by buy and accorsi .",
  "it 's a lovely film with lovely pefrormances by buy and accorsi ."],
 ["and if you 're not nearly moved to tears by a couple of scenes , you 've got ice water in your veins .",
  "and if you 're not nearly movde to tears by a couple of scenes , you 've got ice water in your veins ."],
 ['a warm , funny , engaging film .', 'aw arm , funny , engaging film .']]
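
The perturbed sentences above ("pefrormances", "movde", "aw arm") are all consistent with swapping one pair of neighbouring characters. The sketch below shows that operation in plain Python. It is an illustration under that assumption, not the library's code, and `swap_adjacent_chars` is a hypothetical name.

```python
import random


def swap_adjacent_chars(text: str, rng: random.Random) -> str:
    """Swap one randomly chosen pair of neighbouring characters."""
    if len(text) < 2:
        return text  # nothing to swap
    i = rng.randrange(len(text) - 1)
    chars = list(text)
    chars[i], chars[i + 1] = chars[i + 1], chars[i]
    return "".join(chars)


rng = random.Random(0)  # local generator so the sketch is repeatable
print(swap_adjacent_chars("a warm , funny , engaging film .", rng))
```

Because the swap position is drawn over all characters, it can land on a space and shift a word boundary (as in "aw arm"), or swap two identical characters and leave the sentence unchanged.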

As in the punctuation perturbation analysis, we will apply the classifier to make predictions on the perturbed examples, and then calculate the number of label flips overall and for initially correct predictions.

In [41]:
assert len(ret2.data) == len(dev_exs)

all_labels2 = []
all_preds2 = []

for ex_idx, ex in enumerate(dev_exs):
    assert ex.sentence == ret2.data[ex_idx][0]
    gold_label = ex.label
    preds = batch_predict_prob(ret2.data[ex_idx])

    all_labels2.append(gold_label)
    all_preds2.append(preds)

In [42]:
indices1 = count_label_flips(all_preds2)
indices2 = count_correct_label_flips(all_labels2, all_preds2)

96 examples got their label flipped!
59 examples got their *correct* label flipped!
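
Raw counts are easier to compare as rates. The sketch below divides the flip counts printed for the two perturbation types by the dev-set size (872) and by the number of initially correct predictions (698, from the accuracy output). The dictionary layout is only an illustrative choice.

```python
num_dev = 872        # dev examples
num_correct = 698    # initially correct predictions (from the accuracy output)

# Flip counts copied from the printed outputs of the two analyses.
flip_counts = {
    "punctuation": {"all": 26, "correct": 15},
    "typos": {"all": 96, "correct": 59},
}

flip_rates = {}
for name, counts in flip_counts.items():
    flip_rates[name] = {
        "all": counts["all"] / num_dev,
        "correct": counts["correct"] / num_correct,
    }
    print(f"{name}: {flip_rates[name]['all']:.2%} of all examples, "
          f"{flip_rates[name]['correct']:.2%} of initially correct examples")
```

When comparing the two rows, keep in mind that the punctuation perturbation produced two variants per sentence while the typo perturbation produced one, and an example counts as flipped if any of its variants changes the prediction.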


Now, let's look at all the examples whose labels were flipped after introducing typos. The output format of the following code block is the same as in the prior analysis, except that by default, the library produces only one perturbed sentence for each original sentence on the dev set. Therefore, you should see only two sentences in `Perturbations` and two prediction labels in `Preds`.

In [43]:
for idx in indices1:
    print("Original:", dev_exs[idx])
    print("Perturbations:", ret2.data[idx])
    print("Preds:", all_preds2[idx])
    print("=" * 10)

Original: nothing 's at stake , just a twisty double-cross you can smell a mile away -- still , the derivative nine queens is lots of fun .; label=1
Perturbations: ["nothing 's at stake , just a twisty double-cross you can smell a mile away -- still , the derivative nine queens is lots of fun .", "ntohing 's at stake , just a twisty double-cross you can smell a mile away -- still , the derivative nine queens is lots of fun ."]
Preds: [1 0]
Original: the band 's courage in the face of official repression is inspiring , especially for aging hippies -lrb- this one included -rrb- .; label=1
Perturbations: ["the band 's courage in the face of official repression is inspiring , especially for aging hippies -lrb- this one included -rrb- .", "the band 's courage in the face of official repression is isnpiring , especially for aging hippies -lrb- this one included -rrb- ."]
Preds: [1 0]
Original: writer\/director joe carnahan 's grimy crime drama is a manual of precinct cliches , but it moves f

**Describe:** Read through the printed examples carefully and include the two most interesting ones in the assignment PDF.

**Discuss**: Overall, do you think your model is robust to typos? Discuss any other findings, e.g., for cases where the model predictions got flipped, do the typos exist more commonly in nouns or verbs, or in words with other part-of-speech properties?
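
To answer the part-of-speech question, it helps to first locate which word the typo landed in. The sketch below is one possible way to do that with the standard library's `difflib`. `changed_tokens` is a hypothetical helper, whitespace tokenization is an assumption that fits the pre-tokenized dev sentences, and the example sentence is a shortened version of one printed above.

```python
from difflib import SequenceMatcher
from typing import List, Tuple


def changed_tokens(original: str, perturbed: str) -> List[Tuple[List[str], List[str]]]:
    """Return (original tokens, perturbed tokens) pairs for each span that differs."""
    orig_toks = original.split()
    pert_toks = perturbed.split()
    matcher = SequenceMatcher(a=orig_toks, b=pert_toks, autojunk=False)
    diffs = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            diffs.append((orig_toks[i1:i2], pert_toks[j1:j2]))
    return diffs


print(changed_tokens(
    "the band 's courage is inspiring .",
    "the band 's courage is isnpiring .",
))
```

The original-side tokens can then be looked up in the spaCy documents in `pdata` to read their `pos_` tags, keeping in mind that spaCy's tokenization may not line up exactly with a whitespace split. When a swap crosses a space, two neighbouring words are reported together.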

## Part 2: Explore A Different Test or Analysis (30 points)

Propose a test or analysis other than the punctuation and typo perturbations! Feel free to use any libraries or tools, or simply reuse CheckList.

For example, here are a few analyses you can consider:
- Trying another INV perturbation category (e.g., named entity change) or a different CheckList test type (MFT or DIR);
- Identifying potential ethical problems (e.g., gender or racial bias) by testing the sentiment classifier on test examples created by yourself;
- Testing the multilingual capabilities of the sentiment classifier (for this you should switch to a multilingual BERT sentiment classifier from the model hub https://huggingface.co/models?search=multilingual%20sentiment);
- Visualizing the BERT attention (using BertViz https://github.com/jessevig/bertviz or other tools) of this classifier: does the model attend to the right contents when making a prediction (i.e., "right for the right reason")?
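
For the self-created test examples mentioned in the list above, a template-based setup is one possible starting point. The sketch below is plain Python and makes several assumptions: `fill_template`, `positive_rate_by_group`, the toy predictor, and the name lists are all hypothetical placeholders. CheckList's `Editor` also provides templating, which may be more convenient in the notebook.

```python
from itertools import product
from typing import Callable, Dict, List, Sequence


def fill_template(template: str, slots: Dict[str, List[str]]) -> List[str]:
    """Expand a template such as '{name} is {adj}' over all slot-value combinations."""
    keys = list(slots)
    return [
        template.format(**dict(zip(keys, combo)))
        for combo in product(*(slots[k] for k in keys))
    ]


def positive_rate_by_group(
    predict: Callable[[List[str]], Sequence[int]],
    template: str,
    slot: str,
    groups: Dict[str, List[str]],
    other_slots: Dict[str, List[str]],
) -> Dict[str, float]:
    """Return the fraction of positive (label 1) predictions for each group."""
    rates = {}
    for group, values in groups.items():
        sentences = fill_template(template, {slot: values, **other_slots})
        preds = predict(sentences)
        rates[group] = sum(int(p) for p in preds) / len(sentences)
    return rates


# Toy predictor (hypothetical): positive only if the word "great" appears.
def toy_predict(sentences: List[str]) -> List[int]:
    return [1 if "great" in s else 0 for s in sentences]


rates = positive_rate_by_group(
    toy_predict,
    template="{name} gave a {adj} performance .",
    slot="name",
    groups={"group_a": ["john", "david"], "group_b": ["mary", "linda"]},
    other_slots={"adj": ["great", "dull"]},
)
print(rates)
```

In the notebook, `batch_predict_prob` could be passed in place of `toy_predict`. Since the sentences differ only in the name, a noticeable gap between the groups' positive rates would be worth investigating.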


**Requirements:**
- Propose one different test or analysis and implement the procedure in this notebook (please add the necessary comments so that the grader/instructor can understand your test/analysis);
- Describe the procedure, your results, and your findings in the report PDF.

*Note: While you can perform any test or analysis that interests you, to earn full credit, the analysis procedure should be careful and well justified.*

In [46]:
# Minimum Functionality Test (MFT)
from checklist.pred_wrapper import PredictorWrapper

# Prepare the data: the first `numSamples` dev sentences and their gold labels.
# NOTE: the test is named after negation, but these inputs are unmodified dev
# sentences, so it measures plain accuracy on this subset.
numSamples = 200
sentData = [ex.sentence for ex in dev_exs[:numSamples]]
labels = [ex.label for ex in dev_exs[:numSamples]]

# Define the MFT
test = MFT(
    sentData,
    labels=labels,
    name='Simple negation',
    capability='Negation',
    description='Very simple negations.'
)

def predict_probabilities(inputs):
    return batch_predict_prob(inputs, True)

# Wrap the probability function so that it returns (predictions, confidences)
wrapped_probs = PredictorWrapper.wrap_softmax(predict_probabilities)

# Run the test
test.run(wrapped_probs)

# Summarize the results
print("\nNegation MFT Summary:")
test.summary()


Predicting 200 examples

Negation MFT Summary:
Test cases:      200
Fails (rate):    27 (13.5%)

Example fails:
0.0 ... mafia , rap stars and hood rats butt their ugly heads in a regurgitation of cinematic violence that gives brutal birth to an unlikely , but likable , hero . '
----
0.1 a movie that successfully crushes a best selling novel into a timeframe that mandates that you avoid the godzilla sized soda .
----
0.1 old-form moviemaking at its best .
----


<font color='blue'>YOUR TASK: Implement a different model test/analysis below. </font>

## You have completed this assignment! Please upload your notebook to Blackboard.