##### This notebook is based on:
-  [Roberta Strikes Back !](https://www.kaggle.com/code/theoviel/roberta-strikes-back) 
-  [Pytorch Bert baseline to RoBERTa NBME](https://www.kaggle.com/code/iamsdt/pytorch-bert-baseline-to-roberta-nbme)
-  [【Train】Deberta-v3-large baseline](https://www.kaggle.com/code/librauee/train-deberta-v3-large-baseline)


##### Applying the model trained in this notebook scores 0.876 on the LB. In "Roberta Strikes Back !" the author reached 0.882 with RoBERTa-large, so I hope someone can build on this notebook to get a higher LB score.

# Library

In [None]:
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn import metrics, model_selection
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Config

In [None]:
# Global settings read by the cells below.
config = dict(
    seed = 2020,  # NOTE(review): no seeding call is visible in this file - confirm it is applied
    num_labels=2,  # passed to NBMEModel, whose output layer is nn.Linear(hidden_size, 1)
    num_folds=1,  # not read in the visible source; the CV split uses GroupKFold(n_splits=5)
    fold_to_train = [0],  # `kfold` values used as validation fold, one run() call each
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    model_checkpoint = 'roberta-large',  # name given to AutoTokenizer / AutoConfig / AutoModel
    learning_rate = 1e-5,  # not read by make_optimizer (its own lr of 5e-5 is used)
    weight_decay = 1e-2,  # not read by make_optimizer (its own weight_decay of 0.01 is used)
    max_length = 512,  # tokenizer max_length (padding / truncation) in tokenize_and_add_labels
    train_batch_size = 4,
    valid_batch_size = 8,
    epochs_to_train = 1,  # number of epochs looped over in run()
    total_epochs = 1,  # only used to size the scheduler (num_training_steps) in run()
    grad_acc_steps = 4,  # gradient accumulation factor used in train_fn
    num_cycles=0.5,  # forwarded to the cosine schedule only
    scheduler='linear', # ['linear', 'cosine']
    output_dir = '',  # location prefix for checkpoint files written by save_checkpoint
    debug = None,  # if truthy, number of rows sampled from `train` before training
    precompute_tokens = True  # forwarded as `precompute` to get_tokenizer
)

# Data Loading

In [None]:
# ====================================================
# Data Loading
# ====================================================
# Load the competition tables. `annotation` and `location` are parsed with
# ast.literal_eval because the rest of the notebook treats them as lists.
train = pd.read_csv('../input/nbme-score-clinical-patient-notes/train.csv')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('../input/nbme-score-clinical-patient-notes/features.csv')
def preprocess_features(features):
    """Overwrite `feature_text` of the row labelled 27 and return the frame.

    The frame is modified in place; the same object is returned.
    """
    # NOTE(review): presumably repairs a typo ("I-year") in the raw text - confirm.
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    return features
features = preprocess_features(features)
patient_notes = pd.read_csv('../input/nbme-score-clinical-patient-notes/patient_notes.csv')

# Quick look at the three tables (`display` is provided by the notebook).
print(f"train.shape: {train.shape}")
display(train.head())
print(f"features.shape: {features.shape}")
display(features.head())
print(f"patient_notes.shape: {patient_notes.shape}")
display(patient_notes.head())

In [None]:
def process_feature_text(text):
    """Return `text` with feature-description markup turned into plain words.

    Applied in order: "I-year" becomes "1-year", "-OR-" becomes " or ",
    and every remaining hyphen becomes a single space.
    """
    # Order matters: the generic hyphen rule must run last.
    substitutions = (
        ('I-year', '1-year'),
        ('-OR-', " or "),
        ('-', ' '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text


def clean_spaces(txt):
    """Return `txt` with each newline, tab and carriage return replaced by a space.

    Every matched character is replaced by exactly one space, so the length of
    the string is unchanged; runs of whitespace are not collapsed.
    """
    return re.sub('[\n\t\r]', ' ', txt)

In [None]:
# Attach the feature description and the patient-note text to every training row.
train = train.merge(features, on=['feature_num', 'case_num'], how='left')
train = train.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
# NOTE(review): `location` values are compared against character offsets of
# `pn_history` in tokenize_and_add_labels; strip() would shift those offsets
# for any note that starts with whitespace - confirm the raw notes never do.
train['pn_history'] = train['pn_history'].apply(lambda x: x.strip())
train['feature_text'] = train['feature_text'].apply(process_feature_text)

train['feature_text'] = train['feature_text'].apply(clean_spaces)
# `clean_text` is the note with \n, \t and \r each replaced by a space.
train['clean_text'] = train['pn_history'].apply(clean_spaces)
display(train.head())

In [None]:
# Manual fixes for incorrect annotations: for each row label listed in this
# cell, the parsed `annotation` and `location` values are overwritten with
# corrected ones (one pair of assignments per row).
# NOTE(review): the replacement values are written as nested lists; presumably
# this relies on how `.loc` assignment handles list values - confirm that each
# cell ends up holding the intended flat list of strings.
train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')

train.loc[621, 'annotation'] = ast.literal_eval('[["for the last 2-3 months"]]')
train.loc[621, 'location'] = ast.literal_eval('[["77 100"]]')

train.loc[655, 'annotation'] = ast.literal_eval('[["no heat intolerance"], ["no cold intolerance"]]')
train.loc[655, 'location'] = ast.literal_eval('[["285 292;301 312"], ["285 287;296 312"]]')

train.loc[1262, 'annotation'] = ast.literal_eval('[["mother thyroid problem"]]')
train.loc[1262, 'location'] = ast.literal_eval('[["551 557;565 580"]]')

train.loc[1265, 'annotation'] = ast.literal_eval('[[\'felt like he was going to "pass out"\']]')
train.loc[1265, 'location'] = ast.literal_eval('[["131 135;181 212"]]')

train.loc[1396, 'annotation'] = ast.literal_eval('[["stool , with no blood"]]')
train.loc[1396, 'location'] = ast.literal_eval('[["259 280"]]')

train.loc[1591, 'annotation'] = ast.literal_eval('[["diarrhoe non blooody"]]')
train.loc[1591, 'location'] = ast.literal_eval('[["176 184;201 212"]]')

train.loc[1615, 'annotation'] = ast.literal_eval('[["diarrhea for last 2-3 days"]]')
train.loc[1615, 'location'] = ast.literal_eval('[["249 257;271 288"]]')

train.loc[1664, 'annotation'] = ast.literal_eval('[["no vaginal discharge"]]')
train.loc[1664, 'location'] = ast.literal_eval('[["822 824;907 924"]]')

train.loc[1714, 'annotation'] = ast.literal_eval('[["started about 8-10 hours ago"]]')
train.loc[1714, 'location'] = ast.literal_eval('[["101 129"]]')

train.loc[1929, 'annotation'] = ast.literal_eval('[["no blood in the stool"]]')
train.loc[1929, 'location'] = ast.literal_eval('[["531 539;549 561"]]')

train.loc[2134, 'annotation'] = ast.literal_eval('[["last sexually active 9 months ago"]]')
train.loc[2134, 'location'] = ast.literal_eval('[["540 560;581 593"]]')

train.loc[2191, 'annotation'] = ast.literal_eval('[["right lower quadrant pain"]]')
train.loc[2191, 'location'] = ast.literal_eval('[["32 57"]]')

train.loc[2553, 'annotation'] = ast.literal_eval('[["diarrhoea no blood"]]')
train.loc[2553, 'location'] = ast.literal_eval('[["308 317;376 384"]]')

train.loc[3124, 'annotation'] = ast.literal_eval('[["sweating"]]')
train.loc[3124, 'location'] = ast.literal_eval('[["549 557"]]')

train.loc[3858, 'annotation'] = ast.literal_eval('[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]')
train.loc[3858, 'location'] = ast.literal_eval('[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]')

train.loc[4373, 'annotation'] = ast.literal_eval('[["for 2 months"]]')
train.loc[4373, 'location'] = ast.literal_eval('[["33 45"]]')

train.loc[4763, 'annotation'] = ast.literal_eval('[["35 year old"]]')
train.loc[4763, 'location'] = ast.literal_eval('[["5 16"]]')

train.loc[4782, 'annotation'] = ast.literal_eval('[["darker brown stools"]]')
train.loc[4782, 'location'] = ast.literal_eval('[["175 194"]]')

train.loc[4908, 'annotation'] = ast.literal_eval('[["uncle with peptic ulcer"]]')
train.loc[4908, 'location'] = ast.literal_eval('[["700 723"]]')

train.loc[6016, 'annotation'] = ast.literal_eval('[["difficulty falling asleep"]]')
train.loc[6016, 'location'] = ast.literal_eval('[["225 250"]]')

train.loc[6192, 'annotation'] = ast.literal_eval('[["helps to take care of aging mother and in-laws"]]')
train.loc[6192, 'location'] = ast.literal_eval('[["197 218;236 260"]]')

train.loc[6380, 'annotation'] = ast.literal_eval('[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]')
train.loc[6380, 'location'] = ast.literal_eval('[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]')

train.loc[6562, 'annotation'] = ast.literal_eval('[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]')
train.loc[6562, 'location'] = ast.literal_eval('[["290 320;327 337"], ["290 320;342 358"]]')

train.loc[6862, 'annotation'] = ast.literal_eval('[["stressor taking care of many sick family members"]]')
train.loc[6862, 'location'] = ast.literal_eval('[["288 296;324 363"]]')

train.loc[7022, 'annotation'] = ast.literal_eval('[["heart started racing and felt numbness for the 1st time in her finger tips"]]')
train.loc[7022, 'location'] = ast.literal_eval('[["108 182"]]')

train.loc[7422, 'annotation'] = ast.literal_eval('[["first started 5 yrs"]]')
train.loc[7422, 'location'] = ast.literal_eval('[["102 121"]]')

train.loc[8876, 'annotation'] = ast.literal_eval('[["No shortness of breath"]]')
train.loc[8876, 'location'] = ast.literal_eval('[["481 483;533 552"]]')

train.loc[9027, 'annotation'] = ast.literal_eval('[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]')
train.loc[9027, 'location'] = ast.literal_eval('[["92 102"], ["123 164"]]')

train.loc[9938, 'annotation'] = ast.literal_eval('[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]')
train.loc[9938, 'location'] = ast.literal_eval('[["89 117"], ["122 138"], ["368 402"]]')

train.loc[9973, 'annotation'] = ast.literal_eval('[["gaining 10-15 lbs"]]')
train.loc[9973, 'location'] = ast.literal_eval('[["344 361"]]')

train.loc[10513, 'annotation'] = ast.literal_eval('[["weight gain"], ["gain of 10-16lbs"]]')
train.loc[10513, 'location'] = ast.literal_eval('[["600 611"], ["607 623"]]')

train.loc[11551, 'annotation'] = ast.literal_eval('[["seeing her son knows are not real"]]')
train.loc[11551, 'location'] = ast.literal_eval('[["386 400;443 461"]]')

train.loc[11677, 'annotation'] = ast.literal_eval('[["saw him once in the kitchen after he died"]]')
train.loc[11677, 'location'] = ast.literal_eval('[["160 201"]]')

train.loc[12124, 'annotation'] = ast.literal_eval('[["tried Ambien but it didnt work"]]')
train.loc[12124, 'location'] = ast.literal_eval('[["325 337;349 366"]]')

train.loc[12279, 'annotation'] = ast.literal_eval('[["heard what she described as a party later than evening these things did not actually happen"]]')
train.loc[12279, 'location'] = ast.literal_eval('[["405 459;488 524"]]')

train.loc[12289, 'annotation'] = ast.literal_eval('[["experienced seeing her son at the kitchen table these things did not actually happen"]]')
train.loc[12289, 'location'] = ast.literal_eval('[["353 400;488 524"]]')

train.loc[13238, 'annotation'] = ast.literal_eval('[["SCRACHY THROAT"], ["RUNNY NOSE"]]')
train.loc[13238, 'location'] = ast.literal_eval('[["293 307"], ["321 331"]]')

train.loc[13297, 'annotation'] = ast.literal_eval('[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]')
train.loc[13297, 'location'] = ast.literal_eval('[["182 221"], ["182 213;225 234"]]')

train.loc[13299, 'annotation'] = ast.literal_eval('[["yesterday"], ["yesterday"]]')
train.loc[13299, 'location'] = ast.literal_eval('[["79 88"], ["409 418"]]')

train.loc[13845, 'annotation'] = ast.literal_eval('[["headache global"], ["headache throughout her head"]]')
train.loc[13845, 'location'] = ast.literal_eval('[["86 94;230 236"], ["86 94;237 256"]]')

train.loc[14083, 'annotation'] = ast.literal_eval('[["headache generalized in her head"]]')
train.loc[14083, 'location'] = ast.literal_eval('[["56 64;156 179"]]')

In [None]:
# Number of annotated spans per row, followed by the distribution of that count.
train['annotation_length'] = train['annotation'].apply(len)
display(train['annotation_length'].value_counts())

# CV split

In [None]:
# ====================================================
# CV split
# ====================================================
# Group by patient note so all rows sharing a `pn_num` fall into the same fold.
# NOTE: the number of splits is fixed at 5 here; config['num_folds'] is not used.
Fold = GroupKFold(n_splits=5)
groups = train['pn_num'].values
for n, (train_index, val_index) in enumerate(Fold.split(train, train['location'], groups)):
    # NOTE(review): split() yields positional indices while `.loc` is label-based;
    # this presumably relies on `train` having a default 0..n-1 index - confirm.
    train.loc[val_index, 'kfold'] = int(n)
train['kfold'] = train['kfold'].astype(int)
display(train.groupby('kfold').size())

# Tokenizer

In [None]:
import numpy as np
from transformers import AutoTokenizer


def get_tokenizer(name, precompute=False, df=None, folder=None):
    """Load a tokenizer and attach a few convenience attributes to it.

    Args:
        name: Checkpoint name; used for loading when `folder` is None and
            always stored as `tokenizer.name`.
        precompute: When True, encode every unique text of `df` up front via
            `precompute_tokens` and store the result as `tokenizer.precomputed`.
        df: DataFrame handed to `precompute_tokens`; required when
            `precompute` is True.
        folder: Optional local path to load the tokenizer from instead of `name`.

    Returns:
        The tokenizer, with `name`, `special_tokens` (ids of sep / cls / pad)
        and `precomputed` (dict or None) set on it.

    Raises:
        ValueError: If `precompute` is True but no `df` was given.
    """
    # Fail fast, before the (potentially slow) tokenizer load, instead of
    # crashing later with an opaque error on `None["feature_text"]`.
    if precompute and df is None:
        raise ValueError("`df` must be provided when `precompute=True`")

    source = name if folder is None else folder
    tokenizer = AutoTokenizer.from_pretrained(source)

    tokenizer.name = name
    tokenizer.special_tokens = {
        "sep": tokenizer.sep_token_id,
        "cls": tokenizer.cls_token_id,
        "pad": tokenizer.pad_token_id,
    }

    tokenizer.precomputed = precompute_tokens(df, tokenizer) if precompute else None

    return tokenizer


def precompute_tokens(df, tokenizer):
    """Encode every unique feature text and note text of `df` once.

    The unique values of `df["feature_text"]` are encoded first, then those of
    `df["clean_text"]`; each text is encoded on its own, without special tokens.

    Returns:
        {"ids": {text: input_ids}, "offsets": {text: offset_mapping}}
    """
    token_ids = {}
    token_offsets = {}

    for column in ("feature_text", "clean_text"):
        for text in df[column].unique():
            encoded = tokenizer(
                text,
                return_token_type_ids=True,
                return_offsets_mapping=True,
                return_attention_mask=False,
                add_special_tokens=False,
            )
            token_ids[text] = encoded["input_ids"]
            token_offsets[text] = encoded["offset_mapping"]

    return {"ids": token_ids, "offsets": token_offsets}

# Dataset

In [None]:
def loc_list_to_ints(loc_list):
    """Flatten location strings into a list of (start, end) integer pairs.

    Each entry of `loc_list` may hold several "start end" spans separated by
    ";" - every span becomes one tuple, in the order encountered.
    """
    def _parse_span(span):
        # A span must consist of exactly two whitespace-separated numbers.
        begin, finish = span.split()
        return int(begin), int(finish)

    return [
        _parse_span(span)
        for entry in loc_list
        for span in entry.split(";")
    ]

def tokenize_and_add_labels(tokenizer, example):
    """Tokenize one (feature_text, pn_history) pair and build per-token labels.

    The pair is padded / truncated (second sequence only) to
    config['max_length']. Labels are:
      * -100 for tokens whose sequence id is None or 0,
      * 1.0 for second-sequence tokens whose character span lies entirely
        inside one of the annotated (start, end) locations,
      * 0.0 otherwise.

    Returns:
        The tokenizer output, extended with the keys "location_int",
        "sequence_ids" and "labels".
    """
    encoded = tokenizer(
        example["feature_text"],
        example["pn_history"],
        truncation="only_second",
        max_length=config['max_length'],
        padding="max_length",
        return_offsets_mapping=True,
        return_token_type_ids=True
    )
    token_labels = [0.0] * len(encoded["input_ids"])
    encoded["location_int"] = loc_list_to_ints(example["location"])
    encoded["sequence_ids"] = encoded.sequence_ids()

    pairs = zip(encoded["sequence_ids"], encoded["offset_mapping"])
    for position, (sequence_id, span) in enumerate(pairs):
        if sequence_id is None or sequence_id == 0:
            # Special tokens and feature-text tokens are excluded from the loss.
            token_labels[position] = -100
            continue
        tok_start, tok_end = span
        # Positive only when the token sits fully inside an annotated span.
        if any(
            tok_start >= loc_start and tok_end <= loc_end
            for loc_start, loc_end in encoded["location_int"]
        ):
            token_labels[position] = 1.0
    encoded["labels"] = token_labels

    return encoded

In [None]:
# ====================================================
# Dataset
# ====================================================
class NBMEData(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one training row per item.

    Args:
        data: DataFrame with the columns read by `tokenize_and_add_labels`.
        tokenizer: Tokenizer forwarded to `tokenize_and_add_labels`.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded example at position `idx` as a dict of numpy arrays."""
        # Samplers hand out positions 0..len-1, so index positionally; the
        # label-based `.loc` only worked for frames with a fresh 0..n-1 index.
        example = self.data.iloc[idx]
        tokenized = tokenize_and_add_labels(self.tokenizer, example)

        input_ids = np.array(tokenized["input_ids"])
        attention_mask = np.array(tokenized["attention_mask"])
        token_type_ids = np.array(tokenized["token_type_ids"])

        labels = np.array(tokenized["labels"])
        offset_mapping = np.array(tokenized["offset_mapping"])
        # The float cast turns `None` sequence ids (special tokens) into NaN.
        sequence_ids = np.array(tokenized["sequence_ids"]).astype("float16")

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'targets': labels,
            'offset_mapping': offset_mapping,
            'sequence_ids': sequence_ids,
        }

# Model

In [None]:

class NBMEModel(nn.Module):
    """Transformer backbone with a dropout + single-logit head per token.

    Args:
        num_labels: Stored on the instance; the head itself always has one
            output unit.
        path: Optional checkpoint file whose 'model' entry is loaded into
            the module after construction.
    """

    def __init__(self, num_labels, path=None):
        super().__init__()

        self.path = path
        self.num_labels = num_labels
        self.config = transformers.AutoConfig.from_pretrained(config['model_checkpoint'],output_hidden_states=True)

        self.model = transformers.AutoModel.from_pretrained(config['model_checkpoint'], config=self.config)
        self.dropout = nn.Dropout(0.2)
        self.output = nn.Linear(self.config.hidden_size, 1)

        if self.path is not None:
            # Load onto CPU first so a checkpoint saved on GPU can be restored
            # on a CPU-only machine; the caller moves the module afterwards.
            checkpoint = torch.load(self.path, map_location='cpu')
            self.load_state_dict(checkpoint['model'])

    def forward(self, data):
        """Run the model on a batch dict.

        `data` must hold 'input_ids', 'attention_mask' and 'token_type_ids';
        'targets' is optional.

        Returns:
            A dict with 'logits' (NOTE: these are sigmoid probabilities, not
            raw logits) and, when targets are given, 'loss' and 'targets'.
        """
        ids = data['input_ids']
        mask = data['attention_mask']
        token_type_ids = data['token_type_ids']

        # Explicit membership test instead of a bare `except:` so unrelated
        # errors are not silently swallowed.
        target = data['targets'] if 'targets' in data else None

        transformer_out = self.model(ids,mask,token_type_ids)
        sequence_output = transformer_out[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.output(sequence_output)

        output = {
            "logits": torch.sigmoid(logits),
        }

        if target is not None:
            # The loss is computed on the raw logits, not on the probabilities.
            loss = get_loss(logits, target)
            output['loss'] = loss
            output['targets'] = target

        return output


# Helper functions

In [None]:
class AverageMeter(object):
    """Keeps the latest value plus running sum, count, mean, max and min."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics (max starts at 0, min at 1e5)."""
        self.val = self.avg = self.sum = self.count = self.max = 0
        self.min = 1e5

    def update(self, val, n=1):
        """Record `val`, weighted `n` times in the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        # Extremes only move on a strict improvement.
        self.max = max(self.max, val)
        self.min = min(self.min, val)
            
def train_fn(model, train_loader, optimizer, scheduler, device, current_epoch):
    """Train `model` for one pass over `train_loader` with gradient accumulation.

    Args:
        model: Module whose forward returns a dict containing 'loss'.
        train_loader: Iterable of batch dicts of tensors.
        optimizer: Optimizer stepped once per accumulation group.
        scheduler: LR scheduler stepped together with the optimizer.
        device: Device the batch tensors are moved to.
        current_epoch: Unused; kept for interface compatibility.
    """
    losses = AverageMeter()
    acc_steps = config['grad_acc_steps']
    num_batches = len(train_loader)

    model.train()
    optimizer.zero_grad()

    with tqdm(train_loader, unit="batch") as tepoch:
        for batch_idx, data in enumerate(tepoch):
            for k, v in data.items():
                if k != 'offset_mapping':
                    data[k] = v.to(device)

            raw_loss = model(data)['loss']
            # Scale so the accumulated gradient is the mean over the group.
            (raw_loss / acc_steps).backward()

            # Report the unscaled loss, weighted by the real batch size.
            losses.update(raw_loss.item(), data['input_ids'].size(0))
            tepoch.set_postfix(train_loss=losses.avg)

            # Step after every full group of `acc_steps` micro-batches (the
            # old `batch_idx % acc_steps == 0` test stepped after the very
            # first micro-batch) and flush a trailing partial group so its
            # gradients are not discarded.
            group_complete = (batch_idx + 1) % acc_steps == 0
            last_batch = batch_idx + 1 == num_batches
            if group_complete or last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            
            
def eval_fn(model, valid_loader, device, current_epoch):
    """Evaluate `model` on `valid_loader`.

    Args:
        model: Module whose forward returns a dict with 'loss' and 'logits'.
        valid_loader: Iterable of batch dicts of tensors (must include
            'targets', 'offset_mapping' and 'sequence_ids').
        device: Device the model inputs are moved to.
        current_epoch: Unused; kept for interface compatibility.

    Returns:
        (mean validation loss, span-level F1), both rounded to 4 decimals.
    """
    losses = AverageMeter()

    final_targets = []
    final_predictions = []
    offset_mapping = []
    sequence_ids = []

    model.eval()

    with torch.no_grad():

        with tqdm(valid_loader, unit="batch") as tepoch:

            for batch_idx, data in enumerate(tepoch):
                for k, v in data.items():
                    # These two are only needed for decoding on the host
                    # (the key used to be misspelled 'sequence_id').
                    if k not in ['offset_mapping', 'sequence_ids']:
                        data[k] = v.to(device)

                x = model(data)
                loss = x['loss']
                # Weight by batch size so a smaller final batch does not
                # distort the mean.
                losses.update(loss.item(), data['input_ids'].size(0))

                o = x['logits'].detach().cpu().numpy()
                final_predictions.extend(o)

                y = data['targets'].detach().cpu().numpy()
                final_targets.extend(y)

                offset_mapping.extend(data['offset_mapping'].tolist())
                sequence_ids.extend(data['sequence_ids'].tolist())

    predicted_locations = decode_predictions(final_predictions, offset_mapping, sequence_ids, test=False)
    scores = get_score(predicted_locations, offset_mapping, sequence_ids, final_targets)

    return round(losses.avg, 4), round(scores['f1'], 4)


def decode_predictions(preds, offset_mapping, sequence_ids, test=False):
    """Turn per-token probabilities into character spans.

    Args:
        preds: Per example, one probability per token.
        offset_mapping: Per example, one (char_start, char_end) pair per token.
        sequence_ids: Per example, one sequence id per token; only tokens
            whose id equals 1 are decoded.
        test: When True each example is returned as a "start end; start end"
            string, otherwise as a list of (start, end) tuples.

    Returns:
        One entry per example, in input order.
    """
    all_predictions = []
    for pred, offsets, seq_ids in zip(preds, offset_mapping, sequence_ids):
        start_idx = None
        end_idx = None
        spans = []

        for p, o, s_id in zip(pred, offsets, seq_ids):
            # Skip everything that is not part of the second sequence. Testing
            # `!= 1` also covers special tokens whose `None` id was turned
            # into NaN by a float cast, which `is None or == 0` let through.
            if s_id != 1:
                continue

            # if class = 1, track start and end location from offset map
            if p > 0.5:
                if start_idx is None:
                    start_idx = o[0]
                end_idx = o[1]

            # if class 0, record previously tracked predictions if not done already
            elif start_idx is not None:
                spans.append((start_idx, end_idx))
                start_idx = None  # reset

        # A span still open on the last decoded token used to be dropped.
        if start_idx is not None:
            spans.append((start_idx, end_idx))

        if test:
            # delimiting with semi-colon for submission requirement
            all_predictions.append("; ".join(f"{begin} {finish}" for begin, finish in spans))
        else:
            all_predictions.append(spans)

    return all_predictions


def get_score(predictions, offset_mapping, sequence_ids, labels):
    """Character-level precision / recall / F1 of predicted spans.

    For each example a 0/1 array over characters is built twice: once from the
    token labels (tokens labelled 1 mark their character range) and once from
    the predicted (start, end) spans. All examples are concatenated and scored
    as one binary classification problem.

    Returns:
        {"precision": ..., "recall": ..., "f1": ...}
    """
    truth_chars = []
    predicted_chars = []

    examples = zip(predictions, offset_mapping, sequence_ids, labels)
    for spans, offsets, seq_ids, token_labels in examples:
        # Largest offset value seen in this example.
        num_chars = max(itertools.chain.from_iterable(offsets))

        # formatting ground truth for evaluation
        char_labels = np.zeros((num_chars))
        for span, s_id, label in zip(offsets, seq_ids, token_labels):
            # only tokens with a non-zero, non-None sequence id can contribute
            if s_id is not None and s_id != 0 and int(label) == 1:
                char_labels[span[0]:span[1]] = 1

        # formatting predictions for evaluation
        char_preds = np.zeros((num_chars))
        for begin, finish in spans:
            char_preds[begin:finish] = 1

        truth_chars.extend(char_labels)
        predicted_chars.extend(char_preds)

    precision, recall, f1 = metrics.precision_recall_fscore_support(
        truth_chars, predicted_chars, average = "binary"
    )[:3]
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

def save_checkpoint(model, optimizer, scheduler, epoch, score, best_score, name):
    """Write model / optimizer / scheduler state and metadata to disk.

    Args:
        model, optimizer, scheduler: Objects whose `state_dict()` is saved.
        epoch, score, best_score: Stored as-is in the checkpoint.
        name: File name, placed under config['output_dir'].
    """
    print('saving model of this epoch as:', name)

    # Join as a path rather than concatenating strings, so an output_dir
    # without a trailing separator still yields a file inside that directory.
    path = os.path.join(config['output_dir'], name)

    # `name` can contain sub-directories (e.g. an "org/model" checkpoint id);
    # make sure the target directory exists before saving.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch,
            'score': score,
            'best_score': best_score,
        },
        path
    )

In [None]:
# ====================================================
# scheduler
# ====================================================
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build the LR schedule selected by config["scheduler"].

    Args:
        optimizer: Optimizer the schedule is attached to.
        num_warmup_steps: Number of warm-up steps.
        num_training_steps: Total number of scheduler steps.

    Returns:
        A linear or cosine warm-up scheduler.

    Raises:
        ValueError: If config["scheduler"] is neither 'linear' nor 'cosine'
            (this used to surface as an UnboundLocalError).
    """
    kind = config["scheduler"]
    if kind == 'linear':
        return get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
        )
    if kind == 'cosine':
        return get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=config["num_cycles"]
        )
    raise ValueError(f"Unknown scheduler: {kind!r} (expected 'linear' or 'cosine')")

def make_optimizer(model, optimizer_name="AdamW", lr=5e-5, weight_decay=0.01):
    """Create the optimizer for `model`.

    The previous signature carried a stray `self` parameter, so a call such as
    `make_optimizer(model, "AdamW")` bound the module to `self` and the name
    to `model`; the requested optimizer name was therefore never checked.
    Positional calls of that form keep working unchanged.

    Args:
        model: Module whose parameters are optimized.
        optimizer_name: Only "AdamW" is supported.
        lr: Learning rate (default matches the previously hard-coded value).
        weight_decay: Weight decay (default matches the previous value).

    Returns:
        A torch.optim.AdamW instance.

    Raises:
        ValueError: If `optimizer_name` is not supported.
    """
    if optimizer_name != "AdamW":
        raise ValueError('Unknown optimizer: {}'.format(optimizer_name))
    return AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        

def get_loss(output, target):
    """Mean binary cross-entropy (on logits) over all non-ignored positions.

    Positions whose target equals -100 are excluded from the mean.
    """
    flat_logits = output.view(-1, 1)
    flat_target = target.view(-1, 1)

    criterion = nn.BCEWithLogitsLoss(reduction="none")
    per_token = criterion(flat_logits, flat_target)

    keep = flat_target != -100
    return torch.masked_select(per_token, keep).mean()

# Run

In [None]:
def run(df, fold, tokenizer, device):
    """Train and validate one fold.

    Rows with `kfold != fold` are used for training and rows with
    `kfold == fold` for validation. After every epoch the latest weights are
    saved, and the best-scoring weights are saved separately.

    Side effects: writes 'config.pth' and checkpoint files, and stores
    'num_training_steps' / 'num_warmup_steps' in the global `config`.
    """

    def _make_loader(frame, batch_size, shuffle):
        # Wrap a frame in the dataset + DataLoader pair used by both splits.
        dataset = NBMEData(frame, tokenizer)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
        )
        return dataset, loader

    print('Fold:', fold)

    print('\npreparing training data...')
    train_df = df[df['kfold'] != fold].reset_index(drop=True)
    train_dataset, train_loader = _make_loader(train_df, config['train_batch_size'], True)

    print('\npreparing validation data...')
    valid_df = df[df['kfold'] == fold].reset_index(drop=True)
    _, valid_loader = _make_loader(valid_df, config['valid_batch_size'], False)

    model = NBMEModel(config['num_labels'])
    torch.save(model.config, 'config.pth')
    model.to(device)

    optimizer = make_optimizer(model, "AdamW")

    # One scheduler step per accumulation group, over all planned epochs.
    samples_per_step = config['train_batch_size'] * config['grad_acc_steps']
    num_training_steps = (len(train_dataset) // samples_per_step) * config['total_epochs']
    num_warmup_steps = int(num_training_steps * 0.01)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)
    config['num_training_steps'] = num_training_steps
    config['num_warmup_steps'] = num_warmup_steps

    first_epoch = 0
    last_epoch = first_epoch + config['epochs_to_train']
    best_score = -1
    started_at = time.time()

    for epoch in range(first_epoch, last_epoch):

        print(f'\n\n\nTraining Epoch: {epoch}')
        train_fn(model, train_loader, optimizer, scheduler, device, epoch)

        print('Evaluation...')
        val_loss, val_score = eval_fn(
            model=model,
            valid_loader=valid_loader,
            device=device,
            current_epoch=epoch,
        )

        improved = val_score > best_score
        if improved:
            best_score = val_score
            best_name = f'best_model_{config["model_checkpoint"]}_{fold}.pth'
            save_checkpoint(model, optimizer, scheduler, epoch, val_score, best_score, best_name)

        last_name = f'last_model_{config["model_checkpoint"]}_{fold}.pth'
        save_checkpoint(model, optimizer, scheduler, epoch, val_score, best_score, last_name)

        print('Valid Score:', val_score, 'Valid Loss:', val_loss, 'Best Score:', best_score)

    print(f'Best Score: {best_score}, Time Taken: {round(time.time() - started_at, 4)}s')

In [None]:
# Build the tokenizer once. With precompute enabled, per-text encodings are
# attached as `tokenizer.precomputed` (not read by NBMEData in the visible source).
tokenizer = get_tokenizer(
    config['model_checkpoint'], precompute=config['precompute_tokens'], df=train, folder=None)

# Debug mode: keep only config['debug'] randomly sampled rows.
if config['debug']:
    train = train.sample(config['debug']).reset_index(drop=True)

# Train every requested fold, releasing cached GPU memory between folds.
for fold in config['fold_to_train']:
    run(
        df=train, 
        fold=fold,
        tokenizer=tokenizer,
        device=config['device']
    )
    torch.cuda.empty_cache()
    gc.collect()