In [59]:
# Show which GPU (if any) this session has before the fold models are loaded.
!nvidia-smi

## Imports

In [60]:
import os
import json
import torch
import torch.nn as n
import torch.nn.functional as F
import argparse
import importlib
import pandas as pd
import numpy as np

from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import (AutoTokenizer, 
    AutoConfig, 
    AutoModelForTokenClassification, 
    Trainer, 
    DataCollatorWithPadding
)


## Pretrained language model (PLM) settings

In [61]:
# Fold checkpoints live in PLM_BASE/fold-<k>; the shared tokenizer sits in
# its 'tokenizer' subdirectory.
PLM_BASE = '/kaggle/input/roberta-large/'
TOKENIZER = os.path.join(PLM_BASE, 'tokenizer')

MAX_LENGTH = 512  # maximum sequence length in tokens
BATCH_SIZE = 4
K_FOLD = 5        # number of fold checkpoints to ensemble ('fold-1' .. 'fold-K_FOLD')

## Load Datasets

In [62]:
# Competition data root; test.csv, patient_notes.csv and features.csv are read from here.
dir_path = '/kaggle/input/nbme-score-clinical-patient-notes/'

In [63]:
test_df = pd.read_csv(os.path.join(dir_path, 'test.csv'))
patients_df = pd.read_csv(os.path.join(dir_path, 'patient_notes.csv'))
features_df = pd.read_csv(os.path.join(dir_path, 'features.csv'))

# Attach the feature description and the patient note to every test row.
# validate='many_to_one' makes pandas raise if a lookup table has duplicate
# keys, which would otherwise silently duplicate test (and submission) rows.
test_df = test_df.merge(features_df, on=['feature_num', 'case_num'], how='left', validate='many_to_one')
test_df = test_df.merge(patients_df, on=['pn_num', 'case_num'], how='left', validate='many_to_one')

# A failed left-join leaves NaN here, which would only surface later as a
# TypeError when the text is concatenated for tokenization.
assert test_df['pn_history'].notna().all(), 'test rows without a matching patient note'
assert test_df['feature_text'].notna().all(), 'test rows without a matching feature'

In [64]:
# Number of test rows; the ensembling cell iterates over exactly this many.
test_size = test_df.shape[0]
test_df.head()

## Preprocessing Datasets

In [65]:
# Plain Python lists of the two text fields, aligned row-for-row with test_df.
feature_text = test_df['feature_text'].tolist()
pn_history = test_df['pn_history'].tolist()

In [66]:
# Load the tokenizer that ships alongside the fold checkpoints.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TOKENIZER)

In [67]:
# Each input is "<pn_history><sep_token><feature_text>" as ONE string. The
# post-processing step relies on this layout: it treats the first sep_token_id
# as the end of the patient-note segment.
inputs = [history + tokenizer.sep_token + feature for history, feature in zip(pn_history, feature_text)]
encoded = tokenizer(inputs,
    return_offsets_mapping=True,  # character offsets, used to map token predictions back into pn_history
    return_token_type_ids=False,
    truncation=True,
    # Pass the limit explicitly. Without max_length, truncation falls back to
    # tokenizer.model_max_length, which may be missing from a locally saved tokenizer.
    max_length=MAX_LENGTH,
)

In [68]:
# Per-row token ids and attention masks (unpadded lists) fed to the Dataset.
input_ids, attention_mask = encoded['input_ids'], encoded['attention_mask']

## Datasets

In [69]:
class TestDataset(Dataset) :
    """Map-style dataset over pre-tokenized test rows.

    Args:
        input_ids: indexable of per-row token id sequences.
        attention_mask: indexable of per-row attention masks, aligned with input_ids.

    Each item is a dict with 'input_ids' and 'attention_mask' for one row. No
    padding is applied here.
    """

    def __init__(self, input_ids, attention_mask) :
        super().__init__()
        self.input_ids = input_ids
        self.attention_mask = attention_mask

    def __len__(self) :
        # The dataset size is defined by the token-id rows.
        return len(self.input_ids)

    def __getitem__(self, idx) :
        return dict(input_ids=self.input_ids[idx], attention_mask=self.attention_mask[idx])

In [70]:
# Wrap the tokenized test rows for the DataLoader.
dataset = TestDataset(input_ids=input_ids, attention_mask=attention_mask)

## Collator

In [71]:
# padding=True (the default, spelled out) pads each batch to its longest row.
# NOTE(review): per the transformers docs, max_length only takes effect with
# padding='max_length', so it appears to be inert here. Confirm before relying on it.
collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=MAX_LENGTH)

## Device

In [72]:
# Prefer the first GPU when CUDA is available; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## Dataloader

In [73]:
# Use the configured BATCH_SIZE instead of a hard-coded literal.
# shuffle must stay False: the inference loop concatenates batch outputs and
# later cells index them by row position, so order must match test_df.
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

## Inference

In [74]:
# Per-row (start, end) character offsets for every token. Read with [] rather
# than pop() so that re-running this cell does not raise KeyError.
offset_mapping = encoded['offset_mapping']

In [75]:
# Run every fold checkpoint over the test set. predictions[fold][row] holds that
# row's softmax probabilities, presumably shaped (padded_len, num_labels), with
# labels on the last axis as the ensembling cell's argmax(axis=-1) assumes.
predictions = []

# Inference only: no_grad skips building the autograd graph for every batch,
# which otherwise costs time and (GPU) memory.
with torch.no_grad() :
    for fold in tqdm(range(K_FOLD)) :
        model_path = os.path.join(PLM_BASE, 'fold-' + str(fold + 1))

        config = AutoConfig.from_pretrained(model_path)
        model = AutoModelForTokenClassification.from_pretrained(model_path, config=config).to(device)
        model.eval()

        probs_list = []
        for batch in tqdm(data_loader, leave=False) :
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            probs = F.softmax(outputs.logits, dim=-1).cpu().numpy()
            # Extending with a batched array appends one array per row.
            probs_list.extend(probs)

        predictions.append(probs_list)

        # Release this fold's weights before the next checkpoint is loaded.
        del model
        if device.type == 'cuda' :
            torch.cuda.empty_cache()

In [76]:
# For each test row, average the per-token class probabilities across all
# folds, then take the most likely class for every token.
results = [
    np.argmax(
        np.mean([predictions[fold][row] for fold in range(K_FOLD)], axis=0),
        axis=-1,
    )
    for row in tqdm(range(test_size))
]

## Postprocessing: token predictions to character spans

In [77]:
def postprocess(pos_list) :
    """Merge token character offsets into a location string.

    Args:
        pos_list: ordered sequence of (start, end) character offsets.

    Returns:
        'start end' spans joined by ';'. An offset is merged into the current
        span when its start equals the previous offset's end, or that end + 1
        (directly adjacent, or separated by one character). Returns '' for an
        empty input.
    """
    spans = []
    for cur_start, cur_end in pos_list:
        if spans:
            last_end = spans[-1][1]
            if cur_start == last_end + 1 or cur_start == last_end:
                # Extend the open span. Its end always tracks the previous offset's end.
                spans[-1][1] = cur_end
                continue
        spans.append([cur_start, cur_end])

    return ';'.join(str(start) + ' ' + str(end) for start, end in spans)

In [78]:
locations = []

# Predictions, character offsets and token ids must line up row-for-row.
assert len(results) == len(offset_mapping) == len(encoded['input_ids'])

# Loop names are distinct from the module-level `input_ids`, so running this
# cell no longer clobbers the list the Dataset cell depends on.
for pred_ids, offsets, token_ids in zip(results, offset_mapping, encoded['input_ids']) :
    # Skip index 0 (presumably the leading special token) and stop at the first
    # sep token, which closes the pn_history segment. Only tokens in that window
    # carry offsets into the patient note.
    token_end_index = token_ids.index(tokenizer.sep_token_id)
    span_list = [offsets[j] for j in range(1, token_end_index) if pred_ids[j] == 1]
    locations.append(postprocess(span_list) if span_list else '')

In [79]:
# Attach the predicted spans and drop every column not needed for submission.
test_df = test_df.assign(location=locations).drop(
    columns=['case_num', 'pn_num', 'feature_num', 'feature_text', 'pn_history']
)
test_df.head()

In [80]:
# Write the submission file without the pandas index column.
test_df.to_csv(path_or_buf='submission.csv', index=False)