In [None]:
# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them. The flag is read from the process environment, so
# assigning a plain Python variable has no effect; it has to be exported, and
# that must happen before CUDA is initialised (hence the first cell).
import os

CUDA_LAUNCH_BLOCKING = "1"  # kept so anything reading the old module-level name still works
os.environ["CUDA_LAUNCH_BLOCKING"] = CUDA_LAUNCH_BLOCKING

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Collect the full path of every file mounted under the read-only input
# directory, then print them one per line.
input_files = [
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
]
for input_file in input_files:
    print(input_file)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Load the fold-annotated training table and preview its first rows.
data_path = '/kaggle/input/data-split/chaii_folds.csv'
df = pd.read_csv(data_path)
df.head()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from transformers import (
    AutoConfig,
    AutoModel,
    AdamW,
    AutoTokenizer,
    get_linear_schedule_with_warmup
)

In [None]:
import pandas as pd
from datasets import Dataset

class Config:
    """Settings shared by the data module, the model and the inference loop.

    Every setting is a class attribute. Note that the tokenizer is loaded from
    disk while this class body executes, i.e. when the cell is run.
    """

    # ---- data locations ----
    trainval_path = '/kaggle/input/data-split/chaii_folds.csv'
    test_path = '/kaggle/input/chaii-hindi-and-tamil-question-answering/test.csv'
    pretrained_files_path = '/kaggle/input/muril-large-pt/muril-large-cased'

    # ---- model / tokenizer: weights, config and vocabulary share one directory ----
    model_name = config_name = tokenizer_name = pretrained_files_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    max_seq_length = 384  # passed to the tokenizer as max_length
    doc_stride = 128      # passed to the tokenizer as stride between windows

    # ---- training loop ----
    epochs = 20
    train_batch_size = 4
    eval_batch_size = 8
    gradient_accumulation_steps = 2

    # ---- optimizer ----
    optimizer_type = 'AdamW'
    learning_rate = 1.5e-5
    weight_decay = 1e-2
    adam_epsilon = 1e-8
    max_grad_norm = 1.0

    # ---- scheduler: fraction of total steps spent warming up ----
    warmup_ratio = 0.1

    # ---- logging / output ----
    logging_steps = 10
    output_dir = 'output'
    seed = 42

    # ---- inference ----
    pred_topk_logits = 10   # how many start / end candidates are considered
    pred_max_length = 300   # longest accepted answer span, in tokens
    

In [None]:
class ChaiiDatamodule(pl.LightningDataModule):
    """Tokenises the chaii question-answering CSVs into fixed-length windows.

    Only the test split is materialised by ``setup`` at present; the
    train/validation preparation is commented out, so ``train_dataloader`` and
    ``val_dataloader`` raise ``AttributeError`` until it is re-enabled.
    """

    def __init__(self,config):
        """Read the train/validation and test CSVs named in ``config``."""
        super().__init__()
        self.config = config
        self.tv_df = pd.read_csv(config.trainval_path)
        self.test_df = pd.read_csv(config.test_path)

    def setup(self,val_fold=1,stage=None):
        """Build the tokenised test dataset.

        Args:
            val_fold: fold id meant for validation; only referenced by the
                commented-out train/validation code below.
            stage: accepted but not used.
        """
#         column_names = ['id','context','question','answer_text','answer_start']
#         self.training_data = self.tv_df[self.tv_df.folds!=val_fold][column_names]
#         self.validation_data = self.tv_df[self.tv_df.folds==val_fold][column_names]
        self.testing_data = self.test_df[['id','context','question']]

#         self.training_dataset = Dataset.from_pandas(self.training_data)
#         self.training_dataset = self.preprocess(self.training_dataset,self.config.train_batch_size)

#         self.validation_dataset = Dataset.from_pandas(self.validation_data)
#         self.validation_dataset= self.preprocess(self.validation_dataset,self.config.eval_batch_size)

        self.testing_dataset = Dataset.from_pandas(self.testing_data)
        self.testing_dataset = self.preprocess(self.testing_dataset,self.config.eval_batch_size)

    @staticmethod
    def _answer_token_span(offsets, sequence_ids, start_char, end_char):
        """Locate the tokens of one window that overlap an answer.

        Args:
            offsets: per-token ``(start, end)`` character offsets, end exclusive.
            sequence_ids: per-token sequence id; context tokens have id ``1``.
            start_char: first character of the answer.
            end_char: one past the last character of the answer.

        Returns:
            ``(first_token, last_token)`` (inclusive) or ``None`` when the
            window's context does not overlap the answer at all.
        """
        # Bounds of the context part of the window.
        first = 0
        while sequence_ids[first] != 1:
            first += 1
        last = len(sequence_ids) - 1
        while sequence_ids[last] != 1:
            last -= 1

        # Offsets are half-open, so touching at a boundary is NOT an overlap.
        if offsets[first][0] >= end_char or offsets[last][1] <= start_char:
            return None

        # First token that ends after the answer starts. The check above
        # guarantees `last` satisfies this, so the scan cannot run past it.
        while first < last and offsets[first][1] <= start_char:
            first += 1
        # Last token that starts before the answer ends.
        while last > first and offsets[last][0] >= end_char:
            last -= 1
        return first, last

    def preprocess_(self,example):
        """Tokenise a batch of rows into overlapping windows with span labels.

        Args:
            example: batch of columns (name -> list) as supplied by
                ``Dataset.map(batched=True)``. Requires ``id``, ``question``
                and ``context``; ``answer_text`` / ``answer_start`` are
                optional (absent for the test split).

        Returns:
            The tokenizer output extended with ``id``, ``start_idx`` and
            ``end_idx``, one entry per window. Windows with no (overlapping)
            answer have both labels pointing at the [CLS] token.
        """
        tokenizer = self.config.tokenizer

        questions = [q.strip() for q in example["question"]]
        raw_contexts = example["context"]
        contexts = [c.strip() for c in raw_contexts]
        # `answer_start` indexes the unstripped context while the tokenizer's
        # offsets index the stripped one; remember how far each context moved.
        leading_ws = [len(c) - len(c.lstrip()) for c in raw_contexts]

        tokenized_example = tokenizer(
            questions,
            contexts,
            truncation="only_second",
            max_length=self.config.max_seq_length,
            stride=self.config.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )
        # One input row may yield several windows; this maps window -> row.
        sample_mapping = tokenized_example.pop("overflow_to_sample_mapping")
        offset_mapping = tokenized_example["offset_mapping"]
        has_answers = "answer_text" in example

        tokenized_example['id']=[]
        tokenized_example['start_idx']=[]
        tokenized_example['end_idx']=[]

        for i, offsets in enumerate(offset_mapping):
            input_ids = tokenized_example["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            sample_index = sample_mapping[i]
            tokenized_example['id'].append(example['id'][sample_index])

            # Default label: "no answer in this window" -> [CLS].
            start_idx = end_idx = cls_index
            if has_answers:
                answer_text = example["answer_text"][sample_index]
                if len(answer_text) > 0:
                    start_char = example["answer_start"][sample_index] - leading_ws[sample_index]
                    end_char = start_char + len(answer_text)
                    span = self._answer_token_span(
                        offsets, tokenized_example.sequence_ids(i), start_char, end_char
                    )
                    if span is not None:
                        start_idx, end_idx = span

            tokenized_example["start_idx"].append(start_idx)
            tokenized_example["end_idx"].append(end_idx)

        return tokenized_example

    def preprocess(self,dataset,batch_size):
        """Tokenise ``dataset`` and expose the model inputs as torch tensors.

        Args:
            dataset: a ``datasets.Dataset`` with the columns ``preprocess_`` needs.
            batch_size: accepted for backward compatibility; not used.

        Returns:
            The mapped dataset. The four listed columns come back as torch
            tensors; all other columns (e.g. ``id``) are still returned.
        """
        dataset = dataset.map(
            self.preprocess_,
            batched=True,
            remove_columns=dataset.column_names
        )
        dataset.set_format(
            type='torch',
            columns=['attention_mask', 'input_ids', 'start_idx', 'end_idx'],
            output_all_columns=True
        )
        return dataset

    def train_dataloader(self):
        """Loader over ``self.training_dataset`` (not built by the current ``setup``)."""
        return DataLoader(self.training_dataset,self.config.train_batch_size)

    def val_dataloader(self):
        """Loader over ``self.validation_dataset`` (not built by the current ``setup``)."""
        return DataLoader(self.validation_dataset,self.config.eval_batch_size)

    def test_dataloader(self):
        """Loader over the tokenised test windows, in dataset order."""
        return DataLoader(self.testing_dataset,self.config.eval_batch_size)

In [None]:
from math import inf
class ChaiiHFmodel(pl.LightningModule):
    """Extractive QA model: a pretrained encoder plus a linear start/end head."""

    def __init__(self,config):
        """Load the pretrained encoder named in ``config`` and add the span head."""
        super().__init__()
        self.config = config
        self.model_config = AutoConfig.from_pretrained(config.config_name)
        self.hf_model = AutoModel.from_pretrained(config.model_name,config = self.model_config)
        self.linear_layer = nn.Linear(self.model_config.hidden_size,2) #hidden dim -> start/end logits

    def forward(self,input_ids,attention_mask):
        """Return ``(start_logits, end_logits)``, one score per token position."""
        outputs = self.hf_model(input_ids,attention_mask = attention_mask)
        se_logits = self.linear_layer(outputs[0])
        # A tuple (not a generator) so the result can be read more than once.
        return tuple(x.squeeze(-1) for x in se_logits.split(1,dim=-1))

    def loss_fn(self, s_h,e_h, s,e):
        """Mean of the cross-entropy losses for the start and end positions."""
        loss = nn.CrossEntropyLoss()
        s_loss = loss(s_h,s)
        e_loss = loss(e_h,e)
        return (s_loss + e_loss)/2

    def shared_step(self, batch, compute_loss = True):
        """Run the model on ``batch``.

        Returns:
            ``(loss, (start_logits, end_logits))`` when ``compute_loss`` is
            true (requires ``start_idx`` / ``end_idx`` in the batch), otherwise
            just ``(start_logits, end_logits)``.
        """
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        s,e = self(input_ids,attention_mask)
        if compute_loss:
            return self.loss_fn(s,e,batch['start_idx'],batch['end_idx']),(s,e)
        else:
            return (s,e)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss,_ = self.shared_step(batch)
        self.log('train_loss',loss)
        return loss

    def validation_step(self,batch, batch_idx):
        """Log ``val_loss`` and, for windows with a reference answer, ``val_metric``.

        ``val_metric`` is the mean word-level Jaccard score between decoded
        predictions and decoded reference spans in this batch.
        """
        loss,(s,e) = self.shared_step(batch)
        self.log('val_loss',loss)

        start_index = batch['start_idx'].detach().cpu().tolist()
        end_index = batch['end_idx'].detach().cpu().tolist()
        input_ids = batch['input_ids'].detach().cpu().tolist()

        answers = [self.id_to_string(start_index[b],end_index[b],input_ids[b]) for b in range(len(input_ids))]
        valid_range = [(-inf,inf) for b in range(len(input_ids))] #correct this TODO

        predicted_answers = self.postprocess_(s,e,valid_range,batch['input_ids'])
        # Windows whose reference decodes to "" (label on [CLS]) are skipped.
        ious = [self.jaccard_score(x,y) for x,y in zip(predicted_answers,answers) if len(y)>0]
        if len(ious)>0:
            iou = sum(ious)/max(1,len(ious))
            self.log('val_metric',iou)
        return loss

    def setup(self, stage=None):
        """Compute ``self.total_steps`` (optimizer steps over the whole fit).

        Only runs for ``stage == "fit"``. Assumes the training loader is
        reachable through ``self.train_dataloader()`` -- TODO confirm for the
        Lightning version in use.
        """
        if stage != "fit":
            return
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.train_dataloader()

        # Calculate total steps
        tb_size = self.config.train_batch_size #* max(1, self.trainer.gpus)
        ab_size = self.trainer.accumulate_grad_batches
        max_epochs = self.trainer.max_epochs
        print('Max epochs',max_epochs)
        steps_per_epoch = (len(train_loader.dataset) // tb_size) // ab_size
        # The schedule has to span every epoch; using one epoch's worth of
        # steps would decay the learning rate to zero after the first epoch.
        self.total_steps = steps_per_epoch * max(1, max_epochs or 1)
        print(self.total_steps,'Total size')


    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        All parameters share one group with ``config.learning_rate``; no
        ``weight_decay`` argument is passed, so the optimizer's default
        applies. (A layer-wise learning-rate grouping used to be built here
        but was never handed to the optimizer, so it has been dropped.)
        """
        optimizer = AdamW([p for n,p in self.named_parameters()],lr = self.config.learning_rate, eps = self.config.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.total_steps * self.config.warmup_ratio,
            num_training_steps=self.total_steps,
        )
        # Step the scheduler after every optimizer step rather than every epoch.
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

    def jaccard_score(self, str1, str2):
        """Word-level Jaccard similarity of two strings (case-insensitive).

        Returns 1.0 when neither string contains a word, since two empty
        answers are identical and the ratio would otherwise be 0/0.
        """
        a = set(str1.lower().split())
        b = set(str2.lower().split())
        c = a.intersection(b)
        union_size = len(a) + len(b) - len(c)
        if union_size == 0:
            return 1.0
        return float(len(c)) / union_size

    def id_to_string(self,start_index,end_index,ids):
        """Decode ``ids[start_index:end_index+1]``; a span of just index 0 gives ""."""
        if start_index==end_index and start_index==0: #CLS
            return ""
        s = self.config.tokenizer.decode(ids[start_index:end_index+1])
        return s

    def postprocess_(self, start, end, valid_range, input_ids):
        """Turn start/end logits into one answer string per example.

        Args:
            start, end: logit tensors indexed ``[example][token position]``.
            valid_range: per-example ``(lowest, highest)`` token index that an
                answer span may use (inclusive).
            input_ids: per-example token ids for the same windows.

        Returns:
            A list with one string per example: the decoded best span, or ""
            when no candidate scores above the [CLS]/[CLS] "no answer" span.
        """
        start_logits = start.detach().cpu().numpy()
        end_logits = end.detach().cpu().numpy()
        n_best_size = self.config.pred_topk_logits
        cls_token_id = self.config.tokenizer.cls_token_id

        # Top-k positions per example, best first. The slice must act on the
        # sequence axis; slicing axis 0 would pick whole rows of the batch.
        start_indexes = np.argsort(start_logits, axis=-1)[:, -1 : -n_best_size - 1 : -1].tolist()
        end_indexes = np.argsort(end_logits, axis=-1)[:, -1 : -n_best_size - 1 : -1].tolist()

        batch_answers = []
        for b in range(len(start_indexes)):
            lowest, highest = valid_range[b]
            valid_answers = []
            for start_index in start_indexes[b]:
                for end_index in end_indexes[b]:
                    if start_index > end_index:
                        continue
                    if end_index - start_index + 1 > self.config.pred_max_length:
                        continue
                    if start_index < lowest or end_index > highest:
                        continue

                    valid_answers.append(
                        {
                            "score": start_logits[b][start_index] + end_logits[b][end_index], #<-Heuristic
                            "start_index": start_index,
                            "end_index": end_index,
                        }
                    )

            best_answer = {"text": "", "score": 0.0} #dummy
            if len(valid_answers) > 0:
                # max() keeps the first of equally scored candidates, like a
                # stable descending sort followed by taking element 0.
                best_answer = max(valid_answers, key=lambda x: x["score"])
                best_answer['text'] = self.id_to_string(best_answer['start_index'],best_answer['end_index'],input_ids[b])

            cls_index = input_ids[b].detach().cpu().tolist().index(cls_token_id)
            min_null_score = start_logits[b][cls_index] + end_logits[b][cls_index]
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            batch_answers.append(answer)
        return batch_answers

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Record the learning rate at every step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Keep the weights with the highest val_metric, plus the most recent ones.
checkpoint_callback = ModelCheckpoint(
    filename='checkpoint/{epoch:02d}-{val_metric:.4f}-{val_loss:.4f}',
    monitor='val_metric',
    mode='max',
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
    verbose=False,
)

# Early stopping on the same metric, with a patience of 5.
earlystopping = EarlyStopping(patience=5, mode='max', monitor='val_metric')

In [None]:
# Shared settings object (the tokenizer it exposes is a class attribute of Config).
config = Config()
# Constructing the data module reads the train/validation and test CSVs.
datamodule = ChaiiDatamodule(config)
# setup() is called by hand here (no Trainer is created in the visible cells);
# it tokenises the test split so test_dataloader() can be used below.
datamodule.setup()

In [None]:
# Checkpoint file to restore; the file name records epoch 00 with
# val_metric=0.6430 and val_loss=4.0711.
PATH ='/kaggle/input/muril-large-train/lightning_logs/version_0/checkpoints/checkpoint/epoch=00-val_metric=0.6430-val_loss=4.0711.ckpt'
# `config` is passed explicitly because ChaiiHFmodel.__init__ requires it.
model = ChaiiHFmodel.load_from_checkpoint(PATH,config=config)

**Inference**

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs (slowly) in a CPU-only session instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(model.device)
model.eval()

# question id -> predicted answer text
answers = {}
with torch.no_grad():
    for batch in datamodule.test_dataloader():
        batch['attention_mask']=batch['attention_mask'].to(device)
        batch['input_ids']=batch['input_ids'].to(device)
        (s,e) = model.shared_step(batch, compute_loss = False)
        # No restriction on where a span may lie; sized to the actual batch
        # (the last batch can be smaller than eval_batch_size).
        valid_range = [(-inf,inf)] * len(batch['input_ids'])
        predicted_answers = model.postprocess_(s,e,valid_range,batch['input_ids'])
        # A question can span several windows: keep the first non-empty answer
        # seen for each id, in window order.
        for id_col, predicted in zip(batch['id'], predicted_answers):
            if id_col not in answers or answers[id_col]=='':
                answers[id_col]=predicted

In [None]:
# One row per question id with its predicted answer text: the dict keys become
# the index, which is then moved into a regular column named 'id'.
submission = (
    pd.DataFrame
    .from_dict(answers, orient='index', columns=['PredictionString'])
    .reset_index()
    .rename(columns={'index': 'id'})
)

In [None]:
# Write the predictions to submission.csv in the current directory, omitting the row index.
submission.to_csv("submission.csv", index=False)

In [None]:
# Display the submission table as the cell output for a visual check.
submission