In [1]:
from transformers import BertTokenizer, BertModel, TrainingArguments, Trainer, IntervalStrategy

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from torch.utils.data import Dataset
import torch.nn as nn
import torch

import pandas as pd

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

device(type='cuda')

In [2]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint.
checkpoint = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(checkpoint)

# Only the final hidden states are consumed later on, so the optional
# attention maps and per-layer hidden states are switched off.
bert_model = BertModel.from_pretrained(
    checkpoint,
    output_hidden_states=False,
    output_attentions=False,
)
bert_model = bert_model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
dataset_types = ['train', 'validation', 'test', 'control']


def _read_split(split):
    """Read one split file, keeping columns 1 and 2 as `text` and `label`."""
    return pd.read_csv(
        f'./data/thedeep.subset.{split}.txt',
        header=None,
        usecols=[1, 2],
        names=['text', 'label'],
    )


# One DataFrame per split, keyed by split name.
data = {split: _read_split(split) for split in dataset_types}

In [4]:
# Headerless lookup file; its two columns are read as `id` and `name`.
label_names = pd.read_csv(
    './data/thedeep.labels.txt',
    names=['id', 'name'],
    header=None,
)

In [5]:
max_document_length = 512


def _tokenize_texts(texts):
    """Tokenize a list of strings, padding/truncating each to `max_document_length`."""
    return tokenizer(
        texts,
        padding='max_length',
        max_length=max_document_length,
        truncation=True,
        return_tensors='pt',
    )


# Per-split tokenizer output and per-split label tensor, keyed by split name.
tokens = {split: _tokenize_texts(data[split]['text'].tolist()) for split in dataset_types}
labels = {split: torch.tensor(data[split]['label'].tolist()) for split in dataset_types}

In [6]:
class TextDataset(Dataset):
    """Map-style dataset that pairs tokenizer outputs with their labels.

    `tokens_dict` must expose `input_ids`, `attention_mask` and
    `token_type_ids` as attributes; each is indexed along its first axis.
    `labels` is indexed the same way and determines the dataset length.
    """

    # Tokenizer fields copied onto the instance and returned for every sample.
    _TOKEN_FIELDS = ('input_ids', 'attention_mask', 'token_type_ids')

    def __init__(self, tokens_dict, labels: torch.Tensor):
        for field in self._TOKEN_FIELDS:
            setattr(self, field, getattr(tokens_dict, field))
        self.y = labels

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.y)

    def __getitem__(self, i):
        """Return sample `i` as a dict of model inputs plus its `labels` entry."""
        sample = {field: getattr(self, field)[i] for field in self._TOKEN_FIELDS}
        sample['labels'] = self.y[i]
        return sample

In [7]:
# Wrap every split's tokens and labels in a TextDataset, keyed by split name.
datasets = {
    split: TextDataset(tokens[split], labels[split])
    for split in dataset_types
}

In [8]:
class ClassificationBERTModel(nn.Module):
    """Frozen BERT encoder followed by masked mean pooling and a linear classifier.

    Args:
        bert_model: Encoder whose forward output exposes `last_hidden_state`.
            All of its parameters are frozen, so only the linear head is trained.
        num_labels: Number of output classes (defaults to 12, the value that
            was previously hard-coded).
    """

    def __init__(self, bert_model: 'BertModel', num_labels: int = 12):
        super().__init__()
        self.bert = bert_model
        # Take the embedding width from the encoder config when it is available
        # instead of hard-coding 768 (768 stays as the fallback).
        hidden_size = getattr(getattr(bert_model, 'config', None), 'hidden_size', 768)
        self.linear = nn.Linear(hidden_size, num_labels)
        self.loss = nn.CrossEntropyLoss()

        # Freeze the encoder: only `self.linear` receives gradient updates.
        for param in self.bert.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask, token_type_ids=None, labels=None):
        """Classify a batch of tokenized documents.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: 1 for real tokens and 0 for padding; also used as
                the pooling mask.
            token_type_ids: Optional segment ids, forwarded to the encoder.
            labels: Optional target class indices.

        Returns:
            `(loss, logits)` when `labels` is given, otherwise `(logits,)`.
        """
        hidden = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state

        # Masked mean pooling. The denominator is the number of attended tokens
        # per sequence. Counting non-zero activations instead would skew the
        # mean whenever a real token has an exact-zero feature and would yield
        # NaN when a feature is zero for every token.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # Clamp so a fully masked row gives zeros rather than a division by zero.
        token_counts = mask.sum(dim=1).clamp(min=1)
        logits = self.linear(summed / token_counts)

        if labels is None:
            return (logits,)
        return self.loss(logits, labels), logits

In [9]:
# Training hyper-parameters used by the TrainingArguments cell.
batch_size: int = 16   # samples per device for both training and evaluation
epochs: int = 3        # number of passes over the training split
lr: float = 0.001      # learning rate handed to the Trainer

In [10]:
training_args = TrainingArguments(
    output_dir='ClassificationBERT',
    # Optimisation schedule
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Log, evaluate and checkpoint once per epoch
    logging_strategy=IntervalStrategy.EPOCH,
    evaluation_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    # After training, reload the checkpoint with the best 'accuracy' metric
    metric_for_best_model='accuracy',
    load_best_model_at_end=True,
)

In [11]:
def compute_metrics(p):
    """Turn a `(scores, true_labels)` pair into classification metrics.

    The predicted class is the argmax of `scores` along axis 1. Precision,
    recall and F1 are weighted averages with `zero_division=0`.

    Returns:
        dict with the keys "accuracy", "precision", "recall" and "f1".
    """
    scores, true_labels = p
    predicted = scores.argmax(1)

    # Arguments shared by the three weighted-average metrics.
    weighted = dict(y_true=true_labels, y_pred=predicted, average='weighted', zero_division=0)

    return {
        "accuracy": accuracy_score(y_true=true_labels, y_pred=predicted),
        "precision": precision_score(**weighted),
        "recall": recall_score(**weighted),
        "f1": f1_score(**weighted),
    }

In [12]:
# The classification head is randomly initialised while the `model=` argument
# below is evaluated, i.e. before Trainer has a chance to apply
# `training_args.seed`. Seed torch explicitly first so that the head's initial
# weights (and therefore the reported metrics) are reproducible across runs.
torch.manual_seed(training_args.seed)

trainer = Trainer(
    model=ClassificationBERTModel(bert_model).to(device),
    train_dataset=datasets['train'],
    eval_dataset=datasets['validation'],  # evaluated according to training_args
    compute_metrics=compute_metrics,
    args=training_args
)

In [13]:
# Run training; the returned TrainOutput is shown as the cell result.
train_output = trainer.train()
train_output



Epoch,Training Loss,Validation Loss


TrainOutput(global_step=2271, training_loss=0.9619511243584187, metrics={'train_runtime': 1470.5816, 'train_samples_per_second': 24.705, 'train_steps_per_second': 1.544, 'total_flos': 0.0, 'train_loss': 0.9619511243584187, 'epoch': 3.0})

In [14]:
# Score the held-out test split and display its metrics dictionary.
test_dataset = datasets['test']
test_result = trainer.predict(test_dataset)
test_result.metrics

{'test_loss': 0.8133374452590942,
 'test_accuracy': 0.7617021276595745,
 'test_precision': 0.7522458952494068,
 'test_recall': 0.7617021276595745,
 'test_f1': 0.738799164616739,
 'test_runtime': 82.3011,
 'test_samples_per_second': 31.409,
 'test_steps_per_second': 1.968}

In [15]:
# Score the control split and display its metrics dictionary.
control_dataset = datasets['control']
control_result = trainer.predict(control_dataset)
control_result.metrics

{'test_loss': 0.8438506126403809,
 'test_accuracy': 0.7,
 'test_precision': 0.825,
 'test_recall': 0.7,
 'test_f1': 0.7238095238095237,
 'test_runtime': 0.3198,
 'test_samples_per_second': 31.271,
 'test_steps_per_second': 3.127}