Based on: https://www.kaggle.com/tomohiroh/nbme-bert-for-beginners

- v2: data cleaning from https://www.kaggle.com/yasufuminakama/nbme-deberta-base-baseline-train

In [None]:
import warnings
warnings.simplefilter('ignore')

import os
import ast 
from itertools import chain

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm, trange

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_fscore_support

import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig

# data preprocessing

In [None]:
# Directory that holds the competition CSV files read by create_train_df().
ROOT = '../input/nbme-score-clinical-patient-notes'
# Number of cross-validation folds (used for the split and for the training loop).
N_FOLDS = 5

In [None]:
def create_train_df(debug=False):
    """Load the competition CSVs and build the merged, fold-annotated frame.

    Steps, in order:
      1. read ``features.csv``, ``patient_notes.csv`` and ``train.csv`` from ROOT;
      2. parse the ``annotation`` / ``location`` columns with ``ast.literal_eval``;
      3. apply a fixed list of hand-written overrides to individual train rows
         (and to one feature text);
      4. left-merge train with the notes and the features;
      5. normalise ``feature_text`` ("-OR-" -> ";-", "-" -> " ", lower-case) and
         lower-case ``pn_history``;
      6. assign a ``fold`` column with StratifiedKFold, stratifying on the
         concatenation of ``case_num`` and ``feature_num``.

    Args:
        debug: when True, keep a random 50% sample of the merged rows (with a
            fresh index) before the folds are assigned.

    Returns:
        The merged DataFrame with the extra ``stratify_on`` and ``fold`` columns.
    """
    feats = pd.read_csv(f"{ROOT}/features.csv")
    feats.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"

    notes = pd.read_csv(f"{ROOT}/patient_notes.csv")
    train = pd.read_csv(f"{ROOT}/train.csv")

    # Both columns are stored as Python-literal strings in the CSV.
    for column in ('annotation', 'location'):
        train[column] = train[column].apply(ast.literal_eval)

    # Manual overrides: (row label, annotation literal, location literal).
    # The literals are kept as source strings and parsed right before assignment.
    corrections = [
        (338, '[["father heart attack"]]', '[["764 783"]]'),
        (621, '[["for the last 2-3 months"]]', '[["77 100"]]'),
        (655, '[["no heat intolerance"], ["no cold intolerance"]]', '[["285 292;301 312"], ["285 287;296 312"]]'),
        (1262, '[["mother thyroid problem"]]', '[["551 557;565 580"]]'),
        (1265, '[[\'felt like he was going to "pass out"\']]', '[["131 135;181 212"]]'),
        (1396, '[["stool , with no blood"]]', '[["259 280"]]'),
        (1591, '[["diarrhoe non blooody"]]', '[["176 184;201 212"]]'),
        (1615, '[["diarrhea for last 2-3 days"]]', '[["249 257;271 288"]]'),
        (1664, '[["no vaginal discharge"]]', '[["822 824;907 924"]]'),
        (1714, '[["started about 8-10 hours ago"]]', '[["101 129"]]'),
        (1929, '[["no blood in the stool"]]', '[["531 539;549 561"]]'),
        (2134, '[["last sexually active 9 months ago"]]', '[["540 560;581 593"]]'),
        (2191, '[["right lower quadrant pain"]]', '[["32 57"]]'),
        (2553, '[["diarrhoea no blood"]]', '[["308 317;376 384"]]'),
        (3124, '[["sweating"]]', '[["549 557"]]'),
        (3858, '[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]', '[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]'),
        (4373, '[["for 2 months"]]', '[["33 45"]]'),
        (4763, '[["35 year old"]]', '[["5 16"]]'),
        (4782, '[["darker brown stools"]]', '[["175 194"]]'),
        (4908, '[["uncle with peptic ulcer"]]', '[["700 723"]]'),
        (6016, '[["difficulty falling asleep"]]', '[["225 250"]]'),
        (6192, '[["helps to take care of aging mother and in-laws"]]', '[["197 218;236 260"]]'),
        (6380, '[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]', '[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]'),
        (6562, '[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]', '[["290 320;327 337"], ["290 320;342 358"]]'),
        (6862, '[["stressor taking care of many sick family members"]]', '[["288 296;324 363"]]'),
        (7022, '[["heart started racing and felt numbness for the 1st time in her finger tips"]]', '[["108 182"]]'),
        (7422, '[["first started 5 yrs"]]', '[["102 121"]]'),
        (8876, '[["No shortness of breath"]]', '[["481 483;533 552"]]'),
        (9027, '[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]', '[["92 102"], ["123 164"]]'),
        (9938, '[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]', '[["89 117"], ["122 138"], ["368 402"]]'),
        (9973, '[["gaining 10-15 lbs"]]', '[["344 361"]]'),
        (10513, '[["weight gain"], ["gain of 10-16lbs"]]', '[["600 611"], ["607 623"]]'),
        (11551, '[["seeing her son knows are not real"]]', '[["386 400;443 461"]]'),
        (11677, '[["saw him once in the kitchen after he died"]]', '[["160 201"]]'),
        (12124, '[["tried Ambien but it didnt work"]]', '[["325 337;349 366"]]'),
        (12279, '[["heard what she described as a party later than evening these things did not actually happen"]]', '[["405 459;488 524"]]'),
        (12289, '[["experienced seeing her son at the kitchen table these things did not actually happen"]]', '[["353 400;488 524"]]'),
        (13238, '[["SCRACHY THROAT"], ["RUNNY NOSE"]]', '[["293 307"], ["321 331"]]'),
        (13297, '[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]', '[["182 221"], ["182 213;225 234"]]'),
        (13299, '[["yesterday"], ["yesterday"]]', '[["79 88"], ["409 418"]]'),
        (13845, '[["headache global"], ["headache throughout her head"]]', '[["86 94;230 236"], ["86 94;237 256"]]'),
        (14083, '[["headache generalized in her head"]]', '[["56 64;156 179"]]'),
    ]
    for row_label, annotation_src, location_src in corrections:
        train.loc[row_label, 'annotation'] = ast.literal_eval(annotation_src)
        train.loc[row_label, 'location'] = ast.literal_eval(location_src)

    # Attach the note text and the feature text to every annotation row.
    merged = train.merge(notes, how = "left").merge(feats, how = "left")

    # Text normalisation; note that "-OR-" ends up as "; " after both replaces.
    merged["feature_text"] = merged["feature_text"].apply(
        lambda text: text.replace("-OR-", ";-").replace("-", " ").lower()
    )
    merged["pn_history"] = merged["pn_history"].apply(lambda text: text.lower())

    if debug:
        merged = merged.sample(frac=0.5).reset_index(drop=True)

    # One stratum per (case, feature) pair; -1 marks rows not assigned yet.
    merged["stratify_on"] = merged["case_num"].astype(str) + merged["feature_num"].astype(str)
    merged["fold"] = -1
    splitter = StratifiedKFold(n_splits=N_FOLDS)
    splits = splitter.split(merged["id"], y=merged["stratify_on"])
    for fold_number, (_, valid_rows) in enumerate(splits):
        merged.loc[valid_rows, "fold"] = fold_number

    return merged

# Build the full training frame once; later cells (example, train_fn) read this global.
train_df = create_train_df()
# Notebook preview of the first rows.
display(train_df.head())

In [None]:
def loc_list_to_ints(loc_list):
    """Flatten location strings into a list of ``(start, end)`` integer pairs.

    Each item of ``loc_list`` is a string holding one or more "start end"
    spans separated by ";" (e.g. ``"285 292;301 312"``). Spans are returned
    in the order they appear; an empty input yields an empty list.
    """
    return [
        (int(start), int(end))
        for loc_str in loc_list
        for start, end in (piece.split() for piece in loc_str.split(";"))
    ]

# tokenizer

In [None]:
# Hugging Face checkpoint name; used for both the tokenizer and the model backbone.
MODEL_NAME = 'roberta-base'
# Shared tokenizer instance passed to the datasets below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
def tokenize_and_add_labels(tokenizer, example):
    """Tokenize one (feature_text, pn_history) pair and attach token labels.

    The pair is encoded as question/context, truncating only the context and
    padding to 480 tokens. Three extra entries are stored on the encoding:

    * ``location_int``  - annotated character spans as ``(start, end)`` ints;
    * ``sequence_ids``  - per-token sequence id (None / 0 / 1);
    * ``labels``        - per-token label: ``-100`` for special tokens and
      question tokens, ``1.0`` for context tokens whose character offsets lie
      entirely inside an annotated span, ``0.0`` for the other context tokens.

    Args:
        tokenizer: tokenizer callable that supports ``return_offsets_mapping``.
        example: mapping with ``feature_text``, ``pn_history`` and ``location``.

    Returns:
        The tokenizer output with the three entries above added.
    """
    encoding = tokenizer(
        example["feature_text"],      # question
        example["pn_history"],        # content
        truncation="only_second",
        max_length=480,
        padding="max_length",
        return_offsets_mapping=True
    )
    encoding["location_int"] = loc_list_to_ints(example["location"])
    encoding["sequence_ids"] = encoding.sequence_ids()
    spans = encoding["location_int"]

    def label_for(seq_id, offsets):
        # seq_id None: special tokens | seq_id 0: question -> ignored by the loss
        if seq_id is None or seq_id == 0:
            return -100
        token_start, token_end = offsets
        inside_span = any(
            token_start >= span_start and token_end <= span_end
            for span_start, span_end in spans
        )
        return 1.0 if inside_span else 0.0

    encoding["labels"] = [
        label_for(seq_id, offsets)
        for seq_id, offsets in zip(encoding["sequence_ids"], encoding["offset_mapping"])
    ]

    return encoding

In [None]:
# Take the first training row and show the fields the tokenizer step consumes.
first = train_df.loc[0]
example = {
    field: first[field]
    for field in ("feature_text", "pn_history", "location", "annotation")
}
for key, value in example.items():
    print(key)
    print(value)
    print("=" * 100)

In [None]:
# Run the example through the tokenizer step and dump every resulting field.
tokenized_inputs = tokenize_and_add_labels(tokenizer, example)
for key, value in tokenized_inputs.items():
    print(key)
    print(value)
    print("=" * 100)

# dataset

In [None]:
class NBMEData(torch.utils.data.Dataset):
    """Dataset that tokenizes and labels one DataFrame row per item.

    ``data`` is indexed with ``.loc[idx]``, so it needs labels 0..len-1
    (train_fn resets the index before building the dataset).
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, labels, offset_mapping, sequence_ids)``.

        All five are numpy arrays. ``sequence_ids`` is cast to float16, which
        turns the ``None`` entries of special/padding tokens into NaN.
        """
        row = self.data.loc[idx]
        encoded = tokenize_and_add_labels(self.tokenizer, row)

        fields = ("input_ids", "attention_mask", "labels", "offset_mapping")
        input_ids, attention_mask, labels, offset_mapping = (
            np.array(encoded[name]) for name in fields
        )
        sequence_ids = np.array(encoded["sequence_ids"]).astype("float16")

        return input_ids, attention_mask, labels, offset_mapping, sequence_ids

# model

In [None]:
class NBMEModel(torch.nn.Module):
    """Pretrained backbone followed by dropout and a one-logit-per-token head."""

    def __init__(self):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(MODEL_NAME)
        self.config = AutoConfig.from_pretrained(MODEL_NAME)
        self.dropout = torch.nn.Dropout(p=0.2)
        hidden_size = self.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one logit per input token (the trailing size-1 dim is squeezed)."""
        backbone_out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # First element of the backbone output: for Hugging Face base models
        # this is the per-token last hidden state.
        token_states = backbone_out[0]
        token_logits = self.classifier(self.dropout(token_states))
        return token_logits.squeeze(-1)

# training

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression: the notebook echoes the selected device as the cell output.
DEVICE

In [None]:
# Training batch size per forward pass (validation uses twice this in train_fn).
BATCH_SIZE = 16
# Number of batches whose gradients are accumulated before each optimizer step.
ACCUMULATION_STEPS = 2
# Number of passes over the training split per fold.
EPOCHS = 5

In [None]:
class AverageMeter(object):
    """Running weighted average.

    Attributes:
        val:   the most recent value passed to ``update``.
        sum:   weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg:   ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n = 1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def sigmoid(z):
    """Element-wise logistic function ``1 / (1 + exp(-z))`` for numpy input."""
    return 1 / (1 + np.exp(-z))


def get_location_predictions(preds, offset_mapping, sequence_ids, test = False):
    """Turn per-token logits into predicted character spans.

    Consecutive context tokens whose sigmoid probability exceeds 0.5 are merged
    into one span running from the first token's start offset to the last
    token's end offset.

    Only tokens whose sequence id equals 1 (the patient note) are considered.
    Comparing against 1 rather than testing for ``None`` matters because the
    dataset stores sequence ids as float16, where the ``None`` of special and
    padding tokens has become NaN: an ``is None`` test never matches those, and
    their logits (never trained, since their labels are -100) must not open,
    extend or close a span.

    Args:
        preds: iterable of per-example logit arrays.
        offset_mapping: per-example token ``(start, end)`` character offsets.
        sequence_ids: per-example token sequence ids (None/NaN, 0 or 1).
        test: when True each example is returned as a ``"start end; start end"``
            string, otherwise as a list of ``(start, end)`` tuples.

    Returns:
        One entry per example, in input order.
    """
    all_predictions = []
    for pred, offsets, seq_ids in zip(preds, offset_mapping, sequence_ids):
        probs = sigmoid(np.asarray(pred))
        spans = []
        start_idx = None
        end_idx = None
        for p, o, s_id in zip(probs, offsets, seq_ids):
            # True for None, NaN and 0 alike: everything but the patient note.
            if s_id != 1:
                continue
            if p > 0.5:
                if start_idx is None:
                    start_idx = o[0]
                end_idx = o[1]
            elif start_idx is not None:
                spans.append((start_idx, end_idx))
                start_idx = None
        # A span that reaches the last context token has no negative token
        # after it to close it, so flush it here.
        if start_idx is not None:
            spans.append((start_idx, end_idx))
        if test:
            all_predictions.append("; ".join(f"{s} {e}" for s, e in spans))
        else:
            all_predictions.append(spans)
    return all_predictions


def calculate_char_CV(predictions, offset_mapping, sequence_ids, labels):
    """Character-level precision / recall / F1 of predicted spans.

    For every example two 0/1 character arrays are built, with length equal to
    the largest token offset of that example: one from the token labels
    (characters of context tokens labelled 1) and one from the predicted
    ``(start, end)`` spans. The arrays of all examples are concatenated and
    scored with ``precision_recall_fscore_support(average="binary")``.

    Args:
        predictions: per-example lists of ``(start, end)`` character spans.
        offset_mapping: per-example token ``(start, end)`` character offsets.
        sequence_ids: per-example token sequence ids (None/NaN, 0 or 1).
        labels: per-example token labels (1 marks an annotated token).

    Returns:
        Dict with ``precision``, ``recall`` and ``f1``.
    """
    all_labels = []
    all_preds = []
    # The loop names differ from the parameters on purpose: reusing ``labels``
    # for the per-example row would shadow the argument.
    for spans, offsets, seq_ids, token_labels in zip(predictions, offset_mapping, sequence_ids, labels):
        num_chars = int(np.max(offsets))
        char_labels = np.zeros(num_chars)
        for o, s_id, label in zip(offsets, seq_ids, token_labels):
            # Only patient-note tokens (sequence id 1) count. ``!= 1`` also
            # skips NaN, which is how None ids arrive after a float cast.
            if s_id != 1:
                continue
            if int(label) == 1:
                char_labels[o[0]:o[1]] = 1
        char_preds = np.zeros(num_chars)
        for start_idx, end_idx in spans:
            char_preds[start_idx:end_idx] = 1
        all_labels.extend(char_labels)
        all_preds.extend(char_preds)
    results = precision_recall_fscore_support(all_labels, all_preds, average = "binary")
    return {
        "precision": results[0],
        "recall": results[1],
        "f1": results[2]
    }


def compute_metrics(p):
    """Token-level precision / recall / F1, ignoring positions labelled -100.

    Args:
        p: pair ``(predictions, y_true)``; ``y_true`` must support
            ``.astype(int)``. A prediction counts as positive when it is
            greater than 0.5.

    Returns:
        Dict with ``token_precision``, ``token_recall`` and ``token_f1``.
    """
    scores, targets = p
    targets = targets.astype(int)

    # Flatten both sides, dropping every position whose label is -100.
    flat_pred = [
        int(score > 0.5)
        for score_row, label_row in zip(scores, targets)
        for score, lab in zip(score_row, label_row)
        if lab != -100
    ]
    flat_true = [lab for label_row in targets for lab in label_row if lab != -100]

    precision, recall, f1 = precision_recall_fscore_support(
        flat_true, flat_pred, average = "binary"
    )[:3]
    return {
        "token_precision": precision,
        "token_recall": recall,
        "token_f1": f1
    }

In [None]:
def train_fn(fold):
    """Train, checkpoint and evaluate the model for one cross-validation fold.

    Rows of the global ``train_df`` with ``fold != fold`` are used for training
    and the remaining rows for validation. After every epoch the validation
    loss is computed, and the weights are written to ``nbme_{fold}.pth``
    whenever it improves. Finally the best checkpoint is reloaded and the
    character-level CV score on the validation split is printed.

    Args:
        fold: index of the fold held out for validation.

    Returns:
        Dict with ``fold``, ``best_valid_loss``, ``history`` (per-epoch
        ``train`` / ``valid`` losses) and ``score`` (output of
        ``calculate_char_CV``).
    """

    # get dataloader
    train = train_df[train_df["fold"] != fold].reset_index(drop=True)
    valid = train_df[train_df["fold"] == fold].reset_index(drop=True)
    train_ds = NBMEData(train, tokenizer)
    valid_ds = NBMEData(valid, tokenizer)
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE,
                                           pin_memory=True, shuffle=True, drop_last=True)
    valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=BATCH_SIZE * 2,
                                           pin_memory=True, shuffle=False, drop_last=False)
    n_train_batches = len(train_dl)

    # get model
    model = NBMEModel().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # Per-token losses are kept so that ignored positions can be masked out.
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")

    history = {"train": [], "valid": []}
    best_loss = np.inf

    # training
    for epoch in range(EPOCHS):
        model.train()
        # Start every epoch with clean gradients.
        optimizer.zero_grad()
        train_loss = AverageMeter()
        pbar = tqdm(train_dl)
        for i, batch in enumerate(pbar):
            input_ids = batch[0].to(DEVICE)
            attention_mask = batch[1].to(DEVICE)
            labels = batch[2].to(DEVICE)
            logits = model(input_ids, attention_mask)
            loss = loss_fn(logits, labels)
            # Positions labelled -100 (special/question/padding) are ignored.
            loss = torch.masked_select(loss, labels > -1).mean()
            # Scale only the gradient; the logged value stays the true batch
            # loss so it is comparable with the validation loss.
            (loss / ACCUMULATION_STEPS).backward()
            # Step every ACCUMULATION_STEPS batches, and on the last batch so
            # an incomplete accumulation group does not leak into the next epoch.
            if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == n_train_batches:
                optimizer.step()
                optimizer.zero_grad()
            train_loss.update(val=loss.item(), n=len(input_ids))
            pbar.set_postfix(Loss=train_loss.avg)
        print(f"EPOCH: {epoch} train loss: {train_loss.avg}")
        history["train"].append(train_loss.avg)

        # evaluation
        model.eval()
        valid_loss = AverageMeter()
        pbar = tqdm(valid_dl)
        with torch.no_grad():
            for i, batch in enumerate(pbar):
                input_ids = batch[0].to(DEVICE)
                attention_mask = batch[1].to(DEVICE)
                labels = batch[2].to(DEVICE)
                logits = model(input_ids, attention_mask)
                loss = loss_fn(logits, labels)
                loss = torch.masked_select(loss, labels > -1).mean()
                valid_loss.update(val=loss.item(), n=len(input_ids))
                pbar.set_postfix(Loss=valid_loss.avg)
        print(f"EPOCH: {epoch} valid loss: {valid_loss.avg}")
        history["valid"].append(valid_loss.avg)

        # save model
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), f"nbme_{fold}.pth")

    # evaluation summary
    model.load_state_dict(torch.load(f"nbme_{fold}.pth", map_location = DEVICE))
    model.eval()
    preds = []
    offsets = []
    seq_ids = []
    lbls = []
    with torch.no_grad():
        for batch in tqdm(valid_dl):
            input_ids = batch[0].to(DEVICE)
            attention_mask = batch[1].to(DEVICE)
            logits = model(input_ids, attention_mask)
            preds.append(logits.cpu().numpy())
            # Labels, offsets and sequence ids are only needed on the host.
            lbls.append(batch[2].numpy())
            offsets.append(batch[3].numpy())
            seq_ids.append(batch[4].numpy())
    preds = np.concatenate(preds, axis=0)
    offsets = np.concatenate(offsets, axis=0)
    seq_ids = np.concatenate(seq_ids, axis=0)
    lbls = np.concatenate(lbls, axis=0)
    location_preds = get_location_predictions(preds, offsets, seq_ids, test=False)
    score = calculate_char_CV(location_preds, offsets, seq_ids, lbls)
    print(f"Fold: {fold} CV score: {score}")

    return {
        "fold": fold,
        "best_valid_loss": best_loss,
        "history": history,
        "score": score,
    }

In [None]:
# Train one model per fold; each run writes its own nbme_{fold}.pth checkpoint.
for fold in range(N_FOLDS):
    train_fn(fold)