In [18]:
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Dict

import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertModel, TrainingArguments, Trainer, IntervalStrategy, AutoModelForSequenceClassification, AutoTokenizer

# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

---
# Train set

In [7]:
class Transformation(ABC):
    """Abstract base for a step that maps one DataFrame to another.

    Subclasses must implement :meth:`transform`.
    """

    def __init__(self, **kwargs):
        """Accept arbitrary keyword arguments; the base class ignores them."""

    @abstractmethod
    def transform(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """Return the transformed version of ``dataset``.

        The base implementation does nothing and returns ``None``.
        """


@dataclass
class TransformationConfig:
    """Description of one transformation step: which one, and with what arguments."""

    # Key used to look the transformation class up in the registry.
    name: str
    # Keyword arguments forwarded to the transformation's constructor.
    kwargs: Dict

In [8]:
class DummyTransformation(Transformation):
    """No-op transformation: hands the dataset back untouched."""

    def transform(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """Return ``dataset`` itself (same object, no copy)."""
        unchanged = dataset
        return unchanged

In [9]:
# Registry mapping a configuration name to the transformation class it selects.
TRANSFORMATIONS = dict(
    DUMMY=DummyTransformation,
)

In [10]:
@dataclass
class TrainSetConfig:
    """Where the training set lives and which transformations to apply to it."""

    # Location of the training CSV file.
    path: str
    # Transformation steps, applied in list order.
    transformations: List[TransformationConfig]

In [11]:
def train_set_select(config: TrainSetConfig) -> pd.DataFrame:
    """Load the training set and run the configured transformations over it.

    Args:
        config: Path of the header-less CSV file and the ordered list of
            transformation descriptions to apply.

    Returns:
        The DataFrame (columns ``id``, ``text``, ``label``) after every
        configured transformation has been applied in order.

    Raises:
        KeyError: If a transformation name is not registered in
            ``TRANSFORMATIONS``.
    """
    df = pd.read_csv(config.path, header=None, names=['id', 'text', 'label'])

    for t in config.transformations:
        # The kwargs dict must be keyword-unpacked (**): positional unpacking
        # (*) would pass the dict *keys* as positional arguments, which
        # Transformation.__init__ (keyword-only) rejects for any non-empty dict.
        transformation = TRANSFORMATIONS[t.name](**(t.kwargs or {}))
        df = transformation.transform(df)

    return df

---
# Validation set

---
# Test set

---
# Model

In [20]:
class ModelConfig:
    """Plain container for model-loading and tokenization settings."""

    def __init__(
        self, model_name_or_path: str, num_labels: int,
        max_length: int, truncation, padding, return_tensors: str,
        output_attentions: bool = False, output_hidden_states: bool = False
    ):
        """Store every argument unchanged as an attribute of the same name."""
        # Which checkpoint to load and how many classes the head predicts.
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels

        # Tokenization settings.
        self.max_length = max_length
        self.truncation = truncation
        self.padding = padding
        self.return_tensors = return_tensors

        # Optional extra model outputs.
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states

In [21]:
class ModelLoader:
    """Loads the tokenizer and sequence-classification model named in a ModelConfig."""

    def __init__(self, config: ModelConfig):
        """Load both the tokenizer and the model from ``config.model_name_or_path``."""
        self.config = config
        checkpoint = config.model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        # Options forwarded to the model alongside the checkpoint name.
        model_options = dict(
            num_labels=config.num_labels,
            output_attentions=config.output_attentions,
            output_hidden_states=config.output_hidden_states,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint, **model_options
        )

    def get_model(self, device):
        """Move the loaded model to ``device`` and return it."""
        moved = self.model.to(device)
        return moved

    def get_tokenizer(self):
        """Return the tokenizer loaded in the constructor."""
        return self.tokenizer

In [None]:
# TODO pipeline - e.g.
# class MyTokenClassificationPipeline(TokenClassificationPipeline):
#     def preprocess(self, sentence, offset_mapping=None):
#         truncation = False
#         padding = 'longest'
#         model_inputs = self.tokenizer(
#             sentence,
#             return_tensors=self.framework,
#             truncation=truncation,
#             padding=padding,
#             return_special_tokens_mask=True,
#             return_offsets_mapping=self.tokenizer.is_fast,
#         )
#         if offset_mapping:
#             model_inputs["offset_mapping"] = offset_mapping
    
#         model_inputs["sentence"] = sentence
#         return model_inputs

---
# Configuration

In [22]:
# Example configuration: one nested dict holding every tunable of the notebook.
custom_configuration = {

    # Training data plus the transformations applied to it (in order).
    'train_set': {
        'path': './data/thedeep.subset.train.txt',
        'transformations': [
            {
                'name': 'DUMMY',   # key into TRANSFORMATIONS
                'args': {}         # keyword arguments for the transformation
            }
        ]
    },

    'validation_set': {
        'path': './data/thedeep.subset.validation.txt',
    },

    'test_set': {
        'path': './data/thedeep.subset.test.txt',
    },

    'control_set': {
        'path': './data/thedeep.subset.control.txt'
    },

    # File mapping label ids to label names.
    'labels': {
        'path': './data/thedeep.labels.txt'
    },

    # Values consumed when building TrainingArguments.
    'training': {
        'batch_size': 16,
        'epochs': 3,
        'learning_rate': 1e-3,
        'output_dir': 'ClassificationBERT',
        'metric_for_best_model': 'accuracy',
    },

    # Keyword arguments for ModelConfig.
    'model': {
        'model_name_or_path': 'bert-base-cased',
        # The data sets carry label ids 0..11, i.e. 12 classes; a 2-class head
        # cannot be trained or evaluated against them.
        'num_labels': 12,
        'max_length': 512,
        'truncation': True,
        'padding': 'max_length',
        'return_tensors': 'pt',
        'output_attentions': True,
        'output_hidden_states': True,
    }
}

In [30]:
# Build the model settings from the 'model' section and load tokenizer + model.
model_config = ModelConfig(**custom_configuration['model'])
model_loader = ModelLoader(config=model_config)

# Fetch the model (moved to the selected device) and its tokenizer.
model = model_loader.get_model(device=device)
tokenizer = model_loader.get_tokenizer()
model, tokenizer

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

(BertForSequenceClassification(
   (bert): BertModel(
     (embeddings): BertEmbeddings(
       (word_embeddings): Embedding(28996, 768, padding_idx=0)
       (position_embeddings): Embedding(512, 768)
       (token_type_embeddings): Embedding(2, 768)
       (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
       (dropout): Dropout(p=0.1, inplace=False)
     )
     (encoder): BertEncoder(
       (layer): ModuleList(
         (0): BertLayer(
           (attention): BertAttention(
             (self): BertSelfAttention(
               (query): Linear(in_features=768, out_features=768, bias=True)
               (key): Linear(in_features=768, out_features=768, bias=True)
               (value): Linear(in_features=768, out_features=768, bias=True)
               (dropout): Dropout(p=0.1, inplace=False)
             )
             (output): BertSelfOutput(
               (dense): Linear(in_features=768, out_features=768, bias=True)
               (LayerNorm): LayerNorm((768

---
# Notebook flow

In [32]:
# Translate the 'train_set' section of the configuration into typed config
# objects; the 'args' entry of each transformation becomes its kwargs.
train_set_config = TrainSetConfig(
    path=custom_configuration['train_set']['path'],
    transformations=[
        TransformationConfig(name=t['name'], kwargs=t['args'])
        for t in custom_configuration['train_set']['transformations']
    ],
)

train_set = train_set_select(train_set_config)
train_set

In [9]:
# TODO
# Validation split: header-less CSV read with explicit column names.
validation_set = pd.read_csv(
    filepath_or_buffer=custom_configuration['validation_set']['path'],
    header=None,
    names=['id', 'text', 'label'],
)
validation_set

Unnamed: 0,id,text,label
0,633,The veterans threw up roadblocks on the main n...,9
1,6001,Water department complains about lack of skill...,11
2,14014,"On 13 February 2018, the Ministry of Health of...",4
3,12225,"In Kakuma and Kalobeyei, both host and refugee...",7
4,10181,'Raqqa is now empty of civilians who had been ...,9
...,...,...,...
2591,5109,UNICEF-supported Child Health Days happen in D...,4
2592,5696,The residents of Karachi once again faced prol...,7
2593,4622,Poverty ripping off Malawi farmers Poverty has...,3
2594,9053,The Sheikh Jarrah residential neighbourhood is...,10


In [10]:
# TODO
# Test split: header-less CSV read with explicit column names.
test_set = pd.read_csv(
    filepath_or_buffer=custom_configuration['test_set']['path'],
    header=None,
    names=['id', 'text', 'label'],
)
test_set

Unnamed: 0,id,text,label
0,7162,"After the civil war, Lebanon’s healthcare syst...",4
1,10157,For many communities in Central River Region (...,3
2,4664,Violence in the southeastern part of the Centr...,4
3,11715,"Of the 35 interviewees, 30 reported experienci...",9
4,15090,Several schools in Apuk South County of Gogria...,2
...,...,...,...
2580,1979,Tropical Storm Dineo caused widespread damage ...,10
2581,7549,Monsoon rains and increased water levels in ma...,10
2582,6064,I am extremely concerned about possible outbre...,11
2583,3582,Newly displaced persons from Al Mukha and Dhub...,11


In [11]:
# TODO
# Control split: header-less CSV read with explicit column names.
control_set = pd.read_csv(
    filepath_or_buffer=custom_configuration['control_set']['path'],
    header=None,
    names=['id', 'text', 'label'],
)
control_set

Unnamed: 0,id,text,label
0,5805,Overall 30% decrease in MAM Children admission...,8
1,17120,"In 2014, fear of Ebola also led to attacks on ...",9
2,11901,"Wheat is the staple food for most Afghans, com...",3
3,2589,We have received serious allegations that two ...,9
4,4188,"Somali: 67 of 93 woredas hotspot . 410k MAM, 4...",8
5,8392,"In order to prevent epidemics, the displaced p...",11
6,12693,Local cereal prices declined in most of the So...,3
7,2937,Results from the Swaziland Vulnerability Analy...,3
8,2318,Desert locust population declined due to inten...,0
9,12386,"As a result of the influx, the open defecatio...",11


In [12]:
# All four splits keyed by split name; later cells iterate over this mapping.
data = dict(
    train=train_set,
    validation=validation_set,
    test=test_set,
    control=control_set,
)

In [13]:
# TODO
# Label lookup table: header-less CSV read with columns 'id' and 'name'.
label_names = pd.read_csv(
    filepath_or_buffer=custom_configuration['labels']['path'],
    header=None,
    names=['id', 'name'],
)

In [15]:
# TODO
# Note: this rebinds `tokenizer`, replacing the one obtained from model_loader
# in the configuration section above.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Plain BERT encoder without attention / hidden-state outputs.
bert_model = BertModel.from_pretrained(
    'bert-base-uncased',
    output_attentions=False,
    output_hidden_states=False
)
bert_model = bert_model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Tokenization settings come from the shared 'model' configuration section
# instead of being duplicated here as literals (the configured values are the
# same ones that were previously hard-coded: 512 / 'max_length' / True).
max_document_length = custom_configuration['model']['max_length']

# One tokenizer output per split, keyed by split name.
tokens = {
    dataset_type:
        tokenizer(dataset['text'].tolist(),
                  padding=custom_configuration['model']['padding'],
                  max_length=max_document_length,
                  truncation=custom_configuration['model']['truncation'],
                  # Kept fixed: TextDataset indexes these as torch tensors.
                  return_tensors='pt')

    for dataset_type, dataset in data.items()
}

# One label tensor per split, keyed by split name.
labels = {
    dataset_type:
        torch.tensor(dataset['label'].tolist())

    for dataset_type, dataset in data.items()
}

In [17]:
# TODO: this can probably be left as it is
class TextDataset(Dataset):
    """Dataset pairing tokenizer outputs with their labels, one example per index."""

    def __init__(self, tokens, labels: torch.Tensor):
        """Keep references to the three token fields of ``tokens`` and to ``labels``."""
        self.y = labels
        self.input_ids = tokens.input_ids
        self.attention_mask = tokens.attention_mask
        self.token_type_ids = tokens.token_type_ids

    def __len__(self):
        """Number of examples, taken from the labels."""
        return len(self.y)

    def __getitem__(self, i):
        """Return the i-th example as a dict of model inputs plus 'labels'."""
        example = dict(
            input_ids=self.input_ids[i],
            attention_mask=self.attention_mask[i],
            token_type_ids=self.token_type_ids[i],
        )
        example['labels'] = self.y[i]
        return example

In [18]:
# Wrap every split's tokens and labels in a TextDataset, keyed by split name.
datasets = {
    split: TextDataset(tokens[split], labels[split])
    for split in data
}

In [19]:
# TODO
class ClassificationBERTModel(nn.Module):
    """Frozen BERT encoder + masked mean pooling + one linear classification layer.

    Args:
        bert_model: Encoder whose output exposes ``last_hidden_state`` and whose
            ``config.hidden_size`` gives the width of that state. All of its
            parameters are frozen, so only the linear layer is trained.
        num_labels: Number of target classes (defaults to 12, the value that
            was previously hard-coded).
    """

    def __init__(self, bert_model: "BertModel", num_labels: int = 12):
        super().__init__()
        self.bert = bert_model
        # Take the input width from the encoder instead of assuming 768.
        self.linear = nn.Linear(bert_model.config.hidden_size, num_labels)
        self.loss = nn.CrossEntropyLoss()

        # Freeze the encoder.
        for param in self.bert.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask, token_type_ids, labels):
        """Return ``(loss, logits)`` for a batch.

        The sentence vector is the mean of the hidden states over the positions
        where ``attention_mask`` is non-zero.
        """
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                           token_type_ids=token_type_ids).last_hidden_state

        # (batch, seq, 1) mask in the hidden-state dtype; broadcasts over features.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)

        # Divide by the number of attended tokens taken from the mask itself.
        # Counting non-zero activations instead would miscount whenever an
        # activation is exactly 0 and would yield 0/0 = NaN for all-zero
        # features; the clamp guards fully masked rows.
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (hidden * mask).sum(dim=1) / token_counts

        logits = self.linear(pooled)

        return self.loss(logits, labels), logits

In [21]:
# TODO
# Trainer settings; tunable values come from custom_configuration['training'].
training_args = TrainingArguments(
    output_dir=custom_configuration['training']['output_dir'],

    # Optimisation.
    num_train_epochs=custom_configuration['training']['epochs'],
    learning_rate=custom_configuration['training']['learning_rate'],
    per_device_train_batch_size=custom_configuration['training']['batch_size'],
    per_device_eval_batch_size=custom_configuration['training']['batch_size'],

    # Evaluate, checkpoint and log on the same per-epoch schedule.
    evaluation_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,
    logging_strategy=IntervalStrategy.EPOCH,

    # Model selection.
    load_best_model_at_end=True,
    metric_for_best_model=custom_configuration['training']['metric_for_best_model'],
)

In [22]:
# TODO: ???
def compute_metrics(p):
    """Turn a ``(scores, true_labels)`` pair into accuracy / precision / recall / F1.

    The predicted class is the argmax of the scores along axis 1. Precision,
    recall and F1 are weighted averages with ``zero_division=0``.
    """
    scores, true_labels = p
    predicted = scores.argmax(1)

    # Arguments shared by the three weighted-average metrics.
    weighted = dict(y_true=true_labels, y_pred=predicted, average='weighted', zero_division=0)

    return {
        "accuracy": accuracy_score(y_true=true_labels, y_pred=predicted),
        "precision": precision_score(**weighted),
        "recall": recall_score(**weighted),
        "f1": f1_score(**weighted),
    }

In [23]:
# Wrap the BERT encoder in the classification model and place it on the device.
classification_model = ClassificationBERTModel(bert_model)
classification_model = classification_model.to(device)

# Train on the 'train' split, evaluate on the 'validation' split.
trainer = Trainer(
    model=classification_model,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['validation'],
    compute_metrics=compute_metrics,
)

In [24]:
# FIXME
# Sets the WANDB_DISABLED environment variable to 'true' for this process.
# NOTE(review): this cell runs after TrainingArguments and Trainer were already
# constructed — confirm the setting still takes effect at that point, or move
# it ahead of those cells.
os.environ['WANDB_DISABLED'] = 'true'

In [25]:
# Run training with the trainer configured above; the call's return value is
# displayed as the cell output.
trainer.train()



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Run prediction on the 'test' split and display the resulting metrics.
test_result = trainer.predict(datasets['test'])
test_result.metrics

In [None]:
# Run prediction on the 'control' split and display the resulting metrics.
control_result = trainer.predict(datasets['control'])
control_result.metrics