In [None]:
!pip install pandas==1.2.3
!pip install numpy==1.20.2
!pip install jupyterlab==3.0.12
!pip install matplotlib==3.4.1
!pip install seaborn==0.11.1
!pip install shap==0.39.0
!pip install torch==1.7.1
!pip install datasets==1.5.0
!pip install transformers==4.3.3
!pip install ipywidgets==7.6.3
!pip install tqdm==4.49.0
!pip install checklist==0.0.10
!pip install allennlp==2.1.0
!pip install allennlp-models==2.1.0
!pip install typing==3.7.4.3
!pip install pytreebank==0.2.7
!pip install spacy==3.0.5
!pip install importlib-metadata==3.3.0

# Fine-tuned `bert-base-uncased` on AG News

In [1]:
import os
import random

import pandas as pd
import numpy as np
import scipy as sp
import torch
import spacy
from torch.utils.data import \
    TensorDataset, \
    DataLoader
from transformers import \
    BertTokenizer, \
    BertForSequenceClassification, \
    AdamW, \
    BertConfig, \
    get_linear_schedule_with_warmup
import pytreebank
from tqdm import tqdm
import shap
from checklist.perturb import Perturb

In [2]:
# Left disabled. Presumably meant for local runs started from the notebook's own
# folder, where the repository root (needed for the `src.*` imports) is two
# levels up — confirm before re-enabling.
# os.chdir('../..')

In [3]:
# Display the current working directory: the `src.*` imports and the relative
# `models/...` paths used later in the notebook are resolved against it.
os.getcwd()

'/Users/stevengeorge/Documents/Github/ucl-nlp-group-project'

In [4]:
# Colab-only step: `google.colab` is imported here rather than in the import
# cell at the top of the notebook.
# NOTE(review): presumably so the rest of the notebook can still run outside
# Colab, where this cell has to be skipped — confirm.
from google.colab import drive
# Mount Google Drive at /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Switch to the project checkout inside the mounted Drive.
# The path is absolute (the Drive is mounted at /content/drive) so this cell can
# be re-run safely: a relative 'drive/My Drive/...' path only resolves while the
# working directory is still /content, i.e. the first time the cell is run.
os.chdir('/content/drive/My Drive/Colab Notebooks/Github/ucl-nlp-group-project')

In [10]:
from src.data.dataload import load_sst, load_agnews
from src.models.bert_utils import \
    _pad_sentence_at_end, \
    _create_sentence_input_arrays, \
    fine_tune_bert, \
    make_predictions

In [11]:
# Sequence length used when building the padded input arrays; also stored in the
# hyperparameter dictionary below.
AGN_MAX_LENGTH = 380

# Settings handed to `fine_tune_bert` and `make_predictions`.
AGN_BERT_HYPERPARAMETERS = {
    'batch_size': 16,
    'learning_rate': 2e-5,
    'number_of_epochs': 2,
    'max_length': AGN_MAX_LENGTH,
}

# Passed as `num_labels` to `fine_tune_bert`.
AGN_NUM_LABELS = 4

In [12]:
# Use the first CUDA GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


## AG News

In [13]:
# Load AG News through the project helper. The returned object exposes the
# `train_val_test` attribute that the next cell unpacks into three splits.
agnews = load_agnews()

In [14]:
# Unpack the train / dev / test splits exposed by the loader object.
train_agn, dev_agn, test_agn = agnews.train_val_test
# Show the shape of each split, in the same order.
tuple(split.shape for split in (train_agn, dev_agn, test_agn))

Using custom data configuration default
Reusing dataset ag_news (/Users/stevengeorge/.cache/huggingface/datasets/ag_news/default/0.0.0/17ec33e23df9e89565131f989e0fdf78b0cc4672337b582da83fc3c9f79fe34d)


((108000, 3), (12000, 3), (7600, 3))

In [15]:
# Preview the first rows of the training split.
train_agn.head()

Unnamed: 0,sentence,label,title
0,"Reuters - Short-sellers, Wall Street's dwindli...",2,Wall St. Bears Claw Back Into the Black (Reuters)
1,Reuters - Private investment firm Carlyle Grou...,2,Carlyle Looks Toward Commercial Aerospace (Reu...
2,Reuters - Soaring crude prices plus worriesabo...,2,Oil and Economy Cloud Stocks' Outlook (Reuters)
3,Reuters - Authorities have halted oil exportfl...,2,Iraq Halts Oil Exports from Main Southern Pipe...
4,"AFP - Tearaway world oil prices, toppling reco...",2,"Oil prices soar to all-time record, posing new..."


In [16]:
# Share of training rows carrying each class label, ordered by label id.
(
    train_agn['label']
    .value_counts()
    .div(len(train_agn))
    .sort_index()
)

0    0.251315
1    0.249769
2    0.248278
3    0.250639
Name: label, dtype: float64

In [17]:
# Print column dtypes, non-null counts and memory usage for the training split.
train_agn.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 108000 entries, 0 to 107999
Data columns (total 3 columns):
 #   Column    Non-Null Count   Dtype 
---  ------    --------------   ----- 
 0   sentence  108000 non-null  object
 1   label     108000 non-null  int64 
 2   title     108000 non-null  object
dtypes: int64(1), object(2)
memory usage: 3.3+ MB


### Tokenization

In [18]:
# Tokenizer for the `bert-base-uncased` checkpoint, with lower-casing enabled.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [19]:
# Encode every training sentence into a list of token ids, special tokens included.
train_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in train_agn['sentence'].values
]

In [20]:
# Encode every dev sentence into a list of token ids, special tokens included.
dev_encoded_sentences = [
    tokenizer.encode(text, add_special_tokens=True)
    for text in dev_agn['sentence'].values
]

In [22]:
# Turn the variable-length token-id lists into a pair of arrays per split: one
# for the inputs and one for the attention mask. The shape check in the next
# cell shows each array has AGN_MAX_LENGTH columns.
train_array, train_attention_mask_array = _create_sentence_input_arrays(train_encoded_sentences, AGN_MAX_LENGTH)
dev_array, dev_attention_mask_array = _create_sentence_input_arrays(dev_encoded_sentences, AGN_MAX_LENGTH)

In [23]:
# Shapes of the input and attention-mask arrays for train, then dev.
tuple(
    arr.shape
    for arr in (train_array, train_attention_mask_array, dev_array, dev_attention_mask_array)
)

((108000, 380), (108000, 380), (12000, 380), (12000, 380))

Convert the padded arrays and the labels to PyTorch tensors.

In [24]:
# Copy the inputs, attention masks and labels of each split into PyTorch tensors.
train_tensor, train_attention_mask_tensor, train_labels_tensor = map(
    torch.tensor,
    (train_array, train_attention_mask_array, train_agn['label'].values),
)

dev_tensor, dev_attention_mask_tensor, dev_labels_tensor = map(
    torch.tensor,
    (dev_array, dev_attention_mask_array, dev_agn['label'].values),
)

In [25]:
# Bundle the tensors so that each dataset item is an
# (inputs, attention mask, label) triple, in that order.
train_dataset = TensorDataset(train_tensor, train_attention_mask_tensor, train_labels_tensor)
dev_dataset = TensorDataset(dev_tensor, dev_attention_mask_tensor, dev_labels_tensor)

In [26]:
# Batch both datasets with the configured batch size. Only the training loader
# shuffles; the dev loader keeps the dataset order.
# NOTE(review): no random seed is set anywhere in this notebook, so the shuffled
# batch order is not reproducible unless `fine_tune_bert` seeds the RNGs itself — confirm.
train_data_loader = DataLoader(train_dataset, batch_size=AGN_BERT_HYPERPARAMETERS['batch_size'], shuffle=True)
dev_data_loader = DataLoader(dev_dataset, batch_size=AGN_BERT_HYPERPARAMETERS['batch_size'])

## Fine-tune BERT

Run this section on Google Colab.

In [28]:
# Fine-tune via the project helper, using the train/dev loaders built above and
# the AG News hyperparameters. The returned object is saved with
# `save_pretrained` in the next cell.
# NOTE(review): which pretrained checkpoint the helper starts from is not
# visible here (the notebook title says `bert-base-uncased`) — confirm in
# src/models/bert_utils.py.
bert_agn = fine_tune_bert(
    device, 
    train_data_loader, 
    dev_data_loader, 
    num_labels=AGN_NUM_LABELS, 
    hyperparameter_dict=AGN_BERT_HYPERPARAMETERS
)

In [24]:
# Persist the fine-tuned model under a path relative to the current working
# directory; the "Load model" cell reads it back from the same location.
bert_agn.save_pretrained("models/fine-tuned-bert-base-agn")

## Load model

In [25]:
# Reload the saved fine-tuned model from disk, so the prediction cells can be
# run without repeating the fine-tuning step. The path is relative to the
# current working directory.
bert_agn = BertForSequenceClassification.from_pretrained("models/fine-tuned-bert-base-agn")

In [26]:
%%capture
# Move the model onto `device`. `%%capture` (which must stay on the first line
# of the cell) suppresses the cell's output.
bert_agn.to(device)

## Make predictions

In [27]:
# Preview the training split again before predicting on it.
# NOTE(review): the saved output of this cell shows a shuffled index and no
# `title` column, unlike the earlier `train_agn.head()` output, so it appears to
# come from a different session or an older data loader — re-run to refresh.
train_agn.head()

Unnamed: 0,label,sentence
99719,1,Pacers Season Tossed Into Doubt (AP) AP - Afte...
16354,0,Need for carbon sink technologies Climate scie...
112339,0,Putin: Iraq Still Too Dangerous for Russian Fi...
70216,0,Hassan #39;s husband makes another appeal This...
25883,1,All Change in Bundesliga Title Race Week Four ...


In [28]:
# Run the fine-tuned model over every training sentence using the project helper.
# The saved progress bar shows 6750 iterations (108000 rows at batch size 16)
# taking roughly 24 minutes.
# NOTE(review): the accuracy cell indexes the result as `predictions[1]`, so it
# is presumably a tuple/sequence whose second item holds per-class scores —
# confirm against `make_predictions`.
predictions = make_predictions(
    train_agn['sentence'], 
    bert_agn, 
    tokenizer, 
    device, 
    AGN_BERT_HYPERPARAMETERS
)

100%|██████████| 6750/6750 [23:39<00:00,  4.75it/s]


In [1]:
# Training-set accuracy: fraction of rows whose arg-max prediction equals the gold label.
# assumes predictions[1] holds one row of per-class scores per sentence — TODO confirm
predicted_labels = predictions[1].argmax(axis=1)
(train_agn['label'].values == predicted_labels).mean()