In [None]:
# Run-mode switches: DEBUG controls UserWarning suppression in the imports cell;
# KAGGLE selects Kaggle input paths and GPU index instead of the local layout.
DEBUG, KAGGLE = True, False

In [None]:
import warnings
if DEBUG:
    warnings.filterwarnings('ignore', category=UserWarning)
import os
import gc
gc.enable()
import math
import json
import time
import random
import multiprocessing
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from sklearn import model_selection
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import transformers
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    logging,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
)
# Only error-level messages from transformers are shown.
logging.set_verbosity_error()
# Pin the process to one GPU. This must happen before the first CUDA call in
# this process; nothing above has touched CUDA yet.
os.environ['CUDA_VISIBLE_DEVICES'] = '0' if KAGGLE else '1'

# Config

In [None]:
VER = 'v6'
# Input locations differ between the Kaggle kernel and a local checkout.
if KAGGLE:
    DATA_PATH = '../input/chaii-hindi-and-tamil-question-answering'
    MDLS_PATH = f'../input/chaii-models-{VER}'
else:
    DATA_PATH = './data'
    MDLS_PATH = f'./models_{VER}'

# Shared settings stored next to the fold checkpoints. Keys read later in this
# notebook include max_seq_length, doc_stride, eval_batch_size and folds.
with open(f'{MDLS_PATH}/base_config.json') as config_fh:
    config = json.load(config_fh)
print('config loaded:', config)

def optimal_workers():
    """Choose a DataLoader ``num_workers`` value for this machine.

    With GPUs visible: four workers per GPU, capped at the CPU count.
    Without GPUs: all CPUs except one. The value is printed and returned.
    """
    cpu_total = multiprocessing.cpu_count()
    gpu_total = torch.cuda.device_count()
    if not gpu_total:
        workers = cpu_total - 1
    else:
        workers = min(cpu_total, gpu_total * 4)
    print('optimal number of workers is', workers)
    return workers

start_time = time.time()  # wall-clock reference for the elapsed-time report in the last cell

# Dataset Retriever

In [None]:
class DatasetRetriever(Dataset):
    """Map-style dataset over tokenized QA features.

    In ``'train'`` mode every item holds ``torch.long`` tensors for
    ``input_ids``, ``attention_mask``, ``offset_mapping``, ``start_position``
    and ``end_position``. In any other mode only ``input_ids`` and
    ``attention_mask`` are converted to tensors. Offsets, sequence ids, the
    example id, context and question are passed through unchanged.
    """

    _TRAIN_TENSOR_KEYS = (
        'input_ids',
        'attention_mask',
        'offset_mapping',
        'start_position',
        'end_position',
    )

    def __init__(self, features, mode='train'):
        super().__init__()
        self.features = features
        self.mode = mode

    def __len__(self):
        return len(self.features)

    @staticmethod
    def _as_long(values):
        """Convert a (nested) list / int to a ``torch.long`` tensor."""
        return torch.tensor(values, dtype=torch.long)

    def __getitem__(self, item):
        feat = self.features[item]
        if self.mode != 'train':
            return {
                'input_ids': self._as_long(feat['input_ids']),
                'attention_mask': self._as_long(feat['attention_mask']),
                'offset_mapping': feat['offset_mapping'],
                'sequence_ids': feat['sequence_ids'],
                'id': feat['example_id'],
                'context': feat['context'],
                'question': feat['question'],
            }
        return {key: self._as_long(feat[key]) for key in self._TRAIN_TENSOR_KEYS}

# Model

In [None]:
class Model(nn.Module):
    """Extractive-QA head on top of a pretrained transformer encoder.

    A single linear layer maps each token's final hidden state to two logits,
    which are split into per-token start and end scores.
    """

    def __init__(self, modelname_or_path, config):
        """
        Args:
            modelname_or_path: checkpoint name or directory passed to
                ``AutoModel.from_pretrained``.
            config: matching transformers config; ``hidden_size``,
                ``hidden_dropout_prob`` and ``initializer_range`` are read from it.
        """
        super(Model, self).__init__()
        self.config = config
        # Renaming this attribute would change the state_dict keys and break
        # loading of existing fold checkpoints via load_state_dict.
        self.xlm_roberta = AutoModel.from_pretrained(
            modelname_or_path,
            config=config
        )
        self.qa_outputs = nn.Linear(config.hidden_size, 2)
        # Not applied in forward(). Dropout has no parameters, so keeping it
        # does not affect state_dict loading.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self._init_weights(self.qa_outputs)

    def _init_weights(self, module):
        """Normal(0, initializer_range) weights and zero bias for Linear layers."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(
                mean=0,
                std=self.config.initializer_range
            )
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        attention_mask=None,
    ):
        """Return ``(start_logits, end_logits)``.

        For a ``(batch, seq)`` input these are ``(batch, seq)`` tensors,
        assuming the encoder's first output is ``(batch, seq, hidden)`` hidden
        states. Only that first output is used, so encoders that return no
        pooled output are supported too.
        """
        outputs = self.xlm_roberta(
            input_ids,
            attention_mask=attention_mask,
        )
        sequence_output = outputs[0]
        qa_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = qa_logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Utilities

In [None]:
def make_model(config, model_path=None):
    """Build a ``Model`` together with the transformers config it was built from.

    If ``model_path`` is given, both the architecture config and the encoder
    weights come from that path. Otherwise ``config['config_name']`` and
    ``config['model_name_or_path']`` are used.

    Returns:
        ``(model_config, model)``.
    """
    if not model_path:
        model_config = AutoConfig.from_pretrained(config['config_name'])
        return model_config, Model(config['model_name_or_path'], config=model_config)
    model_config = AutoConfig.from_pretrained(model_path)
    return model_config, Model(model_path, config=model_config)

# Preprocess

In [None]:
def prepare_test_features(config, example, tokenizer):
    """Tokenize one (question, context) example into inference features.

    The pair is tokenized with only the context truncated to
    ``config['max_seq_length']`` tokens, ``stride=config['doc_stride']`` and
    overflowing windows returned. Each window becomes one feature.

    Args:
        config: dict providing ``max_seq_length`` and ``doc_stride``.
        example: mapping (dict or pandas row) with ``id``, ``question`` and
            ``context``. It is not modified.
        tokenizer: tokenizer whose output supports ``sequence_ids(i)``
            (a Hugging Face fast tokenizer).

    Returns:
        list of dicts, one per window, with ``example_id``, ``context``,
        ``question`` (left-stripped), ``input_ids``, ``attention_mask``,
        ``offset_mapping`` and ``sequence_ids``. In ``sequence_ids``, ``None``
        entries are mapped to 0, so only context tokens carry 1.
    """
    # Strip leading whitespace from the question into a local; the caller's
    # example is left untouched.
    question = example['question'].lstrip()
    context = example['context']
    tokenized = tokenizer(
        question,
        context,
        truncation='only_second',
        max_length=config['max_seq_length'],
        stride=config['doc_stride'],
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )
    features = []
    for window_idx in range(len(tokenized['input_ids'])):
        # None (tokens belonging to neither sequence) -> 0, so 1 marks context only.
        sequence_ids = [0 if sid is None else sid for sid in tokenized.sequence_ids(window_idx)]
        features.append({
            'example_id': example['id'],
            'context': context,
            'question': question,
            'input_ids': tokenized['input_ids'][window_idx],
            'attention_mask': tokenized['attention_mask'][window_idx],
            'offset_mapping': tokenized['offset_mapping'][window_idx],
            'sequence_ids': sequence_ids,
        })
    return features

# Postprocess

In [None]:
import collections

def postprocess_qa_predictions(examples, features, raw_predictions, tokenizer,
                               n_best_size=20, max_answer_length=30):
    """Turn per-feature start/end logits into one answer string per example.

    For every feature of an example, the ``n_best_size`` highest start logits
    are paired with the ``n_best_size`` highest end logits. A pair is a
    candidate only if:

    * both tokens lie in the context;
    * start <= end;
    * the span is at most ``max_answer_length`` tokens long.

    Candidates are scored by ``start_logit + end_logit``. The best candidate's
    character span is cut from the context; ties keep the earliest candidate.
    Examples without any candidate get ``''``.

    Args:
        examples: DataFrame with ``id`` and ``context`` columns. Rows are
            matched to features by position, so any index works.
        features: list of feature dicts with ``example_id``, ``sequence_ids``
            and ``offset_mapping``. Not modified.
        raw_predictions: ``(all_start_logits, all_end_logits)``, indexed by
            feature position.
        tokenizer: not used by this implementation; accepted so existing calls
            keep working.
        n_best_size: number of top start / end positions tried per feature.
        max_answer_length: maximum answer length in tokens.

    Returns:
        ``OrderedDict`` mapping example id to answer text.
    """
    all_start_logits, all_end_logits = raw_predictions
    # Position of each example id in ``examples``. The loop below also walks
    # rows by position, so a non-default DataFrame index cannot misalign them.
    example_id_to_index = {k: i for i, k in enumerate(examples['id'])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature['example_id']]].append(i)
    predictions = collections.OrderedDict()
    print(f'post-processing {len(examples)} example predictions',
          f'split into {len(features)} features')
    context_index = 1  # sequence id carried by context tokens
    for example_index, (_, example) in enumerate(examples.iterrows()):
        context = example['context']
        best_answer = None
        for feature_index in features_per_example[example_index]:
            feature = features[feature_index]
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            sequence_ids = feature['sequence_ids']
            # Local masked copy: non-context offsets become None so they can
            # never start or end an answer. ``features`` itself is untouched.
            offset_mapping = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(feature['offset_mapping'])
            ]
            # Positions of the n_best_size largest logits, best first.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    score = start_logits[start_index] + end_logits[end_index]
                    # Strict '>' keeps the earliest candidate on ties.
                    if best_answer is None or score > best_answer['score']:
                        start_char = offset_mapping[start_index][0]
                        end_char = offset_mapping[end_index][1]
                        best_answer = {'score': score, 'text': context[start_char:end_char]}
        predictions[example['id']] = best_answer['text'] if best_answer is not None else ''
    return predictions

# Inference

In [None]:
def get_predictions(config, model_path, dataloader=None):
    """Run one fold checkpoint over a dataloader and return its logits.

    Args:
        config: notebook config dict, forwarded to ``make_model``.
        model_path: fold directory holding the model config and
            ``pytorch_model.bin``.
        dataloader: iterable of batches with ``input_ids`` / ``attention_mask``
            tensors. Defaults to the notebook-level ``test_dataloader``.

    Returns:
        ``(start_logits, end_logits)`` numpy arrays, with the per-batch outputs
        stacked row-wise in dataloader order.
    """
    if dataloader is None:
        dataloader = test_dataloader
    model_config, model = make_model(config, model_path)
    model.cuda()
    checkpoint_path = model_path + '/pytorch_model.bin'
    # Load tensors onto CPU instead of the device they were saved from (which
    # may not exist under this CUDA_VISIBLE_DEVICES). load_state_dict then
    # copies them into the CUDA parameters.
    # NOTE: torch.load unpickles the file; only use trusted checkpoints.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location='cpu')
    )
    # Inference mode for every submodule (no dropout).
    model.eval()
    start_logits = []
    end_logits = []
    with torch.no_grad():
        for batch in dataloader:
            outputs_start, outputs_end = model(
                batch['input_ids'].cuda(),
                batch['attention_mask'].cuda()
            )
            start_logits.append(outputs_start.cpu().numpy().tolist())
            end_logits.append(outputs_end.cpu().numpy().tolist())
            del outputs_start, outputs_end
    del model, model_config
    gc.collect()
    # Release cached GPU blocks so the next fold's model starts from a clean slate.
    torch.cuda.empty_cache()
    return np.vstack(start_logits), np.vstack(end_logits)

In [None]:
TOKENIZER = AutoTokenizer.from_pretrained(f'{MDLS_PATH}')

test = pd.read_csv(f'{DATA_PATH}/test.csv')
# One example can yield several overlapping features; flatten them in row order.
test_features = [
    feature
    for _, row in tqdm(test.iterrows(), total=len(test))
    for feature in prepare_test_features(config, row, TOKENIZER)
]
test_dataset = DatasetRetriever(test_features, mode='test')
# Sequential, non-dropping sampling keeps row k of the stacked logits aligned
# with test_features[k].
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config['eval_batch_size'],
    sampler=SequentialSampler(test_dataset),
    num_workers=optimal_workers(),
    pin_memory=True,
    drop_last=False,
)

In [None]:
# One checkpoint per fold; every fold's logits are collected in fold order.
start_logits, end_logits = [], []
for fold_num in tqdm(range(config['folds'])):
    fold_dir = f'{MDLS_PATH}/checkpoint-fold-{fold_num}'
    fold_start, fold_end = get_predictions(config, fold_dir)
    start_logits.append(fold_start)
    end_logits.append(fold_end)

In [None]:
# Average the per-fold logits over the fold axis (axis 0). Storing the result
# under new names keeps this cell idempotent: re-running it no longer averages
# an already-averaged array across the wrong axis.
avg_start_logits = np.mean(start_logits, axis=0)
avg_end_logits = np.mean(end_logits, axis=0)
predictions = postprocess_qa_predictions(
    test,
    test_features,
    (avg_start_logits, avg_end_logits),
    TOKENIZER
)

In [None]:
from string import punctuation

# Collapse internal whitespace and trim ASCII punctuation from both ends of each answer.
subm = []
for example_id, answer in predictions.items():
    answer = ' '.join(answer.split()).strip(punctuation)
    subm.append((example_id, answer))
sample = pd.DataFrame(subm, columns=["id", "PredictionString"])
# Drop any PredictionString left by an earlier run of this cell. Otherwise a
# re-run would merge into PredictionString_x / _y and break the cleaning step.
test = pd.merge(
    left=test.drop(columns='PredictionString', errors='ignore'),
    right=sample,
    on='id',
)

In [None]:
bad_starts = ['.', ',', '(', ')', '-', '–', ',', ';']
bad_endings = ['...', '-', '(', ')', '–', ',', ';']

tamil_ad = 'கி.பி'
tamil_bc = 'கி.மு'
tamil_km = 'கி.மீ'
hindi_ad = 'ई'
hindi_bc = 'ई.பू'.replace('ப', 'प')

def _clean_prediction(pred, context):
    """Trim stray punctuation from an answer and restore a trailing abbreviation dot.

    Leading characters listed in ``bad_starts`` are removed one at a time.
    Trailing ``bad_endings`` are removed too: '...' as a unit, the others one
    character at a time. If the result then ends with one of the abbreviation
    strings above and the context contains it followed by '.', that '.' is
    appended. Empty answers are returned unchanged.
    """
    if pred == '':
        return pred
    while pred.startswith(tuple(bad_starts)):
        pred = pred[1:]
    while pred.endswith(tuple(bad_endings)):
        pred = pred[:-3] if pred.endswith('...') else pred[:-1]
    abbreviations = (tamil_ad, tamil_bc, tamil_km, hindi_ad, hindi_bc)
    if pred.endswith(abbreviations) and (pred + '.') in context:
        pred += '.'
    return pred

cleaned_preds = [
    _clean_prediction(pred, context)
    for pred, context in test[['PredictionString', 'context']].to_numpy()
]
test['PredictionString'] = cleaned_preds
test[['id', 'PredictionString']].to_csv('submission.csv', index=False)
display(test[['id', 'PredictionString']])
print('subm shape:', test[['id', 'PredictionString']].shape)

elapsed_time = time.time() - start_time
minutes, seconds = divmod(elapsed_time, 60)
print(f'time elapsed: {minutes:.0f} min {seconds:.0f} sec')