TODO:

<strike>- Download data</strike>
- Re-read transformers blog post and BERT paper
<strike>- Fine-tune BERT on dataset</strike>
- SHAP values for examples
- 1/2 CheckList attacks
<strike>- Read vision paper</strike>
- Metrics for evaluating explainability change

# BERT and Sentiment Treebank

The fine-tuning code is based on [this BERT fine-tuning tutorial with PyTorch](https://medium.com/@aniruddha.choudhury94/part-2-bert-fine-tuning-tutorial-with-pytorch-for-text-classification-on-the-corpus-of-linguistic-18057ce330e1).

In [41]:
import os
import random

import pandas as pd
import numpy as np
import torch
from torch.utils.data import \
    TensorDataset, \
    DataLoader
from transformers import \
    BertTokenizer, \
    BertForSequenceClassification, \
    AdamW, \
    BertConfig, \
    get_linear_schedule_with_warmup
import pytreebank
from tqdm import tqdm
import shap

In [2]:
# Move two levels up from the notebook's folder so that the relative
# "data/..." and "models/..." paths used later in this notebook resolve.
# Guarded on the data folder read below so that re-running this cell does not
# keep climbing two more directories each time.
if not os.path.isdir("data/raw/stanford_sentiment_treebank"):
    os.chdir('../..')

## Load data

In [3]:
# Load the Stanford Sentiment Treebank from the local raw-data folder;
# the result is indexed by split name in the next cell.
dataset = pytreebank.load_sst("data/raw/stanford_sentiment_treebank/")

In [4]:
# Pull out the splits used for fine-tuning (train) and evaluation (dev).
train = dataset['train']
dev = dataset['dev']
# The test split is currently not loaded.
# test = dataset['test']

In [5]:
def stanford_raw_to_df(pytree_dataset):
    """
    Build a DataFrame of full-sentence examples from pytreebank LabeledTree objects

    :param: pytree_dataset: iterable of objects exposing to_labeled_lines()
    :return: DataFrame with one row per tree and columns 'sentence' and 'label'
    """
    # The first (label, text) pair returned by to_labeled_lines() is the full sentence
    full_sentence_pairs = [tree.to_labeled_lines()[0] for tree in pytree_dataset]

    return pd.DataFrame(
        {
            'sentence': [text for _, text in full_sentence_pairs],
            'label': [label for label, _ in full_sentence_pairs]
        }
    )

In [6]:
# Build the training DataFrame, then show its size and first rows as a quick check.
train_df = stanford_raw_to_df(train)
print(train_df.shape)
train_df.head()

(8544, 2)


Unnamed: 0,sentence,label
0,The Rock is destined to be the 21st Century 's...,3
1,The gorgeously elaborate continuation of `` Th...,4
2,Singer/composer Bryan Adams contributes a slew...,3
3,You 'd think by now America would have had eno...,2
4,Yet the act is still charming here .,3


In [7]:
# Check column dtypes and non-null counts of the training data.
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8544 entries, 0 to 8543
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   sentence  8544 non-null   object
 1   label     8544 non-null   int64 
dtypes: int64(1), object(1)
memory usage: 133.6+ KB


In [8]:
# Build the dev DataFrame the same way, then show its size and first rows.
dev_df = stanford_raw_to_df(dev)
print(dev_df.shape)
dev_df.head()

(1101, 2)


Unnamed: 0,sentence,label
0,It 's a lovely film with lovely performances b...,3
1,"No one goes unindicted here , which is probabl...",2
2,And if you 're not nearly moved to tears by a ...,3
3,"A warm , funny , engaging film .",4
4,Uses sharp humor and insight into human nature...,4


## Pre-processing

In [9]:
# Tokenizer matching the 'bert-base-uncased' checkpoint fine-tuned below;
# do_lower_case=True lower-cases the text before tokenisation.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [10]:
# Encode every training sentence to a list of token ids,
# with the special [CLS] / [SEP] tokens added by the tokenizer.
train_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in train_df['sentence'].values
]

In [11]:
# Raw text of the third training sentence (compare with its tokens in the next cell)
train_df.loc[2, 'sentence']

'Singer/composer Bryan Adams contributes a slew of songs -- a few potential hits , a few more simply intrusive to the story -- but the whole package certainly captures the intended , er , spirit of the piece .'

In [12]:
# Map the encoded ids of that sentence back to their token strings
tokenizer.convert_ids_to_tokens(train_encoded_sentences[2])

['[CLS]',
 'singer',
 '/',
 'composer',
 'bryan',
 'adams',
 'contributes',
 'a',
 'sl',
 '##ew',
 'of',
 'songs',
 '-',
 '-',
 'a',
 'few',
 'potential',
 'hits',
 ',',
 'a',
 'few',
 'more',
 'simply',
 'int',
 '##rus',
 '##ive',
 'to',
 'the',
 'story',
 '-',
 '-',
 'but',
 'the',
 'whole',
 'package',
 'certainly',
 'captures',
 'the',
 'intended',
 ',',
 'er',
 ',',
 'spirit',
 'of',
 'the',
 'piece',
 '.',
 '[SEP]']

In [13]:
# Encode every dev sentence to a list of token ids,
# with the special [CLS] / [SEP] tokens added by the tokenizer.
dev_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in dev_df['sentence'].values
]

In [14]:
# Longest encoded sentence (in token ids, special tokens included) for train and dev
max(map(len, train_encoded_sentences)), max(map(len, dev_encoded_sentences))

(66, 55)

In [15]:
def pad_sentence_at_end(sentence, max_length):
    """
    Pad tokenised sentence with zeros at end

    :param: sentence: list of encodings for a sentence
    :param: max_length: max length to pad up to
    :return: 1-D int64 numpy array of length max_length
    :raises: ValueError: if sentence is longer than max_length
    """
    num_tokens = len(sentence)

    # Fail with an explicit message instead of numpy's "negative dimensions" error
    if num_tokens > max_length:
        raise ValueError(
            f"Sentence has {num_tokens} tokens, which exceeds max_length={max_length}"
        )

    # Fixed integer dtype so the result does not depend on the platform or on
    # whether any padding was needed (an empty list would otherwise become float)
    padded_sentence = np.zeros(max_length, dtype=np.int64)
    padded_sentence[:num_tokens] = sentence
    return padded_sentence


def create_sentence_input_arrays(list_encoded_sentences, max_length):
    """
    Build the padded token-id matrix and matching attention mask for BERT

    :param: list_encoded_sentences: List of sentence encoding lists
    :param: max_length: max length to pad up to
    :return: tuple (input ids array, attention mask array), one row per sentence
    """
    padded_rows = []
    for encoded_sentence in list_encoded_sentences:
        padded_rows.append(pad_sentence_at_end(encoded_sentence, max_length))

    input_id_matrix = np.vstack(padded_rows)

    # Padding positions hold id 0, so every non-zero entry is marked as a real token
    is_real_token = input_id_matrix != 0
    attention_mask_matrix = is_real_token.astype(int)

    return input_id_matrix, attention_mask_matrix

In [16]:
# Padded sequence length; chosen above the longest encoded sentence
# found earlier (66 tokens in train, 55 in dev).
MAX_LENGTH = 70

# Padded token ids and attention mask for the training sentences
train_array, train_attention_mask_array = create_sentence_input_arrays(
    train_encoded_sentences, 
    MAX_LENGTH
)

# Padded token ids and attention mask for the dev sentences
dev_array, dev_attention_mask_array = create_sentence_input_arrays(
    dev_encoded_sentences, 
    MAX_LENGTH
)

In [17]:
# Sanity check: ids and mask should both be (number of train sentences, MAX_LENGTH)
train_array.shape, train_attention_mask_array.shape

((8544, 70), (8544, 70))

In [18]:
# Sanity check: ids and mask should both be (number of dev sentences, MAX_LENGTH)
dev_array.shape, dev_attention_mask_array.shape

((1101, 70), (1101, 70))

In [19]:
# Unpadded token ids of the first training sentence
np.array(train_encoded_sentences[0])

array([  101,  1996,  2600,  2003, 16036,  2000,  2022,  1996,  7398,
        2301,  1005,  1055,  2047,  1036,  1036, 16608,  1005,  1005,
        1998,  2008,  2002,  1005,  1055,  2183,  2000,  2191,  1037,
       17624,  2130,  3618,  2084,  7779, 29058,  8625, 13327,  1010,
        3744,  1011, 18856, 19513,  3158,  5477,  4168,  2030,  7112,
       16562,  2140,  1012,   102])

In [20]:
# The same sentence after padding: zeros fill the positions up to MAX_LENGTH
train_array[0]

array([  101,  1996,  2600,  2003, 16036,  2000,  2022,  1996,  7398,
        2301,  1005,  1055,  2047,  1036,  1036, 16608,  1005,  1005,
        1998,  2008,  2002,  1005,  1055,  2183,  2000,  2191,  1037,
       17624,  2130,  3618,  2084,  7779, 29058,  8625, 13327,  1010,
        3744,  1011, 18856, 19513,  3158,  5477,  4168,  2030,  7112,
       16562,  2140,  1012,   102,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0])

In [21]:
# Matching attention mask: 1 for real tokens, 0 for padding
train_attention_mask_array[0]

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0])

Convert to tensors

In [22]:
# Training inputs: padded token ids, attention mask, and labels from the 'label' column
train_tensor = torch.tensor(train_array)
train_attention_mask_tensor = torch.tensor(train_attention_mask_array)
train_labels_tensor = torch.tensor(train_df['label'].values)

# Dev inputs, built the same way
dev_tensor = torch.tensor(dev_array)
dev_attention_mask_tensor = torch.tensor(dev_attention_mask_array)
dev_labels_tensor = torch.tensor(dev_df['label'].values)

## Fine-tune BERT

Run on Colab

In [23]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [30]:
BATCH_SIZE = 32  # examples per batch for both the train and dev data loaders
LEARNING_RATE = 2e-5  # passed to the AdamW optimiser as lr
EPS = 1e-8  # passed to the AdamW optimiser as eps
RANDOM_SEED = 3  # seeds random, numpy and torch before the training loop
NUM_EPOCHS = 2  # passes over the training set; also sets the scheduler's total step count

In [31]:
# Each example is a (token ids, attention mask, label) triple; the loops below rely on this order.
train_dataset = TensorDataset(train_tensor, train_attention_mask_tensor, train_labels_tensor)
dev_dataset = TensorDataset(dev_tensor, dev_attention_mask_tensor, dev_labels_tensor)

In [32]:
# Training batches are reshuffled on every pass; no shuffle argument is given for dev.
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

See the [BertForSequenceClassification documentation](https://huggingface.co/transformers/model_doc/bert.html#bertforsequenceclassification) for the model used below.

In [33]:
# Pre-trained BERT with a sequence-classification head.
# num_labels=5 sets the number of output classes (presumably SST labels 0-4 - confirm against the data).
# Attention weights and hidden states are not returned.
bert_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=5,
    output_attentions=False,
    output_hidden_states=False
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [34]:
%%capture
# Move the model to the selected device; %%capture suppresses this cell's output.
bert_model.to(device)

In [35]:
# AdamW over all model parameters, using the learning rate and epsilon set above.
optimiser = AdamW(
    bert_model.parameters(),
    lr=LEARNING_RATE,
    eps=EPS
)

In [36]:
# Linear learning-rate schedule with no warmup steps.
# Total steps = batches per epoch * number of epochs, i.e. one scheduler step per training batch.
scheduler = get_linear_schedule_with_warmup(
    optimiser, 
    num_warmup_steps=0,
    num_training_steps=len(train_data_loader) * NUM_EPOCHS
)

In [65]:
def _predict_labels(model, data_loader, device):
    """
    Run a model over a data loader and collect predicted and true labels

    Gradients are disabled; the caller is responsible for putting the model in eval mode.

    :param: model: model called with input_ids / token_type_ids / attention_mask; its first output is used as the logits
    :param: data_loader: iterable of (input ids, attention mask, labels) batches
    :param: device: torch device the batch inputs are moved to
    :return: tuple (predicted labels, true labels) as two lists of equal length
    """
    pred_labels = []
    true_labels = []

    with torch.no_grad():
        for batch in data_loader:
            batch_input_ids = batch[0].to(device)
            batch_attention_mask = batch[1].to(device)

            outputs = model(
                input_ids=batch_input_ids,
                token_type_ids=None,
                attention_mask=batch_attention_mask
            )
            logits = outputs[0]

            # extend() keeps accumulation linear in the number of examples
            pred_labels.extend(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.extend(batch[2].cpu().numpy())

    return pred_labels, true_labels


# NOTE(review): the seeds are set here, after the model (including its newly
# initialised classification head) was created in an earlier cell, so they do
# not make that initialisation reproducible.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

for epoch in range(NUM_EPOCHS):
    
    #========================================#
    # TRAINING                               #
    #========================================#
    
    bert_model.train()
    
    for batch in tqdm(train_data_loader):
        
        batch_input_ids = batch[0].to(device)
        batch_attention_mask = batch[1].to(device)
        batch_labels = batch[2].to(device)

        optimiser.zero_grad()  # Set gradients to 0 otherwise will accumulate

        outputs = bert_model(
            input_ids=batch_input_ids,
            token_type_ids=None,
            attention_mask=batch_attention_mask,
            labels=batch_labels
        )
        
        # With labels supplied, the first output is the loss
        loss = outputs[0]
        
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(bert_model.parameters(), 1.0)
        
        optimiser.step()
        scheduler.step()  # One scheduler step per batch, matching num_training_steps
        
    #========================================#
    # EVALUATE                               #
    #========================================#  
    
    bert_model.eval()
    
    # Train accuracy:
    train_pred_labels, train_labels = _predict_labels(bert_model, train_data_loader, device)
    train_accuracy = (np.array(train_pred_labels) == np.array(train_labels)).mean()
    
    # Dev accuracy:
    dev_pred_labels, dev_labels = _predict_labels(bert_model, dev_data_loader, device)
    dev_accuracy = (np.array(dev_pred_labels) == np.array(dev_labels)).mean()
    
    print(f"Epoch {epoch+1}: train_acc={train_accuracy}, dev_acc={dev_accuracy}")

## Save model

In [None]:
# Colab only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Show the current working directory, since the save path below is relative.
os.getcwd()

You may need to change the current working directory first so that the relative save path below resolves to the intended folder.

In [None]:
# Save the fine-tuned model to a directory (path is relative to the current working directory).
bert_model.save_pretrained("Colab Notebooks/fine-tuned-bert-base-sst")

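Download the saved model directory and place it in the local `models/` folder; it is loaded from `models/fine-tuned-bert-base-sst` below.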

## Load model

In [38]:
# Load the fine-tuned model from the local models folder (rebinds bert_model).
bert_model = BertForSequenceClassification.from_pretrained("models/fine-tuned-bert-base-sst")

In [40]:
# The model was reloaded from disk in the previous cell, so move it to the same
# device the batches are sent to; otherwise the forward pass fails when
# `device` is a GPU.
bert_model.to(device)
bert_model.eval()

# Dev accuracy:
dev_pred_labels = []
dev_labels = []

with torch.no_grad():
    for batch in tqdm(dev_data_loader):

        batch_input_ids = batch[0].to(device)
        batch_attention_mask = batch[1].to(device)

        outputs = bert_model(
            input_ids=batch_input_ids,
            token_type_ids=None,
            attention_mask=batch_attention_mask
        )

        # No labels are passed, so the first output is the logits
        logits = outputs[0]

        # extend() keeps accumulation linear in the number of examples
        dev_pred_labels.extend(torch.argmax(logits, dim=1).cpu().numpy())
        dev_labels.extend(batch[2].cpu().numpy())

# Fraction of dev sentences whose predicted class matches the label
dev_accuracy = (np.array(dev_pred_labels) == np.array(dev_labels)).mean()
dev_accuracy

100%|██████████| 35/35 [02:18<00:00,  3.97s/it]


0.5213442325158947

# SHAP