# Building the Pytorch Lightning Module

In [1]:
import os
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from scipy import stats

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartTokenizer

import pytorch_lightning as pl

In [2]:
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

In [3]:
train = torch.load('../data/processed/train_dataset.pt')
test = torch.load('../data/processed/test_dataset.pt')
val = torch.load('../data/processed/val_dataset.pt')

In [4]:
def val_dataloader(val):
    return DataLoader(val, shuffle=True, batch_size=4)

In [5]:
val_dl = val_dataloader(val)

In [9]:
step, batch = next(enumerate(val_dl))

In [5]:
for step, batch in enumerate(train):
    
    input_ids, attn_mask, label_ids = batch
    break

In [6]:
tokenizer.decode(input_ids, True, True)

'Nick: Hey Dan, hey Eugenio Eugenio: Hi Dan: Hi, Nick Nick: Did you see that weird German guy yesterday at the party? He looked like fucking Harry Potter Dan: Lol! True Eugenio: And you look like fucking Hagrid, Nick XD'

In [7]:
tokenizer.decode(label_ids, True, True)

'There was an odd German at the party yesterday who resembled Harry Potter. Nick looks like Hagrid.'

In [8]:
train_loader = DataLoader(train, batch_size=1)

In [9]:
for step, batch in enumerate(train_loader):
    
    input_ids, attn_mask, label_ids = batch
    break

In [10]:
batch = {
    'input_ids': input_ids,
    'attention_mask': attn_mask,
    'labels': label_ids,
    'return_dict': True
}

In [11]:
output = model(**batch)

In [12]:
loss = output['loss']

In [13]:
loss

tensor(10.5333, grad_fn=<NllLossBackward>)

In [3]:
import torch
from torch.utils.data import DataLoader
from transformers import BartForConditionalGeneration, BartTokenizer
import pytorch_lightning as pl
from datasets import load_metric



class BartLightningModule(pl.LightningModule):
    
    def __init__(self, 
                 pretrained_nlp_model: str,
                 train_dataset: str,
                 test_dataset: str,
                 val_dataset: str,
                 batch_size: int,
                 learning_rate: float = 3e-05):
        """
        A Pytorch-Lightning Module that trains Bart from the  HuggingFace transformers 
        library.
        
        :param pretrained_nlp_model: (str) the name of the pretrained mode you want to use.
        :param train_dataset: (str) path to pytorch dataset containing train data.
        :param test_dataset: (str) path to pytorch dataset containing test data.
        :param val_dataset: (str) path to pytorch dataset containing validation data.
        :param batch_size: (int) Number of data points to pass per batch in the train, test, and validation sets.
        :param learning_rate: (float) Initial Learning Rate to set.
        :returns: None
        """
        super().__init__()
        
        self.batch_size = int(batch_size)
        self.train_dataset = str(train_dataset)
        self.test_dataset = str(test_dataset)
        self.val_dataset = str(val_dataset)
        self.hparams.learning_rate = learning_rate
        
        self.bart = BartForConditionalGeneration.from_pretrained(pretrained_nlp_model)
        self.tokenizer = BartTokenizer.from_pretrained(pretrained_nlp_model)

        
    def forward(self, x):
        
        # Run through NLP Model
        output = self.bart(**x)

        return output

    
    def training_step(self, batch, batch_idx):

        input_ids, attn_mask, labels = batch
        
        x = {
            'input_ids': input_ids,
            'attention_mask': attn_mask,
            'labels': labels,
            'return_dict': True
        }
        
        # Run through NLP Model
        out = self.bart(**x)
                
        loss = out['loss']
        print(f'current_epoch: {self.current_epoch};')
        print(f'global_step: {self.global_step}')
        print(f'train_loss: {loss};')
        
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, sync_dist=True, reduce_fx=torch.mean)
        
        return result
    
    
    def validation_step(self, batch, batch_idx):
        input_ids, attn_mask, labels = batch
        
        x = {
            'input_ids': input_ids,
            'attention_mask': attn_mask,
            'labels': labels,
            'return_dict': True
        }
        
        # Run through NLP Model
        out = self.bart(**x)
                
        loss = out['loss']
        
        return {
            'logits': out['logits'],
            'labels': x['labels'],
            'loss': loss.reshape(1, -1),
            'summary_ids': self.bart.generate(x['input_ids'], num_beams=4, max_length=90, early_stopping=True)
        }
  

    def validation_epoch_end(self, outputs):
        """
        Runs at the end of the validation epoch. Computing Rouge Scores
        
        """
        
        logits = torch.cat([out['logits'] for out in outputs])
        labels = torch.cat([out['labels'] for out in outputs])
        losses = torch.cat([out['loss'] for out in outputs])
        summary_ids = torch.cat([out['summary_ids'] for out in outputs])
        
        # Generating Rouge Scores
        predictions = torch.argmax(logits, dim=-1)
        predictions = self.tokenizer.batch_decode(
            predictions, 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=True
        )
        predictions = self.tokenizer.batch_decode(
            summary_ids
        )
        
        
        references = self.tokenizer.batch_decode(
            labels, 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=True
        )
        self.logger.experiment.add_text(
            tag='example_summaries', 
            text_string=f'''
            Model Summary: {predictions[0]}
            
            Target Summary: {references[0]}''', 
            global_step=self.global_step, 
        )
        
        metric = load_metric("rouge")
        metric.add_batch(predictions=predictions, references=references)
        rouge_score = metric.compute()

        rs_1_low = rouge_score['rouge1'].low
        rs_1_mid = rouge_score['rouge1'].mid
        rs_1_high = rouge_score['rouge1'].high

        rs_L_low = rouge_score['rougeL'].low
        rs_L_mid = rouge_score['rougeL'].mid
        rs_L_high = rouge_score['rougeL'].high
        
        result = pl.EvalResult()
        result.log('val_loss', losses, sync_dist=True, reduce_fx=torch.mean)
        print(f'val_loss: {torch.mean(losses)};')
        
        result.log('learning_rate', self.hparams.learning_rate)
        
        # Rouge-1 Score (Unigrams)
        ## Low
        result.log('rs1_low_precision', rs_1_low.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_low_recall', rs_1_low.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_low_fmeasure', rs_1_low.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        ## Mid
        result.log('rs1_mid_precision', rs_1_mid.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_mid_recall', rs_1_mid.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_mid_fmeasure', rs_1_mid.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        ## High
        result.log('rs1_high_precision', rs_1_high.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_high_recall', rs_1_high.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rs1_high_fmeasure', rs_1_high.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        # Rouge-L Score
        ## Low
        result.log('rsL_low_precision', rs_L_low.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_low_recall', rs_L_low.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_low_fmeasure', rs_L_low.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        ## Mid
        result.log('rsL_mid_precision', rs_L_mid.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_mid_recall', rs_L_mid.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_mid_fmeasure', rs_L_mid.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        ## High
        result.log('rsL_high_precision', rs_L_high.precision, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_high_recall', rs_L_high.recall, sync_dist=True, reduce_fx=torch.mean)
        result.log('rsL_high_fmeasure', rs_L_high.fmeasure, sync_dist=True, reduce_fx=torch.mean)
        
        return result
    
    
    def test_step(self, batch, batch_idx):
        input_ids, attn_mask, labels = batch
        
        x = {
            'input_ids': input_ids,
            'attention_mask': attn_mask,
            'labels': labels,
            'return_dict': True
        }
        
        # Run through NLP Model
        out = self.bart(**x)
                
        loss = out['loss']
        print(f'test_loss: {loss};')
        
        result = pl.EvalResult()
        result.log('test_loss', loss, sync_dist=True, reduce_fx=torch.mean)
        
        return result

    
    def configure_optimizers(self):
        """
        Recreating the same Adam optimizer used in the author's code.
        
        """
        optimizer = torch.optim.Adam(
            self.parameters(), 
            lr=self.hparams.learning_rate, 
            weight_decay=0.01, 
            betas=(0.9, 0.999), 
            eps=1e-08
        )
        return optimizer
    
    # overriding optimizer_step() so we can implement the custom learning rate warmup
    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """
        overriding optimizer_step() so we can implement the custom learning rate warmup.
        
        For parameter information see docs: 
        
        """
        # warm up lr

        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.hparams.learning_rate

        # update params
        optimizer.step()
        optimizer.zero_grad()
    
    
    def train_dataloader(self):
        return DataLoader(torch.load(self.train_dataset), shuffle=True, batch_size=self.batch_size)

    
    def val_dataloader(self):
        return DataLoader(torch.load(self.val_dataset), shuffle=False, batch_size=self.batch_size)
    
    
    def test_dataloader(self):
        return DataLoader(torch.load(self.test_dataset), shuffle=True, batch_size=self.batch_size)

# Testing the Lightning Module

In [18]:
step, batch = next(enumerate(bart.val_dataloader()))

In [21]:
batch[0].shape

torch.Size([4, 1024])

In [23]:
batch[1].shape

torch.Size([4, 1024])

In [24]:
batch[2].shape

torch.Size([4, 90])

In [25]:
batch[3]

IndexError: list index out of range

In [11]:
bart = BartLightningModule(
    pretrained_nlp_model='sshleifer/distilbart-cnn-12-6',
    train_dataset='../data/processed/train_dataset.pt',
    test_dataset='../data/processed/test_dataset.pt',
    val_dataset='../data/processed/val_dataset.pt',
    batch_size=4
)

## Callbacks

Here we set up tensorboard logging and early stopping

In [16]:
early_stop = pl.callbacks.EarlyStopping(
    monitor='val_loss', 
    min_delta=0.001, 
    patience=3, 
    verbose=False, 
    mode='min'
)

In [17]:
tb_logger = pl.loggers.TensorBoardLogger(
    save_dir='../models/', 
    name='bart_module_testing',
)

In [18]:
lr_logger = pl.callbacks.LearningRateLogger(logging_interval='step')

In [19]:
model_checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(
    filepath='../models/bart_checkpoints', 
    monitor='val_loss', 
    mode='min', 
    save_top_k=1
)



# Original BART codebase uses learning rate of 3e-05 with polynomial decay, with 20,000 total updates and 500 warmup steps      

In [20]:
trainer = pl.Trainer(
    callbacks=[lr_logger],
    logger=tb_logger, 
    early_stop_callback=early_stop, 
    checkpoint_callback=model_checkpoint,
    max_epochs=4,
    fast_dev_run=True
)

Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [21]:
trainer.fit(bart)


  | Name | Type                         | Params
------------------------------------------------------
0 | bart | BartForConditionalGeneration | 305 M 


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

current_epoch: 0;
global_step: 0
train_loss: 9.449356079101562;


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val_loss: 10.54311752319336;


Saving latest checkpoint..





1

In [25]:
checkpoint = torch.load(trainer.checkpoint_callback.best_model_path)

In [None]:
os.rename(trainer.checkpoint_callback.best_model_path, f'{args.model_dir}/model-checkpoint.pth')

In [324]:
with open(os.path.join('../models/bart_checkpoints', 'bart-lightning-module.pth'), 'wb') as f:
    torch.save(bart.state_dict(), f)

In [327]:
model = BartLightningModule.load_state_dict(torch.load('../models/bart_checkpoints/bart-lightning-module.pth'))

TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'