In [None]:
!pip install pandas==1.2.3
!pip install numpy==1.20.2
!pip install jupyterlab==3.0.12
!pip install matplotlib==3.4.1
!pip install seaborn==0.11.1
!pip install shap==0.39.0
!pip install torch==1.7.1
!pip install datasets==1.5.0
!pip install transformers==4.3.3
!pip install ipywidgets==7.6.3
!pip install tqdm==4.49.0
!pip install checklist==0.0.10
!pip install allennlp==2.1.0
!pip install typing==3.7.4.3
!pip install pytreebank==0.2.7

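# Fine-tuning `bert-base-uncased` on AG News

This notebook fine-tunes `bert-base-uncased` on the 4-class AG News dataset, saves the model to `models/fine-tuned-bert-base-agn`, and then checks its accuracy on the training split.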

In [1]:
import os
import random

import pandas as pd
import numpy as np
import scipy as sp
import torch
import spacy
from torch.utils.data import \
    TensorDataset, \
    DataLoader
from transformers import \
    BertTokenizer, \
    BertForSequenceClassification, \
    AdamW, \
    BertConfig, \
    get_linear_schedule_with_warmup
import pytreebank
from tqdm import tqdm
import shap
from checklist.perturb import Perturb

In [2]:
# NOTE(review): presumably for running locally from a notebook two levels below the
# repository root (instead of the Colab/Drive chdir further down) — uncomment in that case.
# os.chdir('../..')

In [3]:
# Show the current working directory before switching into the project folder.
os.getcwd()

'/content'

In [4]:
# Colab only: mount Google Drive, where the project repository is stored.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [5]:
# Work from the project repository on the mounted Drive so that the `src` package and the
# `models/` directory resolve as relative paths. The path is absolute (under the
# '/content/drive' mount point) so re-running this cell is harmless; a relative path would
# be looked up inside the repo after the first chdir and raise FileNotFoundError.
PROJECT_DIR = '/content/drive/My Drive/Colab Notebooks/Github/ucl-nlp-group-project'
os.chdir(PROJECT_DIR)

In [6]:
from src.data.dataload import load_sst, load_agnews
from src.models.bert_utils import \
    pad_sentence_at_end, \
    create_sentence_input_arrays, \
    AGN_MAX_LENGTH, \
    AGN_BERT_HYPERPARAMETERS, \
    AGN_NUM_LABELS, \
    fine_tune_bert, \
    make_predictions

In [7]:
# Seed every RNG this notebook relies on (Python, NumPy, PyTorch CPU and CUDA) so that
# DataLoader shuffling and the randomly initialised classifier weights are repeatable
# under Restart & Run All. Bit-exact GPU results may additionally depend on cuDNN settings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # silently ignored when CUDA is unavailable

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## AG News

In [8]:
# Project-side AG News loader (src/data/dataload.py). In the logged run the dataset
# download messages appear only when the splits are accessed in the next cell.
agnews = load_agnews()

In [9]:
# Unpack the train / validation / test frames and report their (rows, columns) sizes.
train_agn, dev_agn, test_agn = agnews.train_val_test
tuple(split.shape for split in (train_agn, dev_agn, test_agn))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1817.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1227.0, style=ProgressStyle(description…




Using custom data configuration default


Downloading and preparing dataset ag_news/default (download: 29.88 MiB, generated: 30.23 MiB, post-processed: Unknown size, total: 60.10 MiB) to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=11045148.0, style=ProgressStyle(descrip…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=751209.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a. Subsequent calls will reuse this data.


((108000, 2), (12000, 2), (7600, 2))

In [10]:
# Peek at the first rows of the training split (columns: label, text).
train_agn.head()

Unnamed: 0,label,text
99719,1,Pacers Season Tossed Into Doubt (AP) AP - Afte...
16354,0,Need for carbon sink technologies Climate scie...
112339,0,Putin: Iraq Still Too Dangerous for Russian Fi...
70216,0,Hassan #39;s husband makes another appeal This...
25883,1,All Change in Bundesliga Title Race Week Four ...


In [11]:
# Later cells read the text from a 'sentence' column, so rename 'text' -> 'sentence'.
# Rebinding (rather than inplace=True) keeps the cell free of hidden mutation; renaming a
# column that is already gone is a no-op, so re-running the cell is safe.
# NOTE(review): test_agn keeps its original 'text' column — rename it too if it is ever
# fed to the same helpers.
TEXT_COLUMN_RENAME = {'text': 'sentence'}
train_agn = train_agn.rename(columns=TEXT_COLUMN_RENAME)
dev_agn = dev_agn.rename(columns=TEXT_COLUMN_RENAME)

In [12]:
# Class balance of the training split: fraction of rows per label id, ordered by label.
train_agn['label'].value_counts(normalize=True).sort_index()

0    0.250167
1    0.249750
2    0.250324
3    0.249759
Name: label, dtype: float64

In [13]:
# Dtypes and non-null counts of the training split (checks for missing labels/text).
train_agn.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 108000 entries, 99719 to 14325
Data columns (total 2 columns):
 #   Column    Non-Null Count   Dtype 
---  ------    --------------   ----- 
 0   label     108000 non-null  int64 
 1   sentence  108000 non-null  object
dtypes: int64(1), object(1)
memory usage: 2.5+ MB


### Tokenization

In [14]:
# Tokenizer for the uncased BERT checkpoint; do_lower_case=True lower-cases the input text.
BERT_MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME, do_lower_case=True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [15]:
# Token-id list for every training sentence, with the tokenizer's special tokens added.
train_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in train_agn['sentence'].values
]

In [16]:
# Token-id list for every dev sentence, with the tokenizer's special tokens added.
dev_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in dev_agn['sentence'].values
]

In [17]:
# Turn the ragged token-id lists into fixed-width id and attention-mask arrays of width
# AGN_MAX_LENGTH (presumably padded/truncated by the helper — see src/models/bert_utils).
train_array, train_attention_mask_array = create_sentence_input_arrays(train_encoded_sentences, AGN_MAX_LENGTH)
dev_array, dev_attention_mask_array = create_sentence_input_arrays(dev_encoded_sentences, AGN_MAX_LENGTH)

In [18]:
# Sanity-check that ids and masks line up row-for-row and share the same width.
tuple(arr.shape for arr in (train_array, train_attention_mask_array, dev_array, dev_attention_mask_array))

((108000, 380), (108000, 380), (12000, 380), (12000, 380))

Convert the token-id arrays, attention masks and labels to PyTorch tensors.

In [19]:
# Copy inputs, attention masks and gold labels into tensors (torch.tensor copies the data).
train_tensor, train_attention_mask_tensor, train_labels_tensor = map(
    torch.tensor, (train_array, train_attention_mask_array, train_agn['label'].values)
)
dev_tensor, dev_attention_mask_tensor, dev_labels_tensor = map(
    torch.tensor, (dev_array, dev_attention_mask_array, dev_agn['label'].values)
)

In [20]:
# Pair each input row with its attention mask and label so batches stay aligned.
train_dataset, dev_dataset = (
    TensorDataset(inputs, masks, labels)
    for inputs, masks, labels in (
        (train_tensor, train_attention_mask_tensor, train_labels_tensor),
        (dev_tensor, dev_attention_mask_tensor, dev_labels_tensor),
    )
)

In [21]:
# Only the training batches are shuffled; dev batches keep their original order.
BATCH_SIZE = AGN_BERT_HYPERPARAMETERS['batch_size']
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

## Fine-tune BERT

Run on Colab with a GPU: in the logged run, one training epoch (6,750 batches) took about 70 minutes.

In [22]:
# Fine-tune bert-base-uncased on the training batches, evaluating on the dev batches; the
# helper logs train/dev accuracy after each epoch. Slow: about 70 minutes per epoch here.
bert_agn = fine_tune_bert(
    device, train_data_loader, dev_data_loader,
    num_labels=AGN_NUM_LABELS, hyperparameter_dict=AGN_BERT_HYPERPARAMETERS,
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch 1: train_acc=0.9583518518518519, dev_acc=0.9430833333333334


100%|██████████| 6750/6750 [1:10:52<00:00,  1.59it/s]


Epoch 2: train_acc=0.9744259259259259, dev_acc=0.9481666666666667


In [24]:
# Persist the fine-tuned weights and config so the expensive training need not be repeated.
FINE_TUNED_MODEL_DIR = "models/fine-tuned-bert-base-agn"
bert_agn.save_pretrained(FINE_TUNED_MODEL_DIR)

## Load model

In [25]:
# Reload the fine-tuned classifier from disk (lets later sections run without retraining).
FINE_TUNED_MODEL_DIR = "models/fine-tuned-bert-base-agn"
bert_agn = BertForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)

In [26]:
# Move the reloaded model onto the selected device; the trailing `;` hides the long module repr.
bert_agn.to(device);

## Make predictions on the training set

In [27]:
# Same first rows as before, now with the text column named 'sentence'.
train_agn.head()

Unnamed: 0,label,sentence
99719,1,Pacers Season Tossed Into Doubt (AP) AP - Afte...
16354,0,Need for carbon sink technologies Climate scie...
112339,0,Putin: Iraq Still Too Dangerous for Russian Fi...
70216,0,Hassan #39;s husband makes another appeal This...
25883,1,All Change in Bundesliga Title Race Week Four ...


In [28]:
# Predicted label ids for every training example ('sentence' is presumably the name of the
# text column the helper reads); compared against the gold labels in the next cell.
predictions = make_predictions(train_agn, bert_agn, tokenizer, 'sentence',
                               device, AGN_MAX_LENGTH, AGN_BERT_HYPERPARAMETERS)

100%|██████████| 6750/6750 [23:39<00:00,  4.75it/s]


In [29]:
# Training-set accuracy: fraction of rows whose predicted label matches the gold label.
np.mean(train_agn['label'].values == predictions)

0.9744259259259259