# NBME competition notebook

About this notebook:
- This notebook performs binary token classification on the labeled dataset.
- Transfer-learning weights come from another notebook, which performs sequence classification on all interview notes (https://www.kaggle.com/code/joanyeo/nbme-hf-distilbert-2train).
- Uses DistilBERT instead of BERT to speed up training.
- Manages configurations with YAML files, as is common in Docker-based workflows.
- Adjusts the learning rate automatically with a scheduler; the learning rate can be monitored with Weights & Biases.
- Performs 5-fold cross-validation so the models generalize to unseen data.
- Simplifies the training code with `train` and `run` functions.
- Automatically saves the best model of each fold and loads it for inference.
- The number of epochs is adjusted to prevent overfitting: in the first training run, validation losses were generally higher than training losses.

- This notebook borrows some great ideas from these two sources:
    - https://www.kaggle.com/code/tanyadayanand/nbme-bert-base-uncased-using-pytorch
    - https://prgms.tistory.com/73
    

## Contents
- 0. Getting Ready
    - Import Lib
    - ConfigManager
    - Fix SEED
- 1. Data & Model
    - Load Data
    - Dataset
    - Model
    - AverageMeter
    - Loss
- 2. Training
- 3. Inference

## 0. Getting Ready
### Import Lib

In [None]:
import os
import re
import math
import time
import tqdm
import yaml
import torch
import random
import warnings
import tokenizers
import numpy as np
import pandas as pd
import transformers

from tqdm.auto import tqdm
from ast import literal_eval
from easydict import EasyDict
from torch.optim import AdamW
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.model_selection import KFold, train_test_split

In [None]:
# Disable HuggingFace tokenizers' internal parallelism and hide DeprecationWarnings.
os.environ.update(TOKENIZERS_PARALLELISM="false")
warnings.filterwarnings(action="ignore", category=DeprecationWarning)

In [None]:
# Set True to use Wandb
# It is only avaiable with Internet connection.
WANDB = False

if WANDB:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value = user_secrets.get_secret("wandb")
    os.environ["WANDB_API_KEY"] = secret_value

    !pip -q install wandb
    !wandb login 

    import wandb
    wandb.init('nbme-study')
else:
    os.environ['WANDB_CONSOLE'] = 'off'

### ConfigManager
- Use a .yaml file to manage configurations.
- This makes it easy to test different configs.

In [None]:
class YamlConfigManager:
    """Load one named section of a YAML config file into an attribute-style dict.

    ``values`` is an ``EasyDict``, so settings read as
    ``cfg.values.train_args.max_lr``.

    Args:
        config_file_path: Path of the YAML file. When falsy, nothing is loaded
            and ``values`` starts empty.
        config_name: Top-level key of the YAML document whose mapping is loaded.
    """

    def __init__(self, config_file_path='../input/config/config.yaml', config_name='base'):
        super().__init__()
        self.values = EasyDict()
        # Always set both attributes so reload() also works when no file was given.
        self.config_file_path = config_file_path
        self.config_name = config_name
        if config_file_path:
            self.reload()

    def reload(self):
        """Discard current values and re-read the configured section from disk."""
        self.clear()
        if self.config_file_path:
            with open(self.config_file_path, 'r', encoding='utf-8') as f:
                # safe_load: the config is plain data, never arbitrary Python objects.
                self.values.update(yaml.safe_load(f)[self.config_name])

    def clear(self):
        """Remove every setting from ``values``.

        EasyDict mirrors each key as an instance attribute; ``dict.clear`` alone
        leaves those attributes behind, so they are removed explicitly first.
        """
        for key in list(self.values):
            self.values.__dict__.pop(key, None)
        self.values.clear()

    def update(self, yaml_dict):
        """Merge ``yaml_dict`` into ``values``, up to three levels deep.

        Nested mappings are merged key by key instead of replacing the whole
        sub-section. Sections missing from ``values`` (or holding a non-dict)
        are created as empty mappings first instead of raising.
        """
        for k1, v1 in yaml_dict.items():
            if not isinstance(v1, dict):
                self.values[k1] = v1
                continue
            if not isinstance(self.values.get(k1), dict):
                self.values[k1] = {}  # EasyDict converts this to a nested EasyDict
            for k2, v2 in v1.items():
                if not isinstance(v2, dict):
                    self.values[k1][k2] = v2
                    continue
                if not isinstance(self.values[k1].get(k2), dict):
                    self.values[k1][k2] = {}
                for k3, v3 in v2.items():
                    self.values[k1][k2][k3] = v3

    def export(self, save_file_path):
        """Write ``values`` to ``save_file_path`` as plain YAML (no-op if falsy)."""
        if save_file_path:
            with open(save_file_path, 'w', encoding='utf-8') as f:
                # Nested EasyDicts would otherwise be emitted with python/object
                # tags that yaml.safe_load (used by reload) refuses to read.
                yaml.dump(self._to_plain(self.values), f)

    @staticmethod
    def _to_plain(obj):
        """Recursively convert dict subclasses and sequences into plain dicts/lists."""
        if isinstance(obj, dict):
            return {key: YamlConfigManager._to_plain(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [YamlConfigManager._to_plain(val) for val in obj]
        return obj

In [None]:
# config.yaml

# base:
#     seed: 1004
#     model_arc: 'distilbert'
#     num_classes: 2
#     input_dir: '../input/nbme-score-clinical-patient-notes'
#     output_dir: './results/'
#     train_only: False
#     max_len: 512
#     ckp_path: '../input/nbme-hf-distilbert-2train/train/nbme-case/checkpoint-10432/'
#     train_args:
#         num_epochs: 7
#         train_batch_size: 32
#         val_batch_size: 32
#         model_path: 'pytorch_model.bin'
#         dropout_rate: 0.2 # 0.1~0.3
#         max_grad_norm: 1.0
#         max_lr: 0.0001
#         min_lr: 0.00001
#         cycle: 3
#         gamma: 0.5
#         weight_decay: 0.000001
#         log_intervals: 10
#         eval_metric: 'accuracy'
#         n_splits: 5

In [None]:
# Load the 'base' section of the YAML config and expose frequently used settings.
cfg = YamlConfigManager()
SEED = cfg.values.seed
MODEL_ARC = cfg.values.model_arc
INPUT_DIR = cfg.values.input_dir
OUTPUT_DIR = cfg.values.output_dir
TRAIN_ONLY = cfg.values.train_only  # True skips validation
MAX_LEN = cfg.values.max_len  # every encoded sample is padded/truncated to this length
# os.path.join handles a ckp_path that already ends with "/" (no doubled separator).
TOKENIZER = tokenizers.BertWordPieceTokenizer(os.path.join(cfg.values.ckp_path, "vocab.txt"), lowercase=True)

In [None]:
# Override the number of epochs for this run. Only the changed key is passed:
# dict(cfg.values) is a shallow copy, so editing its nested train_args would
# already mutate cfg in place before update() is called.
cfg.update({'train_args': {'num_epochs': 4}})

### Fix SEED

In [None]:
def seed_everything(seed=1004):
    """Seed every RNG the notebook uses so runs are reproducible.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects Python processes started after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can differ
    # between runs and defeats deterministic=True.
    torch.backends.cudnn.benchmark = False

In [None]:
# Seed all RNGs with the configured seed before any data sampling or model creation.
seed_everything(SEED)

## 1. Data & Model
### Load Data

In [None]:
# Read the five competition CSV files from INPUT_DIR, in this order.
train_df, feature_df, pn_df, test_df, submission_df = (
    pd.read_csv(os.path.join(INPUT_DIR, csv_name))
    for csv_name in (
        'train.csv',
        'features.csv',
        'patient_notes.csv',
        'test.csv',
        'sample_submission.csv',
    )
)

In [None]:
# Attach the feature description and the patient-note text to every labeled row.
df = pd.merge(train_df, feature_df, on=['feature_num','case_num'], how='inner')
df = pd.merge(df, pn_df, on=['pn_num','case_num'], how='inner')
df.sample(5, random_state=SEED)  # display a few rows

In [None]:
# How many labeled rows each feature description has.
df['feature_text'].value_counts()

In [None]:
# The CSV stores these list columns as text; parse each cell back into a Python
# list with ast.literal_eval (evaluates literals only, no arbitrary code).
# NOTE(review): assumes every cell is a list literal such as "['85 99']" — confirm.
df["annotation"] = [literal_eval(x) for x in df["annotation"]]
df["location"] = [literal_eval(x) for x in df["location"]]
df.sample(5, random_state=SEED)

### Dataset


In [None]:
def loc_list_to_ints(loc_list):
    """Flatten location strings into a list of ``(start, end)`` integer pairs.

    Each entry of ``loc_list`` holds one or more ``"start end"`` pairs joined
    by ``";"`` (e.g. ``"85 99;126 131"``). Pairs are returned in input order.
    A piece that does not contain exactly two numbers raises ``ValueError``.
    """
    pairs = []
    spans = (piece for entry in loc_list for piece in entry.split(";"))
    for span in spans:
        begin, finish = span.split()
        pairs.append((int(begin), int(finish)))
    return pairs

In [None]:
def preprocess(pn_history, feature_text, annotation, location):
    """Tokenize one (feature, note) pair and build per-token training labels.

    Args:
        pn_history: Patient-note text, encoded as the second sequence.
        feature_text: Feature description, encoded as the first sequence.
        annotation: Annotated text snippets. Kept for interface compatibility;
            labels are derived from ``location`` alone.
        location: List of location strings, each holding one or more
            ``"start end"`` character spans separated by ``";"``.

    Returns:
        dict with ``ids``, ``mask``, ``token_type_ids``, ``offsets`` (lists of
        length ``MAX_LEN``) and ``labels`` (float ndarray of length ``MAX_LEN``):
        1 for note tokens overlapping an annotated span, 0 for other
        segment-1 tokens, -1 for every segment-0 position (feature part and
        padding), which the loss ignores.
    """
    # Character-level mask of annotated positions in the note, taken straight
    # from `location`. Zipping the flattened spans with `annotation` misaligned
    # them whenever one annotation covered several ";"-separated spans.
    char_targets = np.zeros(len(pn_history), dtype=np.int8)
    for start, end in loc_list_to_ints(location):
        char_targets[start:end] = 1

    tokenized_input = TOKENIZER.encode(feature_text, pn_history)

    # Truncate to MAX_LEN: longer encodings would yield tensors of different
    # lengths that the DataLoader cannot stack into one batch.
    input_ids = tokenized_input.ids[:MAX_LEN]
    mask = tokenized_input.attention_mask[:MAX_LEN]
    token_type_ids = tokenized_input.type_ids[:MAX_LEN]
    offsets = tokenized_input.offsets[:MAX_LEN]

    # Positive tokens are note tokens (segment 1) overlapping an annotation.
    # Segment-0 offsets index into feature_text, not the note, so they must not
    # be matched against char_targets (they would override the -1 ignore label).
    target_idx = [
        j
        for j, (type_id, (offset1, offset2)) in enumerate(zip(token_type_ids, offsets))
        if type_id == 1 and char_targets[offset1:offset2].any()
    ]

    # Padding up to MAX_LEN
    padding_length = MAX_LEN - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + [0] * padding_length
        mask = mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        offsets = offsets + [(0, 0)] * padding_length

    # Labels: -1 (ignored) for every segment-0 position, then 1 for targets.
    label = np.zeros(len(offsets))
    label[np.array(token_type_ids) != 1] = -1
    label[target_idx] = 1

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'labels': label,
        'offsets': offsets
    }

In [None]:
class NBMEDataset(torch.utils.data.Dataset):
    """Training dataset with one item per (feature, patient-note) row of ``df``.

    Rows are tokenized lazily by ``preprocess``; each item is a dict of long
    tensors keyed ``ids``, ``mask``, ``token_type_ids``, ``labels``, ``offsets``.
    """

    def __init__(self, df):
        super().__init__()
        self.df = df.reset_index()
        # Plain positional arrays, so __getitem__ ignores df's index labels.
        self.pn_history, self.feature_text, self.annotation, self.location = (
            df[column].values
            for column in ('pn_history', 'feature_text', 'annotation', 'location')
        )

    def __len__(self):
        return self.pn_history.shape[0]

    def __getitem__(self, item):
        sample = preprocess(
            self.pn_history[item],
            self.feature_text[item],
            self.annotation[item],
            self.location[item],
        )
        keys = ('ids', 'mask', 'token_type_ids', 'labels', 'offsets')
        return {key: torch.tensor(sample[key], dtype=torch.long) for key in keys}

In [None]:
def get_dataloader(df, batch_size, shuffle, num_workers=4):
    """Wrap ``df`` in an ``NBMEDataset`` and return a ``DataLoader`` over it.

    Args:
        df: DataFrame with pn_history, feature_text, annotation and location columns.
        batch_size: Samples per batch.
        shuffle: Reshuffle every epoch (True for training, False for validation).
        num_workers: Worker processes used for loading; the default keeps the
            previous hard-coded value, and it can be lowered on CPU-limited hosts.
    """
    dataset = NBMEDataset(df=df)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

### Model


In [None]:
class NBMEModel(transformers.DistilBertModel):
    """DistilBERT encoder plus a linear head that emits one logit per token.

    NOTE(review): calling ``DistilBertModel.__init__`` also builds the base
    class's own layers from ``conf``, but forward() only uses
    ``self.pretrained_model``, so those layers are dead weight in memory and in
    the state_dict. They are kept so saved checkpoints keep loading; consider
    subclassing ``torch.nn.Module`` for new training runs.
    """

    def __init__(self, conf):
        super(NBMEModel, self).__init__(conf)
        self.pretrained_model = transformers.DistilBertModel.from_pretrained(cfg.values.ckp_path, config=conf)
        self.dropout = torch.nn.Dropout(cfg.values.train_args.dropout_rate)
        # Head width follows the encoder's hidden size (conf.dim) instead of a
        # hard-coded 768, so other DistilBERT sizes work too.
        self.classifier = torch.nn.Linear(conf.dim, 1)
        torch.nn.init.normal_(self.classifier.weight, std=0.02)

    def forward(self, ids, mask, token_type_ids):
        """Return per-token logits with the trailing size-1 dim squeezed away.

        ``token_type_ids`` is accepted for call-site compatibility but ignored.
        NOTE(review): assumes ids/mask are (batch, seq_len), giving
        (batch, seq_len) logits — confirm against callers.
        """
        sequence_output = self.pretrained_model(
            input_ids=ids,
            attention_mask=mask,
            # DistilBert does not take in token_type_ids
        )[0]  # last hidden state

        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output)
        logits = logits.squeeze(-1)

        return logits

### AverageMeter

In [None]:
class AverageMeter:
    """Track the most recent value and the sample-weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

### Loss
- Use binary cross-entropy (BCE) loss.

In [None]:
class ComputeMetric:
    """Element-wise loss selected by name; only ``'bce'`` is supported.

    Args:
        metric: Loss name. ``'bce'`` is binary cross-entropy on raw logits.
    """

    def __init__(self, metric='bce') -> None:
        super().__init__()
        self.metric = metric

    def compute_loss(self, logits, labels):
        """Return the unreduced loss, one value per element of ``logits``.

        ``reduction='none'`` lets the caller mask out ignored (-1) positions
        before averaging.

        Raises:
            ValueError: if ``metric`` is not supported. Previously this surfaced
                as an UnboundLocalError.
        """
        if self.metric == 'bce':
            return torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels, reduction='none'
            )
        raise ValueError(f"Unsupported metric: {self.metric!r}")

## 2. Training

- Use 5-fold cross-validation so the models generalize to unseen data.
- Train a separate model for each fold, each with its own learning-rate schedule.
- Use the CosineAnnealingWarmRestarts scheduler for learning-rate annealing.
- When Wandb is enabled, log the epoch, learning rate, loss and logits.

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def _batch_to_device(batch):
    """Move one batch's model inputs and labels to DEVICE.

    Returns ``(ids, mask, token_type_ids, labels)``. Labels are cast to float32
    to match the dtype of the model's logits (default float32 weights).
    """
    ids = batch['ids'].to(DEVICE, dtype=torch.long)
    mask = batch['mask'].to(DEVICE, dtype=torch.long)
    token_type_ids = batch['token_type_ids'].to(DEVICE, dtype=torch.long)
    labels = batch['labels'].to(DEVICE, dtype=torch.float32)
    return ids, mask, token_type_ids, labels


def _masked_mean_loss(eval_metric, logits, labels):
    """Average the element-wise loss over positions whose label is not -1."""
    loss = eval_metric.compute_loss(logits, labels)
    return torch.masked_select(loss, labels > -1.0).mean()


def _build_optimizer(model, lr, weight_decay):
    """Create AdamW that applies weight decay to weights but not to biases/norms."""
    # Cover both "LayerNorm" and "layer_norm" spellings of normalization params.
    no_decay = ("bias", "LayerNorm.weight", "layer_norm.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(key in name for key in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return AdamW(
        [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ],
        lr=lr,
    )


def _train_one_epoch(model, train_loader, optimizer, scheduler, eval_metric,
                     epoch, num_epochs, log_intervals, max_grad_norm):
    """Run one training epoch, stepping the scheduler once per batch.

    Returns the AverageMeter of per-batch training losses.
    """
    model.train()
    since = time.time()
    loss_values = AverageMeter()

    for idx, train_batch in enumerate(tqdm(train_loader, desc='Train')):
        ids, mask, token_type_ids, labels = _batch_to_device(train_batch)

        model.zero_grad()
        logits = model(ids=ids, mask=mask, token_type_ids=token_type_ids)

        loss = _masked_mean_loss(eval_metric, logits, labels)
        loss_values.update(loss.item(), ids.size(0))
        loss.backward()

        # Clip gradients, then update weights and the learning rate.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        # get_last_lr() is the public accessor; calling get_lr() directly warns.
        current_lr = scheduler.get_last_lr()[0]

        if WANDB:
            wandb.log({
                "epoch": epoch,
                "lr": current_lr,
                "loss": loss.item(),
                "logits": wandb.Histogram(logits.detach().cpu().numpy()),
            })

        if idx % log_intervals == 0:
            time_elapsed = time.time() - since
            tqdm.write(f"Epoch : [{epoch + 1} / {num_epochs}][{idx}/{len(train_loader)}] || "
                       f"LR : {current_lr:.5f} || "
                       f"Train Loss : {loss_values.val:.4f} ({loss_values.avg:.4f}) || "
                       f"Elapsed : {time_elapsed:.0f}s"
                       )

    return loss_values


@torch.no_grad()
def _validate(model, valid_loader, eval_metric):
    """Return the AverageMeter of per-batch validation losses (no gradients)."""
    model.eval()
    loss_values = AverageMeter()
    for val_batch in tqdm(valid_loader, desc="Validation"):
        ids, mask, token_type_ids, labels = _batch_to_device(val_batch)
        logits = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
        loss = _masked_mean_loss(eval_metric, logits, labels)
        loss_values.update(loss.item(), ids.size(0))
    return loss_values


def train(cfg, fold, train_loader, valid_loader):
    """Fine-tune a fresh NBMEModel on one cross-validation fold.

    Args:
        cfg: YamlConfigManager whose ``values.train_args`` holds hyper-parameters
            and whose ``values.ckp_path`` points at the pretrained checkpoint.
        fold: Zero-based fold index; checkpoints go to ``{fold+1}_fold``.
        train_loader: DataLoader of training batches (``ids``, ``mask``,
            ``token_type_ids``, ``labels`` with -1 marking ignored positions).
        valid_loader: DataLoader of validation batches, same layout.

    Checkpoints are written under ``OUTPUT_DIR/MODEL_ARC/{fold+1}_fold``. With
    validation (TRAIN_ONLY False), a new file is saved every time the
    validation loss improves. With TRAIN_ONLY, the final weights are saved once
    after the last epoch, so inference always finds a checkpoint.
    """
    train_args = cfg.values.train_args
    num_epochs = train_args.num_epochs
    log_intervals = train_args.log_intervals
    max_lr = train_args.max_lr
    min_lr = train_args.min_lr
    cycle = train_args.cycle
    weight_decay = train_args.weight_decay
    max_grad_norm = train_args.max_grad_norm
    ckp_path = cfg.values.ckp_path
    fold_dir = os.path.join(OUTPUT_DIR, MODEL_ARC, f"{fold+1}_fold")

    # Load model
    model_config = transformers.DistilBertConfig.from_pretrained(os.path.join(ckp_path, 'config.json'))
    model_config.output_hidden_states = True
    model = NBMEModel(conf=model_config)
    model.to(DEVICE)

    # Set optimizer and scheduler
    optimizer = _build_optimizer(model, max_lr, weight_decay)
    # The scheduler is stepped per batch, so T_0 counts batches. Clamp it to 1:
    # CosineAnnealingWarmRestarts rejects T_0 == 0 (tiny loader or large cycle).
    first_cycle_steps = max(1, len(train_loader) * num_epochs // cycle)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=first_cycle_steps,
        eta_min=min_lr,
    )

    eval_metric = ComputeMetric(metric='bce')
    best_loss = np.inf

    os.makedirs(os.path.join(OUTPUT_DIR, MODEL_ARC), exist_ok=True)

    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, optimizer, scheduler, eval_metric,
                         epoch, num_epochs, log_intervals, max_grad_norm)

        if TRAIN_ONLY:
            continue

        since = time.time()
        val_loss = _validate(model, valid_loader, eval_metric).avg
        time_elapsed = time.time() - since
        tqdm.write(f"Epoch : [{epoch + 1} / {num_epochs}] || "
                   f"Val Loss : {val_loss:.4f} || "
                   f"Validation completed in {time_elapsed:.0f}s"
                   )

        if val_loss < best_loss:
            best_loss = val_loss
            os.makedirs(fold_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(fold_dir, f"{epoch+1}_epoch_{best_loss:.4f}%_with_val.pth"))

    if TRAIN_ONLY:
        # No validation means nothing was checkpointed above; keep the final weights.
        os.makedirs(fold_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(fold_dir, f"{num_epochs}_epoch_last.pth"))

In [None]:
def run(cfg, df):
    """Train one model per cross-validation fold of ``df`` and report the total time.

    When ``df`` has a ``pn_num`` column, folds come from GroupKFold on it, so
    all rows of one patient note land in the same fold. Row-wise KFold put the
    same note text (paired with different features) into both training and
    validation, which makes validation losses optimistic. Without ``pn_num``,
    the original shuffled KFold is used.

    Args:
        cfg: YamlConfigManager holding ``train_args`` (n_splits, batch sizes, ...).
        df: Labeled rows with pn_history, feature_text, annotation, location.
    """
    from sklearn.model_selection import GroupKFold  # only needed here

    since = time.time()

    # Set train arguments
    n_splits = cfg.values.train_args.n_splits
    train_batch_size = cfg.values.train_args.train_batch_size
    val_batch_size = cfg.values.train_args.val_batch_size

    if 'pn_num' in df.columns:
        splits = GroupKFold(n_splits=n_splits).split(df, groups=df['pn_num'])
    else:
        splits = KFold(n_splits=n_splits, shuffle=True, random_state=SEED).split(df)

    for fold, (train_idx, val_idx) in enumerate(splits):
        print('\n')
        print('*' * 15 + f" {fold + 1}-Fold Cross Validation " + '*' * 15)

        train_loader = get_dataloader(
            df=df.iloc[train_idx],
            batch_size=train_batch_size,
            shuffle=True
        )
        val_loader = get_dataloader(
            df=df.iloc[val_idx],
            batch_size=val_batch_size,
            shuffle=False
        )

        train(cfg, fold, train_loader, val_loader)

    # Whole hours/minutes/seconds; floor division of the float elapsed time
    # used to print values like "0.0h 3.0m".
    hours, remainder = divmod(int(time.time() - since), 3600)
    minutes, seconds = divmod(remainder, 60)
    print('*' * 50)
    print(f"Total Time {hours}h {minutes}m {seconds}s Elapsed.")

In [None]:
# Release cached, unused GPU memory held by PyTorch, then train one model per fold.
torch.cuda.empty_cache()
run(cfg, df)

## 3. Inference
- Make predictions, convert the token labels to character spans, and write the submission.csv file.

In [None]:
# Build the inference table with the same merges used for training: attach the
# feature description and the patient-note text to every test row.
df_tst = pd.merge(test_df, feature_df, on=['feature_num','case_num'], how='inner')
df_tst = pd.merge(df_tst, pn_df, on=['pn_num','case_num'], how='inner')
df_tst.shape  # display the resulting shape

In [None]:
def test_preprocess(pn_history, feature_text):
    """Tokenize one (feature, note) pair for inference.

    Args:
        pn_history: Patient-note text, encoded as the second sequence.
        feature_text: Feature description, encoded as the first sequence.

    Returns:
        dict with ``ids``, ``mask``, ``token_type_ids`` and ``offsets``, each a
        list truncated/padded to exactly ``MAX_LEN`` entries.
    """
    tokenized_input = TOKENIZER.encode(feature_text, pn_history)

    # Truncate to MAX_LEN: longer encodings would yield tensors of different
    # lengths that the DataLoader cannot stack into one batch.
    input_ids = tokenized_input.ids[:MAX_LEN]
    mask = tokenized_input.attention_mask[:MAX_LEN]
    token_type_ids = tokenized_input.type_ids[:MAX_LEN]
    offsets = tokenized_input.offsets[:MAX_LEN]

    # Padding up to MAX_LEN
    padding_length = MAX_LEN - len(input_ids)
    if padding_length > 0:
        input_ids = input_ids + [0] * padding_length
        mask = mask + [0] * padding_length
        token_type_ids = token_type_ids + [0] * padding_length
        offsets = offsets + [(0, 0)] * padding_length

    return {
        'ids': input_ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'offsets': offsets
    }

In [None]:
class TestDataset(torch.utils.data.Dataset):
    """Inference dataset that tokenizes each (feature, note) row of ``df`` on demand.

    Each item is a dict of long tensors keyed ``ids``, ``mask``,
    ``token_type_ids`` and ``offsets``.
    """

    def __init__(self, df):
        super().__init__()
        self.df = df.reset_index()
        # Positional numpy arrays: indexing a Series with an integer position
        # looks up the index *label*, which raises KeyError (or returns the wrong
        # row) when df's index is not 0..n-1.
        self.pn_history = df.pn_history.values
        self.feature_text = df.feature_text.values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        data = test_preprocess(
            self.pn_history[idx],
            self.feature_text[idx],
        )

        return {
            'ids': torch.tensor(data['ids'], dtype=torch.long),
            'mask': torch.tensor(data['mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(data['token_type_ids'], dtype=torch.long),
            'offsets': torch.tensor(data['offsets'], dtype=torch.long)
        }

In [None]:
# Pick the best checkpoint of each fold. Training writes a new file only when
# the validation loss improves, so the most recently written .pth file is the
# best one. (os.listdir returns entries in arbitrary order, so its last entry
# is not reliably that file.)
model_ckps = []
n_splits = cfg.values.train_args.n_splits

for fold in range(n_splits):
    fold_dir = os.path.join(OUTPUT_DIR, MODEL_ARC, f"{fold+1}_fold")
    candidates = [
        os.path.join(fold_dir, name)
        for name in os.listdir(fold_dir)
        if name.endswith('.pth')
    ]
    if not candidates:
        raise FileNotFoundError(f"No .pth checkpoint found in {fold_dir}")
    model_ckps.append(max(candidates, key=os.path.getmtime))

model_config = transformers.DistilBertConfig.from_pretrained(cfg.values.ckp_path)
model_config.output_hidden_states = True
model = NBMEModel(conf=model_config)


# Prepare Test DataLoader (order preserved so predictions line up with df_tst)
test_dataset = TestDataset(df_tst)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=cfg.values.train_args.train_batch_size,
    num_workers=1
)

In [None]:
# Predict on Test DataLoader: load each fold's checkpoint once, collect its
# logits for the whole test set, then average the logits across folds.
# (Previously one list accumulated logits across *batches*, so each batch's
# "average" mixed in earlier batches, and every checkpoint was re-read per batch.)
fold_logits = []

with torch.no_grad():
    for fold_idx, model_ckp in enumerate(model_ckps):
        model.load_state_dict(torch.load(model_ckp, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()

        batch_logits = []
        for test_batch in tqdm(test_dataloader, total=len(test_dataloader), desc=f"Fold {fold_idx + 1}"):
            ids = test_batch['ids'].to(DEVICE, dtype=torch.long)
            mask = test_batch["mask"].to(DEVICE, dtype=torch.long)
            token_type_ids = test_batch["token_type_ids"].to(DEVICE, dtype=torch.long)

            logits = model(ids=ids,
                           mask=mask,
                           token_type_ids=token_type_ids
                           )
            batch_logits.append(logits.cpu().numpy())

        fold_logits.append(np.concatenate(batch_logits))

# One row of per-token mean logits per test row, in df_tst order.
results_ = np.mean(fold_logits, axis=0)

In [None]:
def token_label2idx(text, tokens, token_label):
    """Converts token labels back to character indices.

    Walks ``text`` token by token. It assumes each token's characters appear
    verbatim in ``text`` (after dropping a leading WordPiece ``##``
    continuation prefix) and that tokens are separated only by whitespace.
    Runs of consecutive tokens labelled 1 merge into one span.

    Args:
        text: Original text the tokens were produced from.
        tokens: Token strings in order, without special tokens such as [CLS].
        token_label: One label per token; 1 (or 1.0) marks a positive token.
            Extra tokens or labels beyond the shorter sequence are ignored.

    Returns:
        ``";"``-joined ``"start end"`` spans (end exclusive, trailing whitespace
        excluded), or ``""`` when no token is positive.
    """
    whitespace = " \t\n\r\f\v"
    text_len = len(text)

    def skip_whitespace(pos):
        while pos < text_len and text[pos] in whitespace:
            pos += 1
        return pos

    char_indices = []
    span_start = span_end = None
    char_idx = skip_whitespace(0)  # text may start with whitespace

    for token, label in zip(tokens, token_label):
        if char_idx >= text_len:
            break
        # Only a leading "##" marks a subword; a literal "#" token is real text.
        piece = token[2:] if token.startswith("##") else token
        token_end = min(char_idx + len(piece), text_len)
        if label == 1:
            if span_start is None:
                span_start = char_idx
            # Record the end *before* skipping whitespace, so a span never
            # absorbs the gap that follows it.
            span_end = token_end
        elif span_start is not None:
            char_indices.append(f"{span_start} {span_end}")
            span_start = None
        char_idx = skip_whitespace(token_end)

    if span_start is not None:
        char_indices.append(f"{span_start} {span_end}")

    return ';'.join(char_indices)

In [None]:
# Make prediction: binarize the averaged logits in place. A logit > 0 means a
# sigmoid probability > 0.5, so the token gets 1; otherwise (including NaN) 0.
results_[:] = results_ > 0

assert len(df_tst) == len(results_), "Prediction length does not match input size."

In [None]:
# Convert token labels to character spans of each note, e.g. "12 20;25 30".
# Each row of results_ is aligned with the *pair* encoding used for inference
# (feature text first, note second). Tokenizing the note alone would shift every
# label onto the wrong token, so the spans are read from the offsets of that
# same pair encoding instead.
def labels_to_char_spans(feature_text, pn_history, labels):
    """Return ";"-joined "start end" note spans for runs of positive note tokens."""
    encoding = TOKENIZER.encode(feature_text, pn_history)
    spans = []
    span_start = span_end = None
    for type_id, is_special, (start, end), label in zip(
        encoding.type_ids, encoding.special_tokens_mask, encoding.offsets, labels
    ):
        if type_id != 1 or is_special:
            continue  # only real tokens of the note (segment 1) map into pn_history
        if label == 1:
            if span_start is None:
                span_start = start
            span_end = end
        elif span_start is not None:
            spans.append(f"{span_start} {span_end}")
            span_start = None
    if span_start is not None:
        spans.append(f"{span_start} {span_end}")
    return ";".join(spans)


char_target = [
    labels_to_char_spans(feature, text, label)
    for text, feature, label in zip(df_tst.pn_history, df_tst.feature_text, results_)
]

In [None]:
# Align predictions with the submission rows by `id` instead of assuming that
# df_tst and sample_submission share the same row order. Ids without a
# prediction get an empty location.
location_by_id = dict(zip(df_tst['id'], char_target))
submission_df['location'] = submission_df['id'].map(location_by_id).fillna('')
submission_df.to_csv('submission.csv', index=False)
submission_df