In [116]:
# Install required libraries
!pip install datasets transformers --quiet
!pip install pytorch-lightning --quiet

# Regular imports
import numpy as np
import os
import pytorch_lightning as pl
import torch
import transformers

# Specific imports
from argparse import ArgumentParser
from datasets import load_dataset, load_metric
from itertools import compress
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from tqdm import tqdm


In [117]:
# Load the pretrained DistilBERT checkpoint: its configuration, its tokenizer
# and the base encoder that the QA model defined below builds upon.
model_checkpoint = 'distilbert-base-uncased'
transformer_config = transformers.AutoConfig.from_pretrained(model_checkpoint)
transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)
transformer_model = transformers.AutoModel.from_pretrained(model_checkpoint)

In [118]:
# Load SQuAD v1.1 and its official evaluation metric through the Hugging Face
# `datasets` library, then show the available splits.
datasets = load_dataset(path='squad')
metric = load_metric(path='squad')
print(datasets)

## Prepare data
A function that tokenizes the SQuAD examples into the features used for training and labels each feature with the token positions of its answer.

In [119]:
def prepare_train_features(examples, max_length=384, doc_stride=128, tokenizer=None):
    '''Tokenize a batch of SQuAD examples into labelled training features.

    Long contexts are split into several overlapping windows, so one example
    may yield several features. Each feature is labelled with the token
    indices of the answer start and end. It gets the CLS token index instead
    when the answer does not lie inside that window or the example has no
    answer.

    Args:
        examples: Batched columns as passed by ``Dataset.map(batched=True)``;
            must contain 'question', 'context' and 'answers' (each answer a
            dict with 'answer_start' and 'text' lists).
        max_length: Maximum length of a feature (question and context).
        doc_stride: Number of overlapping tokens between two consecutive
            windows of the same context.
        tokenizer: Tokenizer to use; defaults to the notebook-level
            ``transformer_tokenizer``.

    Returns:
        The tokenizer output with 'start_positions' and 'end_positions'
        added and 'offset_mapping' / 'overflow_to_sample_mapping' removed.
    '''
    if tokenizer is None:
        tokenizer = transformer_tokenizer
    # The padding side determines the order: (question|context) or
    # (context|question).
    pad_on_right = tokenizer.padding_side == 'right'
    # Sequence id that marks context tokens in ``sequence_ids``.
    context_id = 1 if pad_on_right else 0

    # Tokenize with truncation and padding, but keep the overflows using a
    # stride: a long context yields several features whose windows overlap.
    tokenized_examples = tokenizer(
        examples['question' if pad_on_right else 'context'],
        examples['context' if pad_on_right else 'question'],
        truncation='only_second' if pad_on_right else 'only_first',
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    # Maps each feature back to the example it came from.
    sample_mapping = tokenized_examples.pop('overflow_to_sample_mapping')
    # Maps each token to its character span in the original text.
    offset_mapping = tokenized_examples.pop('offset_mapping')

    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples['input_ids'][i]
        # Impossible answers are labelled with the index of the CLS token.
        cls_index = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = tokenized_examples.sequence_ids(i)
        answers = examples['answers'][sample_mapping[i]]

        if len(answers['answer_start']) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Start/end character index of the answer in the context.
        start_char = answers['answer_start'][0]
        end_char = start_char + len(answers['text'][0])

        # First and last context token of this window.
        context_start = 0
        while sequence_ids[context_start] != context_id:
            context_start += 1
        context_end = len(input_ids) - 1
        while sequence_ids[context_end] != context_id:
            context_end -= 1

        # The answer is not fully inside this window: label with CLS.
        if not (offsets[context_start][0] <= start_char and
                offsets[context_end][1] >= end_char):
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Move to the first context token that starts after start_char; the
        # answer starts on the token just before it. The scan is bounded by
        # context_end so it cannot run on into the trailing special/padding
        # tokens (whose offsets are typically (0, 0)) when the answer starts
        # on the last context token of the window.
        token_start_index = context_start
        while (token_start_index <= context_end and
               offsets[token_start_index][0] <= start_char):
            token_start_index += 1
        start_positions.append(token_start_index - 1)

        # Symmetrically, move back to the last context token that ends
        # before end_char; the answer ends on the token just after it.
        token_end_index = context_end
        while (token_end_index >= context_start and
               offsets[token_end_index][1] >= end_char):
            token_end_index -= 1
        end_positions.append(token_end_index + 1)

    tokenized_examples['start_positions'] = start_positions
    tokenized_examples['end_positions'] = end_positions
    return tokenized_examples


## Define the SQuAD data module

In [120]:
class SquadDataModule(pl.LightningDataModule):
    '''LightningDataModule serving tokenized SQuAD training/validation features.

    The raw splits are read from the notebook-level ``datasets`` DatasetDict
    and tokenized with ``prepare_train_features``. Every feature, validation
    ones included, therefore carries 'start_positions'/'end_positions' labels.

    Args:
        batch_size: Number of features per batch.
        num_workers: Number of worker processes for each DataLoader.
    '''

    def __init__(self, batch_size: int = 32, num_workers: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Used by custom_collate to pad batches.
        self.tokenizer = transformer_tokenizer
        # Tokenized splits, filled lazily by setup().
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        # Lightning runs prepare_data on a single process only. State assigned
        # here would be missing on the other processes of a distributed run,
        # so the raw splits are read in setup() instead.
        pass

    def setup(self, stage=None):
        '''Tokenize the raw splits into model-ready features (once).'''
        self.train_data = datasets['train']
        self.val_data = datasets['validation']
        # setup() may be invoked again for later stages (e.g. predict);
        # only tokenize the first time.
        if self.train_dataset is None:
            self.train_dataset = self.train_data.map(
                prepare_train_features,
                batched=True,
                remove_columns=self.train_data.column_names
            )
        if self.val_dataset is None:
            self.val_dataset = self.val_data.map(
                prepare_train_features,
                batched=True,
                remove_columns=self.val_data.column_names
            )

    def custom_collate(self, features):
        '''Pad a list of feature dicts into one batch of PyTorch tensors.'''
        return self.tokenizer.pad(
            features,
            padding=True,
            return_tensors='pt',
        )

    def _loader(self, dataset, shuffle=False):
        '''Build a DataLoader over ``dataset`` with this module's settings.'''
        return DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.custom_collate
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset)

    def predict_dataloader(self):
        # There is no labelled test split; predictions run on validation.
        return self._loader(self.val_dataset)

## Define the SQuAD QA model

In [121]:
class SquadQAModel(pl.LightningModule):
    '''Extractive-QA span head on top of the pretrained transformer.

    A single linear layer maps every token's hidden state to two scores: the
    logit of that token starting the answer span and the logit of it ending
    the span.
    '''

    # The span head predicts exactly two values per token: (start, end).
    NUM_SPAN_OUTPUTS = 2

    def __init__(self, **kwargs):
        super().__init__()

        # NOTE: this is the module-level ``transformer_model`` instance, so
        # every SquadQAModel shares (and trains) the same encoder weights.
        self.transformer = transformer_model
        # Reuse the config attached to the loaded model instead of
        # re-fetching it from the hub by name.
        config = self.transformer.config

        # forward() splits the head output into exactly two tensors, so the
        # width is fixed rather than taken from config.num_labels.
        self.num_labels = self.NUM_SPAN_OUTPUTS
        self.qa_outputs = torch.nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, batch) -> dict:
        '''Compute start/end span logits for a batch.

        Args:
            batch: Mapping with 'input_ids' and 'attention_mask' tensors.

        Returns:
            dict with 'start_logits' and 'end_logits': one logit per token,
            with the trailing size-1 dimension squeezed away.
        '''
        outputs = self.transformer(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )
        # First output of the base model: the last hidden state per token.
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        return {
            'start_logits': start_logits.squeeze(-1),
            'end_logits': end_logits.squeeze(-1),
        }



## Define the Lightning training module

In [122]:
## The main Pytorch Lightning module
class SquadQuestionAnswering(pl.LightningModule):
    '''Lightning wrapper that trains and evaluates ``SquadQAModel`` on SQuAD.

    Batches are the dicts produced by the collate functions: 'input_ids' and
    'attention_mask', plus 'start_positions' and 'end_positions' when labelled.

    Args:
        learning_rate: Peak learning rate of the one-cycle schedule.
        **kwargs: Extra hyper-parameters (e.g. the parsed command-line
            arguments, including ``max_epochs``); all are recorded by
            ``save_hyperparameters``.
    '''

    def __init__(self, learning_rate: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        # Transformer encoder + span-prediction head.
        self.model = SquadQAModel()

    def _span_loss(self, batch):
        '''Run the model on ``batch`` and return the averaged start/end loss.'''
        output = self.model(batch)
        return self.compute_loss(
            batch['start_positions'],
            batch['end_positions'],
            output['start_logits'],
            output['end_logits'],
        )

    def training_step(self, batch, batch_nb):
        total_loss = self._span_loss(batch)
        self.log('train_loss', total_loss)
        return total_loss

    def validation_step(self, batch, batch_nb):
        total_loss = self._span_loss(batch)
        # Surfaces the scalar in the progress bar / logger; it is also the
        # value monitored by the EarlyStopping and ModelCheckpoint callbacks.
        self.log_dict({'val_loss': total_loss}, prog_bar=True)
        return total_loss

    def test_step(self, batch, batch_nb):
        # SQuAD has no labelled test split. A loss is only computed and logged
        # when the batch carries labels; None is not a valid metric value.
        if 'start_positions' not in batch or 'end_positions' not in batch:
            return None
        total_loss = self._span_loss(batch)
        self.log_dict({'test_loss': total_loss}, prog_bar=True)
        return total_loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = None):
        '''Return the raw start/end logits for every token of every feature.'''
        output = self.model(batch)
        return {'start_logits': output['start_logits'], 'end_logits': output['end_logits']}

    def compute_loss(self, start_positions, end_positions, start_logits, end_logits):
        '''Average cross-entropy of the start and end position predictions.

        Args:
            start_positions, end_positions: Gold token indices; a trailing
                extra dimension (e.g. from multi-GPU gathering) is squeezed.
            start_logits, end_logits: Per-token logits, tokens along dim 1.

        Returns:
            (start_loss + end_loss) / 2, or None when either label is None.
        '''
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the model inputs are clamped to
            # ``ignored_index`` and then excluded from the loss.
            ignored_index = start_logits.size(1)
            # Out-of-place clamp so the caller's batch tensors are not modified.
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
        return total_loss

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        '''Configure Adam with a per-step one-cycle learning-rate schedule.

        Returns:
            ([optimizer], [scheduler_config]) in the format Lightning expects.
        '''
        # AdamW (Adam with decoupled weight decay) is a common alternative
        # when fine-tuning transformers.
        optimizer = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.hparams.learning_rate,
            eps=1e-08
        )
        # max_epochs is only in hparams when passed as a constructor kwarg;
        # otherwise fall back to the trainer's own setting.
        max_epochs = self.hparams.get('max_epochs') or self.trainer.max_epochs
        scheduler = {
            # OneCycleLR drives the optimizer's lr itself, so its peak must
            # come from the learning_rate hyper-parameter for that to matter.
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.hparams.learning_rate,
                steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
                epochs=max_epochs
            ),
            'interval': 'step'  # stepped after each training step
        }
        self.sched = scheduler
        self.optim = optimizer
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):
        '''Extend ``parent_parser`` with the arguments this module defines.

        Args:
            parent_parser: Parser whose arguments are inherited (e.g. the
                Trainer arguments); must be created with ``add_help=False``.
            root_dir: Base directory for the default ``--data_root``.

        Returns:
            A new ArgumentParser holding the parent's and these arguments.
        '''
        parser = ArgumentParser(parents=[parent_parser])

        # network params (NOTE(review): not read anywhere in this notebook)
        parser.add_argument('--drop_prob', default=0.2, type=float)

        # data (NOTE(review): not read anywhere; SQuAD comes from the hub)
        parser.add_argument('--data_root',
                            default=os.path.join(root_dir, 'train_val_data'),
                            type=str)

        # training params (opt)
        parser.add_argument('--learning_rate',
                            default=2e-5,
                            type=float,
                            help='peak learning rate (default: %(default)f)')
        return parser

## Define training parameters

In [123]:
root_dir = os.getcwd()

# Seed every RNG *before* the model and data module are built, so that the
# randomly initialised span head and the data shuffling are reproducible
# (deterministic=True below only controls the algorithms, not the seed).
seed_everything(42, workers=True)

parent_parser = ArgumentParser(add_help=False)
parent_parser = pl.Trainer.add_argparse_args(parent_parser)

# each LightningModule defines arguments relevant to it
parser = SquadQuestionAnswering.add_model_specific_args(parent_parser, root_dir)

parser.set_defaults(
    #profiler='simple',
    deterministic=True,
    max_epochs=3,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    # Use one GPU when available instead of failing on CPU-only machines.
    gpus=1 if torch.cuda.is_available() else 0,
    distributed_backend=None,
    fast_dev_run=False,
    model_load=False,
    model_name='best_model',
)

# parse_known_args ignores unrelated flags, e.g. the ones the Jupyter kernel
# itself was started with.
args, extra = parser.parse_known_args()

# Main training routine specific for this project.
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
if args.model_load:
    model = SquadQuestionAnswering.load_from_checkpoint(args.model_name)
else:
    model = SquadQuestionAnswering(**vars(args))

# ------------------------
# 2 CALLBACKS of MODEL
# ------------------------

# callbacks
early_stop = EarlyStopping(
    monitor='val_loss',
    min_delta=0.0,
    patience=3,
    verbose=True,
    mode='min',
    strict=True,
)

lr_monitor = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    #dirpath='my/path/',
    filename='squad-questionanswer-epoch{epoch:02d}-val_loss{val_loss:.2f}',
    auto_insert_metric_name=False
)

# ------------------------
# 3 INIT TRAINER
# ------------------------
trainer = Trainer.from_argparse_args(
    args,
    callbacks=[early_stop, lr_monitor, checkpoint_callback],
)

squad_dm = SquadDataModule()

## Train

In [None]:
trainer.fit(model, datamodule=squad_dm)  # fine-tune on the SQuAD data module


## Prepare features for validation

In [None]:
def prepare_validation_features(examples, max_length=384, doc_stride=128, tokenizer=None):
    '''Tokenize a batch of SQuAD examples into features for prediction.

    Uses the same windowing as ``prepare_train_features``. Each feature keeps
    the id of its source example, and its offset mapping is restricted to
    context tokens so predictions can be mapped back to context text.

    Args:
        examples: Batched columns as passed by ``Dataset.map(batched=True)``;
            must contain 'question', 'context' and 'id'.
        max_length: Maximum length of a feature (question and context).
        doc_stride: Number of overlapping tokens between two consecutive
            windows of the same context.
        tokenizer: Tokenizer to use; defaults to the notebook-level
            ``transformer_tokenizer``.

    Returns:
        The tokenizer output with an 'example_id' column added, the
        'overflow_to_sample_mapping' removed, and every 'offset_mapping'
        entry that is not part of the context replaced by None.
    '''
    if tokenizer is None:
        tokenizer = transformer_tokenizer
    # The padding side determines the order: (question|context) or
    # (context|question).
    pad_on_right = tokenizer.padding_side == 'right'
    # Sequence id that marks context tokens in ``sequence_ids``.
    context_index = 1 if pad_on_right else 0

    # Tokenize with truncation and padding, keeping the overflows with a
    # stride: a long context yields several overlapping features.
    tokenized_examples = tokenizer(
        examples['question' if pad_on_right else 'context'],
        examples['context' if pad_on_right else 'question'],
        truncation='only_second' if pad_on_right else 'only_first',
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    # Maps each feature back to the example it came from.
    sample_mapping = tokenized_examples.pop('overflow_to_sample_mapping')

    tokenized_examples['example_id'] = []

    for i in range(len(tokenized_examples['input_ids'])):
        sequence_ids = tokenized_examples.sequence_ids(i)
        tokenized_examples['example_id'].append(examples['id'][sample_mapping[i]])

        # None marks tokens outside the context, which makes it easy to tell
        # whether a token position is part of the context.
        tokenized_examples['offset_mapping'][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples['offset_mapping'][i])
        ]

    return tokenized_examples


In [None]:
## Define a smaller validation set to quickly assess performance
# (use datasets['validation'] directly to score the full split)
dataset_valid = datasets['validation'].select(list(range(16)))
validation_features = dataset_valid.map(
    prepare_validation_features,
    batched=True,
    remove_columns=dataset_valid.column_names,
)

In [None]:
def custom_collate(features, tokenizer=None):
    '''Collate validation features into a padded batch of PyTorch tensors.

    'example_id' (a string) and 'offset_mapping' (contains None entries)
    cannot be converted to tensors, so they are left out of the batch.
    The input feature dicts are not modified.

    Args:
        features: List of feature dicts from ``prepare_validation_features``.
        tokenizer: Tokenizer whose ``pad`` builds the batch; defaults to the
            notebook-level ``transformer_tokenizer``.

    Returns:
        The result of ``tokenizer.pad`` on the filtered features.
    '''
    if tokenizer is None:
        tokenizer = transformer_tokenizer
    non_tensor_keys = ('example_id', 'offset_mapping')
    tensorizable = [
        {key: value for key, value in feature.items() if key not in non_tensor_keys}
        for feature in features
    ]
    ## Pad the Batched data
    return tokenizer.pad(
        tensorizable,
        padding=True,
        return_tensors='pt',
    )

# Batches of validation features for trainer.predict (order preserved).
val_dataloader = DataLoader(
    dataset=validation_features,
    batch_size=8,
    num_workers=2,
    collate_fn=custom_collate,
)

## Get predictions

In [None]:
raw_predictions = trainer.predict(model=model, dataloaders=val_dataloader)  # per-batch logits

In [None]:
from tqdm.auto import tqdm
import collections

def postprocess_qa_predictions(
        examples,
        features,
        raw_predictions,
        n_best_size = 20,
        max_answer_length = 30
):
    '''Turn per-feature start/end logits into one answer string per example.

    For every example, all candidate spans are gathered from each of its
    features. A candidate pairs one of the ``n_best_size`` highest start
    logits with one of the highest end logits, must lie inside the context,
    and may be at most ``max_answer_length`` tokens long. The highest-scoring
    span (start + end logit) wins.

    Args:
        examples: Raw examples; ``examples['id']`` gives all ids and iterating
            yields dicts with 'id' and 'context'.
        features: Features from ``prepare_validation_features`` (with
            'example_id' and context-only 'offset_mapping'), in the same order
            as the predictions.
        raw_predictions: List of per-batch dicts with 'start_logits' and
            'end_logits' tensors, as returned by ``trainer.predict``.
        n_best_size: Number of top start/end logits considered per feature.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        OrderedDict mapping example id -> predicted answer text ('' when no
        valid span was found).
    '''
    all_start_logits = torch.cat(
        [predict['start_logits'] for predict in raw_predictions]).cpu().numpy()
    all_end_logits = torch.cat(
        [predict['end_logits'] for predict in raw_predictions]).cpu().numpy()

    example_id_to_index = {k: i for i, k in enumerate(examples['id'])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature['example_id']]].append(i)

    predictions = collections.OrderedDict()

    # Both halves must be f-strings so the feature count is interpolated.
    print(f'Post-processing {len(examples)} example '
          f'predictions split into {len(features)} features.')

    for example_index, example in enumerate(tqdm(examples)):
        context = example['context']
        valid_answers = []

        for feature_index in features_per_example[example_index]:
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # Maps token positions back to character spans of the context
            # (None for tokens outside the context).
            offset_mapping = features[feature_index]['offset_mapping']

            # Indices of the n_best_size largest logits, best first.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip out-of-bounds indices and tokens outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Skip spans of negative length or longer than allowed.
                    if (end_index < start_index or
                            end_index - start_index + 1 > max_answer_length):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            'score': start_logits[start_index] + end_logits[end_index],
                            'text': context[start_char: end_char]
                        }
                    )

        # No null-answer score is compared here (SQuAD v1.1 always has an
        # answer); an empty prediction is used only if no valid span exists.
        if valid_answers:
            best_answer = max(valid_answers, key=lambda x: x['score'])
        else:
            best_answer = {'text': '', 'score': 0.0}

        predictions[example['id']] = best_answer['text']

    return predictions

In [None]:
# Map every validation example id to its best predicted answer text.
final_predictions = postprocess_qa_predictions(
    examples=dataset_valid,
    features=validation_features,
    raw_predictions=raw_predictions,
)

In [None]:
# Inspect the prediction for one validation question.
sample_id = '56be4db0acb8001400a502ec'
final_predictions[sample_id]

In [None]:
metric = load_metric(path='squad')  # official SQuAD exact-match / F1 metric

In [None]:
# Reshape predictions into the record format expected by the SQuAD metric.
formatted_predictions = []
for pred_id, pred_text in final_predictions.items():
    formatted_predictions.append({'id': pred_id, 'prediction_text': pred_text})

In [None]:
# Gold answers in the metric's reference format, then score the predictions.
references = [
    {'id': example['id'], 'answers': example['answers']}
    for example in dataset_valid
]
metric.compute(predictions=formatted_predictions, references=references)

In [None]:
## Define a smaller validation set to quickly assess performance
first_sixteen = list(range(16))
dataset_valid = datasets['validation'].select(first_sixteen)
validation_features = dataset_valid.map(
    prepare_validation_features,
    batched=True,
    remove_columns=dataset_valid.column_names,
)

In [None]:
# Rebuild the prediction DataLoader over the freshly mapped features.
val_dataloader = DataLoader(
    dataset=validation_features,
    batch_size=8,
    num_workers=2,
    collate_fn=custom_collate,
)

In [None]:
# Predict with the best checkpoint written during training, falling back to
# the in-memory model when none has been saved. A fresh
# SquadQuestionAnswering() would pair the encoder with a newly initialised,
# untrained span head.
best_checkpoint = checkpoint_callback.best_model_path
if best_checkpoint:
    model = SquadQuestionAnswering.load_from_checkpoint(best_checkpoint)
raw_predictions = trainer.predict(model, dataloaders=val_dataloader)


In [None]:
# Decode the new raw logits into answer strings and show them.
final_predictions = postprocess_qa_predictions(
    examples=dataset_valid,
    features=validation_features,
    raw_predictions=raw_predictions,
)
print(final_predictions)

In [None]:
# Save the model to disk. Pickling the whole Trainer (with its dataloaders,
# callbacks and loggers) is fragile. Instead, store the model's weights plus
# a full Lightning checkpoint that load_from_checkpoint can restore.
torch.save(model.state_dict(), 'model.pt')
trainer.save_checkpoint('model.ckpt')