In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import math
import random
import warnings
from functools import *
from typing import *

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import pytorch_lightning as pl
import sklearn.metrics as metrics
from torch.utils.data import Dataset, IterableDataset, DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR
from transformers import (BertTokenizer,
                          BertModel, 
                          BertForMaskedLM, 
                          BertForNextSentencePrediction, 
                          AdamW)

In [3]:
pretrained_model = 'bert-large-cased'

epochs = 8
batch_size = 6
max_len = 256

# run trainer.fit or not
run_fit = True
# run trainer.test or not
run_test = True

In [15]:
class STSBData(Dataset):
    
    def __init__(self, file: str):
        super().__init__()
        self.data = pd.read_csv(file, sep='\t', error_bad_lines=False)[['sentence1', 'sentence2', 'score']]
        self.data['score'] /= 5
        
    def __getitem__(self, index):
        return tuple(self.data.iloc[index])

    def __len__(self):
        return len(self.data)

In [16]:
class STSBDataModule(pl.LightningDataModule):
    
    def __init__(self,
                 train_file: str,
                 dev_file: str,
                 test_file: str,
                 pretrained_model: str,
                 batch_size: int,
                 max_len: int,
                 train_batch_size: Optional[int] = None,
                 val_batch_size: Optional[int] = None,
                 test_batch_size: Optional[int] = None):
        super().__init__()
        self.train_file = train_file
        self.dev_file = dev_file
        self.test_file = test_file
        self.pretrained_model = pretrained_model
        self.train_batch_size = train_batch_size or batch_size
        self.dev_batch_size = val_batch_size or batch_size
        self.test_batch_size = test_batch_size or batch_size
        self.max_len = max_len

    def __call__(self, batch: List[Tuple[torch.Tensor, ...]]) -> Dict[str, torch.Tensor]:
        s1, s2, label = zip(*batch)
        pairs = list(zip(s1, s2))
        next_sentence_label = torch.tensor(label).long()
        encoded = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=pairs,
                                                   max_length=self.max_len,
                                                   add_special_tokens=True,
                                                   return_token_type_ids=True,
                                                   return_attention_mask=True)
        input_ids = self._make_batch(encoded['input_ids'])
        token_type_ids = self._make_batch(encoded['token_type_ids'])
        attention_mask = self._make_batch(encoded['attention_mask'])
        return {'input_ids': input_ids, 
                  'token_type_ids': token_type_ids, 
                  'attention_mask': attention_mask,
                  'next_sentence_label': next_sentence_label}
    
    def _make_batch(self, sequence: List[List[torch.Tensor]]) -> torch.Tensor:
        return rnn.pad_sequence(list(map(torch.tensor, sequence)), batch_first=True)
    
    def prepare_data(self):
        self.tokenizer = BertTokenizer.from_pretrained(self.pretrained_model)
        self.train_data = STSBData(self.train_file)
        self.dev_data = STSBData(self.dev_file)
        self.test_data = None

    def setup(self, stage: Optional[str] = None):
        pass

    def train_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.train_data,
                          batch_size=self.train_batch_size,
                          collate_fn=self,
                          num_workers=1) # ensure num_workers=1 to avoid duplicated entries in each epoch

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_data,
                          batch_size=self.val_batch_size,
                          collate_fn=self,
                          num_workers=1) # ensure num_workers=1 to avoid duplicated entries in each epoch
    
    def test_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.test_data,
                          batch_size=self.test_batch_size,
                          collate_fn=self,
                          num_workers=1) # ensure num_workers=1 to avoid duplicated entries in each epoch

In [17]:
dm = STSBDataModule('train.tsv', 'dev.tsv', 'test.tsv', pretrained_model, batch_size, max_len)
dm.prepare_data()

b'Skipping line 2509: expected 10 fields, saw 11\nSkipping line 2650: expected 10 fields, saw 11\nSkipping line 2727: expected 10 fields, saw 11\nSkipping line 3071: expected 10 fields, saw 11\nSkipping line 3393: expected 10 fields, saw 11\n'
b'Skipping line 1042: expected 10 fields, saw 11\nSkipping line 1066: expected 10 fields, saw 11\nSkipping line 1083: expected 10 fields, saw 11\nSkipping line 1137: expected 10 fields, saw 11\nSkipping line 1150: expected 10 fields, saw 11\n'


In [18]:
next(iter(dm.train_dataloader()))

{'input_ids': tensor([[  101,   138,  4261,  1110,  1781,  1228,   119,   102,  1760,  1586,
           4261,  1110,  1781,  1228,   119,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101,   138,  1299,  1110,  1773,   170,  1415, 10284,   119,   102,
            138,  1299,  1110,  1773,   170, 10284,   119,   102,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0],
         [  101,   138,  1299,  1110,  9243,   188,  8167, 15018,  1181,  9553,
           1113,   170, 13473,   119,   102,   138,  1299,  1110,  9243,   188,
           8167, 23372,  1174,  9553,  1113,  1126,  8362,  2528, 27499, 13473,
            119,   102],
         [  101,  2677,  1441,  1132,  1773, 10924,   119,   102,  1960,  1441,
           1132,  1773, 10924,   119,   102,     0,     0,     0,     0,     0,
              0,     0,     0,  

In [7]:
class ScoringModel(pl.LightningModule):
    def __init__(self, *,
                 pretrained_model: str, 
                 learning_rate: float = 0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.nsp = BertForNextSentencePrediction.from_pretrained(pretrained_model)
        self.save_hyperparameters('pretrained_model')
    
    def forward(self, 
                input_ids: torch.Tensor,
                token_type_ids: torch.Tensor, 
                attention_mask: torch.Tensor, 
                next_sentence_label: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        outputs = self.nsp(input_ids=input_ids, 
                           token_type_ids=token_type_ids, 
                           attention_mask=attention_mask, 
                           next_sentence_label=next_sentence_label)
        return outputs
    
    def configure_optimizers(self) -> Optimizer:
        self.optimizer = AdamW(model.parameters(), lr=self.learning_rate, weight_decay=0.01)
        self.scheduler = ExponentialLR(self.optimizer, gamma=0.9)
        return self.optimizer
    
    def training_step(self, batch: torch.Tensor, batch_idx: int) -> pl.TrainResult:
        losses, score = self(**batch)
        loss = losses.mean()
        result = pl.TrainResult(loss)
        result.log_dict({'train_loss': loss,
                         'lr': self.scheduler.get_lr()[0]}, 
                        on_step=True, 
                        on_epoch=False, 
                        logger=True, 
                        prog_bar=True)
        return result
    
    def training_epoch_end(self, results: pl.EvalResult) -> pl.TrainResult:
        # Shrink learning rate after each epoch
        self.scheduler.step()
        return results
    
    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> pl.EvalResult:
        losses, score = self(**batch)
        loss = losses.mean()
        result = pl.EvalResult(checkpoint_on=loss)
        result.log_dict({'val_loss': loss}, 
                        on_step=True, 
                        on_epoch=False, 
                        logger=True, 
                        prog_bar=True)
        result.label = batch['next_sentence_label'].tolist()
        result.pred = score.argmax(dim=1).tolist()
        return result
    
    def validation_epoch_end(self, results: pl.EvalResult) -> pl.EvalResult:
        label = sum(results.label, [])
        pred = sum(results.pred, [])
        accuracy = metrics.accuracy_score(label, pred)
        precision = metrics.precision_score(label, pred)
        recall = metrics.recall_score(label, pred)
        f1 = metrics.f1_score(label, pred)
        confusion = metrics.confusion_matrix(label, pred)
        scores = f'accuracy: {accuracy:.2f}, precision: {precision:.2f}, recall: {recall:.2f}, f1: {f1:.2f}'
        title = 'sanity check' if self.current_epoch == 0 else f'epoch {self.current_epoch}'
        
        # printing only from master process
        self.print(f' {title} '.center(len(scores), "="))
        self.print(pd.DataFrame(confusion, 
                           columns=['pred: -', 'pred: +'], 
                           index=['label: -', 'label: +']))
        self.print(scores)
        return results
    
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> pl.EvalResult:
        score, = self(**batch)
        result = pl.EvalResult()
        result.score = score[:, 0].tolist()
        return result
    
    def test_epoch_end(self, results: pl.EvalResult) -> pl.EvalResult:
        torch.save(torch.tensor(sum(results.score, [])), 'predictions.pt')
        return results

In [16]:
dataset = CtrlRegMappingNSPDataModule(data_file=data_file, 
                                      pretrained_model=pretrained_model, 
                                      batch_size=batch_size,
                                      max_len=max_len)
model = CtrlRegAlignerModel(pretrained_model=pretrained_model)

In [9]:
trainer = pl.Trainer(gpus=1, 
                     max_epochs=epochs,
                     auto_lr_find=True,
                     accumulate_grad_batches=4,
                     precision=32) # Due to a bug in pytorch-lightning, 16 bit training won't work with auto_lr_find

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


In [10]:
if run_fit:
    warnings.filterwarnings('ignore')
    trainer.fit(model, datamodule=dataset)

missing regs for train split:
CFR:12:1026:1026.19
CFR:12:1026:1026.21
CFR:12:1026:1026.23:a
CFR:12:1026:1026.23:b
CFR:12:1026:1026.24:a
CFR:12:1026:1026.24:c
CFR:12:1026:1026.24:d
CFR:12:1026:1026.24:f:2
CFR:12:1026:1026.24:f:3
CFR:12:1026:1026.24:g
CFR:12:1026:1026.24:h
CFR:12:1026:1026.24:i
CFR:12:1026:1026.26
CFR:12:1026:1026.32:a
CFR:12:1026:1026.32:c
CFR:12:1026:1026.32:d
CFR:12:1026:1026.34
CFR:12:1026:1026.35
CFR:12:1026:1026.40:a
CFR:12:1026:1026.40:b
CFR:12:1026:1026.40:c
CFR:12:1026:1026.40:d
CFR:12:1026:1026.40:e
CFR:12:1026:1026.40:f
CFR:12:1026:1026.42
CFR:12:1026:1026.51:a
CFR:12:1026:1026.51:b
CFR:12:1026:1026.5:b:i
CFR:12:1026:1026.5:b:ii
CFR:12:1026:1026.60:a
CFR:12:1026:1026.60:b
CFR:12:1026:1026.60:c
CFR:12:1026:1026.6:a
Data quality check on DCO applications 
USC:15:41:1681g:e
USC:15:41:1681g:g:1
missing regs for val split:
USC:15:41:1681g:e

  | Name | Type                          | Params
-------------------------------------------------------
0 | nsp  | BertForN

          pred: -  pred: +
label: -        4        2
label: +        5        1
accuracy: 0.42, precision: 0.33, recall: 0.17, f1: 0.22


HBox(children=(FloatProgress(value=0.0, description='Finding best initial lr', style=ProgressStyle(description…




LR finder stopped early due to diverging loss.
Learning rate set to 1.9054607179632464e-05

  | Name | Type                          | Params
-------------------------------------------------------
0 | nsp  | BertForNextSentencePrediction | 333 M 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

          pred: -  pred: +
label: -        6        0
label: +        6        0
accuracy: 0.50, precision: 0.00, recall: 0.00, f1: 0.00


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       46      196
label: +        7      235
accuracy: 0.58, precision: 0.55, recall: 0.97, f1: 0.70


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       68      174
label: +       13      229
accuracy: 0.61, precision: 0.57, recall: 0.95, f1: 0.71


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       82      160
label: +       12      230
accuracy: 0.64, precision: 0.59, recall: 0.95, f1: 0.73


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -      112      130
label: +       24      218
accuracy: 0.68, precision: 0.63, recall: 0.90, f1: 0.74


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       90      152
label: +       17      225
accuracy: 0.65, precision: 0.60, recall: 0.93, f1: 0.73


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       87      155
label: +       14      228
accuracy: 0.65, precision: 0.60, recall: 0.94, f1: 0.73


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

          pred: -  pred: +
label: -       61      181
label: +        6      236
accuracy: 0.61, precision: 0.57, recall: 0.98, f1: 0.72


In [17]:
if run_test:
    # I believe this is a bug of pytorch_lightning. If I use auto_lr_find 
    # for trainer, I can't reuse the trainer and model for testing, and I have 
    # to use a separate tester and load model from checkpoint like this
    tester = pl.Trainer(gpus=1, auto_lr_find=False) 
    test_model = CtrlRegAlignerModel(pretrained_model=pretrained_model)
    tester.test(test_model, ckpt_path=testing_ckpt, datamodule=dataset)
    scores = torch.load('predictions.pt')
    prediction = dataset.get_top_k_predictions_from_test_results(scores, k=5)
    print(prediction)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
missing regs for train split:
CFR:12:1026:1026.19
CFR:12:1026:1026.21
CFR:12:1026:1026.23:a
CFR:12:1026:1026.23:b
CFR:12:1026:1026.24:a
CFR:12:1026:1026.24:c
CFR:12:1026:1026.24:d
CFR:12:1026:1026.24:f:2
CFR:12:1026:1026.24:f:3
CFR:12:1026:1026.24:g
CFR:12:1026:1026.24:h
CFR:12:1026:1026.24:i
CFR:12:1026:1026.26
CFR:12:1026:1026.32:a
CFR:12:1026:1026.32:c
CFR:12:1026:1026.32:d
CFR:12:1026:1026.34
CFR:12:1026:1026.35
CFR:12:1026:1026.40:a
CFR:12:1026:1026.40:b
CFR:12:1026:1026.40:c
CFR:12:1026:1026.40:d
CFR:12:1026:1026.40:e
CFR:12:1026:1026.40:f
CFR:12:1026:1026.42
CFR:12:1026:1026.51:a
CFR:12:1026:1026.51:b
CFR:12:1026:1026.5:b:i
CFR:12:1026:1026.5:b:ii
CFR:12:1026:1026.60:a
CFR:12:1026:1026.60:b
CFR:12:1026:1026.60:c
CFR:12:1026:1026.6:a
Data quality check on DCO applications 
USC:15:41:1681g:e
USC:15:41:1681g:g:1
missing regs for val split:
USC:15:41:1681g:e


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------

{'CDD Review of Regional Public Files': ['CFR:12:1003:1003.5:e', 'CFR:12:1003:1003.5:b:2', 'CFR:12:25:25.43:b:2', 'CFR:12:25:25.43:a:2', 'CFR:12:25:25.43:b:1:ii'], 'Incentive Plan Document Distribution': ['CFR:12:222:222.90:d:1', 'CFR:12:222:222.90:d:2:iv', 'CFR:12:1022:1022.137:a:2:iii:C', 'CFR:12:222:222.90:d:2:i', 'CFR:12:222:222.90:d:2:iii'], 'Quality Check on Dating of Payment Activity': ['CFR:12:1024:1024.35:e:3:i:C', 'CFR:12:1022:1022.23:a:4', 'CFR:12:1022:1022.42:c', 'CFR:12:1024:1024.35:d', 'CFR:12:1024:1024.35:e:3:i:A'], 'Quality check on manual approvals': ['CFR:12:25:25.24:a', 'CFR:12:25:25.43:b:3:i', 'CFR:12:25:25.43:a:2', 'CFR:12:1003:1003.5:e', 'CFR:12:1026:1026.5:a:3:vi']}
