# Trainining on all dataset ModernBERT



![Comparison of BERT variants in 2025](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/modernbert/modernbert_pareto_curve.png)

# Directory settings

In [1]:
# ====================================================
# Directory settings
# ====================================================
import os, shutil

OUTPUT_DIR = os.path.join(os.path.dirname(os.getcwd()), 'modern-bert/')
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

for item in os.listdir(OUTPUT_DIR):
    item_path = os.path.join(OUTPUT_DIR, item)
    if os.path.isfile(item_path):  # If it's a file, remove it
        os.remove(item_path)
        print(f"Removed file: {item}")
    elif os.path.isdir(item_path):  # If it's a directory, remove it
        shutil.rmtree(item_path)
        print(f"Removed folder: {item}")

# Library

In [2]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

os.system('pip uninstall -y transformers')
os.system('python -m pip install transformers>=4.48.0')
os.system('pip uninstall -y tokenizer')
os.system('python -m pip install tokenizers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Found existing installation: transformers 4.48.3


Uninstalling transformers-4.48.3:
  Successfully uninstalled transformers-4.48.3


[0m

[0m

[0m



[0m

tokenizers.__version__: 0.21.0
transformers.__version__: 4.48.3


env: TOKENIZERS_PARALLELISM=true


In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" # if running on Kaggle, use this path 
    model = "answerdotai/ModernBERT-base" # if running on local machine, use this path  
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=1 #0.05
    num_warmup_steps=0
    epochs=5
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 12 #batch_size=16
    fc_dropout=0.2
    max_len=512
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=2 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True
    
if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

In [4]:
# ====================================================
# wandb: local notebook login
# ====================================================
if CFG.wandb:
    
    import wandb
    from dotenv import load_dotenv
    import os

    # Load environment variables from .env file in the parent folder
    load_dotenv(os.path.join(os.path.dirname(os.getcwd()), '.env'))

    secret_value_0 = os.getenv("WANDB_API_KEY")
    if secret_value_0:
        wandb.login(key=secret_value_0)
        anony = None
    else:
        anony = "must"
        print('If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. \nGet your W&B access token from here: https://wandb.ai/authorize')

    def class2dict(f):
        return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

    run = wandb.init(project='NBME-Public', 
                     name=CFG.model,
                     config=class2dict(CFG),
                     group=CFG.model,
                     job_type="train",
                     anonymous=anony)

If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. 
Get your W&B access token from here: https://wandb.ai/authorize


[34m[1mwandb[0m: Currently logged in as: [33mzehrakorkusuz[0m ([33mproject-zero[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/workspace/wandb/run-20250211_211117-bnksm6wu[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33manswerdotai/ModernBERT-base[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/project-zero/NBME-Public?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/project-zero/NBME-Public/runs/bnksm6wu?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m




# Helper functions for scoring

In [5]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline

def micro_f1(preds, truths):
    """
    Micro f1 on binary arrays.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score.
    """
    # Micro : aggregating over all instances
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return f1_score(truths, preds)


def spans_to_binary(spans, length=None):
    """
    Converts spans to a binary array indicating whether each character is in the span.

    Args:
        spans (list of lists of two ints): Spans.

    Returns:
        np array [length]: Binarized spans.
    """
    length = np.max(spans) if length is None else length
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary


def span_micro_f1(preds, truths):
    """
    Micro f1 on spans.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
    """
    bin_preds = []
    bin_truths = []
    for pred, truth in zip(preds, truths):
        if not len(pred) and not len(truth):
            continue
        length = max(np.max(pred) if len(pred) else 0, np.max(truth) if len(truth) else 0)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)

In [6]:
## UPDATED FOR HANDLING NONEs in THE PSEUDOLABELS / only including confident inferences
import ast

# Update the create_labels_for_scoring function
def create_labels_for_scoring(df):
    # Initialize the 'location_for_create_labels' column with empty lists
    df['location_for_create_labels'] = [ast.literal_eval(f'[]')] * len(df)
    
    for i in range(len(df)):
        lst = df.loc[i, 'location']
        if lst and lst != '':  # Check if lst is not None or empty
            # Ensure lst is a list of strings
            if isinstance(lst, str):
                lst = [lst]
            elif isinstance(lst, list):
                lst = [str(item) for item in lst]
            new_lst = ';'.join(lst)
            df.loc[i, 'location_for_create_labels'] = ast.literal_eval(f'[["{new_lst}"]]')

    # Create labels
    truths = []
    for location_list in df['location_for_create_labels'].values:
        truth = []
        if len(location_list) > 0:
            location = location_list[0]
            for loc in [s.split() for s in location.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                truth.append([start, end])
        truths.append(truth)
    
    return truths


def get_char_probs(texts, predictions, tokenizer):
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text, 
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
            start = offset_mapping[0]
            end = offset_mapping[1]
            results[i][start:end] = pred
    return results


def get_results(char_probs, th=0.5):
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results


def get_predictions(results):
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions

# Utils

In [7]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = span_micro_f1(y_true, y_pred)
    return score


def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

# Data Loading

In [8]:
# ====================================================
# Data Loading
# ====================================================
train = pd.read_csv('/workspace/data/train.csv')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('/workspace/data/features.csv')
def preprocess_features(features):
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    return features
features = preprocess_features(features)
patient_notes = pd.read_csv('/workspace/data/patient_notes.csv')

print(f"train.shape: {train.shape}")
display(train.head())
print(f"features.shape: {features.shape}")
display(features.head())
print(f"patient_notes.shape: {patient_notes.shape}")
display(patient_notes.head())


train.shape: (14300, 6)


Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724]
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693]
2,00016_002,0,16,2,[chest pressure],[203 217]
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]"
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258]


features.shape: (143, 3)


Unnamed: 0,feature_num,case_num,feature_text
0,0,0,Family-history-of-MI-OR-Family-history-of-myoc...
1,1,0,Family-history-of-thyroid-disorder
2,2,0,Chest-pressure
3,3,0,Intermittent-symptoms
4,4,0,Lightheaded


patient_notes.shape: (42146, 3)


Unnamed: 0,pn_num,case_num,pn_history
0,0,0,"17-year-old male, has come to the student heal..."
1,1,0,17 yo male with recurrent palpitations for the...
2,2,0,Dillon Cleveland is a 17 y.o. male patient wit...
3,3,0,a 17 yo m c/o palpitation started 3 mos ago; \...
4,4,0,17yo male with no pmh here for evaluation of p...


In [9]:
train['annotation_length'] = train['annotation'].apply(len) # refers to number of annotation
train

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1
2,00016_002,0,16,2,[chest pressure],[203 217],1
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1
...,...,...,...,...,...,...,...
14295,95333_912,9,95333,912,[],[],0
14296,95333_913,9,95333,913,[],[],0
14297,95333_914,9,95333,914,[photobia],[274 282],1
14298,95333_915,9,95333,915,[no sick contacts],[421 437],1


In [10]:
train = train.merge(features, on=['feature_num', 'case_num'], how='left')
train = train.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
display(train.head())

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length,feature_text,pn_history
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1,Family-history-of-MI-OR-Family-history-of-myoc...,HPI: 17yo M presents with palpitations. Patien...
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1,Family-history-of-thyroid-disorder,HPI: 17yo M presents with palpitations. Patien...
2,00016_002,0,16,2,[chest pressure],[203 217],1,Chest-pressure,HPI: 17yo M presents with palpitations. Patien...
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2,Intermittent-symptoms,HPI: 17yo M presents with palpitations. Patien...
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1,Lightheaded,HPI: 17yo M presents with palpitations. Patien...


In [11]:
# incorrect annotation
train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')

train.loc[621, 'annotation'] = ast.literal_eval('[["for the last 2-3 months"]]')
train.loc[621, 'location'] = ast.literal_eval('[["77 100"]]')

train.loc[655, 'annotation'] = ast.literal_eval('[["no heat intolerance"], ["no cold intolerance"]]')
train.loc[655, 'location'] = ast.literal_eval('[["285 292;301 312"], ["285 287;296 312"]]')

train.loc[1262, 'annotation'] = ast.literal_eval('[["mother thyroid problem"]]')
train.loc[1262, 'location'] = ast.literal_eval('[["551 557;565 580"]]')

train.loc[1265, 'annotation'] = ast.literal_eval('[[\'felt like he was going to "pass out"\']]')
train.loc[1265, 'location'] = ast.literal_eval('[["131 135;181 212"]]')

train.loc[1396, 'annotation'] = ast.literal_eval('[["stool , with no blood"]]')
train.loc[1396, 'location'] = ast.literal_eval('[["259 280"]]')

train.loc[1591, 'annotation'] = ast.literal_eval('[["diarrhoe non blooody"]]')
train.loc[1591, 'location'] = ast.literal_eval('[["176 184;201 212"]]')

train.loc[1615, 'annotation'] = ast.literal_eval('[["diarrhea for last 2-3 days"]]')
train.loc[1615, 'location'] = ast.literal_eval('[["249 257;271 288"]]')

train.loc[1664, 'annotation'] = ast.literal_eval('[["no vaginal discharge"]]')
train.loc[1664, 'location'] = ast.literal_eval('[["822 824;907 924"]]')

train.loc[1714, 'annotation'] = ast.literal_eval('[["started about 8-10 hours ago"]]')
train.loc[1714, 'location'] = ast.literal_eval('[["101 129"]]')

train.loc[1929, 'annotation'] = ast.literal_eval('[["no blood in the stool"]]')
train.loc[1929, 'location'] = ast.literal_eval('[["531 539;549 561"]]')

train.loc[2134, 'annotation'] = ast.literal_eval('[["last sexually active 9 months ago"]]')
train.loc[2134, 'location'] = ast.literal_eval('[["540 560;581 593"]]')

train.loc[2191, 'annotation'] = ast.literal_eval('[["right lower quadrant pain"]]')
train.loc[2191, 'location'] = ast.literal_eval('[["32 57"]]')

train.loc[2553, 'annotation'] = ast.literal_eval('[["diarrhoea no blood"]]')
train.loc[2553, 'location'] = ast.literal_eval('[["308 317;376 384"]]')

train.loc[3124, 'annotation'] = ast.literal_eval('[["sweating"]]')
train.loc[3124, 'location'] = ast.literal_eval('[["549 557"]]')

train.loc[3858, 'annotation'] = ast.literal_eval('[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]')
train.loc[3858, 'location'] = ast.literal_eval('[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]')

train.loc[4373, 'annotation'] = ast.literal_eval('[["for 2 months"]]')
train.loc[4373, 'location'] = ast.literal_eval('[["33 45"]]')

train.loc[4763, 'annotation'] = ast.literal_eval('[["35 year old"]]')
train.loc[4763, 'location'] = ast.literal_eval('[["5 16"]]')

train.loc[4782, 'annotation'] = ast.literal_eval('[["darker brown stools"]]')
train.loc[4782, 'location'] = ast.literal_eval('[["175 194"]]')

train.loc[4908, 'annotation'] = ast.literal_eval('[["uncle with peptic ulcer"]]')
train.loc[4908, 'location'] = ast.literal_eval('[["700 723"]]')

train.loc[6016, 'annotation'] = ast.literal_eval('[["difficulty falling asleep"]]')
train.loc[6016, 'location'] = ast.literal_eval('[["225 250"]]')

train.loc[6192, 'annotation'] = ast.literal_eval('[["helps to take care of aging mother and in-laws"]]')
train.loc[6192, 'location'] = ast.literal_eval('[["197 218;236 260"]]')

train.loc[6380, 'annotation'] = ast.literal_eval('[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]')
train.loc[6380, 'location'] = ast.literal_eval('[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]')

train.loc[6562, 'annotation'] = ast.literal_eval('[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]')
train.loc[6562, 'location'] = ast.literal_eval('[["290 320;327 337"], ["290 320;342 358"]]')

train.loc[6862, 'annotation'] = ast.literal_eval('[["stressor taking care of many sick family members"]]')
train.loc[6862, 'location'] = ast.literal_eval('[["288 296;324 363"]]')

train.loc[7022, 'annotation'] = ast.literal_eval('[["heart started racing and felt numbness for the 1st time in her finger tips"]]')
train.loc[7022, 'location'] = ast.literal_eval('[["108 182"]]')

train.loc[7422, 'annotation'] = ast.literal_eval('[["first started 5 yrs"]]')
train.loc[7422, 'location'] = ast.literal_eval('[["102 121"]]')

train.loc[8876, 'annotation'] = ast.literal_eval('[["No shortness of breath"]]')
train.loc[8876, 'location'] = ast.literal_eval('[["481 483;533 552"]]')

train.loc[9027, 'annotation'] = ast.literal_eval('[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]')
train.loc[9027, 'location'] = ast.literal_eval('[["92 102"], ["123 164"]]')

train.loc[9938, 'annotation'] = ast.literal_eval('[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]')
train.loc[9938, 'location'] = ast.literal_eval('[["89 117"], ["122 138"], ["368 402"]]')

train.loc[9973, 'annotation'] = ast.literal_eval('[["gaining 10-15 lbs"]]')
train.loc[9973, 'location'] = ast.literal_eval('[["344 361"]]')

train.loc[10513, 'annotation'] = ast.literal_eval('[["weight gain"], ["gain of 10-16lbs"]]')
train.loc[10513, 'location'] = ast.literal_eval('[["600 611"], ["607 623"]]')

train.loc[11551, 'annotation'] = ast.literal_eval('[["seeing her son knows are not real"]]')
train.loc[11551, 'location'] = ast.literal_eval('[["386 400;443 461"]]')

train.loc[11677, 'annotation'] = ast.literal_eval('[["saw him once in the kitchen after he died"]]')
train.loc[11677, 'location'] = ast.literal_eval('[["160 201"]]')

train.loc[12124, 'annotation'] = ast.literal_eval('[["tried Ambien but it didnt work"]]')
train.loc[12124, 'location'] = ast.literal_eval('[["325 337;349 366"]]')

train.loc[12279, 'annotation'] = ast.literal_eval('[["heard what she described as a party later than evening these things did not actually happen"]]')
train.loc[12279, 'location'] = ast.literal_eval('[["405 459;488 524"]]')

train.loc[12289, 'annotation'] = ast.literal_eval('[["experienced seeing her son at the kitchen table these things did not actually happen"]]')
train.loc[12289, 'location'] = ast.literal_eval('[["353 400;488 524"]]')

train.loc[13238, 'annotation'] = ast.literal_eval('[["SCRACHY THROAT"], ["RUNNY NOSE"]]')
train.loc[13238, 'location'] = ast.literal_eval('[["293 307"], ["321 331"]]')

train.loc[13297, 'annotation'] = ast.literal_eval('[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]')
train.loc[13297, 'location'] = ast.literal_eval('[["182 221"], ["182 213;225 234"]]')

train.loc[13299, 'annotation'] = ast.literal_eval('[["yesterday"], ["yesterday"]]')
train.loc[13299, 'location'] = ast.literal_eval('[["79 88"], ["409 418"]]')

train.loc[13845, 'annotation'] = ast.literal_eval('[["headache global"], ["headache throughout her head"]]')
train.loc[13845, 'location'] = ast.literal_eval('[["86 94;230 236"], ["86 94;237 256"]]')

train.loc[14083, 'annotation'] = ast.literal_eval('[["headache generalized in her head"]]')
train.loc[14083, 'location'] = ast.literal_eval('[["56 64;156 179"]]')

In [12]:
pseudolabels = pd.read_csv('/workspace/data/pseudolabels.csv')
pseudolabels =  pseudolabels.merge(features, on=['feature_num', 'case_num'], how='left')
pseudolabels = pseudolabels.merge(patient_notes, on=['pn_num', 'case_num'], how='left')

pseudolabels = pseudolabels.dropna(subset=['location'])
pseudolabels = pseudolabels[pseudolabels['location'].apply(lambda x: len(x) > 0)]

def convert_location_format(location):
    if pd.isna(location) or not location:
        return []
    if isinstance(location, str):
        return [loc.strip() for loc in location.split(';')]
    return location

pseudolabels['location'] = pseudolabels['location'].apply(convert_location_format)

pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal..."
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...
...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...


In [13]:
pseudolabels['annotation_length'] = pseudolabels['location'].apply(len)

In [14]:
pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history,annotation_length
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal...",1
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...,1
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...,1
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...,1
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...,1
...,...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,2
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,1
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...,1
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...,1


In [15]:
pseudolabels = pseudolabels.drop(columns=['fold'])

In [16]:
# add pseudolabels to train
## IF YOU WANT TO INCLUDE PSEUDOLABELS
### train = pd.concat([train, pseudolabels], axis=0, ignore_index=True)

# CV split

In [17]:
# ====================================================
# CV split
# ====================================================
train = train.reset_index(drop=True)

Fold = GroupKFold(n_splits=CFG.n_fold)
groups = train['pn_num'].values
for n, (train_index, val_index) in enumerate(Fold.split(train, train['location'], groups)):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

fold
0    2860
1    2860
2    2860
3    2860
4    2860
dtype: int64

In [18]:
if CFG.debug:
    display(train.groupby('fold').size())
    train = train.sample(n=1000, random_state=0).reset_index(drop=True)
    display(train.groupby('fold').size())

# Dataset

In [19]:
# Define max_len
for text_col in ['pn_history']:
    pn_history_lengths = []
    tk0 = tqdm(patient_notes[text_col].fillna("").values, total=len(patient_notes))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        pn_history_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(pn_history_lengths)}')

for text_col in ['feature_text']:
    features_lengths = []
    tk0 = tqdm(features[text_col].fillna("").values, total=len(features))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        features_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(features_lengths)}')

CFG.max_len = max(pn_history_lengths) + max(features_lengths) + 3  # cls & sep & sep
LOGGER.info(f"max_len: {CFG.max_len}")

  0%|          | 0/42146 [00:00<?, ?it/s]

  1%|          | 234/42146 [00:00<00:17, 2335.79it/s]

  1%|          | 491/42146 [00:00<00:16, 2472.18it/s]

  2%|▏         | 750/42146 [00:00<00:16, 2524.73it/s]

  2%|▏         | 1004/42146 [00:00<00:16, 2529.92it/s]

  3%|▎         | 1259/42146 [00:00<00:16, 2536.08it/s]

  4%|▎         | 1514/42146 [00:00<00:15, 2539.52it/s]

  4%|▍         | 1770/42146 [00:00<00:15, 2544.90it/s]

  5%|▍         | 2025/42146 [00:00<00:15, 2543.03it/s]

  5%|▌         | 2284/42146 [00:00<00:15, 2556.72it/s]

  6%|▌         | 2540/42146 [00:01<00:15, 2541.34it/s]

  7%|▋         | 2795/42146 [00:01<00:15, 2514.67it/s]

  7%|▋         | 3047/42146 [00:01<00:15, 2506.93it/s]

  8%|▊         | 3298/42146 [00:01<00:15, 2464.84it/s]

  8%|▊         | 3545/42146 [00:01<00:15, 2432.46it/s]

  9%|▉         | 3789/42146 [00:01<00:15, 2411.18it/s]

 10%|▉         | 4031/42146 [00:01<00:15, 2402.51it/s]

 10%|█         | 4272/42146 [00:01<00:15, 2380.41it/s]

 11%|█         | 4511/42146 [00:01<00:15, 2374.54it/s]

 11%|█▏        | 4749/42146 [00:01<00:15, 2374.15it/s]

 12%|█▏        | 4987/42146 [00:02<00:15, 2350.07it/s]

 12%|█▏        | 5223/42146 [00:02<00:15, 2341.25it/s]

 13%|█▎        | 5470/42146 [00:02<00:15, 2378.61it/s]

 14%|█▎        | 5720/42146 [00:02<00:15, 2412.56it/s]

 14%|█▍        | 5969/42146 [00:02<00:14, 2432.56it/s]

 15%|█▍        | 6219/42146 [00:02<00:14, 2452.19it/s]

 15%|█▌        | 6471/42146 [00:02<00:14, 2470.17it/s]

 16%|█▌        | 6721/42146 [00:02<00:14, 2478.14it/s]

 17%|█▋        | 6975/42146 [00:02<00:14, 2495.50it/s]

 17%|█▋        | 7227/42146 [00:02<00:13, 2501.88it/s]

 18%|█▊        | 7478/42146 [00:03<00:13, 2502.69it/s]

 18%|█▊        | 7731/42146 [00:03<00:13, 2507.51it/s]

 19%|█▉        | 7984/42146 [00:03<00:13, 2512.52it/s]

 20%|█▉        | 8236/42146 [00:03<00:13, 2506.98it/s]

 20%|██        | 8487/42146 [00:03<00:13, 2501.25it/s]

 21%|██        | 8738/42146 [00:03<00:13, 2487.12it/s]

 21%|██▏       | 8988/42146 [00:03<00:13, 2490.54it/s]

 22%|██▏       | 9238/42146 [00:03<00:13, 2470.69it/s]

 23%|██▎       | 9486/42146 [00:03<00:13, 2473.11it/s]

 23%|██▎       | 9734/42146 [00:03<00:13, 2473.72it/s]

 24%|██▎       | 9982/42146 [00:04<00:13, 2468.22it/s]

 24%|██▍       | 10235/42146 [00:04<00:12, 2484.56it/s]

 25%|██▍       | 10484/42146 [00:04<00:12, 2478.32it/s]

 25%|██▌       | 10734/42146 [00:04<00:12, 2483.43it/s]

 26%|██▌       | 10983/42146 [00:04<00:12, 2476.50it/s]

 27%|██▋       | 11231/42146 [00:04<00:12, 2473.80it/s]

 27%|██▋       | 11481/42146 [00:04<00:12, 2480.39it/s]

 28%|██▊       | 11730/42146 [00:04<00:12, 2479.56it/s]

 28%|██▊       | 11983/42146 [00:04<00:12, 2491.78it/s]

 29%|██▉       | 12236/42146 [00:04<00:11, 2502.26it/s]

 30%|██▉       | 12489/42146 [00:05<00:11, 2508.16it/s]

 30%|███       | 12740/42146 [00:05<00:11, 2504.57it/s]

 31%|███       | 12992/42146 [00:05<00:11, 2507.94it/s]

 31%|███▏      | 13243/42146 [00:05<00:11, 2495.99it/s]

 32%|███▏      | 13493/42146 [00:05<00:11, 2479.98it/s]

 33%|███▎      | 13742/42146 [00:05<00:11, 2423.86it/s]

 33%|███▎      | 13991/42146 [00:05<00:11, 2443.05it/s]

 34%|███▍      | 14240/42146 [00:05<00:11, 2455.16it/s]

 34%|███▍      | 14486/42146 [00:05<00:11, 2447.81it/s]

 35%|███▍      | 14734/42146 [00:05<00:11, 2456.61it/s]

 36%|███▌      | 14980/42146 [00:06<00:11, 2456.49it/s]

 36%|███▌      | 15226/42146 [00:06<00:11, 2410.75it/s]

 37%|███▋      | 15469/42146 [00:06<00:11, 2415.59it/s]

 37%|███▋      | 15711/42146 [00:06<00:10, 2403.51it/s]

 38%|███▊      | 15952/42146 [00:06<00:10, 2395.80it/s]

 38%|███▊      | 16192/42146 [00:06<00:10, 2377.15it/s]

 39%|███▉      | 16432/42146 [00:06<00:10, 2381.51it/s]

 40%|███▉      | 16675/42146 [00:06<00:10, 2395.05it/s]

 40%|████      | 16917/42146 [00:06<00:10, 2400.68it/s]

 41%|████      | 17163/42146 [00:06<00:10, 2417.83it/s]

 41%|████▏     | 17409/42146 [00:07<00:10, 2428.10it/s]

 42%|████▏     | 17657/42146 [00:07<00:10, 2441.72it/s]

 42%|████▏     | 17905/42146 [00:07<00:09, 2450.42it/s]

 43%|████▎     | 18154/42146 [00:07<00:09, 2462.06it/s]

 44%|████▎     | 18401/42146 [00:07<00:09, 2450.85it/s]

 44%|████▍     | 18647/42146 [00:07<00:09, 2440.04it/s]

 45%|████▍     | 18894/42146 [00:07<00:09, 2448.46it/s]

 45%|████▌     | 19139/42146 [00:07<00:09, 2445.56it/s]

 46%|████▌     | 19385/42146 [00:07<00:09, 2448.15it/s]

 47%|████▋     | 19634/42146 [00:07<00:09, 2459.44it/s]

 47%|████▋     | 19882/42146 [00:08<00:09, 2463.91it/s]

 48%|████▊     | 20129/42146 [00:08<00:08, 2456.68it/s]

 48%|████▊     | 20381/42146 [00:08<00:08, 2475.03it/s]

 49%|████▉     | 20637/42146 [00:08<00:08, 2497.81it/s]

 50%|████▉     | 20888/42146 [00:08<00:08, 2499.55it/s]

 50%|█████     | 21142/42146 [00:08<00:08, 2509.12it/s]

 51%|█████     | 21394/42146 [00:08<00:08, 2508.57it/s]

 51%|█████▏    | 21646/42146 [00:08<00:08, 2510.15it/s]

 52%|█████▏    | 21901/42146 [00:08<00:08, 2519.71it/s]

 53%|█████▎    | 22153/42146 [00:08<00:07, 2517.54it/s]

 53%|█████▎    | 22407/42146 [00:09<00:07, 2523.63it/s]

 54%|█████▍    | 22662/42146 [00:09<00:07, 2529.76it/s]

 54%|█████▍    | 22915/42146 [00:09<00:07, 2516.47it/s]

 55%|█████▍    | 23168/42146 [00:09<00:07, 2518.43it/s]

 56%|█████▌    | 23420/42146 [00:09<00:07, 2516.94it/s]

 56%|█████▌    | 23676/42146 [00:09<00:07, 2527.14it/s]

 57%|█████▋    | 23929/42146 [00:09<00:07, 2520.03it/s]

 57%|█████▋    | 24183/42146 [00:09<00:07, 2525.01it/s]

 58%|█████▊    | 24438/42146 [00:09<00:06, 2529.83it/s]

 59%|█████▊    | 24695/42146 [00:09<00:06, 2540.64it/s]

 59%|█████▉    | 24952/42146 [00:10<00:06, 2548.81it/s]

 60%|█████▉    | 25207/42146 [00:10<00:06, 2537.33it/s]

 60%|██████    | 25461/42146 [00:10<00:06, 2525.12it/s]

 61%|██████    | 25716/42146 [00:10<00:06, 2530.15it/s]

 62%|██████▏   | 25972/42146 [00:10<00:06, 2538.28it/s]

 62%|██████▏   | 26229/42146 [00:10<00:06, 2546.45it/s]

 63%|██████▎   | 26488/42146 [00:10<00:06, 2557.20it/s]

 63%|██████▎   | 26744/42146 [00:10<00:06, 2549.46it/s]

 64%|██████▍   | 26999/42146 [00:10<00:05, 2531.53it/s]

 65%|██████▍   | 27253/42146 [00:11<00:05, 2495.74it/s]

 65%|██████▌   | 27503/42146 [00:11<00:05, 2480.57it/s]

 66%|██████▌   | 27761/42146 [00:11<00:05, 2508.53it/s]

 66%|██████▋   | 28020/42146 [00:11<00:05, 2531.98it/s]

 67%|██████▋   | 28278/42146 [00:11<00:05, 2544.80it/s]

 68%|██████▊   | 28533/42146 [00:11<00:05, 2543.45it/s]

 68%|██████▊   | 28788/42146 [00:11<00:05, 2518.87it/s]

 69%|██████▉   | 29040/42146 [00:11<00:05, 2468.18it/s]

 69%|██████▉   | 29288/42146 [00:11<00:05, 2446.30it/s]

 70%|███████   | 29533/42146 [00:11<00:05, 2417.10it/s]

 71%|███████   | 29775/42146 [00:12<00:05, 2417.59it/s]

 71%|███████   | 30017/42146 [00:12<00:05, 2413.53it/s]

 72%|███████▏  | 30259/42146 [00:12<00:04, 2385.60it/s]

 72%|███████▏  | 30498/42146 [00:12<00:04, 2379.24it/s]

 73%|███████▎  | 30736/42146 [00:12<00:04, 2372.09it/s]

 73%|███████▎  | 30974/42146 [00:12<00:04, 2355.02it/s]

 74%|███████▍  | 31213/42146 [00:12<00:04, 2362.63it/s]

 75%|███████▍  | 31450/42146 [00:12<00:04, 2353.49it/s]

 75%|███████▌  | 31686/42146 [00:12<00:04, 2308.55it/s]

 76%|███████▌  | 31923/42146 [00:12<00:04, 2325.67it/s]

 76%|███████▋  | 32156/42146 [00:13<00:04, 2325.94it/s]

 77%|███████▋  | 32389/42146 [00:13<00:04, 2298.75it/s]

 77%|███████▋  | 32623/42146 [00:13<00:04, 2309.29it/s]

 78%|███████▊  | 32857/42146 [00:13<00:04, 2316.89it/s]

 79%|███████▊  | 33093/42146 [00:13<00:03, 2327.65it/s]

 79%|███████▉  | 33326/42146 [00:13<00:03, 2323.45it/s]

 80%|███████▉  | 33559/42146 [00:13<00:03, 2321.73it/s]

 80%|████████  | 33792/42146 [00:13<00:03, 2306.79it/s]

 81%|████████  | 34026/42146 [00:13<00:03, 2316.25it/s]

 81%|████████▏ | 34258/42146 [00:13<00:03, 2317.00it/s]

 82%|████████▏ | 34490/42146 [00:14<00:03, 2315.72it/s]

 82%|████████▏ | 34722/42146 [00:14<00:03, 2315.98it/s]

 83%|████████▎ | 34955/42146 [00:14<00:03, 2318.92it/s]

 83%|████████▎ | 35187/42146 [00:14<00:03, 2265.69it/s]

 84%|████████▍ | 35418/42146 [00:14<00:02, 2276.64it/s]

 85%|████████▍ | 35649/42146 [00:14<00:02, 2285.28it/s]

 85%|████████▌ | 35879/42146 [00:14<00:02, 2289.20it/s]

 86%|████████▌ | 36109/42146 [00:14<00:02, 2288.65it/s]

 86%|████████▌ | 36338/42146 [00:14<00:02, 2280.31it/s]

 87%|████████▋ | 36567/42146 [00:14<00:02, 2280.21it/s]

 87%|████████▋ | 36796/42146 [00:15<00:02, 2279.83it/s]

 88%|████████▊ | 37029/42146 [00:15<00:02, 2291.72it/s]

 88%|████████▊ | 37287/42146 [00:15<00:02, 2376.42it/s]

 89%|████████▉ | 37550/42146 [00:15<00:01, 2450.99it/s]

 90%|████████▉ | 37813/42146 [00:15<00:01, 2504.36it/s]

 90%|█████████ | 38072/42146 [00:15<00:01, 2529.28it/s]

 91%|█████████ | 38325/42146 [00:15<00:01, 2509.86it/s]

 92%|█████████▏| 38577/42146 [00:15<00:01, 2511.68it/s]

 92%|█████████▏| 38833/42146 [00:15<00:01, 2524.24it/s]

 93%|█████████▎| 39094/42146 [00:15<00:01, 2549.41it/s]

 93%|█████████▎| 39356/42146 [00:16<00:01, 2567.48it/s]

 94%|█████████▍| 39616/42146 [00:16<00:00, 2574.38it/s]

 95%|█████████▍| 39879/42146 [00:16<00:00, 2590.01it/s]

 95%|█████████▌| 40146/42146 [00:16<00:00, 2610.95it/s]

 96%|█████████▌| 40408/42146 [00:16<00:00, 2611.12it/s]

 96%|█████████▋| 40670/42146 [00:16<00:00, 2595.68it/s]

 97%|█████████▋| 40934/42146 [00:16<00:00, 2608.45it/s]

 98%|█████████▊| 41195/42146 [00:16<00:00, 2606.63it/s]

 98%|█████████▊| 41462/42146 [00:16<00:00, 2625.42it/s]

 99%|█████████▉| 41725/42146 [00:16<00:00, 2600.81it/s]

100%|█████████▉| 41986/42146 [00:17<00:00, 2597.12it/s]

100%|██████████| 42146/42146 [00:17<00:00, 2458.86it/s]


pn_history max(lengths): 412


  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [00:00<00:00, 20218.62it/s]


feature_text max(lengths): 28


max_len: 443


In [20]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text, feature_text):

    # Omit token_type_ids for ModernBERT
    is_token_type_ids = CFG.model != 'answerdotai/ModernBERT-base'

    inputs = cfg.tokenizer(text, feature_text, 
                           add_special_tokens=True,
                           max_length=CFG.max_len,
                           padding="max_length",
                           truncation=True,
                           return_offsets_mapping=False,
                           return_token_type_ids=is_token_type_ids)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def create_label(cfg, text, annotation_length, location_list):
    encoded = cfg.tokenizer(text,
                            add_special_tokens=True,
                            max_length=CFG.max_len,
                            padding="max_length",
                            return_offsets_mapping=True)
    offset_mapping = encoded['offset_mapping']
    ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]
    label = np.zeros(len(offset_mapping))
    label[ignore_idxes] = -1
    if annotation_length != 0:
        for location in location_list:
            for loc in [s.split() for s in location.split(';')]:
                start_idx = -1
                end_idx = -1
                start, end = int(loc[0]), int(loc[1])
                for idx in range(len(offset_mapping)):
                    if (start_idx == -1) & (start < offset_mapping[idx][0]):
                        start_idx = idx - 1
                    if (end_idx == -1) & (end <= offset_mapping[idx][1]):
                        end_idx = idx + 1
                if start_idx == -1:
                    start_idx = end_idx
                if (start_idx != -1) & (end_idx != -1):
                    label[start_idx:end_idx] = 1
    return torch.tensor(label, dtype=torch.bfloat16)


class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.feature_texts = df['feature_text'].values
        self.pn_historys = df['pn_history'].values
        self.annotation_lengths = df['annotation_length'].values
        self.locations = df['location'].values

    def __len__(self):
        return len(self.feature_texts)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, 
                               self.pn_historys[item], 
                               self.feature_texts[item])
        label = create_label(self.cfg, 
                             self.pn_historys[item], 
                             self.annotation_lengths[item], 
                             self.locations[item])
        return inputs, label

# Model

In [21]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from torch.utils.checkpoint import checkpoint

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel(self.config)

        self.model.gradient_checkpointing_enable()
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        # Debugging input shapes and types
        #print("Feature extraction inputs:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        # Process the inputs through the model
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        
        # Debugging the output of the model
        #print(f"Last hidden states: shape={last_hidden_states.shape}, dtype={last_hidden_states.dtype}")
        return last_hidden_states

    def forward(self, inputs):
        # Debugging the forward pass
        #print("Starting forward pass...")
        #print("Input keys and shapes:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):  # Enable AMP
            #print("Running model under AMP...")
            outputs = self.model(**inputs)
        
        # Debugging model output
        #print("Model outputs:")
        #if isinstance(outputs, tuple):
           # for idx, output in enumerate(outputs):
                #print(f"Output {idx}: shape={output.shape if hasattr(output, 'shape') else 'N/A'}")
        #else:
            #print(f"Outputs: {outputs}")
            #pass
        
        # Final linear layer
        feature = outputs.last_hidden_state
        output = self.fc(self.fc_dropout(feature))
        
        # Debugging final output
        #print(f"Final output shape: {output.shape}, dtype={output.dtype}")
        return output


# Helper functions

In [22]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x76df1d590fd0>

In [23]:
def valid_fn(valid_loader, model, criterion, device):
    """
    Validation function with detailed debugging for intermediate outputs.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()

    print("Starting validation...")
    for step, (inputs, labels) in enumerate(valid_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        
        with torch.no_grad():
            y_preds = model(inputs)  # Model predictions
        
        # Compute loss
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)

        # Store predictions
        preds.append(y_preds.sigmoid().to('cpu').numpy())

        end = time.time()

        # Debugging: Log step-wise progress and intermediate outputs
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
            #print(f"Sample Predictions: {y_preds.sigmoid()[:5].view(-1).tolist()}")  # Example predictions
            #print(f"Sample Labels: {labels[:5].view(-1).tolist()}")  # Example labels

    predictions = np.concatenate(preds)

    # Debugging: Check shape and content of predictions
    #print(f"Validation Predictions Shape: {predictions.shape}")
    # print(f"First 5 Predictions: {predictions[:5]}")

    return losses.avg, predictions


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Training function with detailed debugging for loss and gradients.
    """
    print(f"\nStarting training epoch {epoch + 1} for fold {fold}")
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)

    losses = AverageMeter()
    start = time.time()
    global_step = 0

    for step, (inputs, labels) in enumerate(train_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # Mixed precision training
        with torch.cuda.amp.autocast(enabled=CFG.apex, dtype=torch.bfloat16):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
            loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()

        # Gradient accumulation
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps

        losses.update(loss.item(), batch_size)

        # Backpropagation
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)

        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()

        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print(f'Epoch: [{epoch+1}][{step}/{len(train_loader)}] '
                  f'Elapsed {timeSince(start, float(step+1)/len(train_loader))} '
                  f'Loss: {losses.val:.4f}({losses.avg:.4f}) '
                  f'Grad Norm: {grad_norm:.4f}  '
                  f'LR: {scheduler.get_lr()[0]}')

    print(f"Epoch {epoch + 1} completed. Average loss: {losses.avg:.4f}")
    return losses.avg


def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.sigmoid().to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

In [24]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_texts = valid_folds['pn_history'].values
    valid_labels = create_labels_for_scoring(valid_folds)
    
    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)

    # checkpoint_path = f"/workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold{fold}_best.pth"
    # if os.path.exists(checkpoint_path):
    #     LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
    #     checkpoint = torch.load(checkpoint_path, map_location="cuda", weights_only=False) # need to use weight_only = True for latest pytorch version
    #     model.load_state_dict(checkpoint['model'], strict=False)  # Load model weights
    #     model.to(device)
    model.train()  # Set model to training mode

    # else:
    #     LOGGER.info("No checkpoint found, starting fresh.")
        
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr, 
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)
    
    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        if cfg.scheduler=='linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler=='cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        return scheduler
    
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    
    best_score = 0.

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        print("Starting training")
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        print("Starting evaluation")
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        predictions = predictions.reshape((len(valid_folds), CFG.max_len))
        
        # scoring
        char_probs = get_char_probs(valid_texts, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        if CFG.wandb:
            wandb.log({f"[fold{fold}] epoch": epoch+1, 
                       f"[fold{fold}] avg_train_loss": avg_loss, 
                       f"[fold{fold}] avg_val_loss": avg_val_loss,
                       f"[fold{fold}] score": score})
        
        if best_score <= score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
            

    #predictions = torch.load(OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             #map_location=torch.device('cpu'))['predictions']
    predictions = torch.load(OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             map_location=torch.device('cpu'),weights_only=False)['predictions']

    
    valid_folds[[i for i in range(CFG.max_len)]] = predictions

    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

In [25]:
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    #model="microsoft/deberta-v3-large" #"answerdotai/ModernBERT-base"
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" If running on Kaggle
    model = "answerdotai/ModernBERT-base" # running on runpod instance
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=0.05
    num_warmup_steps=0
    epochs=4
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 10 #batch_size=16
    fc_dropout=0.2
    max_len=443
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=1 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True

tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

if CFG.debug:
    CFG.epochs = 4
    CFG.trn_fold = [0]

In [26]:
torch.cuda.empty_cache()

In [27]:
def run_fold_training(oof_df, fold):
    """
    Helper function to run training for a specific fold.
    Concatenates the out-of-fold predictions to oof_df.
    """
    _oof_df = train_loop(train, fold)
    oof_df = pd.concat([oof_df, _oof_df])
    LOGGER.info(f"========== fold: {fold} result ==========")
    evaluate_results(_oof_df)  # Evaluate results after each fold
    return oof_df

def evaluate_results(oof_df):
    """
    Function to handle the result evaluation, including scoring and logging.
    """
    labels = create_labels_for_scoring(oof_df)
    predictions = oof_df[[i for i in range(CFG.max_len)]].values
    char_probs = get_char_probs(oof_df['pn_history'].values, predictions, CFG.tokenizer)
    results = get_results(char_probs, th=0.5)
    preds = get_predictions(results)
    score = get_score(labels, preds)
    LOGGER.info(f'Score: {score:<.4f}')

In [28]:
if __name__ == '__main__':
    CFG.debug = True

    if CFG.train:
        oof_df = pd.DataFrame()

        # Run training for a single fold in debug mode or multiple folds otherwise
        if CFG.debug:
            fold = 0  # In debug mode, we only use fold 0
            if fold in CFG.trn_fold:
                oof_df = run_fold_training(oof_df, fold)
        else:
            for fold in range(CFG.n_fold):  # For all folds in non-debug mode
                if fold in CFG.trn_fold:
                    oof_df = run_fold_training(oof_df, fold)

        # Final evaluation and saving results
        oof_df = oof_df.reset_index(drop=True)
        LOGGER.info(f"========== CV ==========")
        evaluate_results(oof_df)  # Final evaluation of all results
        oof_df.to_pickle(OUTPUT_DIR + 'oof_df.pkl')

    if CFG.wandb:
        wandb.finish()



Starting training

Starting training epoch 1 for fold 0


Epoch: [1][0/1144] Elapsed 0m 4s (remain 92m 45s) Loss: 0.2581(0.2581) Grad Norm: 416179.5000  LR: 9.99999998821668e-06


Epoch: [1][100/1144] Elapsed 1m 10s (remain 12m 4s) Loss: 0.0831(0.0900) Grad Norm: 18727.7070  LR: 9.999879798820318e-06


Epoch: [1][200/1144] Elapsed 2m 14s (remain 10m 32s) Loss: 0.0391(0.0768) Grad Norm: 31531.2207  LR: 9.99952394958765e-06


Epoch: [1][300/1144] Elapsed 3m 17s (remain 9m 13s) Loss: 0.0446(0.0684) Grad Norm: 43202.3320  LR: 9.998932457290953e-06


Epoch: [1][400/1144] Elapsed 4m 22s (remain 8m 6s) Loss: 0.0469(0.0639) Grad Norm: 89377.3828  LR: 9.998105349809095e-06


Epoch: [1][500/1144] Elapsed 5m 26s (remain 6m 59s) Loss: 0.0490(0.0605) Grad Norm: 62269.3320  LR: 9.997042666126213e-06


Epoch: [1][600/1144] Elapsed 6m 32s (remain 5m 54s) Loss: 0.0309(0.0570) Grad Norm: 18640.8965  LR: 9.995744456329885e-06


Epoch: [1][700/1144] Elapsed 7m 37s (remain 4m 49s) Loss: 0.0213(0.0549) Grad Norm: 23287.6035  LR: 9.994210781608763e-06


Epoch: [1][800/1144] Elapsed 8m 43s (remain 3m 44s) Loss: 0.0110(0.0525) Grad Norm: 32788.4531  LR: 9.992441714249694e-06


Epoch: [1][900/1144] Elapsed 9m 52s (remain 2m 39s) Loss: 0.0464(0.0503) Grad Norm: 26934.8906  LR: 9.990437337634305e-06


Epoch: [1][1000/1144] Elapsed 10m 56s (remain 1m 33s) Loss: 0.0542(0.0488) Grad Norm: 62043.4570  LR: 9.98819774623508e-06


Epoch: [1][1100/1144] Elapsed 12m 2s (remain 0m 28s) Loss: 0.0131(0.0469) Grad Norm: 22764.7109  LR: 9.985723045610904e-06


Epoch: [1][1143/1144] Elapsed 12m 30s (remain 0m 0s) Loss: 0.0382(0.0460) Grad Norm: 88658.2344  LR: 9.984586668665641e-06
Epoch 1 completed. Average loss: 0.0460
Starting evaluation
Starting validation...


EVAL: [0/286] Elapsed 0m 1s (remain 8m 50s) Loss: 0.0300(0.0300) 


EVAL: [100/286] Elapsed 0m 46s (remain 1m 25s) Loss: 0.0215(0.0245) 


EVAL: [200/286] Elapsed 1m 30s (remain 0m 38s) Loss: 0.0166(0.0301) 


EVAL: [285/286] Elapsed 2m 9s (remain 0m 0s) Loss: 0.0237(0.0302) 


Epoch 1 - avg_train_loss: 0.0460  avg_val_loss: 0.0302  time: 883s


Epoch 1 - Score: 0.6701


Epoch 1 - Save Best Score: 0.6701 Model


Starting training

Starting training epoch 2 for fold 0


Epoch: [2][0/1144] Elapsed 0m 0s (remain 18m 42s) Loss: 0.0423(0.0423) Grad Norm: 286566.0312  LR: 9.984559724388208e-06


Epoch: [2][100/1144] Elapsed 1m 7s (remain 11m 34s) Loss: 0.0140(0.0272) Grad Norm: 13693.2559  LR: 9.981746674247888e-06


Epoch: [2][200/1144] Elapsed 2m 12s (remain 10m 19s) Loss: 0.0198(0.0264) Grad Norm: 27617.1504  LR: 9.978698818941641e-06


Epoch: [2][300/1144] Elapsed 3m 16s (remain 9m 10s) Loss: 0.0259(0.0267) Grad Norm: 38128.8281  LR: 9.97541630212434e-06


Epoch: [2][400/1144] Elapsed 4m 20s (remain 8m 2s) Loss: 0.0192(0.0258) Grad Norm: 24234.7285  LR: 9.971899278511176e-06


Epoch: [2][500/1144] Elapsed 5m 25s (remain 6m 57s) Loss: 0.0154(0.0255) Grad Norm: 61970.0117  LR: 9.968147913870378e-06


Epoch: [2][600/1144] Elapsed 6m 33s (remain 5m 55s) Loss: 0.0133(0.0252) Grad Norm: 26718.2109  LR: 9.964162385015392e-06


Epoch: [2][700/1144] Elapsed 7m 43s (remain 4m 52s) Loss: 0.0576(0.0249) Grad Norm: 85386.4062  LR: 9.95994287979655e-06


Epoch: [2][800/1144] Elapsed 8m 48s (remain 3m 46s) Loss: 0.0248(0.0246) Grad Norm: 19250.7305  LR: 9.955489597092213e-06


Epoch: [2][900/1144] Elapsed 9m 55s (remain 2m 40s) Loss: 0.0199(0.0241) Grad Norm: 49929.7891  LR: 9.950802746799404e-06


Epoch: [2][1000/1144] Elapsed 11m 2s (remain 1m 34s) Loss: 0.0143(0.0239) Grad Norm: 27087.4727  LR: 9.945882549823906e-06


Epoch: [2][1100/1144] Elapsed 12m 8s (remain 0m 28s) Loss: 0.0082(0.0238) Grad Norm: 16851.1777  LR: 9.940729238069857e-06


Epoch: [2][1143/1144] Elapsed 12m 39s (remain 0m 0s) Loss: 0.0078(0.0237) Grad Norm: 27736.3867  LR: 9.938441702975689e-06
Epoch 2 completed. Average loss: 0.0237
Starting evaluation
Starting validation...


EVAL: [0/286] Elapsed 0m 0s (remain 4m 3s) Loss: 0.0222(0.0222) 


EVAL: [100/286] Elapsed 0m 48s (remain 1m 27s) Loss: 0.0157(0.0189) 


EVAL: [200/286] Elapsed 1m 34s (remain 0m 39s) Loss: 0.0129(0.0224) 


EVAL: [285/286] Elapsed 2m 13s (remain 0m 0s) Loss: 0.0107(0.0216) 


Epoch 2 - avg_train_loss: 0.0237  avg_val_loss: 0.0216  time: 895s


Epoch 2 - Score: 0.7492


Epoch 2 - Save Best Score: 0.7492 Model


Starting training

Starting training epoch 3 for fold 0


Epoch: [3][0/1144] Elapsed 0m 1s (remain 20m 40s) Loss: 0.0294(0.0294) Grad Norm: 47518.5742  LR: 9.938387992324574e-06


Epoch: [3][100/1144] Elapsed 1m 11s (remain 12m 16s) Loss: 0.0014(0.0198) Grad Norm: 4877.2617  LR: 9.932899424829846e-06


Epoch: [3][200/1144] Elapsed 2m 19s (remain 10m 55s) Loss: 0.0149(0.0183) Grad Norm: 53736.7539  LR: 9.927178354491498e-06


Epoch: [3][300/1144] Elapsed 3m 28s (remain 9m 43s) Loss: 0.0185(0.0180) Grad Norm: 31640.7812  LR: 9.921225050961318e-06


Epoch: [3][400/1144] Elapsed 4m 37s (remain 8m 33s) Loss: 0.0222(0.0182) Grad Norm: 50278.8477  LR: 9.915039794836955e-06


Epoch: [3][500/1144] Elapsed 5m 45s (remain 7m 23s) Loss: 0.0446(0.0183) Grad Norm: 56499.6211  LR: 9.908622877648706e-06


Epoch: [3][600/1144] Elapsed 6m 54s (remain 6m 14s) Loss: 0.0203(0.0182) Grad Norm: 45107.9414  LR: 9.901974601845776e-06


Epoch: [3][700/1144] Elapsed 8m 0s (remain 5m 3s) Loss: 0.0117(0.0179) Grad Norm: 41939.3164  LR: 9.895095280782014e-06


Epoch: [3][800/1144] Elapsed 9m 8s (remain 3m 54s) Loss: 0.0149(0.0177) Grad Norm: 21571.5508  LR: 9.88798523870115e-06


Epoch: [3][900/1144] Elapsed 10m 17s (remain 2m 46s) Loss: 0.0131(0.0176) Grad Norm: 31795.3301  LR: 9.88064481072151e-06


Epoch: [3][1000/1144] Elapsed 11m 26s (remain 1m 38s) Loss: 0.0045(0.0174) Grad Norm: 13249.1348  LR: 9.873074342820225e-06


Epoch: [3][1100/1144] Elapsed 12m 35s (remain 0m 29s) Loss: 0.0085(0.0172) Grad Norm: 41038.0938  LR: 9.865274191816917e-06


Epoch: [3][1143/1144] Elapsed 13m 4s (remain 0m 0s) Loss: 0.0081(0.0172) Grad Norm: 10022.7051  LR: 9.861849601988384e-06
Epoch 3 completed. Average loss: 0.0172
Starting evaluation
Starting validation...


EVAL: [0/286] Elapsed 0m 0s (remain 4m 29s) Loss: 0.0239(0.0239) 


EVAL: [100/286] Elapsed 0m 47s (remain 1m 27s) Loss: 0.0153(0.0169) 


EVAL: [200/286] Elapsed 1m 31s (remain 0m 38s) Loss: 0.0062(0.0201) 


EVAL: [285/286] Elapsed 2m 8s (remain 0m 0s) Loss: 0.0063(0.0191) 


Epoch 3 - avg_train_loss: 0.0172  avg_val_loss: 0.0191  time: 915s


Epoch 3 - Score: 0.7735


Epoch 3 - Save Best Score: 0.7735 Model


Starting training

Starting training epoch 4 for fold 0


Epoch: [4][0/1144] Elapsed 0m 1s (remain 21m 28s) Loss: 0.0070(0.0070) Grad Norm: 24013.2812  LR: 9.86176945610761e-06


Epoch: [4][100/1144] Elapsed 1m 5s (remain 11m 17s) Loss: 0.0044(0.0131) Grad Norm: 6075.4004  LR: 9.853639210102213e-06


Epoch: [4][200/1144] Elapsed 2m 10s (remain 10m 13s) Loss: 0.0322(0.0141) Grad Norm: 93926.9297  LR: 9.845280197032851e-06


Epoch: [4][300/1144] Elapsed 3m 17s (remain 9m 13s) Loss: 0.0582(0.0144) Grad Norm: 93377.8047  LR: 9.836692810885728e-06


Epoch: [4][400/1144] Elapsed 4m 22s (remain 8m 6s) Loss: 0.0073(0.0136) Grad Norm: 19828.7656  LR: 9.827877456410978e-06


Epoch: [4][500/1144] Elapsed 5m 27s (remain 7m 0s) Loss: 0.0290(0.0136) Grad Norm: 56001.6562  LR: 9.818834549103587e-06


Epoch: [4][600/1144] Elapsed 6m 33s (remain 5m 55s) Loss: 0.0124(0.0138) Grad Norm: 27602.7148  LR: 9.809564515183814e-06


Epoch: [4][700/1144] Elapsed 7m 37s (remain 4m 49s) Loss: 0.0108(0.0134) Grad Norm: 23608.0938  LR: 9.800067791577096e-06


Epoch: [4][800/1144] Elapsed 8m 43s (remain 3m 44s) Loss: 0.0140(0.0132) Grad Norm: 32052.4980  LR: 9.790344825893463e-06


Epoch: [4][900/1144] Elapsed 9m 53s (remain 2m 39s) Loss: 0.0054(0.0133) Grad Norm: 24686.0625  LR: 9.780396076406429e-06


Epoch: [4][1000/1144] Elapsed 10m 57s (remain 1m 33s) Loss: 0.0032(0.0135) Grad Norm: 5844.7861  LR: 9.770222012031404e-06


Epoch: [4][1100/1144] Elapsed 12m 2s (remain 0m 28s) Loss: 0.0037(0.0134) Grad Norm: 12571.4834  LR: 9.759823112303583e-06


Epoch: [4][1143/1144] Elapsed 12m 29s (remain 0m 0s) Loss: 0.0317(0.0134) Grad Norm: 67056.9609  LR: 9.755282581475769e-06
Epoch 4 completed. Average loss: 0.0134
Starting evaluation
Starting validation...


EVAL: [0/286] Elapsed 0m 0s (remain 3m 52s) Loss: 0.0144(0.0144) 


EVAL: [100/286] Elapsed 0m 44s (remain 1m 21s) Loss: 0.0132(0.0158) 


EVAL: [200/286] Elapsed 1m 28s (remain 0m 37s) Loss: 0.0050(0.0189) 


EVAL: [285/286] Elapsed 2m 7s (remain 0m 0s) Loss: 0.0051(0.0177) 


Epoch 4 - avg_train_loss: 0.0134  avg_val_loss: 0.0177  time: 879s


Epoch 4 - Score: 0.8024


Epoch 4 - Save Best Score: 0.8024 Model




Score: 0.8024




Score: 0.8024


[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m: [fold0] avg_train_loss █▃▂▁
[34m[1mwandb[0m:   [fold0] avg_val_loss █▃▂▁
[34m[1mwandb[0m:          [fold0] epoch ▁▃▆█
[34m[1mwandb[0m:          [fold0] score ▁▅▆█
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m: [fold0] avg_train_loss 0.0134
[34m[1mwandb[0m:   [fold0] avg_val_loss 0.01768
[34m[1mwandb[0m:          [fold0] epoch 4
[34m[1mwandb[0m:          [fold0] score 0.80237
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33manswerdotai/ModernBERT-base[0m at: [34m[4mhttps://wandb.ai/project-zero/NBME-Public/runs/bnksm6wu?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/project-zero/NBME-Public?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20250211_211117-bnksm6wu/logs[0m
