# Trainining on all dataset ModernBERT



![Comparison of BERT variants in 2025](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/modernbert/modernbert_pareto_curve.png)

# Directory settings

In [1]:
# ====================================================
# Directory settings
# ====================================================
import os, shutil

OUTPUT_DIR = os.path.join(os.path.dirname(os.getcwd()), 'modern-bert/')
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

for item in os.listdir(OUTPUT_DIR):
    item_path = os.path.join(OUTPUT_DIR, item)
    if os.path.isfile(item_path):  # If it's a file, remove it
        os.remove(item_path)
        print(f"Removed file: {item}")
    elif os.path.isdir(item_path):  # If it's a directory, remove it
        shutil.rmtree(item_path)
        print(f"Removed folder: {item}")

Removed folder: tokenizer
Removed file: train.log
Removed file: config.pth


# Library

In [2]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

os.system('pip uninstall -y transformers')
os.system('python -m pip install transformers>=4.48.0')
os.system('pip uninstall -y tokenizer')
os.system('python -m pip install tokenizers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Found existing installation: transformers 4.48.3


Uninstalling transformers-4.48.3:
  Successfully uninstalled transformers-4.48.3


[0m

[0m

[0m



[0m

tokenizers.__version__: 0.21.0
transformers.__version__: 4.48.3


env: TOKENIZERS_PARALLELISM=true


In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" # if running on Kaggle, use this path 
    model = "answerdotai/ModernBERT-base" # if running on local machine, use this path  
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=1 #0.05
    num_warmup_steps=0
    epochs=5
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 12 #batch_size=16
    fc_dropout=0.2
    max_len=512
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=2 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True
    
if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

In [4]:
# ====================================================
# wandb: local notebook login
# ====================================================
if CFG.wandb:
    
    import wandb
    from dotenv import load_dotenv
    import os

    # Load environment variables from .env file in the parent folder
    load_dotenv(os.path.join(os.path.dirname(os.getcwd()), '.env'))

    secret_value_0 = os.getenv("WANDB_API_KEY")
    if secret_value_0:
        wandb.login(key=secret_value_0)
        anony = None
    else:
        anony = "must"
        print('If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. \nGet your W&B access token from here: https://wandb.ai/authorize')

    def class2dict(f):
        return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

    run = wandb.init(project='NBME-Public', 
                     name=CFG.model,
                     config=class2dict(CFG),
                     group=CFG.model,
                     job_type="train",
                     anonymous=anony)

[34m[1mwandb[0m: Currently logged in as: [33mzehrakorkusuz[0m ([33mproject-zero[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. 
Get your W&B access token from here: https://wandb.ai/authorize


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/workspace/wandb/run-20250211_224808-4ym6uuzz[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33manswerdotai/ModernBERT-base[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/project-zero/NBME-Public?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/project-zero/NBME-Public/runs/4ym6uuzz?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m




# Helper functions for scoring

In [5]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline

def micro_f1(preds, truths):
    """
    Micro f1 on binary arrays.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score.
    """
    # Micro : aggregating over all instances
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return f1_score(truths, preds)


def spans_to_binary(spans, length=None):
    """
    Converts spans to a binary array indicating whether each character is in the span.

    Args:
        spans (list of lists of two ints): Spans.

    Returns:
        np array [length]: Binarized spans.
    """
    length = np.max(spans) if length is None else length
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary


def span_micro_f1(preds, truths):
    """
    Micro f1 on spans.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
    """
    bin_preds = []
    bin_truths = []
    for pred, truth in zip(preds, truths):
        if not len(pred) and not len(truth):
            continue
        length = max(np.max(pred) if len(pred) else 0, np.max(truth) if len(truth) else 0)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)

In [6]:
## UPDATED FOR HANDLING NONEs in THE PSEUDOLABELS / only including confident inferences
import ast

# Update the create_labels_for_scoring function
def create_labels_for_scoring(df):
    # Initialize the 'location_for_create_labels' column with empty lists
    df['location_for_create_labels'] = [ast.literal_eval(f'[]')] * len(df)
    
    for i in range(len(df)):
        lst = df.loc[i, 'location']
        if lst and lst != '':  # Check if lst is not None or empty
            # Ensure lst is a list of strings
            if isinstance(lst, str):
                lst = [lst]
            elif isinstance(lst, list):
                lst = [str(item) for item in lst]
            new_lst = ';'.join(lst)
            df.loc[i, 'location_for_create_labels'] = ast.literal_eval(f'[["{new_lst}"]]')

    # Create labels
    truths = []
    for location_list in df['location_for_create_labels'].values:
        truth = []
        if len(location_list) > 0:
            location = location_list[0]
            for loc in [s.split() for s in location.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                truth.append([start, end])
        truths.append(truth)
    
    return truths


def get_char_probs(texts, predictions, tokenizer):
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text, 
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
            start = offset_mapping[0]
            end = offset_mapping[1]
            results[i][start:end] = pred
    return results


def get_results(char_probs, th=0.5):
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results


def get_predictions(results):
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions

# Utils

In [7]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = span_micro_f1(y_true, y_pred)
    return score


def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

# Data Loading

In [8]:
# ====================================================
# Data Loading
# ====================================================
train = pd.read_csv('/workspace/data/train.csv')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('/workspace/data/features.csv')
def preprocess_features(features):
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    return features
features = preprocess_features(features)
patient_notes = pd.read_csv('/workspace/data/patient_notes.csv')

print(f"train.shape: {train.shape}")
display(train.head())
print(f"features.shape: {features.shape}")
display(features.head())
print(f"patient_notes.shape: {patient_notes.shape}")
display(patient_notes.head())


train.shape: (14300, 6)


Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724]
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693]
2,00016_002,0,16,2,[chest pressure],[203 217]
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]"
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258]


features.shape: (143, 3)


Unnamed: 0,feature_num,case_num,feature_text
0,0,0,Family-history-of-MI-OR-Family-history-of-myoc...
1,1,0,Family-history-of-thyroid-disorder
2,2,0,Chest-pressure
3,3,0,Intermittent-symptoms
4,4,0,Lightheaded


patient_notes.shape: (42146, 3)


Unnamed: 0,pn_num,case_num,pn_history
0,0,0,"17-year-old male, has come to the student heal..."
1,1,0,17 yo male with recurrent palpitations for the...
2,2,0,Dillon Cleveland is a 17 y.o. male patient wit...
3,3,0,a 17 yo m c/o palpitation started 3 mos ago; \...
4,4,0,17yo male with no pmh here for evaluation of p...


In [9]:
train['annotation_length'] = train['annotation'].apply(len) # refers to number of annotation
train

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1
2,00016_002,0,16,2,[chest pressure],[203 217],1
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1
...,...,...,...,...,...,...,...
14295,95333_912,9,95333,912,[],[],0
14296,95333_913,9,95333,913,[],[],0
14297,95333_914,9,95333,914,[photobia],[274 282],1
14298,95333_915,9,95333,915,[no sick contacts],[421 437],1


In [10]:
train = train.merge(features, on=['feature_num', 'case_num'], how='left')
train = train.merge(patient_notes, on=['pn_num', 'case_num'], how='left')
display(train.head())

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length,feature_text,pn_history
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1,Family-history-of-MI-OR-Family-history-of-myoc...,HPI: 17yo M presents with palpitations. Patien...
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1,Family-history-of-thyroid-disorder,HPI: 17yo M presents with palpitations. Patien...
2,00016_002,0,16,2,[chest pressure],[203 217],1,Chest-pressure,HPI: 17yo M presents with palpitations. Patien...
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2,Intermittent-symptoms,HPI: 17yo M presents with palpitations. Patien...
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1,Lightheaded,HPI: 17yo M presents with palpitations. Patien...


In [11]:
# incorrect annotation
train.loc[338, 'annotation'] = ast.literal_eval('[["father heart attack"]]')
train.loc[338, 'location'] = ast.literal_eval('[["764 783"]]')

train.loc[621, 'annotation'] = ast.literal_eval('[["for the last 2-3 months"]]')
train.loc[621, 'location'] = ast.literal_eval('[["77 100"]]')

train.loc[655, 'annotation'] = ast.literal_eval('[["no heat intolerance"], ["no cold intolerance"]]')
train.loc[655, 'location'] = ast.literal_eval('[["285 292;301 312"], ["285 287;296 312"]]')

train.loc[1262, 'annotation'] = ast.literal_eval('[["mother thyroid problem"]]')
train.loc[1262, 'location'] = ast.literal_eval('[["551 557;565 580"]]')

train.loc[1265, 'annotation'] = ast.literal_eval('[[\'felt like he was going to "pass out"\']]')
train.loc[1265, 'location'] = ast.literal_eval('[["131 135;181 212"]]')

train.loc[1396, 'annotation'] = ast.literal_eval('[["stool , with no blood"]]')
train.loc[1396, 'location'] = ast.literal_eval('[["259 280"]]')

train.loc[1591, 'annotation'] = ast.literal_eval('[["diarrhoe non blooody"]]')
train.loc[1591, 'location'] = ast.literal_eval('[["176 184;201 212"]]')

train.loc[1615, 'annotation'] = ast.literal_eval('[["diarrhea for last 2-3 days"]]')
train.loc[1615, 'location'] = ast.literal_eval('[["249 257;271 288"]]')

train.loc[1664, 'annotation'] = ast.literal_eval('[["no vaginal discharge"]]')
train.loc[1664, 'location'] = ast.literal_eval('[["822 824;907 924"]]')

train.loc[1714, 'annotation'] = ast.literal_eval('[["started about 8-10 hours ago"]]')
train.loc[1714, 'location'] = ast.literal_eval('[["101 129"]]')

train.loc[1929, 'annotation'] = ast.literal_eval('[["no blood in the stool"]]')
train.loc[1929, 'location'] = ast.literal_eval('[["531 539;549 561"]]')

train.loc[2134, 'annotation'] = ast.literal_eval('[["last sexually active 9 months ago"]]')
train.loc[2134, 'location'] = ast.literal_eval('[["540 560;581 593"]]')

train.loc[2191, 'annotation'] = ast.literal_eval('[["right lower quadrant pain"]]')
train.loc[2191, 'location'] = ast.literal_eval('[["32 57"]]')

train.loc[2553, 'annotation'] = ast.literal_eval('[["diarrhoea no blood"]]')
train.loc[2553, 'location'] = ast.literal_eval('[["308 317;376 384"]]')

train.loc[3124, 'annotation'] = ast.literal_eval('[["sweating"]]')
train.loc[3124, 'location'] = ast.literal_eval('[["549 557"]]')

train.loc[3858, 'annotation'] = ast.literal_eval('[["previously as regular"], ["previously eveyr 28-29 days"], ["previously lasting 5 days"], ["previously regular flow"]]')
train.loc[3858, 'location'] = ast.literal_eval('[["102 123"], ["102 112;125 141"], ["102 112;143 157"], ["102 112;159 171"]]')

train.loc[4373, 'annotation'] = ast.literal_eval('[["for 2 months"]]')
train.loc[4373, 'location'] = ast.literal_eval('[["33 45"]]')

train.loc[4763, 'annotation'] = ast.literal_eval('[["35 year old"]]')
train.loc[4763, 'location'] = ast.literal_eval('[["5 16"]]')

train.loc[4782, 'annotation'] = ast.literal_eval('[["darker brown stools"]]')
train.loc[4782, 'location'] = ast.literal_eval('[["175 194"]]')

train.loc[4908, 'annotation'] = ast.literal_eval('[["uncle with peptic ulcer"]]')
train.loc[4908, 'location'] = ast.literal_eval('[["700 723"]]')

train.loc[6016, 'annotation'] = ast.literal_eval('[["difficulty falling asleep"]]')
train.loc[6016, 'location'] = ast.literal_eval('[["225 250"]]')

train.loc[6192, 'annotation'] = ast.literal_eval('[["helps to take care of aging mother and in-laws"]]')
train.loc[6192, 'location'] = ast.literal_eval('[["197 218;236 260"]]')

train.loc[6380, 'annotation'] = ast.literal_eval('[["No hair changes"], ["No skin changes"], ["No GI changes"], ["No palpitations"], ["No excessive sweating"]]')
train.loc[6380, 'location'] = ast.literal_eval('[["480 482;507 519"], ["480 482;499 503;512 519"], ["480 482;521 531"], ["480 482;533 545"], ["480 482;564 582"]]')

train.loc[6562, 'annotation'] = ast.literal_eval('[["stressed due to taking care of her mother"], ["stressed due to taking care of husbands parents"]]')
train.loc[6562, 'location'] = ast.literal_eval('[["290 320;327 337"], ["290 320;342 358"]]')

train.loc[6862, 'annotation'] = ast.literal_eval('[["stressor taking care of many sick family members"]]')
train.loc[6862, 'location'] = ast.literal_eval('[["288 296;324 363"]]')

train.loc[7022, 'annotation'] = ast.literal_eval('[["heart started racing and felt numbness for the 1st time in her finger tips"]]')
train.loc[7022, 'location'] = ast.literal_eval('[["108 182"]]')

train.loc[7422, 'annotation'] = ast.literal_eval('[["first started 5 yrs"]]')
train.loc[7422, 'location'] = ast.literal_eval('[["102 121"]]')

train.loc[8876, 'annotation'] = ast.literal_eval('[["No shortness of breath"]]')
train.loc[8876, 'location'] = ast.literal_eval('[["481 483;533 552"]]')

train.loc[9027, 'annotation'] = ast.literal_eval('[["recent URI"], ["nasal stuffines, rhinorrhea, for 3-4 days"]]')
train.loc[9027, 'location'] = ast.literal_eval('[["92 102"], ["123 164"]]')

train.loc[9938, 'annotation'] = ast.literal_eval('[["irregularity with her cycles"], ["heavier bleeding"], ["changes her pad every couple hours"]]')
train.loc[9938, 'location'] = ast.literal_eval('[["89 117"], ["122 138"], ["368 402"]]')

train.loc[9973, 'annotation'] = ast.literal_eval('[["gaining 10-15 lbs"]]')
train.loc[9973, 'location'] = ast.literal_eval('[["344 361"]]')

train.loc[10513, 'annotation'] = ast.literal_eval('[["weight gain"], ["gain of 10-16lbs"]]')
train.loc[10513, 'location'] = ast.literal_eval('[["600 611"], ["607 623"]]')

train.loc[11551, 'annotation'] = ast.literal_eval('[["seeing her son knows are not real"]]')
train.loc[11551, 'location'] = ast.literal_eval('[["386 400;443 461"]]')

train.loc[11677, 'annotation'] = ast.literal_eval('[["saw him once in the kitchen after he died"]]')
train.loc[11677, 'location'] = ast.literal_eval('[["160 201"]]')

train.loc[12124, 'annotation'] = ast.literal_eval('[["tried Ambien but it didnt work"]]')
train.loc[12124, 'location'] = ast.literal_eval('[["325 337;349 366"]]')

train.loc[12279, 'annotation'] = ast.literal_eval('[["heard what she described as a party later than evening these things did not actually happen"]]')
train.loc[12279, 'location'] = ast.literal_eval('[["405 459;488 524"]]')

train.loc[12289, 'annotation'] = ast.literal_eval('[["experienced seeing her son at the kitchen table these things did not actually happen"]]')
train.loc[12289, 'location'] = ast.literal_eval('[["353 400;488 524"]]')

train.loc[13238, 'annotation'] = ast.literal_eval('[["SCRACHY THROAT"], ["RUNNY NOSE"]]')
train.loc[13238, 'location'] = ast.literal_eval('[["293 307"], ["321 331"]]')

train.loc[13297, 'annotation'] = ast.literal_eval('[["without improvement when taking tylenol"], ["without improvement when taking ibuprofen"]]')
train.loc[13297, 'location'] = ast.literal_eval('[["182 221"], ["182 213;225 234"]]')

train.loc[13299, 'annotation'] = ast.literal_eval('[["yesterday"], ["yesterday"]]')
train.loc[13299, 'location'] = ast.literal_eval('[["79 88"], ["409 418"]]')

train.loc[13845, 'annotation'] = ast.literal_eval('[["headache global"], ["headache throughout her head"]]')
train.loc[13845, 'location'] = ast.literal_eval('[["86 94;230 236"], ["86 94;237 256"]]')

train.loc[14083, 'annotation'] = ast.literal_eval('[["headache generalized in her head"]]')
train.loc[14083, 'location'] = ast.literal_eval('[["56 64;156 179"]]')

In [12]:
pseudolabels = pd.read_csv('/workspace/data/pseudolabels.csv')
pseudolabels =  pseudolabels.merge(features, on=['feature_num', 'case_num'], how='left')
pseudolabels = pseudolabels.merge(patient_notes, on=['pn_num', 'case_num'], how='left')

pseudolabels = pseudolabels.dropna(subset=['location'])
pseudolabels = pseudolabels[pseudolabels['location'].apply(lambda x: len(x) > 0)]

def convert_location_format(location):
    if pd.isna(location) or not location:
        return []
    if isinstance(location, str):
        return [loc.strip() for loc in location.split(';')]
    return location

pseudolabels['location'] = pseudolabels['location'].apply(convert_location_format)

pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal..."
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...
...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...


In [13]:
pseudolabels['annotation_length'] = pseudolabels['location'].apply(len)

In [14]:
pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history,annotation_length
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal...",1
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...,1
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...,1
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...,1
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...,1
...,...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,2
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,1
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...,1
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...,1


In [15]:
pseudolabels = pseudolabels.drop(columns=['fold'])

In [16]:
# add pseudolabels to train
## IF YOU WANT TO INCLUDE PSEUDOLABELS
train = pd.concat([train, pseudolabels], axis=0, ignore_index=True)

# CV split

In [17]:
# ====================================================
# CV split
# ====================================================
train = train.reset_index(drop=True)

Fold = GroupKFold(n_splits=CFG.n_fold)
groups = train['pn_num'].values
for n, (train_index, val_index) in enumerate(Fold.split(train, train['location'], groups)):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

fold
0    20107
1    20107
2    20106
3    20106
4    20106
dtype: int64

In [18]:
if CFG.debug:
    display(train.groupby('fold').size())
    train = train.sample(n=1000, random_state=0).reset_index(drop=True)
    display(train.groupby('fold').size())

# Dataset

In [19]:
# Define max_len
for text_col in ['pn_history']:
    pn_history_lengths = []
    tk0 = tqdm(patient_notes[text_col].fillna("").values, total=len(patient_notes))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        pn_history_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(pn_history_lengths)}')

for text_col in ['feature_text']:
    features_lengths = []
    tk0 = tqdm(features[text_col].fillna("").values, total=len(features))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        features_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(features_lengths)}')

CFG.max_len = max(pn_history_lengths) + max(features_lengths) + 3  # cls & sep & sep
LOGGER.info(f"max_len: {CFG.max_len}")

  0%|          | 0/42146 [00:00<?, ?it/s]

  0%|          | 197/42146 [00:00<00:21, 1969.75it/s]

  1%|          | 439/42146 [00:00<00:18, 2234.37it/s]

  2%|▏         | 682/42146 [00:00<00:17, 2322.73it/s]

  2%|▏         | 925/42146 [00:00<00:17, 2364.03it/s]

  3%|▎         | 1171/42146 [00:00<00:17, 2398.22it/s]

  3%|▎         | 1415/42146 [00:00<00:16, 2412.28it/s]

  4%|▍         | 1657/42146 [00:00<00:16, 2402.36it/s]

  5%|▍         | 1899/42146 [00:00<00:16, 2406.06it/s]

  5%|▌         | 2142/42146 [00:00<00:16, 2413.37it/s]

  6%|▌         | 2385/42146 [00:01<00:16, 2415.65it/s]

  6%|▌         | 2627/42146 [00:01<00:16, 2405.53it/s]

  7%|▋         | 2868/42146 [00:01<00:16, 2377.20it/s]

  7%|▋         | 3106/42146 [00:01<00:16, 2353.21it/s]

  8%|▊         | 3342/42146 [00:01<00:17, 2254.63it/s]

  8%|▊         | 3569/42146 [00:01<00:17, 2158.83it/s]

  9%|▉         | 3787/42146 [00:01<00:18, 2088.51it/s]

  9%|▉         | 3997/42146 [00:01<00:18, 2047.00it/s]

 10%|█         | 4217/42146 [00:01<00:18, 2089.46it/s]

 11%|█         | 4444/42146 [00:01<00:17, 2140.65it/s]

 11%|█         | 4659/42146 [00:02<00:18, 2063.96it/s]

 12%|█▏        | 4882/42146 [00:02<00:17, 2110.94it/s]

 12%|█▏        | 5112/42146 [00:02<00:17, 2163.22it/s]

 13%|█▎        | 5350/42146 [00:02<00:16, 2225.18it/s]

 13%|█▎        | 5574/42146 [00:02<00:16, 2226.10it/s]

 14%|█▍        | 5819/42146 [00:02<00:15, 2292.05it/s]

 14%|█▍        | 6070/42146 [00:02<00:15, 2354.86it/s]

 15%|█▍        | 6315/42146 [00:02<00:15, 2382.14it/s]

 16%|█▌        | 6560/42146 [00:02<00:14, 2400.80it/s]

 16%|█▌        | 6813/42146 [00:02<00:14, 2437.28it/s]

 17%|█▋        | 7060/42146 [00:03<00:14, 2444.24it/s]

 17%|█▋        | 7305/42146 [00:03<00:14, 2418.20it/s]

 18%|█▊        | 7547/42146 [00:03<00:14, 2401.20it/s]

 18%|█▊        | 7788/42146 [00:03<00:14, 2387.39it/s]

 19%|█▉        | 8029/42146 [00:03<00:14, 2391.20it/s]

 20%|█▉        | 8269/42146 [00:03<00:14, 2389.33it/s]

 20%|██        | 8508/42146 [00:03<00:14, 2365.71it/s]

 21%|██        | 8745/42146 [00:03<00:14, 2349.32it/s]

 21%|██▏       | 8987/42146 [00:03<00:14, 2368.50it/s]

 22%|██▏       | 9224/42146 [00:03<00:13, 2354.88it/s]

 22%|██▏       | 9470/42146 [00:04<00:13, 2385.72it/s]

 23%|██▎       | 9718/42146 [00:04<00:13, 2413.14it/s]

 24%|██▎       | 9965/42146 [00:04<00:13, 2428.66it/s]

 24%|██▍       | 10211/42146 [00:04<00:13, 2435.59it/s]

 25%|██▍       | 10455/42146 [00:04<00:13, 2422.85it/s]

 25%|██▌       | 10701/42146 [00:04<00:12, 2431.58it/s]

 26%|██▌       | 10945/42146 [00:04<00:12, 2410.54it/s]

 27%|██▋       | 11187/42146 [00:04<00:12, 2405.21it/s]

 27%|██▋       | 11428/42146 [00:04<00:12, 2390.57it/s]

 28%|██▊       | 11673/42146 [00:05<00:12, 2406.96it/s]

 28%|██▊       | 11920/42146 [00:05<00:12, 2425.40it/s]

 29%|██▉       | 12170/42146 [00:05<00:12, 2445.29it/s]

 29%|██▉       | 12424/42146 [00:05<00:12, 2471.27it/s]

 30%|███       | 12672/42146 [00:05<00:11, 2457.72it/s]

 31%|███       | 12918/42146 [00:05<00:12, 2421.14it/s]

 31%|███       | 13161/42146 [00:05<00:11, 2417.69it/s]

 32%|███▏      | 13403/42146 [00:05<00:11, 2417.63it/s]

 32%|███▏      | 13645/42146 [00:05<00:11, 2386.49it/s]

 33%|███▎      | 13888/42146 [00:05<00:11, 2398.24it/s]

 34%|███▎      | 14128/42146 [00:06<00:11, 2376.53it/s]

 34%|███▍      | 14374/42146 [00:06<00:11, 2399.04it/s]

 35%|███▍      | 14620/42146 [00:06<00:11, 2416.68it/s]

 35%|███▌      | 14864/42146 [00:06<00:11, 2421.83it/s]

 36%|███▌      | 15107/42146 [00:06<00:11, 2411.50it/s]

 36%|███▋      | 15357/42146 [00:06<00:10, 2437.42it/s]

 37%|███▋      | 15612/42146 [00:06<00:10, 2469.90it/s]

 38%|███▊      | 15860/42146 [00:06<00:10, 2441.81it/s]

 38%|███▊      | 16105/42146 [00:06<00:10, 2427.38it/s]

 39%|███▉      | 16349/42146 [00:06<00:10, 2428.42it/s]

 39%|███▉      | 16596/42146 [00:07<00:10, 2438.41it/s]

 40%|███▉      | 16840/42146 [00:07<00:10, 2435.16it/s]

 41%|████      | 17084/42146 [00:07<00:10, 2433.79it/s]

 41%|████      | 17328/42146 [00:07<00:10, 2435.36it/s]

 42%|████▏     | 17577/42146 [00:07<00:10, 2449.71it/s]

 42%|████▏     | 17824/42146 [00:07<00:09, 2455.71it/s]

 43%|████▎     | 18075/42146 [00:07<00:09, 2470.13it/s]

 43%|████▎     | 18324/42146 [00:07<00:09, 2473.86it/s]

 44%|████▍     | 18572/42146 [00:07<00:09, 2451.13it/s]

 45%|████▍     | 18818/42146 [00:07<00:09, 2430.83it/s]

 45%|████▌     | 19062/42146 [00:08<00:09, 2395.12it/s]

 46%|████▌     | 19302/42146 [00:08<00:09, 2312.87it/s]

 46%|████▋     | 19538/42146 [00:08<00:09, 2325.39it/s]

 47%|████▋     | 19775/42146 [00:08<00:09, 2336.86it/s]

 47%|████▋     | 20010/42146 [00:08<00:09, 2338.30it/s]

 48%|████▊     | 20246/42146 [00:08<00:09, 2342.34it/s]

 49%|████▊     | 20490/42146 [00:08<00:09, 2370.55it/s]

 49%|████▉     | 20757/42146 [00:08<00:08, 2458.26it/s]

 50%|████▉     | 21027/42146 [00:08<00:08, 2528.22it/s]

 51%|█████     | 21285/42146 [00:08<00:08, 2541.04it/s]

 51%|█████     | 21554/42146 [00:09<00:07, 2584.31it/s]

 52%|█████▏    | 21813/42146 [00:09<00:07, 2573.21it/s]

 52%|█████▏    | 22071/42146 [00:09<00:07, 2569.50it/s]

 53%|█████▎    | 22336/42146 [00:09<00:07, 2591.95it/s]

 54%|█████▎    | 22596/42146 [00:09<00:07, 2543.43it/s]

 54%|█████▍    | 22851/42146 [00:09<00:07, 2497.98it/s]

 55%|█████▍    | 23102/42146 [00:09<00:07, 2484.90it/s]

 55%|█████▌    | 23351/42146 [00:09<00:07, 2485.26it/s]

 56%|█████▌    | 23603/42146 [00:09<00:07, 2494.87it/s]

 57%|█████▋    | 23853/42146 [00:09<00:07, 2494.30it/s]

 57%|█████▋    | 24103/42146 [00:10<00:07, 2491.64it/s]

 58%|█████▊    | 24354/42146 [00:10<00:07, 2494.69it/s]

 58%|█████▊    | 24604/42146 [00:10<00:07, 2471.46it/s]

 59%|█████▉    | 24855/42146 [00:10<00:06, 2480.43it/s]

 60%|█████▉    | 25110/42146 [00:10<00:06, 2498.74it/s]

 60%|██████    | 25367/42146 [00:10<00:06, 2518.78it/s]

 61%|██████    | 25619/42146 [00:10<00:06, 2491.12it/s]

 61%|██████▏   | 25869/42146 [00:10<00:06, 2483.44it/s]

 62%|██████▏   | 26118/42146 [00:10<00:06, 2476.66it/s]

 63%|██████▎   | 26366/42146 [00:10<00:06, 2471.07it/s]

 63%|██████▎   | 26616/42146 [00:11<00:06, 2479.27it/s]

 64%|██████▎   | 26864/42146 [00:11<00:06, 2463.19it/s]

 64%|██████▍   | 27114/42146 [00:11<00:06, 2472.21it/s]

 65%|██████▍   | 27370/42146 [00:11<00:05, 2496.59it/s]

 66%|██████▌   | 27620/42146 [00:11<00:05, 2483.59it/s]

 66%|██████▌   | 27885/42146 [00:11<00:05, 2531.47it/s]

 67%|██████▋   | 28139/42146 [00:11<00:05, 2515.01it/s]

 67%|██████▋   | 28391/42146 [00:11<00:05, 2482.39it/s]

 68%|██████▊   | 28640/42146 [00:11<00:05, 2443.07it/s]

 69%|██████▊   | 28885/42146 [00:12<00:05, 2378.22it/s]

 69%|██████▉   | 29124/42146 [00:12<00:05, 2331.62it/s]

 70%|██████▉   | 29358/42146 [00:12<00:05, 2305.51it/s]

 70%|███████   | 29589/42146 [00:12<00:05, 2295.71it/s]

 71%|███████   | 29819/42146 [00:12<00:05, 2287.88it/s]

 71%|███████▏  | 30048/42146 [00:12<00:05, 2282.03it/s]

 72%|███████▏  | 30277/42146 [00:12<00:05, 2266.14it/s]

 72%|███████▏  | 30504/42146 [00:12<00:05, 2261.61it/s]

 73%|███████▎  | 30731/42146 [00:12<00:05, 2230.75it/s]

 73%|███████▎  | 30956/42146 [00:12<00:05, 2234.74it/s]

 74%|███████▍  | 31181/42146 [00:13<00:04, 2236.60it/s]

 75%|███████▍  | 31405/42146 [00:13<00:04, 2233.81it/s]

 75%|███████▌  | 31635/42146 [00:13<00:04, 2250.82it/s]

 76%|███████▌  | 31863/42146 [00:13<00:04, 2258.88it/s]

 76%|███████▌  | 32089/42146 [00:13<00:04, 2257.75it/s]

 77%|███████▋  | 32315/42146 [00:13<00:04, 2252.09it/s]

 77%|███████▋  | 32541/42146 [00:13<00:04, 2253.72it/s]

 78%|███████▊  | 32767/42146 [00:13<00:04, 2240.64it/s]

 78%|███████▊  | 32997/42146 [00:13<00:04, 2257.85it/s]

 79%|███████▉  | 33231/42146 [00:13<00:03, 2282.00it/s]

 79%|███████▉  | 33460/42146 [00:14<00:03, 2265.58it/s]

 80%|███████▉  | 33687/42146 [00:14<00:03, 2237.65it/s]

 80%|████████  | 33911/42146 [00:14<00:03, 2212.38it/s]

 81%|████████  | 34133/42146 [00:14<00:03, 2203.03it/s]

 82%|████████▏ | 34354/42146 [00:14<00:03, 2194.08it/s]

 82%|████████▏ | 34574/42146 [00:14<00:03, 2187.72it/s]

 83%|████████▎ | 34796/42146 [00:14<00:03, 2196.24it/s]

 83%|████████▎ | 35020/42146 [00:14<00:03, 2209.15it/s]

 84%|████████▎ | 35241/42146 [00:14<00:03, 2205.68it/s]

 84%|████████▍ | 35462/42146 [00:14<00:03, 2198.31it/s]

 85%|████████▍ | 35684/42146 [00:15<00:02, 2204.23it/s]

 85%|████████▌ | 35905/42146 [00:15<00:02, 2196.93it/s]

 86%|████████▌ | 36125/42146 [00:15<00:02, 2183.91it/s]

 86%|████████▌ | 36345/42146 [00:15<00:02, 2187.07it/s]

 87%|████████▋ | 36565/42146 [00:15<00:02, 2190.28it/s]

 87%|████████▋ | 36785/42146 [00:15<00:02, 2189.32it/s]

 88%|████████▊ | 37010/42146 [00:15<00:02, 2206.40it/s]

 88%|████████▊ | 37261/42146 [00:15<00:02, 2295.79it/s]

 89%|████████▉ | 37515/42146 [00:15<00:01, 2368.76it/s]

 90%|████████▉ | 37778/42146 [00:15<00:01, 2445.44it/s]

 90%|█████████ | 38033/42146 [00:16<00:01, 2474.30it/s]

 91%|█████████ | 38301/42146 [00:16<00:01, 2534.59it/s]

 91%|█████████▏| 38556/42146 [00:16<00:01, 2538.49it/s]

 92%|█████████▏| 38817/42146 [00:16<00:01, 2558.70it/s]

 93%|█████████▎| 39091/42146 [00:16<00:01, 2611.86it/s]

 93%|█████████▎| 39353/42146 [00:16<00:01, 2596.42it/s]

 94%|█████████▍| 39625/42146 [00:16<00:00, 2631.00it/s]

 95%|█████████▍| 39889/42146 [00:16<00:00, 2588.77it/s]

 95%|█████████▌| 40149/42146 [00:16<00:00, 2569.39it/s]

 96%|█████████▌| 40407/42146 [00:16<00:00, 2547.98it/s]

 96%|█████████▋| 40662/42146 [00:17<00:00, 2521.00it/s]

 97%|█████████▋| 40915/42146 [00:17<00:00, 2515.71it/s]

 98%|█████████▊| 41176/42146 [00:17<00:00, 2543.16it/s]

 98%|█████████▊| 41443/42146 [00:17<00:00, 2580.49it/s]

 99%|█████████▉| 41707/42146 [00:17<00:00, 2596.94it/s]

100%|█████████▉| 41967/42146 [00:17<00:00, 2560.38it/s]

100%|██████████| 42146/42146 [00:17<00:00, 2385.01it/s]


pn_history max(lengths): 412


  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [00:00<00:00, 17686.00it/s]


feature_text max(lengths): 28


max_len: 443


In [20]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text, feature_text):

    # Omit token_type_ids for ModernBERT
    is_token_type_ids = CFG.model != 'answerdotai/ModernBERT-base'

    inputs = cfg.tokenizer(text, feature_text, 
                           add_special_tokens=True,
                           max_length=CFG.max_len,
                           padding="max_length",
                           truncation=True,
                           return_offsets_mapping=False,
                           return_token_type_ids=is_token_type_ids)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def create_label(cfg, text, annotation_length, location_list):
    encoded = cfg.tokenizer(text,
                            add_special_tokens=True,
                            max_length=CFG.max_len,
                            padding="max_length",
                            return_offsets_mapping=True)
    offset_mapping = encoded['offset_mapping']
    ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]
    label = np.zeros(len(offset_mapping))
    label[ignore_idxes] = -1
    if annotation_length != 0:
        for location in location_list:
            for loc in [s.split() for s in location.split(';')]:
                start_idx = -1
                end_idx = -1
                start, end = int(loc[0]), int(loc[1])
                for idx in range(len(offset_mapping)):
                    if (start_idx == -1) & (start < offset_mapping[idx][0]):
                        start_idx = idx - 1
                    if (end_idx == -1) & (end <= offset_mapping[idx][1]):
                        end_idx = idx + 1
                if start_idx == -1:
                    start_idx = end_idx
                if (start_idx != -1) & (end_idx != -1):
                    label[start_idx:end_idx] = 1
    return torch.tensor(label, dtype=torch.bfloat16)


class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.feature_texts = df['feature_text'].values
        self.pn_historys = df['pn_history'].values
        self.annotation_lengths = df['annotation_length'].values
        self.locations = df['location'].values

    def __len__(self):
        return len(self.feature_texts)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, 
                               self.pn_historys[item], 
                               self.feature_texts[item])
        label = create_label(self.cfg, 
                             self.pn_historys[item], 
                             self.annotation_lengths[item], 
                             self.locations[item])
        return inputs, label

# Model

In [21]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from torch.utils.checkpoint import checkpoint

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel(self.config)

        self.model.gradient_checkpointing_enable()
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        # Debugging input shapes and types
        #print("Feature extraction inputs:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        # Process the inputs through the model
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        
        # Debugging the output of the model
        #print(f"Last hidden states: shape={last_hidden_states.shape}, dtype={last_hidden_states.dtype}")
        return last_hidden_states

    def forward(self, inputs):
        # Debugging the forward pass
        #print("Starting forward pass...")
        #print("Input keys and shapes:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):  # Enable AMP
            #print("Running model under AMP...")
            outputs = self.model(**inputs)
        
        # Debugging model output
        #print("Model outputs:")
        #if isinstance(outputs, tuple):
           # for idx, output in enumerate(outputs):
                #print(f"Output {idx}: shape={output.shape if hasattr(output, 'shape') else 'N/A'}")
        #else:
            #print(f"Outputs: {outputs}")
            #pass
        
        # Final linear layer
        feature = outputs.last_hidden_state
        output = self.fc(self.fc_dropout(feature))
        
        # Debugging final output
        #print(f"Final output shape: {output.shape}, dtype={output.dtype}")
        return output


# Helper functions

In [22]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f891ee1a650>

In [23]:
def valid_fn(valid_loader, model, criterion, device):
    """
    Validation function with detailed debugging for intermediate outputs.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()

    print("Starting validation...")
    for step, (inputs, labels) in enumerate(valid_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        
        with torch.no_grad():
            y_preds = model(inputs)  # Model predictions
        
        # Compute loss
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)

        # Store predictions
        preds.append(y_preds.sigmoid().to('cpu').numpy())

        end = time.time()

        # Debugging: Log step-wise progress and intermediate outputs
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
            #print(f"Sample Predictions: {y_preds.sigmoid()[:5].view(-1).tolist()}")  # Example predictions
            #print(f"Sample Labels: {labels[:5].view(-1).tolist()}")  # Example labels

    predictions = np.concatenate(preds)

    # Debugging: Check shape and content of predictions
    #print(f"Validation Predictions Shape: {predictions.shape}")
    # print(f"First 5 Predictions: {predictions[:5]}")

    return losses.avg, predictions


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Training function with detailed debugging for loss and gradients.
    """
    print(f"\nStarting training epoch {epoch + 1} for fold {fold}")
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)

    losses = AverageMeter()
    start = time.time()
    global_step = 0

    for step, (inputs, labels) in enumerate(train_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # Mixed precision training
        with torch.cuda.amp.autocast(enabled=CFG.apex, dtype=torch.bfloat16):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
            loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()

        # Gradient accumulation
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps

        losses.update(loss.item(), batch_size)

        # Backpropagation
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)

        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()

        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print(f'Epoch: [{epoch+1}][{step}/{len(train_loader)}] '
                  f'Elapsed {timeSince(start, float(step+1)/len(train_loader))} '
                  f'Loss: {losses.val:.4f}({losses.avg:.4f}) '
                  f'Grad Norm: {grad_norm:.4f}  '
                  f'LR: {scheduler.get_lr()[0]}')

    print(f"Epoch {epoch + 1} completed. Average loss: {losses.avg:.4f}")
    return losses.avg


def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.sigmoid().to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

In [24]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_texts = valid_folds['pn_history'].values
    valid_labels = create_labels_for_scoring(valid_folds)
    
    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)

    # checkpoint_path = f"/workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold{fold}_best.pth"
    # if os.path.exists(checkpoint_path):
    #     LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
    #     checkpoint = torch.load(checkpoint_path, map_location="cuda", weights_only=False) # need to use weight_only = True for latest pytorch version
    #     model.load_state_dict(checkpoint['model'], strict=False)  # Load model weights
    #     model.to(device)
    model.train()  # Set model to training mode

    # else:
    #     LOGGER.info("No checkpoint found, starting fresh.")
        
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr, 
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)
    
    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        if cfg.scheduler=='linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler=='cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        return scheduler
    
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    
    best_score = 0.

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        print("Starting training")
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        print("Starting evaluation")
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        predictions = predictions.reshape((len(valid_folds), CFG.max_len))
        
        # scoring
        char_probs = get_char_probs(valid_texts, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        if CFG.wandb:
            wandb.log({f"[fold{fold}] epoch": epoch+1, 
                       f"[fold{fold}] avg_train_loss": avg_loss, 
                       f"[fold{fold}] avg_val_loss": avg_val_loss,
                       f"[fold{fold}] score": score})
        
        if best_score <= score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
            

    #predictions = torch.load(OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             #map_location=torch.device('cpu'))['predictions']
    predictions = torch.load(OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             map_location=torch.device('cpu'),weights_only=False)['predictions']

    
    valid_folds[[i for i in range(CFG.max_len)]] = predictions

    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

In [25]:
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    #model="microsoft/deberta-v3-large" #"answerdotai/ModernBERT-base"
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" If running on Kaggle
    model = "answerdotai/ModernBERT-base" # running on runpod instance
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=0.05
    num_warmup_steps=0
    epochs=4
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 10 #batch_size=16
    fc_dropout=0.2
    max_len=443
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=1 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True

tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

if CFG.debug:
    CFG.epochs = 4
    CFG.trn_fold = [0]

In [26]:
torch.cuda.empty_cache()

In [27]:
def run_fold_training(oof_df, fold):
    """
    Helper function to run training for a specific fold.
    Concatenates the out-of-fold predictions to oof_df.
    """
    _oof_df = train_loop(train, fold)
    oof_df = pd.concat([oof_df, _oof_df])
    LOGGER.info(f"========== fold: {fold} result ==========")
    evaluate_results(_oof_df)  # Evaluate results after each fold
    return oof_df

def evaluate_results(oof_df):
    """
    Function to handle the result evaluation, including scoring and logging.
    """
    labels = create_labels_for_scoring(oof_df)
    predictions = oof_df[[i for i in range(CFG.max_len)]].values
    char_probs = get_char_probs(oof_df['pn_history'].values, predictions, CFG.tokenizer)
    results = get_results(char_probs, th=0.5)
    preds = get_predictions(results)
    score = get_score(labels, preds)
    LOGGER.info(f'Score: {score:<.4f}')

In [28]:
if __name__ == '__main__':
    CFG.debug = True

    if CFG.train:
        oof_df = pd.DataFrame()

        # Run training for a single fold in debug mode or multiple folds otherwise
        if CFG.debug:
            fold = 0  # In debug mode, we only use fold 0
            if fold in CFG.trn_fold:
                oof_df = run_fold_training(oof_df, fold)
        else:
            for fold in range(CFG.n_fold):  # For all folds in non-debug mode
                if fold in CFG.trn_fold:
                    oof_df = run_fold_training(oof_df, fold)

        # Final evaluation and saving results
        oof_df = oof_df.reset_index(drop=True)
        LOGGER.info(f"========== CV ==========")
        evaluate_results(oof_df)  # Final evaluation of all results
        oof_df.to_pickle(OUTPUT_DIR + 'oof_df.pkl')

    if CFG.wandb:
        wandb.finish()



Starting training

Starting training epoch 1 for fold 0


Epoch: [1][0/8042] Elapsed 0m 5s (remain 674m 17s) Loss: 0.2891(0.2891) Grad Norm: 371675.4688  LR: 9.999999999761583e-06


Epoch: [1][100/8042] Elapsed 1m 8s (remain 90m 14s) Loss: 0.0518(0.1065) Grad Norm: 13329.0332  LR: 9.999997567906066e-06


Epoch: [1][200/8042] Elapsed 2m 11s (remain 85m 12s) Loss: 0.0807(0.0913) Grad Norm: 90074.0781  LR: 9.99999036770871e-06


Epoch: [1][300/8042] Elapsed 3m 14s (remain 83m 13s) Loss: 0.0612(0.0825) Grad Norm: 36674.0078  LR: 9.99997839917638e-06


Epoch: [1][400/8042] Elapsed 4m 18s (remain 82m 8s) Loss: 0.0324(0.0756) Grad Norm: 14883.4004  LR: 9.999961662320492e-06


Epoch: [1][500/8042] Elapsed 5m 22s (remain 80m 49s) Loss: 0.0544(0.0705) Grad Norm: 199112.4531  LR: 9.999940157157006e-06


Epoch: [1][600/8042] Elapsed 6m 24s (remain 79m 19s) Loss: 0.0332(0.0661) Grad Norm: 37233.4180  LR: 9.99991388370643e-06


Epoch: [1][700/8042] Elapsed 7m 28s (remain 78m 15s) Loss: 0.0711(0.0628) Grad Norm: 66610.6953  LR: 9.999882841993823e-06


Epoch: [1][800/8042] Elapsed 8m 31s (remain 77m 8s) Loss: 0.0520(0.0597) Grad Norm: 64699.3203  LR: 9.999847032048786e-06


Epoch: [1][900/8042] Elapsed 9m 35s (remain 75m 58s) Loss: 0.0415(0.0569) Grad Norm: 66695.7969  LR: 9.999806453905471e-06


Epoch: [1][1000/8042] Elapsed 10m 40s (remain 75m 7s) Loss: 0.0311(0.0545) Grad Norm: 61588.7891  LR: 9.999761107602578e-06


Epoch: [1][1100/8042] Elapsed 11m 44s (remain 74m 1s) Loss: 0.0265(0.0524) Grad Norm: 43677.4023  LR: 9.999710993183348e-06


Epoch: [1][1200/8042] Elapsed 12m 49s (remain 73m 5s) Loss: 0.0185(0.0505) Grad Norm: 41191.3750  LR: 9.999656110695576e-06


Epoch: [1][1300/8042] Elapsed 13m 55s (remain 72m 6s) Loss: 0.0127(0.0485) Grad Norm: 33883.4297  LR: 9.999596460191603e-06


Epoch: [1][1400/8042] Elapsed 14m 59s (remain 71m 5s) Loss: 0.0177(0.0470) Grad Norm: 28564.7031  LR: 9.999532041728313e-06


Epoch: [1][1500/8042] Elapsed 16m 2s (remain 69m 55s) Loss: 0.0387(0.0456) Grad Norm: 73322.2812  LR: 9.999462855367142e-06


Epoch: [1][1600/8042] Elapsed 17m 7s (remain 68m 54s) Loss: 0.0235(0.0442) Grad Norm: 34897.4883  LR: 9.99938890117407e-06


Epoch: [1][1700/8042] Elapsed 18m 12s (remain 67m 54s) Loss: 0.0381(0.0427) Grad Norm: 101567.0234  LR: 9.999310179219625e-06


Epoch: [1][1800/8042] Elapsed 19m 17s (remain 66m 51s) Loss: 0.0277(0.0415) Grad Norm: 44191.3594  LR: 9.999226689578882e-06


Epoch: [1][1900/8042] Elapsed 20m 22s (remain 65m 50s) Loss: 0.0213(0.0405) Grad Norm: 46243.7617  LR: 9.99913843233146e-06


Epoch: [1][2000/8042] Elapsed 21m 26s (remain 64m 44s) Loss: 0.0135(0.0395) Grad Norm: 74400.0234  LR: 9.999045407561532e-06


Epoch: [1][2100/8042] Elapsed 22m 29s (remain 63m 35s) Loss: 0.0181(0.0385) Grad Norm: 38946.6328  LR: 9.998947615357808e-06


Epoch: [1][2200/8042] Elapsed 23m 34s (remain 62m 34s) Loss: 0.0081(0.0377) Grad Norm: 48026.1211  LR: 9.998845055813553e-06


Epoch: [1][2300/8042] Elapsed 24m 37s (remain 61m 26s) Loss: 0.0094(0.0368) Grad Norm: 51269.4805  LR: 9.998737729026573e-06


Epoch: [1][2400/8042] Elapsed 25m 40s (remain 60m 19s) Loss: 0.0213(0.0359) Grad Norm: 97236.9141  LR: 9.998625635099223e-06


Epoch: [1][2500/8042] Elapsed 26m 44s (remain 59m 14s) Loss: 0.0031(0.0352) Grad Norm: 39192.1953  LR: 9.998508774138403e-06


Epoch: [1][2600/8042] Elapsed 27m 50s (remain 58m 13s) Loss: 0.0186(0.0345) Grad Norm: 79498.4922  LR: 9.998387146255558e-06


Epoch: [1][2700/8042] Elapsed 28m 54s (remain 57m 9s) Loss: 0.0235(0.0339) Grad Norm: 118486.5234  LR: 9.998260751566684e-06


Epoch: [1][2800/8042] Elapsed 29m 57s (remain 56m 4s) Loss: 0.0166(0.0333) Grad Norm: 75739.7266  LR: 9.998129590192316e-06


Epoch: [1][2900/8042] Elapsed 31m 2s (remain 55m 1s) Loss: 0.0102(0.0327) Grad Norm: 41682.1328  LR: 9.997993662257542e-06


Epoch: [1][3000/8042] Elapsed 32m 6s (remain 53m 55s) Loss: 0.0141(0.0321) Grad Norm: 61662.4570  LR: 9.997852967891992e-06


Epoch: [1][3100/8042] Elapsed 33m 12s (remain 52m 55s) Loss: 0.0082(0.0317) Grad Norm: 73989.1172  LR: 9.99770750722984e-06


Epoch: [1][3200/8042] Elapsed 34m 16s (remain 51m 50s) Loss: 0.0103(0.0311) Grad Norm: 29647.0117  LR: 9.997557280409806e-06


Epoch: [1][3300/8042] Elapsed 35m 19s (remain 50m 44s) Loss: 0.0036(0.0307) Grad Norm: 59532.3242  LR: 9.99740228757516e-06


Epoch: [1][3400/8042] Elapsed 36m 22s (remain 49m 38s) Loss: 0.0096(0.0303) Grad Norm: 37914.2695  LR: 9.997242528873713e-06


Epoch: [1][3500/8042] Elapsed 37m 26s (remain 48m 33s) Loss: 0.0051(0.0299) Grad Norm: 30620.1016  LR: 9.99707800445782e-06


Epoch: [1][3600/8042] Elapsed 38m 30s (remain 47m 29s) Loss: 0.0286(0.0295) Grad Norm: 149949.7188  LR: 9.996908714484386e-06


Epoch: [1][3700/8042] Elapsed 39m 34s (remain 46m 25s) Loss: 0.0117(0.0291) Grad Norm: 62678.2500  LR: 9.996734659114854e-06


Epoch: [1][3800/8042] Elapsed 40m 40s (remain 45m 22s) Loss: 0.0086(0.0287) Grad Norm: 70010.8125  LR: 9.996555838515217e-06


Epoch: [1][3900/8042] Elapsed 41m 44s (remain 44m 18s) Loss: 0.0190(0.0283) Grad Norm: 153237.3594  LR: 9.996372252856011e-06


Epoch: [1][4000/8042] Elapsed 42m 46s (remain 43m 12s) Loss: 0.0067(0.0279) Grad Norm: 44335.5977  LR: 9.996183902312316e-06


Epoch: [1][4100/8042] Elapsed 43m 50s (remain 42m 7s) Loss: 0.0100(0.0276) Grad Norm: 189559.7031  LR: 9.995990787063755e-06


Epoch: [1][4200/8042] Elapsed 44m 51s (remain 41m 0s) Loss: 0.0011(0.0272) Grad Norm: 14467.4336  LR: 9.995792907294496e-06


Epoch: [1][4300/8042] Elapsed 45m 56s (remain 39m 57s) Loss: 0.0325(0.0269) Grad Norm: 324096.9688  LR: 9.995590263193251e-06


Epoch: [1][4400/8042] Elapsed 47m 1s (remain 38m 54s) Loss: 0.0444(0.0266) Grad Norm: 398594.1250  LR: 9.995382854953278e-06


Epoch: [1][4500/8042] Elapsed 48m 5s (remain 37m 50s) Loss: 0.0015(0.0263) Grad Norm: 26727.2422  LR: 9.995170682772371e-06


Epoch: [1][4600/8042] Elapsed 49m 8s (remain 36m 44s) Loss: 0.0061(0.0260) Grad Norm: 73584.7656  LR: 9.994953746852873e-06


Epoch: [1][4700/8042] Elapsed 50m 11s (remain 35m 40s) Loss: 0.0064(0.0257) Grad Norm: 91820.4219  LR: 9.994732047401673e-06


Epoch: [1][4800/8042] Elapsed 51m 13s (remain 34m 34s) Loss: 0.0104(0.0254) Grad Norm: 103457.8438  LR: 9.994505584630195e-06


Epoch: [1][4900/8042] Elapsed 52m 15s (remain 33m 29s) Loss: 0.0126(0.0251) Grad Norm: 110200.2422  LR: 9.994274358754412e-06


Epoch: [1][5000/8042] Elapsed 53m 17s (remain 32m 24s) Loss: 0.0233(0.0249) Grad Norm: 242729.6406  LR: 9.994038369994834e-06


Epoch: [1][5100/8042] Elapsed 54m 21s (remain 31m 20s) Loss: 0.0158(0.0246) Grad Norm: 152178.3906  LR: 9.993797618576519e-06


Epoch: [1][5200/8042] Elapsed 55m 27s (remain 30m 17s) Loss: 0.0040(0.0243) Grad Norm: 23978.1211  LR: 9.993552104729064e-06


Epoch: [1][5300/8042] Elapsed 56m 30s (remain 29m 13s) Loss: 0.0095(0.0241) Grad Norm: 116346.2188  LR: 9.993301828686604e-06


Epoch: [1][5400/8042] Elapsed 57m 33s (remain 28m 8s) Loss: 0.0143(0.0239) Grad Norm: 340224.2500  LR: 9.993046790687824e-06


Epoch: [1][5500/8042] Elapsed 58m 35s (remain 27m 3s) Loss: 0.0086(0.0237) Grad Norm: 159671.7031  LR: 9.992786990975943e-06


Epoch: [1][5600/8042] Elapsed 59m 41s (remain 26m 0s) Loss: 0.0190(0.0234) Grad Norm: 117176.6328  LR: 9.992522429798725e-06


Epoch: [1][5700/8042] Elapsed 60m 46s (remain 24m 57s) Loss: 0.0044(0.0232) Grad Norm: 80946.3359  LR: 9.992253107408474e-06


Epoch: [1][5800/8042] Elapsed 61m 51s (remain 23m 53s) Loss: 0.0052(0.0230) Grad Norm: 84931.4844  LR: 9.991979024062035e-06


Epoch: [1][5900/8042] Elapsed 62m 58s (remain 22m 50s) Loss: 0.0150(0.0228) Grad Norm: 377779.5938  LR: 9.991700180020791e-06


Epoch: [1][6000/8042] Elapsed 64m 1s (remain 21m 46s) Loss: 0.0121(0.0226) Grad Norm: 300436.9688  LR: 9.991416575550667e-06


Epoch: [1][6100/8042] Elapsed 65m 5s (remain 20m 42s) Loss: 0.0014(0.0224) Grad Norm: 54548.5938  LR: 9.99112821092213e-06


Epoch: [1][6200/8042] Elapsed 66m 9s (remain 19m 38s) Loss: 0.0229(0.0222) Grad Norm: 488269.4062  LR: 9.990835086410181e-06


Epoch: [1][6300/8042] Elapsed 67m 14s (remain 18m 34s) Loss: 0.0204(0.0221) Grad Norm: 328504.6562  LR: 9.990537202294368e-06


Epoch: [1][6400/8042] Elapsed 68m 18s (remain 17m 30s) Loss: 0.0067(0.0219) Grad Norm: 227420.1250  LR: 9.990234558858768e-06


Epoch: [1][6500/8042] Elapsed 69m 21s (remain 16m 26s) Loss: 0.0068(0.0218) Grad Norm: 234089.6875  LR: 9.989927156392009e-06


Epoch: [1][6600/8042] Elapsed 70m 25s (remain 15m 22s) Loss: 0.0026(0.0216) Grad Norm: 107627.3906  LR: 9.989614995187246e-06


Epoch: [1][6700/8042] Elapsed 71m 29s (remain 14m 18s) Loss: 0.0093(0.0214) Grad Norm: 432216.6250  LR: 9.989298075542182e-06


Epoch: [1][6800/8042] Elapsed 72m 34s (remain 13m 14s) Loss: 0.0034(0.0213) Grad Norm: 123648.0547  LR: 9.988976397759048e-06


Epoch: [1][6900/8042] Elapsed 73m 38s (remain 12m 10s) Loss: 0.0071(0.0211) Grad Norm: 271149.7500  LR: 9.988649962144622e-06


Epoch: [1][7000/8042] Elapsed 74m 42s (remain 11m 6s) Loss: 0.0072(0.0210) Grad Norm: 180840.6562  LR: 9.988318769010215e-06


Epoch: [1][7100/8042] Elapsed 75m 46s (remain 10m 2s) Loss: 0.0058(0.0208) Grad Norm: 106775.0781  LR: 9.987982818671673e-06


Epoch: [1][7200/8042] Elapsed 76m 49s (remain 8m 58s) Loss: 0.0219(0.0207) Grad Norm: 590665.1250  LR: 9.987642111449386e-06


Epoch: [1][7300/8042] Elapsed 77m 55s (remain 7m 54s) Loss: 0.0177(0.0205) Grad Norm: 194030.4531  LR: 9.98729664766827e-06


Epoch: [1][7400/8042] Elapsed 79m 1s (remain 6m 50s) Loss: 0.0087(0.0204) Grad Norm: 189757.9375  LR: 9.986946427657789e-06


Epoch: [1][7500/8042] Elapsed 80m 5s (remain 5m 46s) Loss: 0.0110(0.0202) Grad Norm: 203836.9375  LR: 9.986591451751933e-06


Epoch: [1][7600/8042] Elapsed 81m 11s (remain 4m 42s) Loss: 0.0096(0.0201) Grad Norm: 253041.4375  LR: 9.986231720289233e-06


Epoch: [1][7700/8042] Elapsed 82m 15s (remain 3m 38s) Loss: 0.0140(0.0200) Grad Norm: 268108.8750  LR: 9.98586723361275e-06


Epoch: [1][7800/8042] Elapsed 83m 19s (remain 2m 34s) Loss: 0.0141(0.0199) Grad Norm: 375391.0625  LR: 9.985497992070092e-06


Epoch: [1][7900/8042] Elapsed 84m 24s (remain 1m 30s) Loss: 0.0272(0.0198) Grad Norm: 481172.9375  LR: 9.985123996013386e-06


Epoch: [1][8000/8042] Elapsed 85m 28s (remain 0m 26s) Loss: 0.0070(0.0197) Grad Norm: 391251.8438  LR: 9.984745245799303e-06


Epoch: [1][8041/8042] Elapsed 85m 54s (remain 0m 0s) Loss: 0.0050(0.0196) Grad Norm: 376937.7188  LR: 9.984588584106055e-06
Epoch 1 completed. Average loss: 0.0196
Starting evaluation
Starting validation...


EVAL: [0/2011] Elapsed 0m 1s (remain 61m 46s) Loss: 0.0126(0.0126) 


EVAL: [100/2011] Elapsed 0m 46s (remain 14m 40s) Loss: 0.0083(0.0151) 


EVAL: [200/2011] Elapsed 1m 31s (remain 13m 40s) Loss: 0.0032(0.0163) 


EVAL: [300/2011] Elapsed 2m 16s (remain 12m 52s) Loss: 0.0155(0.0158) 


EVAL: [400/2011] Elapsed 3m 0s (remain 12m 3s) Loss: 0.0125(0.0144) 


EVAL: [500/2011] Elapsed 3m 45s (remain 11m 18s) Loss: 0.0007(0.0136) 


EVAL: [600/2011] Elapsed 4m 29s (remain 10m 32s) Loss: 0.0085(0.0122) 


EVAL: [700/2011] Elapsed 5m 14s (remain 9m 48s) Loss: 0.0005(0.0113) 


EVAL: [800/2011] Elapsed 5m 59s (remain 9m 3s) Loss: 0.0046(0.0105) 


EVAL: [900/2011] Elapsed 6m 45s (remain 8m 19s) Loss: 0.0114(0.0099) 


EVAL: [1000/2011] Elapsed 7m 29s (remain 7m 33s) Loss: 0.0082(0.0098) 


EVAL: [1100/2011] Elapsed 8m 13s (remain 6m 47s) Loss: 0.0153(0.0100) 


EVAL: [1200/2011] Elapsed 8m 59s (remain 6m 3s) Loss: 0.0009(0.0101) 


EVAL: [1300/2011] Elapsed 9m 44s (remain 5m 18s) Loss: 0.0096(0.0101) 


EVAL: [1400/2011] Elapsed 10m 28s (remain 4m 33s) Loss: 0.0023(0.0101) 


EVAL: [1500/2011] Elapsed 11m 13s (remain 3m 48s) Loss: 0.0289(0.0100) 


EVAL: [1600/2011] Elapsed 11m 59s (remain 3m 4s) Loss: 0.0027(0.0101) 


EVAL: [1700/2011] Elapsed 12m 43s (remain 2m 19s) Loss: 0.0077(0.0102) 


EVAL: [1800/2011] Elapsed 13m 28s (remain 1m 34s) Loss: 0.0016(0.0102) 


EVAL: [1900/2011] Elapsed 14m 12s (remain 0m 49s) Loss: 0.0037(0.0100) 


EVAL: [2000/2011] Elapsed 14m 57s (remain 0m 4s) Loss: 0.0061(0.0098) 


EVAL: [2010/2011] Elapsed 15m 1s (remain 0m 0s) Loss: 0.0069(0.0097) 


Epoch 1 - avg_train_loss: 0.0196  avg_val_loss: 0.0097  time: 6070s


Epoch 1 - Score: 0.9140


Epoch 1 - Save Best Score: 0.9140 Model


Starting training

Starting training epoch 2 for fold 0


Epoch: [2][0/8042] Elapsed 0m 1s (remain 172m 2s) Loss: 0.0021(0.0021) Grad Norm: 10434.1660  LR: 9.984584753106386e-06


Epoch: [2][100/8042] Elapsed 1m 5s (remain 86m 26s) Loss: 0.0130(0.0120) Grad Norm: 44179.5039  LR: 9.984199252610533e-06


Epoch: [2][200/8042] Elapsed 2m 10s (remain 85m 4s) Loss: 0.0218(0.0114) Grad Norm: 41638.8711  LR: 9.983808998839205e-06


Epoch: [2][300/8042] Elapsed 3m 16s (remain 84m 4s) Loss: 0.0154(0.0118) Grad Norm: 44786.3594  LR: 9.983413992164567e-06


Epoch: [2][400/8042] Elapsed 4m 20s (remain 82m 52s) Loss: 0.0054(0.0115) Grad Norm: 21516.0664  LR: 9.983014232963331e-06


Epoch: [2][500/8042] Elapsed 5m 26s (remain 81m 56s) Loss: 0.0176(0.0114) Grad Norm: 48946.3477  LR: 9.982609721616732e-06


Epoch: [2][600/8042] Elapsed 6m 31s (remain 80m 44s) Loss: 0.0296(0.0113) Grad Norm: 47033.2305  LR: 9.982200458510541e-06


Epoch: [2][700/8042] Elapsed 7m 36s (remain 79m 41s) Loss: 0.0038(0.0112) Grad Norm: 18140.6582  LR: 9.981786444035059e-06


Epoch: [2][800/8042] Elapsed 8m 43s (remain 78m 50s) Loss: 0.0007(0.0109) Grad Norm: 3626.0771  LR: 9.981367678585118e-06


Epoch: [2][900/8042] Elapsed 9m 47s (remain 77m 35s) Loss: 0.0057(0.0107) Grad Norm: 14385.0146  LR: 9.980944162560083e-06


Epoch: [2][1000/8042] Elapsed 10m 51s (remain 76m 20s) Loss: 0.0090(0.0107) Grad Norm: 11741.9502  LR: 9.980515896363848e-06


Epoch: [2][1100/8042] Elapsed 11m 56s (remain 75m 19s) Loss: 0.0048(0.0107) Grad Norm: 32874.3086  LR: 9.980082880404833e-06


Epoch: [2][1200/8042] Elapsed 13m 3s (remain 74m 23s) Loss: 0.0062(0.0107) Grad Norm: 41052.8711  LR: 9.979645115095999e-06


Epoch: [2][1300/8042] Elapsed 14m 8s (remain 73m 18s) Loss: 0.0045(0.0105) Grad Norm: 12419.5850  LR: 9.979202600854823e-06


Epoch: [2][1400/8042] Elapsed 15m 14s (remain 72m 15s) Loss: 0.0053(0.0105) Grad Norm: 17062.6641  LR: 9.978755338103322e-06


Epoch: [2][1500/8042] Elapsed 16m 18s (remain 71m 4s) Loss: 0.0095(0.0104) Grad Norm: 31232.0938  LR: 9.97830332726803e-06


Epoch: [2][1600/8042] Elapsed 17m 22s (remain 69m 52s) Loss: 0.0079(0.0104) Grad Norm: 25190.2949  LR: 9.977846568780022e-06


Epoch: [2][1700/8042] Elapsed 18m 28s (remain 68m 53s) Loss: 0.0097(0.0104) Grad Norm: 19666.6348  LR: 9.977385063074894e-06


Epoch: [2][1800/8042] Elapsed 19m 34s (remain 67m 51s) Loss: 0.0064(0.0104) Grad Norm: 13699.8945  LR: 9.976918810592763e-06


Epoch: [2][1900/8042] Elapsed 20m 40s (remain 66m 46s) Loss: 0.0038(0.0103) Grad Norm: 17089.2578  LR: 9.976447811778288e-06


Epoch: [2][2000/8042] Elapsed 21m 45s (remain 65m 41s) Loss: 0.0103(0.0103) Grad Norm: 50650.4414  LR: 9.97597206708064e-06


Epoch: [2][2100/8042] Elapsed 22m 50s (remain 64m 35s) Loss: 0.0060(0.0102) Grad Norm: 29804.8047  LR: 9.975491576953524e-06


Epoch: [2][2200/8042] Elapsed 23m 55s (remain 63m 29s) Loss: 0.0302(0.0102) Grad Norm: 174525.5781  LR: 9.975006341855167e-06


Epoch: [2][2300/8042] Elapsed 25m 0s (remain 62m 24s) Loss: 0.0250(0.0101) Grad Norm: 105306.8047  LR: 9.974516362248325e-06


Epoch: [2][2400/8042] Elapsed 26m 6s (remain 61m 20s) Loss: 0.0118(0.0100) Grad Norm: 63052.0469  LR: 9.974021638600276e-06


Epoch: [2][2500/8042] Elapsed 27m 12s (remain 60m 16s) Loss: 0.0076(0.0099) Grad Norm: 41719.6680  LR: 9.973522171382821e-06


Epoch: [2][2600/8042] Elapsed 28m 18s (remain 59m 13s) Loss: 0.0028(0.0099) Grad Norm: 81516.5312  LR: 9.973017961072287e-06


Epoch: [2][2700/8042] Elapsed 29m 22s (remain 58m 5s) Loss: 0.0012(0.0098) Grad Norm: 11209.3262  LR: 9.972509008149522e-06


Epoch: [2][2800/8042] Elapsed 30m 28s (remain 57m 1s) Loss: 0.0010(0.0097) Grad Norm: 9021.8545  LR: 9.971995313099903e-06


Epoch: [2][2900/8042] Elapsed 31m 34s (remain 55m 56s) Loss: 0.0018(0.0097) Grad Norm: 22137.7988  LR: 9.97147687641332e-06


Epoch: [2][3000/8042] Elapsed 32m 37s (remain 54m 48s) Loss: 0.0087(0.0096) Grad Norm: 106901.6328  LR: 9.970953698584192e-06


Epoch: [2][3100/8042] Elapsed 33m 42s (remain 53m 42s) Loss: 0.0003(0.0096) Grad Norm: 3879.6655  LR: 9.970425780111456e-06


Epoch: [2][3200/8042] Elapsed 34m 48s (remain 52m 38s) Loss: 0.0044(0.0096) Grad Norm: 20536.1094  LR: 9.969893121498575e-06


Epoch: [2][3300/8042] Elapsed 35m 56s (remain 51m 37s) Loss: 0.0067(0.0095) Grad Norm: 60112.6289  LR: 9.969355723253527e-06


Epoch: [2][3400/8042] Elapsed 37m 2s (remain 50m 32s) Loss: 0.0069(0.0095) Grad Norm: 57386.5859  LR: 9.968813585888811e-06


Epoch: [2][3500/8042] Elapsed 38m 10s (remain 49m 30s) Loss: 0.0177(0.0094) Grad Norm: 65821.1641  LR: 9.968266709921448e-06


Epoch: [2][3600/8042] Elapsed 39m 16s (remain 48m 26s) Loss: 0.0347(0.0094) Grad Norm: 157594.4062  LR: 9.967715095872975e-06


Epoch: [2][3700/8042] Elapsed 40m 23s (remain 47m 22s) Loss: 0.0817(0.0094) Grad Norm: 159220.5781  LR: 9.96715874426945e-06


Epoch: [2][3800/8042] Elapsed 41m 30s (remain 46m 19s) Loss: 0.0089(0.0093) Grad Norm: 49592.7852  LR: 9.966597655641445e-06


Epoch: [2][3900/8042] Elapsed 42m 36s (remain 45m 13s) Loss: 0.0112(0.0093) Grad Norm: 48017.2070  LR: 9.96603183052406e-06


Epoch: [2][4000/8042] Elapsed 43m 43s (remain 44m 9s) Loss: 0.0053(0.0093) Grad Norm: 26084.7695  LR: 9.9654612694569e-06


Epoch: [2][4100/8042] Elapsed 44m 48s (remain 43m 3s) Loss: 0.0176(0.0092) Grad Norm: 317289.1875  LR: 9.96488597298409e-06


Epoch: [2][4200/8042] Elapsed 45m 53s (remain 41m 57s) Loss: 0.0009(0.0092) Grad Norm: 28589.7305  LR: 9.964305941654275e-06


Epoch: [2][4300/8042] Elapsed 46m 59s (remain 40m 52s) Loss: 0.0042(0.0091) Grad Norm: 82529.2109  LR: 9.963721176020612e-06


Epoch: [2][4400/8042] Elapsed 48m 5s (remain 39m 47s) Loss: 0.0076(0.0091) Grad Norm: 185025.7812  LR: 9.963131676640773e-06


Epoch: [2][4500/8042] Elapsed 49m 12s (remain 38m 42s) Loss: 0.0011(0.0091) Grad Norm: 15605.0215  LR: 9.962537444076948e-06


Epoch: [2][4600/8042] Elapsed 50m 18s (remain 37m 37s) Loss: 0.0018(0.0090) Grad Norm: 30242.3652  LR: 9.961938478895834e-06


Epoch: [2][4700/8042] Elapsed 51m 24s (remain 36m 31s) Loss: 0.0099(0.0090) Grad Norm: 542266.8750  LR: 9.961334781668648e-06


Epoch: [2][4800/8042] Elapsed 52m 29s (remain 35m 26s) Loss: 0.0045(0.0090) Grad Norm: 56070.2617  LR: 9.960726352971117e-06


Epoch: [2][4900/8042] Elapsed 53m 36s (remain 34m 21s) Loss: 0.0105(0.0090) Grad Norm: 87944.1016  LR: 9.960113193383479e-06


Epoch: [2][5000/8042] Elapsed 54m 43s (remain 33m 16s) Loss: 0.0042(0.0089) Grad Norm: 121476.7188  LR: 9.959495303490487e-06


Epoch: [2][5100/8042] Elapsed 55m 49s (remain 32m 11s) Loss: 0.0014(0.0089) Grad Norm: 33101.9922  LR: 9.958872683881404e-06


Epoch: [2][5200/8042] Elapsed 56m 56s (remain 31m 5s) Loss: 0.0013(0.0089) Grad Norm: 33296.1914  LR: 9.958245335149999e-06


Epoch: [2][5300/8042] Elapsed 58m 2s (remain 30m 0s) Loss: 0.0044(0.0089) Grad Norm: 95359.6562  LR: 9.95761325789456e-06


Epoch: [2][5400/8042] Elapsed 59m 7s (remain 28m 54s) Loss: 0.0002(0.0088) Grad Norm: 2292.2913  LR: 9.956976452717874e-06


Epoch: [2][5500/8042] Elapsed 60m 12s (remain 27m 48s) Loss: 0.0033(0.0088) Grad Norm: 37061.2539  LR: 9.956334920227249e-06


Epoch: [2][5600/8042] Elapsed 61m 17s (remain 26m 42s) Loss: 0.0038(0.0088) Grad Norm: 74245.6641  LR: 9.955688661034487e-06


Epoch: [2][5700/8042] Elapsed 62m 22s (remain 25m 36s) Loss: 0.0028(0.0087) Grad Norm: 20292.1465  LR: 9.95503767575591e-06


Epoch: [2][5800/8042] Elapsed 63m 27s (remain 24m 30s) Loss: 0.0064(0.0087) Grad Norm: 153439.3281  LR: 9.954381965012343e-06


Epoch: [2][5900/8042] Elapsed 64m 34s (remain 23m 25s) Loss: 0.0123(0.0087) Grad Norm: 109807.0859  LR: 9.953721529429114e-06


Epoch: [2][6000/8042] Elapsed 65m 38s (remain 22m 19s) Loss: 0.0150(0.0087) Grad Norm: 97349.6875  LR: 9.95305636963606e-06


Epoch: [2][6100/8042] Elapsed 66m 43s (remain 21m 13s) Loss: 0.0051(0.0087) Grad Norm: 192418.8125  LR: 9.952386486267525e-06


Epoch: [2][6200/8042] Elapsed 67m 48s (remain 20m 8s) Loss: 0.0064(0.0087) Grad Norm: 241125.9844  LR: 9.951711879962356e-06


Epoch: [2][6300/8042] Elapsed 68m 53s (remain 19m 2s) Loss: 0.0045(0.0086) Grad Norm: 167562.5625  LR: 9.951032551363902e-06


Epoch: [2][6400/8042] Elapsed 69m 58s (remain 17m 56s) Loss: 0.0096(0.0086) Grad Norm: 196427.0469  LR: 9.95034850112002e-06


Epoch: [2][6500/8042] Elapsed 71m 2s (remain 16m 50s) Loss: 0.0113(0.0086) Grad Norm: 665815.3750  LR: 9.949659729883063e-06


Epoch: [2][6600/8042] Elapsed 72m 7s (remain 15m 44s) Loss: 0.0023(0.0086) Grad Norm: 92259.4844  LR: 9.948966238309897e-06


Epoch: [2][6700/8042] Elapsed 73m 10s (remain 14m 38s) Loss: 0.0029(0.0086) Grad Norm: 177987.2188  LR: 9.948268027061878e-06


Epoch: [2][6800/8042] Elapsed 74m 14s (remain 13m 32s) Loss: 0.0252(0.0085) Grad Norm: 235109.0625  LR: 9.94756509680487e-06


Epoch: [2][6900/8042] Elapsed 75m 18s (remain 12m 27s) Loss: 0.0043(0.0085) Grad Norm: 102843.4219  LR: 9.946857448209238e-06


Epoch: [2][7000/8042] Elapsed 76m 22s (remain 11m 21s) Loss: 0.0035(0.0085) Grad Norm: 156661.2812  LR: 9.946145081949839e-06


Epoch: [2][7100/8042] Elapsed 77m 28s (remain 10m 16s) Loss: 0.0130(0.0085) Grad Norm: 318940.0000  LR: 9.94542799870604e-06


Epoch: [2][7200/8042] Elapsed 78m 33s (remain 9m 10s) Loss: 0.0083(0.0085) Grad Norm: 203350.7656  LR: 9.944706199161698e-06


Epoch: [2][7300/8042] Elapsed 79m 38s (remain 8m 5s) Loss: 0.0052(0.0085) Grad Norm: 410005.5312  LR: 9.943979684005172e-06


Epoch: [2][7400/8042] Elapsed 80m 43s (remain 6m 59s) Loss: 0.0016(0.0085) Grad Norm: 40489.6055  LR: 9.943248453929317e-06


Epoch: [2][7500/8042] Elapsed 81m 49s (remain 5m 54s) Loss: 0.0110(0.0084) Grad Norm: 330252.0000  LR: 9.942512509631484e-06


Epoch: [2][7600/8042] Elapsed 82m 52s (remain 4m 48s) Loss: 0.0068(0.0084) Grad Norm: 239537.0000  LR: 9.941771851813517e-06


Epoch: [2][7700/8042] Elapsed 83m 57s (remain 3m 43s) Loss: 0.0059(0.0084) Grad Norm: 210781.7188  LR: 9.941026481181763e-06


Epoch: [2][7800/8042] Elapsed 85m 2s (remain 2m 37s) Loss: 0.0061(0.0084) Grad Norm: 276639.4688  LR: 9.940276398447058e-06


Epoch: [2][7900/8042] Elapsed 86m 6s (remain 1m 32s) Loss: 0.0100(0.0084) Grad Norm: 204881.6562  LR: 9.939521604324729e-06


Epoch: [2][8000/8042] Elapsed 87m 12s (remain 0m 26s) Loss: 0.0022(0.0083) Grad Norm: 234999.2031  LR: 9.938762099534604e-06


Epoch: [2][8041/8042] Elapsed 87m 39s (remain 0m 0s) Loss: 0.0173(0.0083) Grad Norm: 641309.0000  LR: 9.93844934112016e-06
Epoch 2 completed. Average loss: 0.0083
Starting evaluation
Starting validation...


EVAL: [0/2011] Elapsed 0m 0s (remain 25m 24s) Loss: 0.0083(0.0083) 


EVAL: [100/2011] Elapsed 0m 45s (remain 14m 24s) Loss: 0.0000(0.0119) 


EVAL: [200/2011] Elapsed 1m 31s (remain 13m 39s) Loss: 0.0016(0.0137) 


EVAL: [300/2011] Elapsed 2m 16s (remain 12m 53s) Loss: 0.0113(0.0129) 


EVAL: [400/2011] Elapsed 3m 0s (remain 12m 5s) Loss: 0.0042(0.0113) 


EVAL: [500/2011] Elapsed 3m 44s (remain 11m 17s) Loss: 0.0002(0.0105) 


EVAL: [600/2011] Elapsed 4m 30s (remain 10m 33s) Loss: 0.0055(0.0094) 


EVAL: [700/2011] Elapsed 5m 15s (remain 9m 49s) Loss: 0.0004(0.0086) 


EVAL: [800/2011] Elapsed 5m 59s (remain 9m 3s) Loss: 0.0023(0.0080) 


EVAL: [900/2011] Elapsed 6m 44s (remain 8m 18s) Loss: 0.0048(0.0075) 


EVAL: [1000/2011] Elapsed 7m 28s (remain 7m 32s) Loss: 0.0057(0.0073) 


EVAL: [1100/2011] Elapsed 8m 13s (remain 6m 47s) Loss: 0.0102(0.0074) 


EVAL: [1200/2011] Elapsed 8m 58s (remain 6m 2s) Loss: 0.0004(0.0074) 


EVAL: [1300/2011] Elapsed 9m 43s (remain 5m 18s) Loss: 0.0094(0.0075) 


EVAL: [1400/2011] Elapsed 10m 28s (remain 4m 33s) Loss: 0.0028(0.0074) 


EVAL: [1500/2011] Elapsed 11m 13s (remain 3m 48s) Loss: 0.0120(0.0073) 


EVAL: [1600/2011] Elapsed 11m 58s (remain 3m 3s) Loss: 0.0017(0.0073) 


EVAL: [1700/2011] Elapsed 12m 44s (remain 2m 19s) Loss: 0.0045(0.0073) 


EVAL: [1800/2011] Elapsed 13m 27s (remain 1m 34s) Loss: 0.0003(0.0073) 


EVAL: [1900/2011] Elapsed 14m 12s (remain 0m 49s) Loss: 0.0021(0.0071) 


EVAL: [2000/2011] Elapsed 14m 57s (remain 0m 4s) Loss: 0.0024(0.0069) 


EVAL: [2010/2011] Elapsed 15m 2s (remain 0m 0s) Loss: 0.0067(0.0069) 


Epoch 2 - avg_train_loss: 0.0083  avg_val_loss: 0.0069  time: 6174s


Epoch 2 - Score: 0.9408


Epoch 2 - Save Best Score: 0.9408 Model


Starting training

Starting training epoch 3 for fold 0


Epoch: [3][0/8042] Elapsed 0m 1s (remain 165m 29s) Loss: 0.0034(0.0034) Grad Norm: 26147.6230  LR: 9.938441702975689e-06


Epoch: [3][100/8042] Elapsed 1m 6s (remain 87m 6s) Loss: 0.0022(0.0073) Grad Norm: 20317.1367  LR: 9.937675510282891e-06


Epoch: [3][200/8042] Elapsed 2m 8s (remain 83m 50s) Loss: 0.0033(0.0073) Grad Norm: 11545.1084  LR: 9.936904608682855e-06


Epoch: [3][300/8042] Elapsed 3m 12s (remain 82m 25s) Loss: 0.0020(0.0077) Grad Norm: 10410.8828  LR: 9.93612899891077e-06


Epoch: [3][400/8042] Elapsed 4m 14s (remain 80m 54s) Loss: 0.0005(0.0081) Grad Norm: 2386.1934  LR: 9.93534868170631e-06


Epoch: [3][500/8042] Elapsed 5m 17s (remain 79m 38s) Loss: 0.0082(0.0080) Grad Norm: 30833.8203  LR: 9.934563657813637e-06


Epoch: [3][600/8042] Elapsed 6m 20s (remain 78m 35s) Loss: 0.0038(0.0078) Grad Norm: 22137.3281  LR: 9.933773927981405e-06


Epoch: [3][700/8042] Elapsed 7m 23s (remain 77m 23s) Loss: 0.0213(0.0078) Grad Norm: 70965.4609  LR: 9.932979492962756e-06


Epoch: [3][800/8042] Elapsed 8m 27s (remain 76m 29s) Loss: 0.0019(0.0078) Grad Norm: 14275.9580  LR: 9.932180353515314e-06


Epoch: [3][900/8042] Elapsed 9m 32s (remain 75m 34s) Loss: 0.0045(0.0076) Grad Norm: 33577.6914  LR: 9.931376510401199e-06


Epoch: [3][1000/8042] Elapsed 10m 35s (remain 74m 28s) Loss: 0.0028(0.0076) Grad Norm: 12746.0215  LR: 9.930567964387006e-06


Epoch: [3][1100/8042] Elapsed 11m 40s (remain 73m 36s) Loss: 0.0103(0.0076) Grad Norm: 18003.9160  LR: 9.929754716243825e-06


Epoch: [3][1200/8042] Elapsed 12m 45s (remain 72m 41s) Loss: 0.0072(0.0076) Grad Norm: 27651.3750  LR: 9.92893676674722e-06


Epoch: [3][1300/8042] Elapsed 13m 49s (remain 71m 39s) Loss: 0.0019(0.0076) Grad Norm: 25325.9883  LR: 9.928114116677248e-06


Epoch: [3][1400/8042] Elapsed 14m 53s (remain 70m 37s) Loss: 0.0053(0.0076) Grad Norm: 26518.7949  LR: 9.927286766818443e-06


Epoch: [3][1500/8042] Elapsed 16m 0s (remain 69m 44s) Loss: 0.0078(0.0076) Grad Norm: 29946.2734  LR: 9.926454717959824e-06


Epoch: [3][1600/8042] Elapsed 17m 5s (remain 68m 44s) Loss: 0.0046(0.0075) Grad Norm: 28144.6309  LR: 9.925617970894887e-06


Epoch: [3][1700/8042] Elapsed 18m 10s (remain 67m 44s) Loss: 0.0078(0.0074) Grad Norm: 18167.3945  LR: 9.924776526421615e-06


Epoch: [3][1800/8042] Elapsed 19m 13s (remain 66m 37s) Loss: 0.0005(0.0074) Grad Norm: 5964.4971  LR: 9.923930385342468e-06


Epoch: [3][1900/8042] Elapsed 20m 19s (remain 65m 38s) Loss: 0.0038(0.0073) Grad Norm: 7261.3413  LR: 9.92307954846438e-06


Epoch: [3][2000/8042] Elapsed 21m 24s (remain 64m 38s) Loss: 0.0135(0.0073) Grad Norm: 64539.7305  LR: 9.922224016598773e-06


Epoch: [3][2100/8042] Elapsed 22m 30s (remain 63m 38s) Loss: 0.0102(0.0073) Grad Norm: 70509.6016  LR: 9.921363790561535e-06


Epoch: [3][2200/8042] Elapsed 23m 35s (remain 62m 36s) Loss: 0.0220(0.0073) Grad Norm: 68632.5859  LR: 9.920498871173043e-06


Epoch: [3][2300/8042] Elapsed 24m 40s (remain 61m 33s) Loss: 0.0016(0.0073) Grad Norm: 36981.9297  LR: 9.919629259258139e-06


Epoch: [3][2400/8042] Elapsed 25m 44s (remain 60m 28s) Loss: 0.0049(0.0072) Grad Norm: 33213.0156  LR: 9.918754955646147e-06


Epoch: [3][2500/8042] Elapsed 26m 48s (remain 59m 24s) Loss: 0.0051(0.0072) Grad Norm: 37050.4141  LR: 9.917875961170863e-06


Epoch: [3][2600/8042] Elapsed 27m 53s (remain 58m 21s) Loss: 0.0043(0.0072) Grad Norm: 27939.0352  LR: 9.916992276670556e-06


Epoch: [3][2700/8042] Elapsed 28m 57s (remain 57m 16s) Loss: 0.0001(0.0072) Grad Norm: 1283.4169  LR: 9.916103902987967e-06


Epoch: [3][2800/8042] Elapsed 30m 2s (remain 56m 11s) Loss: 0.0050(0.0072) Grad Norm: 23391.2676  LR: 9.915210840970314e-06


Epoch: [3][2900/8042] Elapsed 31m 6s (remain 55m 8s) Loss: 0.0088(0.0072) Grad Norm: 58771.9961  LR: 9.914313091469279e-06


Epoch: [3][3000/8042] Elapsed 32m 12s (remain 54m 5s) Loss: 0.0039(0.0071) Grad Norm: 22030.5898  LR: 9.91341065534102e-06


Epoch: [3][3100/8042] Elapsed 33m 17s (remain 53m 3s) Loss: 0.0236(0.0071) Grad Norm: 61609.7539  LR: 9.91250353344616e-06


Epoch: [3][3200/8042] Elapsed 34m 24s (remain 52m 2s) Loss: 0.0043(0.0070) Grad Norm: 31267.0156  LR: 9.911591726649794e-06


Epoch: [3][3300/8042] Elapsed 35m 30s (remain 50m 59s) Loss: 0.0016(0.0070) Grad Norm: 21802.7207  LR: 9.910675235821485e-06


Epoch: [3][3400/8042] Elapsed 36m 34s (remain 49m 55s) Loss: 0.0020(0.0070) Grad Norm: 19300.7812  LR: 9.909754061835262e-06


Epoch: [3][3500/8042] Elapsed 37m 40s (remain 48m 52s) Loss: 0.0014(0.0069) Grad Norm: 16748.8965  LR: 9.908828205569617e-06


Epoch: [3][3600/8042] Elapsed 38m 46s (remain 47m 49s) Loss: 0.0092(0.0069) Grad Norm: 83428.6172  LR: 9.90789766790751e-06


Epoch: [3][3700/8042] Elapsed 39m 51s (remain 46m 45s) Loss: 0.0011(0.0069) Grad Norm: 19622.4375  LR: 9.90696244973637e-06


Epoch: [3][3800/8042] Elapsed 40m 55s (remain 45m 39s) Loss: 0.0061(0.0069) Grad Norm: 66054.6250  LR: 9.90602255194808e-06


Epoch: [3][3900/8042] Elapsed 42m 0s (remain 44m 35s) Loss: 0.0029(0.0069) Grad Norm: 25263.8281  LR: 9.905077975438997e-06


Epoch: [3][4000/8042] Elapsed 43m 4s (remain 43m 30s) Loss: 0.0034(0.0069) Grad Norm: 42643.5586  LR: 9.90412872110993e-06


Epoch: [3][4100/8042] Elapsed 44m 10s (remain 42m 26s) Loss: 0.0154(0.0069) Grad Norm: 126590.6875  LR: 9.903174789866154e-06


Epoch: [3][4200/8042] Elapsed 45m 14s (remain 41m 21s) Loss: 0.0159(0.0068) Grad Norm: 88930.0859  LR: 9.902216182617405e-06


Epoch: [3][4300/8042] Elapsed 46m 18s (remain 40m 17s) Loss: 0.0016(0.0068) Grad Norm: 50942.7500  LR: 9.901252900277875e-06


Epoch: [3][4400/8042] Elapsed 47m 22s (remain 39m 12s) Loss: 0.0010(0.0068) Grad Norm: 19166.3848  LR: 9.90028494376622e-06


Epoch: [3][4500/8042] Elapsed 48m 28s (remain 38m 7s) Loss: 0.0231(0.0068) Grad Norm: 1152769.0000  LR: 9.899312314005545e-06


Epoch: [3][4600/8042] Elapsed 49m 31s (remain 37m 2s) Loss: 0.0013(0.0067) Grad Norm: 18845.1504  LR: 9.898335011923419e-06


Epoch: [3][4700/8042] Elapsed 50m 37s (remain 35m 58s) Loss: 0.0056(0.0067) Grad Norm: 113532.0000  LR: 9.897353038451865e-06


Epoch: [3][4800/8042] Elapsed 51m 41s (remain 34m 53s) Loss: 0.0056(0.0067) Grad Norm: 62280.3594  LR: 9.89636639452736e-06


Epoch: [3][4900/8042] Elapsed 52m 44s (remain 33m 47s) Loss: 0.0086(0.0067) Grad Norm: 90409.8828  LR: 9.895375081090835e-06


Epoch: [3][5000/8042] Elapsed 53m 48s (remain 32m 43s) Loss: 0.0028(0.0066) Grad Norm: 65756.3125  LR: 9.894379099087675e-06


Epoch: [3][5100/8042] Elapsed 54m 52s (remain 31m 38s) Loss: 0.0005(0.0066) Grad Norm: 23567.4375  LR: 9.893378449467719e-06


Epoch: [3][5200/8042] Elapsed 55m 55s (remain 30m 32s) Loss: 0.0008(0.0066) Grad Norm: 40768.6172  LR: 9.892373133185251e-06


Epoch: [3][5300/8042] Elapsed 56m 59s (remain 29m 28s) Loss: 0.0027(0.0066) Grad Norm: 32616.1367  LR: 9.891363151199013e-06


Epoch: [3][5400/8042] Elapsed 58m 1s (remain 28m 22s) Loss: 0.0008(0.0066) Grad Norm: 13210.4102  LR: 9.890348504472194e-06


Epoch: [3][5500/8042] Elapsed 59m 3s (remain 27m 16s) Loss: 0.0044(0.0066) Grad Norm: 30835.6133  LR: 9.88932919397243e-06


Epoch: [3][5600/8042] Elapsed 60m 4s (remain 26m 11s) Loss: 0.0084(0.0066) Grad Norm: 123623.5859  LR: 9.888305220671804e-06


Epoch: [3][5700/8042] Elapsed 61m 7s (remain 25m 5s) Loss: 0.0041(0.0065) Grad Norm: 90534.2188  LR: 9.887276585546848e-06


Epoch: [3][5800/8042] Elapsed 62m 9s (remain 24m 0s) Loss: 0.0050(0.0065) Grad Norm: 92837.1328  LR: 9.886243289578541e-06


Epoch: [3][5900/8042] Elapsed 63m 13s (remain 22m 56s) Loss: 0.0004(0.0065) Grad Norm: 17495.0859  LR: 9.885205333752306e-06


Epoch: [3][6000/8042] Elapsed 64m 19s (remain 21m 52s) Loss: 0.0012(0.0065) Grad Norm: 33636.0508  LR: 9.884162719058006e-06


Epoch: [3][6100/8042] Elapsed 65m 25s (remain 20m 48s) Loss: 0.0077(0.0065) Grad Norm: 177857.7500  LR: 9.88311544648995e-06


Epoch: [3][6200/8042] Elapsed 66m 29s (remain 19m 44s) Loss: 0.0024(0.0065) Grad Norm: 110384.6953  LR: 9.882063517046892e-06


Epoch: [3][6300/8042] Elapsed 67m 32s (remain 18m 39s) Loss: 0.0027(0.0065) Grad Norm: 104055.7266  LR: 9.881006931732023e-06


Epoch: [3][6400/8042] Elapsed 68m 37s (remain 17m 35s) Loss: 0.0014(0.0065) Grad Norm: 125604.3750  LR: 9.879945691552975e-06


Epoch: [3][6500/8042] Elapsed 69m 40s (remain 16m 30s) Loss: 0.0167(0.0065) Grad Norm: 440756.2188  LR: 9.87887979752182e-06


Epoch: [3][6600/8042] Elapsed 70m 42s (remain 15m 26s) Loss: 0.0031(0.0065) Grad Norm: 120457.3203  LR: 9.877809250655069e-06


Epoch: [3][6700/8042] Elapsed 71m 48s (remain 14m 22s) Loss: 0.0071(0.0064) Grad Norm: 100922.0156  LR: 9.876734051973668e-06


Epoch: [3][6800/8042] Elapsed 72m 51s (remain 13m 17s) Loss: 0.0113(0.0064) Grad Norm: 462769.6875  LR: 9.875654202503e-06


Epoch: [3][6900/8042] Elapsed 73m 54s (remain 12m 13s) Loss: 0.0001(0.0064) Grad Norm: 5119.4995  LR: 9.874569703272885e-06


Epoch: [3][7000/8042] Elapsed 74m 57s (remain 11m 8s) Loss: 0.0007(0.0064) Grad Norm: 99111.4219  LR: 9.873480555317575e-06


Epoch: [3][7100/8042] Elapsed 75m 59s (remain 10m 4s) Loss: 0.0007(0.0064) Grad Norm: 41384.3672  LR: 9.872386759675757e-06


Epoch: [3][7200/8042] Elapsed 77m 3s (remain 8m 59s) Loss: 0.0128(0.0064) Grad Norm: 211750.4062  LR: 9.87128831739055e-06


Epoch: [3][7300/8042] Elapsed 78m 7s (remain 7m 55s) Loss: 0.0436(0.0064) Grad Norm: 381441.5625  LR: 9.870185229509506e-06


Epoch: [3][7400/8042] Elapsed 79m 12s (remain 6m 51s) Loss: 0.0011(0.0064) Grad Norm: 60191.6133  LR: 9.869077497084601e-06


Epoch: [3][7500/8042] Elapsed 80m 15s (remain 5m 47s) Loss: 0.0002(0.0064) Grad Norm: 9022.7812  LR: 9.867965121172248e-06


Epoch: [3][7600/8042] Elapsed 81m 18s (remain 4m 43s) Loss: 0.0058(0.0064) Grad Norm: 316575.0000  LR: 9.866848102833286e-06


Epoch: [3][7700/8042] Elapsed 82m 21s (remain 3m 38s) Loss: 0.0003(0.0063) Grad Norm: 11289.1709  LR: 9.865726443132978e-06


Epoch: [3][7800/8042] Elapsed 83m 25s (remain 2m 34s) Loss: 0.0019(0.0063) Grad Norm: 192919.7031  LR: 9.864600143141018e-06


Epoch: [3][7900/8042] Elapsed 84m 30s (remain 1m 30s) Loss: 0.0024(0.0063) Grad Norm: 48462.4805  LR: 9.863469203931522e-06


Epoch: [3][8000/8042] Elapsed 85m 34s (remain 0m 26s) Loss: 0.0075(0.0063) Grad Norm: 575394.5000  LR: 9.862333626583032e-06


Epoch: [3][8041/8042] Elapsed 86m 1s (remain 0m 0s) Loss: 0.0111(0.0063) Grad Norm: 394063.3438  LR: 9.86186669946739e-06
Epoch 3 completed. Average loss: 0.0063
Starting evaluation
Starting validation...


EVAL: [0/2011] Elapsed 0m 0s (remain 27m 22s) Loss: 0.0073(0.0073) 


EVAL: [100/2011] Elapsed 0m 45s (remain 14m 12s) Loss: -0.0003(0.0121) 


EVAL: [200/2011] Elapsed 1m 29s (remain 13m 23s) Loss: 0.0029(0.0142) 


EVAL: [300/2011] Elapsed 2m 13s (remain 12m 41s) Loss: 0.0123(0.0133) 


EVAL: [400/2011] Elapsed 2m 58s (remain 11m 56s) Loss: 0.0038(0.0113) 


EVAL: [500/2011] Elapsed 3m 43s (remain 11m 12s) Loss: -0.0000(0.0104) 


EVAL: [600/2011] Elapsed 4m 27s (remain 10m 27s) Loss: 0.0049(0.0092) 


EVAL: [700/2011] Elapsed 5m 12s (remain 9m 43s) Loss: 0.0001(0.0084) 


EVAL: [800/2011] Elapsed 5m 56s (remain 8m 58s) Loss: 0.0014(0.0078) 


EVAL: [900/2011] Elapsed 6m 40s (remain 8m 13s) Loss: 0.0066(0.0072) 


EVAL: [1000/2011] Elapsed 7m 25s (remain 7m 29s) Loss: 0.0018(0.0070) 


EVAL: [1100/2011] Elapsed 8m 9s (remain 6m 44s) Loss: 0.0090(0.0070) 


EVAL: [1200/2011] Elapsed 8m 53s (remain 6m 0s) Loss: 0.0001(0.0070) 


EVAL: [1300/2011] Elapsed 9m 38s (remain 5m 15s) Loss: 0.0067(0.0070) 


EVAL: [1400/2011] Elapsed 10m 23s (remain 4m 31s) Loss: 0.0032(0.0070) 


EVAL: [1500/2011] Elapsed 11m 7s (remain 3m 46s) Loss: 0.0042(0.0068) 


EVAL: [1600/2011] Elapsed 11m 52s (remain 3m 2s) Loss: 0.0007(0.0068) 


EVAL: [1700/2011] Elapsed 12m 36s (remain 2m 17s) Loss: 0.0017(0.0068) 


EVAL: [1800/2011] Elapsed 13m 21s (remain 1m 33s) Loss: -0.0000(0.0067) 


EVAL: [1900/2011] Elapsed 14m 7s (remain 0m 49s) Loss: 0.0019(0.0065) 


EVAL: [2000/2011] Elapsed 14m 51s (remain 0m 4s) Loss: 0.0024(0.0063) 


EVAL: [2010/2011] Elapsed 14m 55s (remain 0m 0s) Loss: 0.0046(0.0063) 


Epoch 3 - avg_train_loss: 0.0063  avg_val_loss: 0.0063  time: 6070s


Epoch 3 - Score: 0.9483


Epoch 3 - Save Best Score: 0.9483 Model


Starting training

Starting training epoch 4 for fold 0


Epoch: [4][0/8042] Elapsed 0m 1s (remain 165m 27s) Loss: 0.0024(0.0024) Grad Norm: 11719.9014  LR: 9.861855301263968e-06


Epoch: [4][100/8042] Elapsed 1m 6s (remain 86m 59s) Loss: 0.0005(0.0065) Grad Norm: 3971.9387  LR: 9.86071313961992e-06


Epoch: [4][200/8042] Elapsed 2m 12s (remain 85m 56s) Loss: 0.0281(0.0063) Grad Norm: 35084.8281  LR: 9.85956634246525e-06


Epoch: [4][300/8042] Elapsed 3m 18s (remain 84m 59s) Loss: 0.0021(0.0059) Grad Norm: 8668.1064  LR: 9.858414910893623e-06


Epoch: [4][400/8042] Elapsed 4m 23s (remain 83m 32s) Loss: 0.0064(0.0061) Grad Norm: 25814.1074  LR: 9.857258846003124e-06


Epoch: [4][500/8042] Elapsed 5m 26s (remain 81m 52s) Loss: 0.0019(0.0059) Grad Norm: 19153.7559  LR: 9.856098148896256e-06


Epoch: [4][600/8042] Elapsed 6m 30s (remain 80m 37s) Loss: 0.0018(0.0060) Grad Norm: 8213.3789  LR: 9.854932820679938e-06


Epoch: [4][700/8042] Elapsed 7m 34s (remain 79m 15s) Loss: 0.0078(0.0060) Grad Norm: 26792.7383  LR: 9.85376286246551e-06


Epoch: [4][800/8042] Elapsed 8m 36s (remain 77m 51s) Loss: 0.0030(0.0061) Grad Norm: 13909.8145  LR: 9.852588275368722e-06


Epoch: [4][900/8042] Elapsed 9m 40s (remain 76m 43s) Loss: 0.0005(0.0062) Grad Norm: 2011.5255  LR: 9.851409060509743e-06


Epoch: [4][1000/8042] Elapsed 10m 44s (remain 75m 34s) Loss: 0.0125(0.0062) Grad Norm: 54382.1836  LR: 9.850225219013153e-06


Epoch: [4][1100/8042] Elapsed 11m 47s (remain 74m 18s) Loss: 0.0015(0.0060) Grad Norm: 9490.8535  LR: 9.849036752007943e-06


Epoch: [4][1200/8042] Elapsed 12m 49s (remain 73m 2s) Loss: 0.0014(0.0060) Grad Norm: 5355.8530  LR: 9.847843660627521e-06


Epoch: [4][1300/8042] Elapsed 13m 52s (remain 71m 52s) Loss: 0.0011(0.0060) Grad Norm: 12461.3994  LR: 9.846645946009697e-06


Epoch: [4][1400/8042] Elapsed 14m 55s (remain 70m 43s) Loss: 0.0122(0.0059) Grad Norm: 16301.5869  LR: 9.845443609296694e-06


Epoch: [4][1500/8042] Elapsed 15m 58s (remain 69m 36s) Loss: 0.0001(0.0058) Grad Norm: 707.2938  LR: 9.844236651635147e-06


Epoch: [4][1600/8042] Elapsed 17m 1s (remain 68m 29s) Loss: 0.0075(0.0059) Grad Norm: 35617.9258  LR: 9.84302507417609e-06


Epoch: [4][1700/8042] Elapsed 18m 7s (remain 67m 32s) Loss: 0.0012(0.0058) Grad Norm: 7496.8027  LR: 9.841808878074968e-06


Epoch: [4][1800/8042] Elapsed 19m 9s (remain 66m 23s) Loss: 0.0007(0.0058) Grad Norm: 4063.8093  LR: 9.840588064491633e-06


Epoch: [4][1900/8042] Elapsed 20m 14s (remain 65m 21s) Loss: 0.0147(0.0058) Grad Norm: 56259.9141  LR: 9.839362634590329e-06


Epoch: [4][2000/8042] Elapsed 21m 18s (remain 64m 18s) Loss: 0.0013(0.0058) Grad Norm: 22060.4863  LR: 9.838132589539716e-06


Epoch: [4][2100/8042] Elapsed 22m 20s (remain 63m 10s) Loss: 0.0005(0.0058) Grad Norm: 8596.3398  LR: 9.83689793051285e-06


Epoch: [4][2200/8042] Elapsed 23m 23s (remain 62m 5s) Loss: 0.0048(0.0058) Grad Norm: 13070.9277  LR: 9.835658658687182e-06


Epoch: [4][2300/8042] Elapsed 24m 26s (remain 60m 58s) Loss: 0.0091(0.0058) Grad Norm: 193321.8125  LR: 9.834414775244572e-06


Epoch: [4][2400/8042] Elapsed 25m 27s (remain 59m 49s) Loss: 0.0010(0.0058) Grad Norm: 17368.1875  LR: 9.833166281371272e-06


Epoch: [4][2500/8042] Elapsed 26m 30s (remain 58m 44s) Loss: 0.0099(0.0057) Grad Norm: 27954.0625  LR: 9.831913178257928e-06


Epoch: [4][2600/8042] Elapsed 27m 35s (remain 57m 42s) Loss: 0.0041(0.0057) Grad Norm: 15004.2959  LR: 9.830655467099589e-06


Epoch: [4][2700/8042] Elapsed 28m 39s (remain 56m 40s) Loss: 0.0000(0.0057) Grad Norm: 149.6065  LR: 9.829393149095693e-06


Epoch: [4][2800/8042] Elapsed 29m 43s (remain 55m 38s) Loss: 0.0027(0.0057) Grad Norm: 31896.0430  LR: 9.828126225450077e-06


Epoch: [4][2900/8042] Elapsed 30m 49s (remain 54m 38s) Loss: 0.0206(0.0056) Grad Norm: 53417.9805  LR: 9.826854697370962e-06


Epoch: [4][3000/8042] Elapsed 31m 57s (remain 53m 41s) Loss: 0.0069(0.0055) Grad Norm: 60017.1289  LR: 9.825578566070966e-06


Epoch: [4][3100/8042] Elapsed 33m 3s (remain 52m 41s) Loss: 0.0088(0.0055) Grad Norm: 34616.3555  LR: 9.824297832767097e-06


Epoch: [4][3200/8042] Elapsed 34m 9s (remain 51m 40s) Loss: 0.0008(0.0055) Grad Norm: 20633.9375  LR: 9.82301249868075e-06


Epoch: [4][3300/8042] Elapsed 35m 15s (remain 50m 38s) Loss: 0.0017(0.0055) Grad Norm: 44090.9531  LR: 9.821722565037706e-06


Epoch: [4][3400/8042] Elapsed 36m 22s (remain 49m 38s) Loss: 0.0095(0.0055) Grad Norm: 28004.1172  LR: 9.820428033068137e-06


Epoch: [4][3500/8042] Elapsed 37m 27s (remain 48m 35s) Loss: 0.0008(0.0055) Grad Norm: 11318.6104  LR: 9.819128904006598e-06


Epoch: [4][3600/8042] Elapsed 38m 33s (remain 47m 32s) Loss: 0.0116(0.0055) Grad Norm: 93784.9062  LR: 9.817825179092026e-06


Epoch: [4][3700/8042] Elapsed 39m 40s (remain 46m 32s) Loss: 0.0064(0.0055) Grad Norm: 35961.1406  LR: 9.816516859567744e-06


Epoch: [4][3800/8042] Elapsed 40m 48s (remain 45m 32s) Loss: 0.0010(0.0054) Grad Norm: 90364.7031  LR: 9.815203946681456e-06


Epoch: [4][3900/8042] Elapsed 41m 56s (remain 44m 31s) Loss: 0.0035(0.0054) Grad Norm: 53250.1016  LR: 9.813886441685243e-06


Epoch: [4][4000/8042] Elapsed 43m 5s (remain 43m 30s) Loss: 0.0051(0.0054) Grad Norm: 70769.1328  LR: 9.812564345835573e-06


Epoch: [4][4100/8042] Elapsed 44m 13s (remain 42m 30s) Loss: 0.0107(0.0054) Grad Norm: 358027.9062  LR: 9.811237660393284e-06


Epoch: [4][4200/8042] Elapsed 45m 22s (remain 41m 28s) Loss: 0.0005(0.0054) Grad Norm: 7883.3647  LR: 9.809906386623598e-06


Epoch: [4][4300/8042] Elapsed 46m 31s (remain 40m 28s) Loss: 0.0001(0.0054) Grad Norm: 2111.5515  LR: 9.808570525796104e-06


Epoch: [4][4400/8042] Elapsed 47m 39s (remain 39m 25s) Loss: 0.0041(0.0054) Grad Norm: 65530.7070  LR: 9.807230079184777e-06


Epoch: [4][4500/8042] Elapsed 48m 45s (remain 38m 21s) Loss: 0.0002(0.0054) Grad Norm: 10191.3672  LR: 9.805885048067955e-06


Epoch: [4][4600/8042] Elapsed 49m 53s (remain 37m 18s) Loss: 0.0022(0.0053) Grad Norm: 39428.3281  LR: 9.804535433728353e-06


Epoch: [4][4700/8042] Elapsed 50m 59s (remain 36m 14s) Loss: 0.0269(0.0053) Grad Norm: 138111.7969  LR: 9.803181237453058e-06


Epoch: [4][4800/8042] Elapsed 52m 7s (remain 35m 11s) Loss: 0.0028(0.0053) Grad Norm: 16614.5215  LR: 9.801822460533523e-06


Epoch: [4][4900/8042] Elapsed 53m 15s (remain 34m 8s) Loss: 0.0019(0.0053) Grad Norm: 110382.6797  LR: 9.800459104265571e-06


Epoch: [4][5000/8042] Elapsed 54m 23s (remain 33m 4s) Loss: 0.0086(0.0053) Grad Norm: 652237.1250  LR: 9.799091169949393e-06


Epoch: [4][5100/8042] Elapsed 55m 31s (remain 32m 1s) Loss: 0.0098(0.0053) Grad Norm: 269340.5312  LR: 9.797718658889545e-06


Epoch: [4][5200/8042] Elapsed 56m 39s (remain 30m 56s) Loss: 0.0081(0.0054) Grad Norm: 129488.2031  LR: 9.796341572394949e-06


Epoch: [4][5300/8042] Elapsed 57m 46s (remain 29m 52s) Loss: 0.0000(0.0054) Grad Norm: 739.3557  LR: 9.794959911778888e-06


Epoch: [4][5400/8042] Elapsed 58m 53s (remain 28m 47s) Loss: 0.0018(0.0054) Grad Norm: 33012.0430  LR: 9.793573678359009e-06


Epoch: [4][5500/8042] Elapsed 60m 2s (remain 27m 43s) Loss: 0.0024(0.0054) Grad Norm: 25034.1035  LR: 9.792182873457322e-06


Epoch: [4][5600/8042] Elapsed 61m 10s (remain 26m 39s) Loss: 0.0143(0.0054) Grad Norm: 254348.1406  LR: 9.790787498400192e-06


Epoch: [4][5700/8042] Elapsed 62m 20s (remain 25m 35s) Loss: 0.0015(0.0054) Grad Norm: 40443.0469  LR: 9.789387554518343e-06


Epoch: [4][5800/8042] Elapsed 63m 29s (remain 24m 31s) Loss: 0.0026(0.0054) Grad Norm: 47627.0781  LR: 9.787983043146863e-06


Epoch: [4][5900/8042] Elapsed 64m 35s (remain 23m 26s) Loss: 0.0063(0.0053) Grad Norm: 105478.2266  LR: 9.786573965625186e-06


Epoch: [4][6000/8042] Elapsed 65m 41s (remain 22m 20s) Loss: 0.0035(0.0053) Grad Norm: 82880.1797  LR: 9.785160323297107e-06


Epoch: [4][6100/8042] Elapsed 66m 45s (remain 21m 14s) Loss: 0.0000(0.0053) Grad Norm: 1082.4861  LR: 9.783742117510772e-06


Epoch: [4][6200/8042] Elapsed 67m 48s (remain 20m 7s) Loss: 0.0013(0.0053) Grad Norm: 84880.3438  LR: 9.782319349618681e-06


Epoch: [4][6300/8042] Elapsed 68m 53s (remain 19m 2s) Loss: 0.0018(0.0053) Grad Norm: 236416.1250  LR: 9.780892020977682e-06


Epoch: [4][6400/8042] Elapsed 69m 57s (remain 17m 56s) Loss: 0.0010(0.0053) Grad Norm: 51404.7500  LR: 9.779460132948974e-06


Epoch: [4][6500/8042] Elapsed 71m 0s (remain 16m 49s) Loss: 0.0028(0.0053) Grad Norm: 64815.2812  LR: 9.778023686898106e-06


Epoch: [4][6600/8042] Elapsed 72m 4s (remain 15m 44s) Loss: 0.0064(0.0053) Grad Norm: 184972.4062  LR: 9.77658268419497e-06


Epoch: [4][6700/8042] Elapsed 73m 9s (remain 14m 38s) Loss: 0.0011(0.0053) Grad Norm: 62624.5820  LR: 9.775137126213805e-06


Epoch: [4][6800/8042] Elapsed 74m 14s (remain 13m 32s) Loss: 0.0025(0.0053) Grad Norm: 39332.1055  LR: 9.773687014333194e-06


Epoch: [4][6900/8042] Elapsed 75m 19s (remain 12m 27s) Loss: 0.0146(0.0053) Grad Norm: 163456.7188  LR: 9.772232349936067e-06


Epoch: [4][7000/8042] Elapsed 76m 22s (remain 11m 21s) Loss: 0.0020(0.0053) Grad Norm: 146190.1719  LR: 9.770773134409688e-06


Epoch: [4][7100/8042] Elapsed 77m 28s (remain 10m 15s) Loss: 0.0007(0.0053) Grad Norm: 35188.7812  LR: 9.769309369145667e-06


Epoch: [4][7200/8042] Elapsed 78m 34s (remain 9m 10s) Loss: 0.0384(0.0053) Grad Norm: 1012404.6250  LR: 9.767841055539952e-06


Epoch: [4][7300/8042] Elapsed 79m 39s (remain 8m 5s) Loss: 0.0001(0.0053) Grad Norm: 5900.2134  LR: 9.766368194992828e-06


Epoch: [4][7400/8042] Elapsed 80m 44s (remain 6m 59s) Loss: 0.0006(0.0053) Grad Norm: 51682.2656  LR: 9.764890788908916e-06


Epoch: [4][7500/8042] Elapsed 81m 52s (remain 5m 54s) Loss: 0.0074(0.0052) Grad Norm: 596826.8750  LR: 9.76340883869717e-06


Epoch: [4][7600/8042] Elapsed 82m 57s (remain 4m 48s) Loss: 0.0001(0.0052) Grad Norm: 21822.2891  LR: 9.761922345770883e-06


Epoch: [4][7700/8042] Elapsed 84m 2s (remain 3m 43s) Loss: 0.0012(0.0052) Grad Norm: 86300.1641  LR: 9.760431311547675e-06


Epoch: [4][7800/8042] Elapsed 85m 7s (remain 2m 37s) Loss: 0.0004(0.0052) Grad Norm: 34925.7617  LR: 9.7589357374495e-06


Epoch: [4][7900/8042] Elapsed 86m 11s (remain 1m 32s) Loss: 0.0040(0.0052) Grad Norm: 126483.8281  LR: 9.757435624902638e-06


Epoch: [4][8000/8042] Elapsed 87m 13s (remain 0m 26s) Loss: 0.0007(0.0052) Grad Norm: 111470.1953  LR: 9.755930975337703e-06


Epoch: [4][8041/8042] Elapsed 87m 39s (remain 0m 0s) Loss: 0.0016(0.0052) Grad Norm: 113459.7969  LR: 9.755312757924054e-06
Epoch 4 completed. Average loss: 0.0052
Starting evaluation
Starting validation...


EVAL: [0/2011] Elapsed 0m 0s (remain 28m 7s) Loss: 0.0068(0.0068) 


EVAL: [100/2011] Elapsed 0m 45s (remain 14m 16s) Loss: -0.0002(0.0126) 


EVAL: [200/2011] Elapsed 1m 28s (remain 13m 19s) Loss: 0.0028(0.0155) 


EVAL: [300/2011] Elapsed 2m 12s (remain 12m 31s) Loss: 0.0050(0.0145) 


EVAL: [400/2011] Elapsed 2m 56s (remain 11m 47s) Loss: 0.0030(0.0123) 


EVAL: [500/2011] Elapsed 3m 41s (remain 11m 6s) Loss: 0.0002(0.0113) 


EVAL: [600/2011] Elapsed 4m 25s (remain 10m 23s) Loss: 0.0039(0.0100) 


EVAL: [700/2011] Elapsed 5m 10s (remain 9m 39s) Loss: 0.0004(0.0092) 


EVAL: [800/2011] Elapsed 5m 55s (remain 8m 57s) Loss: 0.0013(0.0085) 


EVAL: [900/2011] Elapsed 6m 40s (remain 8m 13s) Loss: 0.0058(0.0079) 


EVAL: [1000/2011] Elapsed 7m 24s (remain 7m 28s) Loss: -0.0003(0.0076) 


EVAL: [1100/2011] Elapsed 8m 8s (remain 6m 43s) Loss: 0.0154(0.0077) 


EVAL: [1200/2011] Elapsed 8m 52s (remain 5m 59s) Loss: 0.0003(0.0076) 


EVAL: [1300/2011] Elapsed 9m 36s (remain 5m 14s) Loss: 0.0106(0.0076) 


EVAL: [1400/2011] Elapsed 10m 21s (remain 4m 30s) Loss: 0.0004(0.0076) 


EVAL: [1500/2011] Elapsed 11m 5s (remain 3m 46s) Loss: 0.0032(0.0074) 


EVAL: [1600/2011] Elapsed 11m 49s (remain 3m 1s) Loss: 0.0006(0.0075) 


EVAL: [1700/2011] Elapsed 12m 33s (remain 2m 17s) Loss: 0.0038(0.0074) 


EVAL: [1800/2011] Elapsed 13m 17s (remain 1m 32s) Loss: 0.0004(0.0073) 


EVAL: [1900/2011] Elapsed 14m 1s (remain 0m 48s) Loss: 0.0011(0.0071) 


EVAL: [2000/2011] Elapsed 14m 45s (remain 0m 4s) Loss: 0.0015(0.0069) 


EVAL: [2010/2011] Elapsed 14m 50s (remain 0m 0s) Loss: 0.0024(0.0069) 


Epoch 4 - avg_train_loss: 0.0052  avg_val_loss: 0.0069  time: 6164s


Epoch 4 - Score: 0.9504


Epoch 4 - Save Best Score: 0.9504 Model




Score: 0.9504




Score: 0.9504


[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m: [fold0] avg_train_loss █▃▂▁
[34m[1mwandb[0m:   [fold0] avg_val_loss █▂▁▂
[34m[1mwandb[0m:          [fold0] epoch ▁▃▆█
[34m[1mwandb[0m:          [fold0] score ▁▆██
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m: [fold0] avg_train_loss 0.00521
[34m[1mwandb[0m:   [fold0] avg_val_loss 0.0069
[34m[1mwandb[0m:          [fold0] epoch 4
[34m[1mwandb[0m:          [fold0] score 0.95044
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33manswerdotai/ModernBERT-base[0m at: [34m[4mhttps://wandb.ai/project-zero/NBME-Public/runs/4ym6uuzz?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/project-zero/NBME-Public?apiKey=93189478594090aa7c369a5a7c30ee204eb7d192[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20250211_224808-4ym6uuzz/logs[0m
