In [None]:
from IPython.core.magic import register_cell_magic
import os
from pathlib import Path
import pickle

## Custom cell magic: writes the cell's source to a file and then runs it, so the most useful
## classes can be reused in the inference notebook instead of copy-pasting the code after every change.
@register_cell_magic
def write_and_run(line, cell):
    """Cell magic: save the cell's source to a file, then execute the cell.

    Usage::

        %%write_and_run [-a] path/to/file.py

    Args:
        line: text after the magic name; its last token is the target file.
            When exactly two tokens are given and the first is ``-a`` the cell
            is appended to the file, otherwise the file is overwritten.
        cell: source of the cell body.

    Raises:
        ValueError: if no target file name is given.
    """
    argz = line.split()
    # Fail with a readable message instead of an IndexError (empty line) or
    # silently writing to a file literally called "-a".
    if not argz or argz[-1] == '-a':
        raise ValueError('usage: %%write_and_run [-a] <file>')
    file = argz[-1]
    mode = 'a' if len(argz) == 2 and argz[0] == '-a' else 'w'
    # Make sure the target folder exists so the write cannot fail on a fresh
    # working directory.
    Path(file).parent.mkdir(parents=True, exist_ok=True)
    with open(file, mode) as f:
        f.write(cell)
    get_ipython().run_cell(cell)
    
# Output folders: one for the scripts written by %%write_and_run, one for the
# saved model weights. Both are created if they do not exist yet.
models_dir = Path('/kaggle/working/models')
for _out_dir in (Path('/kaggle/working/scripts'), models_dir):
    _out_dir.mkdir(exist_ok=True)

In [None]:
import os
import gc
gc.enable()
import math
import json
import time
import random
import multiprocessing
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from sklearn import model_selection

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.optim as optim
from torch.utils.data import (
    Dataset, DataLoader,
    SequentialSampler, RandomSampler
)
import matplotlib.pyplot as plt
import transformers
from transformers import (
    WEIGHTS_NAME,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    logging,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
)
logging.set_verbosity_warning()
logging.set_verbosity_error()
import pytorch_lightning as pl
import collections
from time import time
from tqdm import tqdm
tqdm.pandas()

# Keys of the question-answering auto-mapping, and the `model_type` string
# declared by each of them.
MODEL_CONFIG_CLASSES = [*MODEL_FOR_QUESTION_ANSWERING_MAPPING]
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)

In [None]:
import ast
import pickle

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load this from a dataset you trust.
with open('../input/chaiiqa-sampling/xlm-roberta-base-squad2_features_kfold.pkl', 'rb') as f:
    feats_data = pickle.load(f)
# The pickle is indexed as (column names, rows).
features_df = pd.DataFrame(data=feats_data[1], columns=feats_data[0])
features_df['index'] = features_df['index'].apply(str)


def _parse_answers(cell):
    """Parse one raw `answers` CSV cell into a Python object.

    The previous converter was `exec`, which always returns None (so the whole
    column became None) and runs the cell's text as code. `ast.literal_eval`
    only accepts Python literals and returns the parsed value.
    Assumes the cells hold Python-literal text (e.g. a dict or list repr) —
    TODO confirm against the CSV. Empty cells are mapped to None.
    """
    cell = cell.strip()
    return ast.literal_eval(cell) if cell else None


raw_kfold = pd.read_csv('../input/chaiiqa-sampling/kfold_raw.csv', converters={'index': str, 'answers': _parse_answers})

In [None]:
%%write_and_run scripts/config.py

class Config:
    """Hyper-parameters for the QA fine-tuning run, kept as class attributes."""

    # --- model ---
    seed = 42
    model_type = "deepset/xlm-roberta-base-squad2"
    # gradient_accumulation_steps = 2   (disabled)

    # --- tokenisation ---
    # presumably the window length / stride used when the features were
    # built; neither value is read elsewhere in this notebook — TODO confirm
    max_seq_length = 384
    doc_stride = 128

    # --- training ---
    epochs = 5
    train_batch_size = 8
    eval_batch_size = 8

    # --- optimizer ---
    lr = 5e-6
    weight_decay = 1e-2
    epsilon = 1e-8
    max_grad_norm = 1.0

    # --- scheduler ---
    warmup_ratio = 0.1

    
# Seed the random number generators once for the whole run. The top-level
# `pl.seed_everything` is used instead of the `pl.utilities.seed` module path,
# which is deprecated in later PyTorch Lightning releases.
pl.seed_everything(Config.seed, workers=True)


### Dataset 

In [None]:
%%write_and_run scripts/dataset.py


class QADataset(Dataset):
    """Dataset over a DataFrame of pre-tokenised QA features.

    Args:
        features: DataFrame with one feature per row. Columns read:
            ``index``, ``input_ids``, ``attention_mask`` and, in ``'train'``
            mode, ``offset_mapping``, ``start_position``, ``end_position``.
        mode: ``'train'`` (default) returns inputs plus offsets and label
            positions; any other value returns only the model inputs and
            the ``index``.
    """

    def __init__(self, features, mode='train'):
        super().__init__()
        self.features = features
        self.mode = mode
        self.size = len(self.features)

    def __len__(self):
        return self.size

    @staticmethod
    def _as_long(values):
        """Convert a sequence or scalar to a ``torch.long`` tensor."""
        return torch.tensor(values, dtype=torch.long)

    def __getitem__(self, item):
        row = self.features.iloc[item]

        if self.mode != 'train':
            return {
                'input_ids': self._as_long(row['input_ids']),
                'attention_mask': self._as_long(row['attention_mask']),
                'index': row['index'],
            }

        # Training/validation sample: the identifier first, then every tensor
        # column in a fixed order.
        sample = {'index': row['index']}
        for column in ('input_ids', 'attention_mask', 'offset_mapping',
                       'start_position', 'end_position'):
            sample[column] = self._as_long(row[column])
        return sample

### Post-processing and model

In [None]:
%%write_and_run scripts/postprocess.py


def postprocess_qa_predictions(samples, all_features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    """Turn per-feature start/end logits into one answer span per sample.

    For each row of ``samples`` every feature with the same ``index`` is
    scanned: the ``n_best_size`` highest start logits are paired with the
    ``n_best_size`` highest end logits; pairs that fall outside the context,
    run backwards, or span more than ``max_answer_length`` tokens are dropped.
    The highest-scoring remaining span across all features is the answer.

    Args:
        samples: DataFrame with at least ``index`` and ``context`` columns.
        all_features: DataFrame with ``index``, ``sequence_ids`` and
            ``offset_mapping`` columns (one row per feature).
        raw_predictions: mapping ``index -> sequence`` whose ``i``-th item is
            ``(start_logits, end_logits)`` for the ``i``-th feature of that
            sample, in the row order of ``all_features``. The logits are
            negated and passed to ``np.argsort``, so they must be NumPy arrays.
        n_best_size: number of top start / end candidates considered.
        max_answer_length: maximum answer length in tokens.

    Returns:
        OrderedDict mapping each sample ``index`` to a dict with ``score``,
        ``text``, ``token_positions`` and ``feature_n``. When no valid span
        exists the value is ``{'text': '', 'score': 0.0}``.
    """
    CONTEXT_INDEX = 1  # sequence id that marks tokens usable as answer text

    predictions = collections.OrderedDict()
    for _, sample in samples.iterrows():
        sample_index = sample['index']
        context = sample["context"]
        valid_answers = []

        features = all_features[all_features['index'] == sample_index]
        for feat_i, (_, feature) in enumerate(features.iterrows()):
            feature_logits = raw_predictions[sample_index][feat_i]
            start_logits, end_logits = feature_logits[0], feature_logits[1]

            sequence_ids = feature["sequence_ids"]
            # Hide the offsets of every token whose sequence id is not
            # CONTEXT_INDEX so those tokens can never be part of an answer.
            offset_mapping = [
                (offsets if sequence_ids[tok_i] == CONTEXT_INDEX else None)
                for tok_i, offsets in enumerate(feature["offset_mapping"])
            ]
            n_tokens = len(offset_mapping)

            # The former "null score" lookup was removed: its result was never
            # used, but it required a global `tokenizer` and raised when the
            # CLS id was missing from `input_ids`.

            start_indices = np.argsort(-start_logits)[:n_best_size].tolist()
            end_indices = np.argsort(-end_logits)[:n_best_size].tolist()

            for start_index in start_indices:
                # A start outside the feature or outside the context is
                # invalid for every end candidate.
                if start_index >= n_tokens or offset_mapping[start_index] is None:
                    continue
                for end_index in end_indices:
                    if (
                        end_index >= n_tokens
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]

                    valid_answers.append(
                        {
                            'score': start_logits[start_index] + end_logits[end_index],
                            'text': context[start_char: end_char].strip(),
                            'token_positions': (start_index, end_index),
                            'feature_n': feat_i
                        }
                    )

        if valid_answers:
            # max() keeps the first of equally scored candidates, like the
            # stable descending sort it replaces.
            best_answer = max(valid_answers, key=lambda answer: answer['score'])
        else:
            best_answer = {'text': '', 'score': 0.0}

        predictions[sample_index] = best_answer
    return predictions

In [None]:
%%write_and_run scripts/model.py

class PTModel(nn.Module):
    """Backbone network followed by a linear start/end span head.

    Args:
        net: backbone called as ``net(input_ids=..., attention_mask=...)``;
            element ``0`` of its output is used as the per-token hidden
            states.
        config: object exposing ``hidden_size`` and ``hidden_dropout_prob``.
    """

    def __init__(self, net, config):
        super().__init__()
        self.net = net
        self.config = config
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        # Created but not applied in forward() (the call there is disabled).
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-normal weights and zero bias for linear layers."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_normal_(module.weight)
            module.bias.data.zero_()

    def forward(self, input_ids, attention_mask):
        """Return ``(start_logits, end_logits)``, each with the last axis of
        the head output squeezed away."""
        outputs = self.net(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Only the first output is needed. The former `outputs[1]` lookup was
        # unused and raised IndexError for backbones that return a single
        # element.
        sequence_output = outputs[0]

        # sequence_output = self.dropout(sequence_output)
        qa_logits = self.qa_outputs(sequence_output)

        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        return start_logits, end_logits
    
class Model(pl.LightningModule):
    """Lightning wrapper that trains ``ptmodel`` and scores it with Jaccard.

    Args:
        ptmodel: module whose ``forward(input_ids, attention_mask)`` returns
            ``(start_logits, end_logits)``.
        loss_fn: callable invoked as
            ``loss_fn((start_logits, end_logits), (start_position, end_position))``.
        params: dict of run settings. Keys read here: ``lr``, ``epsilon``,
            ``weight_decay`` (optimizer); ``val_raw_df``, ``val_feats_df``
            (validation post-processing); and, when ``save_best`` is truthy,
            ``models_dir`` (path-like supporting ``/``) and ``fold``.

    ``validation_epoch_end`` calls the module-level functions
    ``postprocess_qa_predictions`` and ``jaccard``; both must be defined
    before validation runs.
    """

    def __init__(self, ptmodel, loss_fn=None, params=None):
        super().__init__()
        self.ptmodel = ptmodel
        self.loss_fn = loss_fn
        self.params = params
        self.best_score = float('-inf')

    def forward(self, input_ids, attention_mask):
        return self.ptmodel(input_ids, attention_mask)

    def _compute_loss(self, batch):
        """Run the model on a batch; return ``(loss, start_logits, end_logits)``."""
        y_hat_start, y_hat_end = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(
            (y_hat_start, y_hat_end),
            (batch['start_position'], batch['end_position']),
        )
        return loss, y_hat_start, y_hat_end

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._compute_loss(batch)
        return {
            'loss': loss,
        }

    def validation_step(self, batch, batch_idx):
        loss, y_hat_start, y_hat_end = self._compute_loss(batch)
        return {
            'loss': loss,
            'batch_index': batch['index'],
            'y_hat_start': y_hat_start.detach().cpu().numpy(),
            'y_hat_end': y_hat_end.detach().cpu().numpy(),
        }

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # `dataloader_idx` defaults to None so the hook also works when it is
        # called without that argument.
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']
        y_hat_start, y_hat_end = self(input_ids, attention_mask)
        return batch['index'], y_hat_start.detach().cpu().numpy(), y_hat_end.detach().cpu().numpy()

    def validation_epoch_end(self, validation_step_outputs):
        """Report validation loss and Jaccard, and keep the best weights.

        Per-feature logits are grouped by sample ``index`` in the order the
        batches arrived, post-processed into answer texts and compared with
        ``answer_text`` from ``params['val_raw_df']``. When the average score
        improves and ``params['save_best']`` is truthy, ``self.ptmodel`` is
        saved and the previously saved file for this fold is removed.
        """
        if not validation_step_outputs:
            # Nothing was validated; the averages below would divide by zero.
            return

        example_id_to_preds = collections.defaultdict(list)
        losses = []
        for out in validation_step_outputs:
            losses.append(out['loss'].item())
            for index, feat_start_preds, feat_end_preds in zip(out['batch_index'], out['y_hat_start'], out['y_hat_end']):
                example_id_to_preds[index].append((feat_start_preds, feat_end_preds))

        val_loss = sum(losses) / len(losses)
        print('Validation loss:', val_loss)

        # Post-processing: logits -> answer text per sample.
        val_raw_df = self.params['val_raw_df']
        final_preds = postprocess_qa_predictions(val_raw_df, self.params['val_feats_df'], example_id_to_preds, n_best_size=20, max_answer_length=30)

        jaccard_scores = [
            jaccard(sample['answer_text'], final_preds[sample['index']]['text'])
            for _, sample in val_raw_df.iterrows()
        ]
        if not jaccard_scores:
            return

        avg_jaccard_score = sum(jaccard_scores) / len(jaccard_scores)
        print('Validation Jaccard score:', round(avg_jaccard_score, 3))

        if avg_jaccard_score > self.best_score:
            if self.params.get('save_best'):
                save_dir = self.params['models_dir']
                fold = self.params['fold']
                new_path = save_dir / f"model_fold_{fold}_score_{avg_jaccard_score:.3f}.pth"
                # Save this instance's network (not a notebook global), and
                # write the new file before removing the old one so a failed
                # save never leaves the fold without a checkpoint.
                torch.save(self.ptmodel, new_path)
                if self.best_score != float('-inf'):
                    old_path = save_dir / f"model_fold_{fold}_score_{self.best_score:.3f}.pth"
                    # Both scores can round to the same file name; in that
                    # case the file was just overwritten and must be kept.
                    if old_path != new_path and old_path.exists():
                        old_path.unlink()
            self.best_score = avg_jaccard_score

    def configure_optimizers(self):
        """AdamW with weight decay disabled for biases and LayerNorm weights."""
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {
                "params": decay_params,
                "weight_decay": self.params['weight_decay'],
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
            },
        ]

        # A layer-wise learning-rate grouping and a cosine warm-up scheduler
        # were tried here earlier and are currently disabled, so
        # params['num_training_steps'] / params['num_warmup_steps'] are unused.
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.params['lr'],
            eps=self.params['epsilon'],
        )

        return optimizer

    def on_train_epoch_start(self):
        self.epoch_start = time()

    def on_train_epoch_end(self):
        print('Epoch took:', time() - self.epoch_start, 'secs')

    def on_validation_epoch_start(self):
        self.val_epoch_start = time()

    def on_validation_epoch_end(self):
        print('Validation epoch took:', time() - self.val_epoch_start, 'secs')


### Metric, loss and scheduler helpers

In [None]:
def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings.

    Both strings are lower-cased and split on whitespace; the score is
    ``|A & B| / |A | B|`` over the resulting word sets.

    Returns:
        A float in ``[0, 1]``. When both strings contain no words the sets
        are identical, so 1.0 is returned instead of dividing by zero.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    union_size = len(a) + len(b) - len(c)
    if union_size == 0:
        return 1.0
    return float(len(c)) / union_size




def loss_fn(preds, labels):
    """Mean of the start and end cross-entropy losses.

    Args:
        preds: ``(start_logits, end_logits)``.
        labels: ``(start_labels, end_labels)``; labels equal to ``-1`` are
            ignored.

    Returns:
        Scalar tensor: ``(start_loss + end_loss) / 2``.
    """
    span_losses = [
        F.cross_entropy(logits, target, ignore_index=-1)
        for logits, target in zip(preds, labels)
    ]
    start_loss, end_loss = span_losses
    return (start_loss + end_loss) / 2

def get_scheduler_steps(dl_len):
    """Return ``(num_training_steps, num_warmup_steps)`` for a dataloader.

    Training steps are ``dl_len * Config.epochs``; warm-up steps are that
    total scaled by ``Config.warmup_ratio`` and truncated to an int, or 0
    when the ratio is not positive.
    """
    total_steps = dl_len * Config.epochs
    warmup_steps = int(Config.warmup_ratio * total_steps) if Config.warmup_ratio > 0 else 0
    return total_steps, warmup_steps



### Run

In [None]:
# Train one model per fold: rows whose `kfold` equals the current fold are
# held out for validation, the rest are used for training.
for fold in range(1):
    print(f'Fold#{fold+1}')
    train_ds = QADataset(features_df[features_df.kfold!=fold])
    val_feats_fold_df = features_df[features_df.kfold==fold]
    # The validation set keeps the default 'train' mode because
    # Model.validation_step reads start_position / end_position.
    val_ds = QADataset(val_feats_fold_df)
    val_raw_fold_df = raw_kfold[raw_kfold.kfold==fold]

    train_dl = DataLoader(
        train_ds,
        batch_size=Config.train_batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        drop_last=False
    )

    # Not shuffled: post-processing matches predictions to features by their
    # row order within each sample.
    val_dl = DataLoader(
        val_ds,
        batch_size=Config.eval_batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        drop_last=False
    )

    config = AutoConfig.from_pretrained(Config.model_type)

    tokenizer = AutoTokenizer.from_pretrained(Config.model_type)
    torch.save(tokenizer, 'tokenizer.pth')
    # Pass the config by keyword: a second positional argument is forwarded
    # to the model constructor instead of being used as the configuration.
    net = AutoModel.from_pretrained(Config.model_type, config=config)
    num_training_steps, num_warmup_steps = get_scheduler_steps(len(train_dl))

    ptmodel = PTModel(net, config)
    model = Model(
        ptmodel,
        loss_fn=loss_fn,
        params = {
            'num_training_steps': num_training_steps,
            'num_warmup_steps': num_warmup_steps,
            'lr': Config.lr,
            'epsilon': Config.epsilon,
            'weight_decay': Config.weight_decay,
            'models_dir': models_dir,
            'fold': fold,
            'save_best': True,
            'val_raw_df': val_raw_fold_df,
            'val_feats_df': val_feats_fold_df
        }
    )

    trainer = pl.Trainer(
        fast_dev_run=False, max_epochs=Config.epochs,
        checkpoint_callback=False,
        gpus=1, precision=16,
        limit_train_batches=1.0, limit_val_batches=1.0,
        num_sanity_val_steps=0, val_check_interval=0.33,
    )

    trainer.fit(model, train_dl, val_dl)

    # Release memory in the order that actually frees it: drop every
    # reference (including `ptmodel`, which holds the network), collect the
    # garbage, and only then empty the CUDA cache.
    del trainer, model, ptmodel, net, config
    del train_dl, val_dl, train_ds, val_ds
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
!ls models