# Training deberta-v3-large model to generate high quality pseudolabels

Initially deberta-v3-large model raises OOM error. 

Optimized the training pipeline via Pytorch
- Gradient accumulation
- Gradient checkpoint
- Using bfloat data types
- Autocast

1 fold (with 4-5 epochs) takes around ~5 hours, each epoch ~50 minutes. %87+ accuracy. (See notebook version 6)

Another approach would be to use [huggingface accelerator](https://huggingface.co/docs/accelerate/en/index) as it automates this process without the need to customize manually as well as distribution of training while using multiple GPUs.

# Why Deberta-v3-large?
![Comparison of BERT variants in 2025](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/modernbert/modernbert_pareto_curve.png)

# Directory settings

In [1]:
# ====================================================
# Directory settings
# ====================================================
import os, shutil

OUTPUT_DIR = os.path.join(os.path.dirname(os.getcwd()), 'deberta-v3-large-finetuned-with-pseudolabels/')
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

for item in os.listdir(OUTPUT_DIR):
    item_path = os.path.join(OUTPUT_DIR, item)
    if os.path.isfile(item_path):  # If it's a file, remove it
        os.remove(item_path)
        print(f"Removed file: {item}")
    elif os.path.isdir(item_path):  # If it's a directory, remove it
        shutil.rmtree(item_path)
        print(f"Removed folder: {item}")

Removed folder: tokenizer
Removed file: train.log
Removed file: config.pth


# Library

In [2]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

os.system('pip uninstall -y transformers')
os.system('python -m pip install transformers>=4.48.0')
os.system('pip uninstall -y tokenizer')
os.system('python -m pip install tokenizers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Found existing installation: transformers 4.48.3
Uninstalling transformers-4.48.3:


  Successfully uninstalled transformers-4.48.3


[0m

[0m

[0m



[0m

tokenizers.__version__: 0.21.0
transformers.__version__: 4.48.3


env: TOKENIZERS_PARALLELISM=true


In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" # if running on Kaggle, use this path 
    model = "/workspace/deberta-v3-large" # if running on local machine, use this path  
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=1 #0.05
    num_warmup_steps=0
    epochs=5
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 12 #batch_size=16
    fc_dropout=0.2
    max_len=512
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=2 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True
    
if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

In [4]:
# ====================================================
# wandb: local notebook login
# ====================================================
if CFG.wandb:
    
    import wandb
    from dotenv import load_dotenv
    import os

    # Load environment variables from .env file in the parent folder
    load_dotenv(os.path.join(os.path.dirname(os.getcwd()), '.env'))

    secret_value_0 = os.getenv("WANDB_API_KEY")
    if secret_value_0:
        wandb.login(key=secret_value_0)
        anony = None
    else:
        anony = "must"
        print('If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. \nGet your W&B access token from here: https://wandb.ai/authorize')

    def class2dict(f):
        return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

    run = wandb.init(project='NBME-Public', 
                     name=CFG.model,
                     config=class2dict(CFG),
                     group=CFG.model,
                     job_type="train",
                     anonymous=anony)

If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. 
Get your W&B access token from here: https://wandb.ai/authorize


[34m[1mwandb[0m: Currently logged in as: [33manony-moose-628529071760834444[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: - Waiting for wandb.init()...

[34m[1mwandb[0m: \ Waiting for wandb.init()...

[34m[1mwandb[0m: | Waiting for wandb.init()...

[34m[1mwandb[0m: / Waiting for wandb.init()...

[34m[1mwandb[0m: - Waiting for wandb.init()...

[34m[1mwandb[0m: Tracking run with wandb version 0.19.6


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/workspace/wandb/run-20250211_210932-kl09y3t4[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33m/workspace/deberta-v3-large[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public/runs/kl09y3t4?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m




# Helper functions for scoring

In [5]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline

def micro_f1(preds, truths):
    """
    Micro f1 on binary arrays.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score.
    """
    # Micro : aggregating over all instances
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return f1_score(truths, preds)


def spans_to_binary(spans, length=None):
    """
    Converts spans to a binary array indicating whether each character is in the span.

    Args:
        spans (list of lists of two ints): Spans.

    Returns:
        np array [length]: Binarized spans.
    """
    length = np.max(spans) if length is None else length
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary


def span_micro_f1(preds, truths):
    """
    Micro f1 on spans.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
    """
    bin_preds = []
    bin_truths = []
    for pred, truth in zip(preds, truths):
        if not len(pred) and not len(truth):
            continue
        length = max(np.max(pred) if len(pred) else 0, np.max(truth) if len(truth) else 0)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)

In [6]:
## UPDATED FOR HANDLING NONEs in THE PSEUDOLABELS / only including confident inferences
import ast

# Update the create_labels_for_scoring function
def create_labels_for_scoring(df):
    # Initialize the 'location_for_create_labels' column with empty lists
    df['location_for_create_labels'] = [ast.literal_eval(f'[]')] * len(df)
    
    for i in range(len(df)):
        lst = df.loc[i, 'location']
        if lst and lst != '':  # Check if lst is not None or empty
            # Ensure lst is a list of strings
            if isinstance(lst, str):
                lst = [lst]
            elif isinstance(lst, list):
                lst = [str(item) for item in lst]
            new_lst = ';'.join(lst)
            df.loc[i, 'location_for_create_labels'] = ast.literal_eval(f'[["{new_lst}"]]')

    # Create labels
    truths = []
    for location_list in df['location_for_create_labels'].values:
        truth = []
        if len(location_list) > 0:
            location = location_list[0]
            for loc in [s.split() for s in location.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                truth.append([start, end])
        truths.append(truth)
    
    return truths


def get_char_probs(texts, predictions, tokenizer):
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text, 
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
            start = offset_mapping[0]
            end = offset_mapping[1]
            results[i][start:end] = pred
    return results


def get_results(char_probs, th=0.5):
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results


def get_predictions(results):
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions

# Utils

In [7]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = span_micro_f1(y_true, y_pred)
    return score


def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

# Data Loading

In [8]:
# ====================================================
# Data Loading
# ====================================================
train = pd.read_csv('/workspace/data/train.csv')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('/workspace/data/features.csv')
def preprocess_features(features):
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    return features
features = preprocess_features(features)
patient_notes = pd.read_csv('/workspace/data/patient_notes.csv')

print(f"train.shape: {train.shape}")
display(train.head())
print(f"features.shape: {features.shape}")
display(features.head())
print(f"patient_notes.shape: {patient_notes.shape}")
display(patient_notes.head())


train.shape: (14300, 6)


Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724]
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693]
2,00016_002,0,16,2,[chest pressure],[203 217]
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]"
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258]


features.shape: (143, 3)


Unnamed: 0,feature_num,case_num,feature_text
0,0,0,Family-history-of-MI-OR-Family-history-of-myoc...
1,1,0,Family-history-of-thyroid-disorder
2,2,0,Chest-pressure
3,3,0,Intermittent-symptoms
4,4,0,Lightheaded


patient_notes.shape: (42146, 3)


Unnamed: 0,pn_num,case_num,pn_history
0,0,0,"17-year-old male, has come to the student heal..."
1,1,0,17 yo male with recurrent palpitations for the...
2,2,0,Dillon Cleveland is a 17 y.o. male patient wit...
3,3,0,a 17 yo m c/o palpitation started 3 mos ago; \...
4,4,0,17yo male with no pmh here for evaluation of p...


In [9]:
train['annotation_length'] = train['annotation'].apply(len) # refers to number of annotation
train

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1
2,00016_002,0,16,2,[chest pressure],[203 217],1
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1
...,...,...,...,...,...,...,...
14295,95333_912,9,95333,912,[],[],0
14296,95333_913,9,95333,913,[],[],0
14297,95333_914,9,95333,914,[photobia],[274 282],1
14298,95333_915,9,95333,915,[no sick contacts],[421 437],1


In [10]:
pseudolabels = pd.read_csv('/workspace/data/pseudolabels.csv')
pseudolabels =  pseudolabels.merge(features, on=['feature_num', 'case_num'], how='left')
pseudolabels = pseudolabels.merge(patient_notes, on=['pn_num', 'case_num'], how='left')

pseudolabels = pseudolabels.dropna(subset=['location'])
pseudolabels = pseudolabels[pseudolabels['location'].apply(lambda x: len(x) > 0)]

def convert_location_format(location):
    if pd.isna(location) or not location:
        return []
    if isinstance(location, str):
        return [loc.strip() for loc in location.split(';')]
    return location

pseudolabels['location'] = pseudolabels['location'].apply(convert_location_format)

pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal..."
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...
...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...


In [11]:
pseudolabels['annotation_length'] = pseudolabels['location'].apply(len)

In [12]:
pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history,annotation_length
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal...",1
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...,1
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...,1
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...,1
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...,1
...,...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,2
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,1
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...,1
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...,1


In [13]:
train = pseudolabels
# drop fold column
train = train.drop(columns=['fold'])

# CV split

In [14]:
# ====================================================
# CV split
# ====================================================
train = train.reset_index(drop=True)

Fold = GroupKFold(n_splits=CFG.n_fold)
groups = train['pn_num'].values
for n, (train_index, val_index) in enumerate(Fold.split(train, train['location'], groups)):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

fold
0    17247
1    17247
2    17246
3    17246
4    17246
dtype: int64

In [15]:
if CFG.debug:
    display(train.groupby('fold').size())
    train = train.sample(n=1000, random_state=0).reset_index(drop=True)
    display(train.groupby('fold').size())

# Dataset

In [16]:
# ====================================================
# Define max_len
# ====================================================
for text_col in ['pn_history']:
    pn_history_lengths = []
    tk0 = tqdm(patient_notes[text_col].fillna("").values, total=len(patient_notes))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        pn_history_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(pn_history_lengths)}')

for text_col in ['feature_text']:
    features_lengths = []
    tk0 = tqdm(features[text_col].fillna("").values, total=len(features))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        features_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(features_lengths)}')

CFG.max_len = max(pn_history_lengths) + max(features_lengths) + 3 # cls & sep & sep
LOGGER.info(f"max_len: {CFG.max_len}")

  0%|          | 0/42146 [00:00<?, ?it/s]

  0%|          | 196/42146 [00:00<00:21, 1956.37it/s]

  1%|          | 456/42146 [00:00<00:17, 2330.86it/s]

  2%|▏         | 691/42146 [00:00<00:17, 2337.35it/s]

  2%|▏         | 965/42146 [00:00<00:16, 2491.95it/s]

  3%|▎         | 1231/42146 [00:00<00:16, 2550.71it/s]

  4%|▎         | 1499/42146 [00:00<00:15, 2592.09it/s]

  4%|▍         | 1782/42146 [00:00<00:15, 2669.15it/s]

  5%|▍         | 2049/42146 [00:00<00:15, 2661.39it/s]

  6%|▌         | 2323/42146 [00:00<00:14, 2683.93it/s]

  6%|▌         | 2592/42146 [00:01<00:15, 2561.92it/s]

  7%|▋         | 2850/42146 [00:01<00:15, 2459.03it/s]

  7%|▋         | 3098/42146 [00:01<00:16, 2409.40it/s]

  8%|▊         | 3340/42146 [00:01<00:16, 2324.81it/s]

  8%|▊         | 3574/42146 [00:01<00:17, 2256.53it/s]

  9%|▉         | 3801/42146 [00:01<00:17, 2207.47it/s]

 10%|▉         | 4023/42146 [00:01<00:17, 2168.56it/s]

 10%|█         | 4241/42146 [00:01<00:18, 2055.35it/s]

 11%|█         | 4448/42146 [00:01<00:18, 2033.62it/s]

 11%|█         | 4652/42146 [00:02<00:18, 2018.75it/s]

 12%|█▏        | 4855/42146 [00:02<00:18, 2018.01it/s]

 12%|█▏        | 5058/42146 [00:02<00:18, 1968.45it/s]

 13%|█▎        | 5269/42146 [00:02<00:18, 2008.77it/s]

 13%|█▎        | 5486/42146 [00:02<00:17, 2053.48it/s]

 14%|█▎        | 5707/42146 [00:02<00:17, 2096.05it/s]

 14%|█▍        | 5940/42146 [00:02<00:16, 2162.94it/s]

 15%|█▍        | 6177/42146 [00:02<00:16, 2223.19it/s]

 15%|█▌        | 6411/42146 [00:02<00:15, 2255.69it/s]

 16%|█▌        | 6646/42146 [00:02<00:15, 2283.18it/s]

 16%|█▋        | 6882/42146 [00:03<00:15, 2305.81it/s]

 17%|█▋        | 7114/42146 [00:03<00:15, 2309.72it/s]

 17%|█▋        | 7346/42146 [00:03<00:15, 2305.11it/s]

 18%|█▊        | 7577/42146 [00:03<00:15, 2302.55it/s]

 19%|█▊        | 7809/42146 [00:03<00:14, 2306.03it/s]

 19%|█▉        | 8041/42146 [00:03<00:14, 2309.62it/s]

 20%|█▉        | 8272/42146 [00:03<00:14, 2296.94it/s]

 20%|██        | 8502/42146 [00:03<00:14, 2294.68it/s]

 21%|██        | 8732/42146 [00:03<00:14, 2292.10it/s]

 21%|██▏       | 8962/42146 [00:03<00:14, 2244.41it/s]

 22%|██▏       | 9187/42146 [00:04<00:14, 2221.57it/s]

 22%|██▏       | 9414/42146 [00:04<00:14, 2235.64it/s]

 23%|██▎       | 9645/42146 [00:04<00:14, 2256.16it/s]

 23%|██▎       | 9885/42146 [00:04<00:14, 2297.51it/s]

 24%|██▍       | 10130/42146 [00:04<00:13, 2342.57it/s]

 25%|██▍       | 10370/42146 [00:04<00:13, 2359.30it/s]

 25%|██▌       | 10607/42146 [00:04<00:13, 2320.05it/s]

 26%|██▌       | 10840/42146 [00:04<00:14, 2196.41it/s]

 26%|██▋       | 11066/42146 [00:04<00:14, 2213.22it/s]

 27%|██▋       | 11293/42146 [00:04<00:13, 2227.16it/s]

 27%|██▋       | 11529/42146 [00:05<00:13, 2265.96it/s]

 28%|██▊       | 11767/42146 [00:05<00:13, 2296.27it/s]

 28%|██▊       | 12004/42146 [00:05<00:13, 2316.42it/s]

 29%|██▉       | 12242/42146 [00:05<00:12, 2332.55it/s]

 30%|██▉       | 12476/42146 [00:05<00:12, 2332.99it/s]

 30%|███       | 12710/42146 [00:05<00:12, 2331.22it/s]

 31%|███       | 12946/42146 [00:05<00:12, 2339.72it/s]

 31%|███▏      | 13181/42146 [00:05<00:12, 2325.06it/s]

 32%|███▏      | 13414/42146 [00:05<00:12, 2302.90it/s]

 32%|███▏      | 13649/42146 [00:05<00:12, 2313.58it/s]

 33%|███▎      | 13881/42146 [00:06<00:12, 2315.45it/s]

 33%|███▎      | 14115/42146 [00:06<00:12, 2320.71it/s]

 34%|███▍      | 14348/42146 [00:06<00:11, 2321.02it/s]

 35%|███▍      | 14581/42146 [00:06<00:11, 2321.69it/s]

 35%|███▌      | 14814/42146 [00:06<00:11, 2315.03it/s]

 36%|███▌      | 15046/42146 [00:06<00:11, 2288.87it/s]

 36%|███▌      | 15275/42146 [00:06<00:11, 2281.83it/s]

 37%|███▋      | 15504/42146 [00:06<00:11, 2280.35it/s]

 37%|███▋      | 15733/42146 [00:06<00:11, 2255.32it/s]

 38%|███▊      | 15959/42146 [00:06<00:11, 2240.28it/s]

 38%|███▊      | 16184/42146 [00:07<00:11, 2198.00it/s]

 39%|███▉      | 16404/42146 [00:07<00:11, 2194.23it/s]

 39%|███▉      | 16624/42146 [00:07<00:11, 2186.67it/s]

 40%|███▉      | 16843/42146 [00:07<00:11, 2174.55it/s]

 40%|████      | 17065/42146 [00:07<00:11, 2187.62it/s]

 41%|████      | 17287/42146 [00:07<00:11, 2194.91it/s]

 42%|████▏     | 17514/42146 [00:07<00:11, 2214.36it/s]

 42%|████▏     | 17736/42146 [00:07<00:11, 2205.20it/s]

 43%|████▎     | 17961/42146 [00:07<00:10, 2217.42it/s]

 43%|████▎     | 18188/42146 [00:08<00:10, 2230.88it/s]

 44%|████▎     | 18412/42146 [00:08<00:10, 2227.10it/s]

 44%|████▍     | 18635/42146 [00:08<00:10, 2224.60it/s]

 45%|████▍     | 18860/42146 [00:08<00:10, 2230.06it/s]

 45%|████▌     | 19084/42146 [00:08<00:10, 2209.67it/s]

 46%|████▌     | 19306/42146 [00:08<00:10, 2201.31it/s]

 46%|████▋     | 19529/42146 [00:08<00:10, 2207.10it/s]

 47%|████▋     | 19756/42146 [00:08<00:10, 2224.29it/s]

 47%|████▋     | 19979/42146 [00:08<00:09, 2221.76it/s]

 48%|████▊     | 20208/42146 [00:08<00:09, 2240.65it/s]

 49%|████▊     | 20448/42146 [00:09<00:09, 2286.97it/s]

 49%|████▉     | 20700/42146 [00:09<00:09, 2354.39it/s]

 50%|████▉     | 20950/42146 [00:09<00:08, 2394.43it/s]

 50%|█████     | 21200/42146 [00:09<00:08, 2425.08it/s]

 51%|█████     | 21446/42146 [00:09<00:08, 2432.23it/s]

 51%|█████▏    | 21690/42146 [00:09<00:08, 2409.48it/s]

 52%|█████▏    | 21934/42146 [00:09<00:08, 2416.75it/s]

 53%|█████▎    | 22176/42146 [00:09<00:08, 2408.71it/s]

 53%|█████▎    | 22418/42146 [00:09<00:08, 2410.11it/s]

 54%|█████▍    | 22661/42146 [00:09<00:08, 2414.32it/s]

 54%|█████▍    | 22903/42146 [00:10<00:07, 2411.41it/s]

 55%|█████▍    | 23148/42146 [00:10<00:07, 2421.69it/s]

 56%|█████▌    | 23393/42146 [00:10<00:07, 2427.96it/s]

 56%|█████▌    | 23641/42146 [00:10<00:07, 2442.27it/s]

 57%|█████▋    | 23886/42146 [00:10<00:07, 2439.12it/s]

 57%|█████▋    | 24130/42146 [00:10<00:07, 2438.64it/s]

 58%|█████▊    | 24374/42146 [00:10<00:07, 2427.70it/s]

 58%|█████▊    | 24617/42146 [00:10<00:07, 2414.36it/s]

 59%|█████▉    | 24861/42146 [00:10<00:07, 2419.69it/s]

 60%|█████▉    | 25107/42146 [00:10<00:07, 2430.63it/s]

 60%|██████    | 25351/42146 [00:11<00:06, 2420.89it/s]

 61%|██████    | 25597/42146 [00:11<00:06, 2432.26it/s]

 61%|██████▏   | 25841/42146 [00:11<00:06, 2427.96it/s]

 62%|██████▏   | 26084/42146 [00:11<00:06, 2426.98it/s]

 62%|██████▏   | 26328/42146 [00:11<00:06, 2428.43it/s]

 63%|██████▎   | 26571/42146 [00:11<00:06, 2339.61it/s]

 64%|██████▎   | 26811/42146 [00:11<00:06, 2355.25it/s]

 64%|██████▍   | 27056/42146 [00:11<00:06, 2382.61it/s]

 65%|██████▍   | 27295/42146 [00:11<00:06, 2380.35it/s]

 65%|██████▌   | 27535/42146 [00:11<00:06, 2383.33it/s]

 66%|██████▌   | 27777/42146 [00:12<00:06, 2392.21it/s]

 66%|██████▋   | 28017/42146 [00:12<00:06, 2352.76it/s]

 67%|██████▋   | 28253/42146 [00:12<00:05, 2339.00it/s]

 68%|██████▊   | 28488/42146 [00:12<00:05, 2312.77it/s]

 68%|██████▊   | 28725/42146 [00:12<00:05, 2328.18it/s]

 69%|██████▊   | 28958/42146 [00:12<00:05, 2268.85it/s]

 69%|██████▉   | 29186/42146 [00:12<00:05, 2240.22it/s]

 70%|██████▉   | 29411/42146 [00:12<00:05, 2225.85it/s]

 70%|███████   | 29634/42146 [00:12<00:05, 2219.34it/s]

 71%|███████   | 29857/42146 [00:12<00:05, 2209.48it/s]

 71%|███████▏  | 30079/42146 [00:13<00:05, 2207.48it/s]

 72%|███████▏  | 30300/42146 [00:13<00:05, 2199.50it/s]

 72%|███████▏  | 30522/42146 [00:13<00:05, 2202.75it/s]

 73%|███████▎  | 30743/42146 [00:13<00:05, 2198.79it/s]

 73%|███████▎  | 30963/42146 [00:13<00:05, 2194.26it/s]

 74%|███████▍  | 31183/42146 [00:13<00:04, 2195.87it/s]

 75%|███████▍  | 31403/42146 [00:13<00:04, 2193.45it/s]

 75%|███████▌  | 31626/42146 [00:13<00:04, 2203.01it/s]

 76%|███████▌  | 31849/42146 [00:13<00:04, 2210.13it/s]

 76%|███████▌  | 32071/42146 [00:13<00:04, 2192.39it/s]

 77%|███████▋  | 32291/42146 [00:14<00:04, 2185.03it/s]

 77%|███████▋  | 32512/42146 [00:14<00:04, 2189.75it/s]

 78%|███████▊  | 32731/42146 [00:14<00:04, 2179.15it/s]

 78%|███████▊  | 32949/42146 [00:14<00:04, 2161.83it/s]

 79%|███████▊  | 33166/42146 [00:14<00:04, 2140.59it/s]

 79%|███████▉  | 33381/42146 [00:14<00:04, 2105.29it/s]

 80%|███████▉  | 33592/42146 [00:14<00:04, 2087.11it/s]

 80%|████████  | 33801/42146 [00:14<00:04, 2081.90it/s]

 81%|████████  | 34010/42146 [00:14<00:03, 2082.34it/s]

 81%|████████  | 34221/42146 [00:15<00:03, 2088.11it/s]

 82%|████████▏ | 34430/42146 [00:15<00:03, 2084.49it/s]

 82%|████████▏ | 34639/42146 [00:15<00:03, 2078.04it/s]

 83%|████████▎ | 34847/42146 [00:15<00:03, 2075.71it/s]

 83%|████████▎ | 35055/42146 [00:15<00:03, 2066.64it/s]

 84%|████████▎ | 35264/42146 [00:15<00:03, 2072.31it/s]

 84%|████████▍ | 35472/42146 [00:15<00:03, 2064.58it/s]

 85%|████████▍ | 35682/42146 [00:15<00:03, 2073.83it/s]

 85%|████████▌ | 35904/42146 [00:15<00:02, 2115.20it/s]

 86%|████████▌ | 36121/42146 [00:15<00:02, 2128.80it/s]

 86%|████████▌ | 36334/42146 [00:16<00:02, 2128.54it/s]

 87%|████████▋ | 36555/42146 [00:16<00:02, 2152.76it/s]

 87%|████████▋ | 36786/42146 [00:16<00:02, 2198.93it/s]

 88%|████████▊ | 37011/42146 [00:16<00:02, 2213.03it/s]

 88%|████████▊ | 37283/42146 [00:16<00:02, 2364.70it/s]

 89%|████████▉ | 37549/42146 [00:16<00:01, 2452.71it/s]

 90%|████████▉ | 37829/42146 [00:16<00:01, 2555.58it/s]

 90%|█████████ | 38106/42146 [00:16<00:01, 2618.45it/s]

 91%|█████████ | 38368/42146 [00:16<00:01, 2605.67it/s]

 92%|█████████▏| 38638/42146 [00:16<00:01, 2630.95it/s]

 92%|█████████▏| 38907/42146 [00:17<00:01, 2646.92it/s]

 93%|█████████▎| 39173/42146 [00:17<00:01, 2650.50it/s]

 94%|█████████▎| 39439/42146 [00:17<00:01, 2606.15it/s]

 94%|█████████▍| 39700/42146 [00:17<00:00, 2580.73it/s]

 95%|█████████▍| 39959/42146 [00:17<00:00, 2579.02it/s]

 95%|█████████▌| 40218/42146 [00:17<00:00, 2572.93it/s]

 96%|█████████▌| 40476/42146 [00:17<00:00, 2544.26it/s]

 97%|█████████▋| 40731/42146 [00:17<00:00, 2520.72it/s]

 97%|█████████▋| 40984/42146 [00:17<00:00, 2411.43it/s]

 98%|█████████▊| 41227/42146 [00:17<00:00, 2338.38it/s]

 98%|█████████▊| 41462/42146 [00:18<00:00, 2325.83it/s]

 99%|█████████▉| 41696/42146 [00:18<00:00, 2328.16it/s]

100%|█████████▉| 41936/42146 [00:18<00:00, 2347.98it/s]

100%|██████████| 42146/42146 [00:18<00:00, 2296.74it/s]


pn_history max(lengths): 323


  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [00:00<00:00, 16462.25it/s]


feature_text max(lengths): 28


max_len: 354


In [17]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text, feature_text):

    # Omit token_type_ids for ModernBERT
    is_token_type_ids = CFG.model != 'answerdotai/ModernBERT-base'

    inputs = cfg.tokenizer(text, feature_text, 
                           add_special_tokens=True,
                           max_length=CFG.max_len,
                           padding="max_length",
                           return_offsets_mapping=False,
                           return_token_type_ids=is_token_type_ids)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def create_label(cfg, text, annotation_length, location_list):
    encoded = cfg.tokenizer(text,
                            add_special_tokens=True,
                            max_length=CFG.max_len,
                            padding="max_length",
                            return_offsets_mapping=True)
    offset_mapping = encoded['offset_mapping']
    ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]
    label = np.zeros(len(offset_mapping))
    label[ignore_idxes] = -1
    if annotation_length != 0:
        for location in location_list:
            for loc in [s.split() for s in location.split(';')]:
                start_idx = -1
                end_idx = -1
                start, end = int(loc[0]), int(loc[1])
                for idx in range(len(offset_mapping)):
                    if (start_idx == -1) & (start < offset_mapping[idx][0]):
                        start_idx = idx - 1
                    if (end_idx == -1) & (end <= offset_mapping[idx][1]):
                        end_idx = idx + 1
                if start_idx == -1:
                    start_idx = end_idx
                if (start_idx != -1) & (end_idx != -1):
                    label[start_idx:end_idx] = 1
    return torch.tensor(label, dtype=torch.bfloat16)


class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.feature_texts = df['feature_text'].values
        self.pn_historys = df['pn_history'].values
        self.annotation_lengths = df['annotation_length'].values
        self.locations = df['location'].values

    def __len__(self):
        return len(self.feature_texts)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, 
                               self.pn_historys[item], 
                               self.feature_texts[item])
        label = create_label(self.cfg, 
                             self.pn_historys[item], 
                             self.annotation_lengths[item], 
                             self.locations[item])
        return inputs, label

# Model

In [18]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from torch.utils.checkpoint import checkpoint

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel(self.config)

        self.model.gradient_checkpointing_enable()
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        # Debugging input shapes and types
        #print("Feature extraction inputs:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        # Process the inputs through the model
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        
        # Debugging the output of the model
        #print(f"Last hidden states: shape={last_hidden_states.shape}, dtype={last_hidden_states.dtype}")
        return last_hidden_states

    def forward(self, inputs):
        # Debugging the forward pass
        #print("Starting forward pass...")
        #print("Input keys and shapes:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):  # Enable AMP
            #print("Running model under AMP...")
            outputs = self.model(**inputs)
        
        # Debugging model output
        #print("Model outputs:")
        #if isinstance(outputs, tuple):
           # for idx, output in enumerate(outputs):
                #print(f"Output {idx}: shape={output.shape if hasattr(output, 'shape') else 'N/A'}")
        #else:
            #print(f"Outputs: {outputs}")
            #pass
        
        # Final linear layer
        feature = outputs.last_hidden_state
        output = self.fc(self.fc_dropout(feature))
        
        # Debugging final output
        #print(f"Final output shape: {output.shape}, dtype={output.dtype}")
        return output


# Helper functions

In [19]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7270e916a690>

In [20]:
def valid_fn(valid_loader, model, criterion, device):
    """
    Validation function with detailed debugging for intermediate outputs.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()

    print("Starting validation...")
    for step, (inputs, labels) in enumerate(valid_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        
        with torch.no_grad():
            y_preds = model(inputs)  # Model predictions
        
        # Compute loss
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)

        # Store predictions
        preds.append(y_preds.sigmoid().to('cpu').numpy())

        end = time.time()

        # Debugging: Log step-wise progress and intermediate outputs
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
            #print(f"Sample Predictions: {y_preds.sigmoid()[:5].view(-1).tolist()}")  # Example predictions
            #print(f"Sample Labels: {labels[:5].view(-1).tolist()}")  # Example labels

    predictions = np.concatenate(preds)

    # Debugging: Check shape and content of predictions
    #print(f"Validation Predictions Shape: {predictions.shape}")
    # print(f"First 5 Predictions: {predictions[:5]}")

    return losses.avg, predictions


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Training function with detailed debugging for loss and gradients.
    """
    print(f"\nStarting training epoch {epoch + 1} for fold {fold}")
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)

    losses = AverageMeter()
    start = time.time()
    global_step = 0

    for step, (inputs, labels) in enumerate(train_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # Mixed precision training
        with torch.cuda.amp.autocast(enabled=CFG.apex, dtype=torch.bfloat16):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
            loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()

        # Gradient accumulation
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps

        losses.update(loss.item(), batch_size)

        # Backpropagation
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)

        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()

        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print(f'Epoch: [{epoch+1}][{step}/{len(train_loader)}] '
                  f'Elapsed {timeSince(start, float(step+1)/len(train_loader))} '
                  f'Loss: {losses.val:.4f}({losses.avg:.4f}) '
                  f'Grad Norm: {grad_norm:.4f}  '
                  f'LR: {scheduler.get_lr()[0]}')

    print(f"Epoch {epoch + 1} completed. Average loss: {losses.avg:.4f}")
    return losses.avg


def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.sigmoid().to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

In [21]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_texts = valid_folds['pn_history'].values
    valid_labels = create_labels_for_scoring(valid_folds)
    
    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)

    checkpoint_path = f"/workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold{fold}_best.pth"
    if os.path.exists(checkpoint_path):
        LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cuda", weights_only=False) # need to use weight_only = True for latest pytorch version
        model.load_state_dict(checkpoint['model'], strict=False)  # Load model weights
        model.to(device)
        model.train()  # Set model to training mode

    else:
        LOGGER.info("No checkpoint found, starting fresh.")
        
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr, 
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)
    
    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        if cfg.scheduler=='linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler=='cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        return scheduler
    
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    
    best_score = 0.

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        print("Starting training")
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        print("Starting evaluation")
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        predictions = predictions.reshape((len(valid_folds), CFG.max_len))
        
        # scoring
        char_probs = get_char_probs(valid_texts, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        if CFG.wandb:
            wandb.log({f"[fold{fold}] epoch": epoch+1, 
                       f"[fold{fold}] avg_train_loss": avg_loss, 
                       f"[fold{fold}] avg_val_loss": avg_val_loss,
                       f"[fold{fold}] score": score})
        
        if best_score <= score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
            

    #predictions = torch.load(OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             #map_location=torch.device('cpu'))['predictions']
    predictions = torch.load(OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             map_location=torch.device('cpu'),weights_only=False)['predictions']

    
    valid_folds[[i for i in range(CFG.max_len)]] = predictions

    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

In [22]:
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    #model="microsoft/deberta-v3-large" #"answerdotai/ModernBERT-base"
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" If running on Kaggle
    model = "/workspace/deberta-v3-large" # running on runpod instance
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=0.05
    num_warmup_steps=0
    epochs=4
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 10 #batch_size=16
    fc_dropout=0.2
    max_len=354
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=1 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True

tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

In [23]:
torch.cuda.empty_cache()

In [24]:
def run_fold_training(oof_df, fold):
    """
    Helper function to run training for a specific fold.
    Concatenates the out-of-fold predictions to oof_df.
    """
    _oof_df = train_loop(train, fold)
    oof_df = pd.concat([oof_df, _oof_df])
    LOGGER.info(f"========== fold: {fold} result ==========")
    evaluate_results(_oof_df)  # Evaluate results after each fold
    return oof_df

def evaluate_results(oof_df):
    """
    Function to handle the result evaluation, including scoring and logging.
    """
    labels = create_labels_for_scoring(oof_df)
    predictions = oof_df[[i for i in range(CFG.max_len)]].values
    char_probs = get_char_probs(oof_df['pn_history'].values, predictions, CFG.tokenizer)
    results = get_results(char_probs, th=0.5)
    preds = get_predictions(results)
    score = get_score(labels, preds)
    LOGGER.info(f'Score: {score:<.4f}')

In [25]:
if __name__ == '__main__':
    CFG.debug = True

    if CFG.train:
        oof_df = pd.DataFrame()

        # Run training for a single fold in debug mode or multiple folds otherwise
        if CFG.debug:
            fold = 0  # In debug mode, we only use fold 0 along with 2-3 epochs 
            if fold in CFG.trn_fold:
                oof_df = run_fold_training(oof_df, fold)
        else:
            for fold in range(CFG.n_fold):  # For all folds in non-debug mode
                if fold in CFG.trn_fold:
                    oof_df = run_fold_training(oof_df, fold)

        # Final evaluation and saving results
        oof_df = oof_df.reset_index(drop=True)
        LOGGER.info(f"========== CV ==========")
        evaluate_results(oof_df)  # Final evaluation of all results
        oof_df.to_pickle(OUTPUT_DIR + 'oof_df.pkl')

    if CFG.wandb:
        wandb.finish()



Loading checkpoint from /workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold0_best.pth


Starting training

Starting training epoch 1 for fold 0


Epoch: [1][0/6898] Elapsed 0m 2s (remain 298m 56s) Loss: 0.0122(0.0122) Grad Norm: 86587.5938  LR: 9.999999999675951e-06


Epoch: [1][100/6898] Elapsed 1m 25s (remain 96m 26s) Loss: 0.0073(0.0063) Grad Norm: 41556.8477  LR: 9.99999669437843e-06


Epoch: [1][200/6898] Elapsed 2m 49s (remain 94m 15s) Loss: 0.0104(0.0061) Grad Norm: 18553.5176  LR: 9.999986908109645e-06


Epoch: [1][300/6898] Elapsed 4m 12s (remain 92m 15s) Loss: 0.0019(0.0059) Grad Norm: 8929.4521  LR: 9.999970640882282e-06


Epoch: [1][400/6898] Elapsed 5m 35s (remain 90m 42s) Loss: 0.0016(0.0058) Grad Norm: 12893.6094  LR: 9.999947892717426e-06


Epoch: [1][500/6898] Elapsed 6m 58s (remain 89m 2s) Loss: 0.0050(0.0057) Grad Norm: 12037.0732  LR: 9.999918663644563e-06


Epoch: [1][600/6898] Elapsed 8m 20s (remain 87m 28s) Loss: 0.0009(0.0055) Grad Norm: 3438.8408  LR: 9.999882953701581e-06


Epoch: [1][700/6898] Elapsed 9m 43s (remain 85m 57s) Loss: 0.0048(0.0053) Grad Norm: 56405.3438  LR: 9.999840762934764e-06


Epoch: [1][800/6898] Elapsed 11m 6s (remain 84m 32s) Loss: 0.0010(0.0052) Grad Norm: 24389.8828  LR: 9.999792091398804e-06


Epoch: [1][900/6898] Elapsed 12m 29s (remain 83m 11s) Loss: 0.0550(0.0053) Grad Norm: 56596.9062  LR: 9.999736939156785e-06


Epoch: [1][1000/6898] Elapsed 13m 53s (remain 81m 52s) Loss: 0.0096(0.0053) Grad Norm: 24598.0508  LR: 9.999675306280197e-06


Epoch: [1][1100/6898] Elapsed 15m 17s (remain 80m 30s) Loss: 0.0038(0.0052) Grad Norm: 15923.2549  LR: 9.999607192848925e-06


Epoch: [1][1200/6898] Elapsed 16m 39s (remain 79m 0s) Loss: 0.0035(0.0051) Grad Norm: 134362.8438  LR: 9.999532598951263e-06


Epoch: [1][1300/6898] Elapsed 18m 2s (remain 77m 36s) Loss: 0.0000(0.0050) Grad Norm: 45.0808  LR: 9.999451524683896e-06


Epoch: [1][1400/6898] Elapsed 19m 24s (remain 76m 7s) Loss: 0.0018(0.0050) Grad Norm: 9286.2549  LR: 9.99936397015191e-06


Epoch: [1][1500/6898] Elapsed 20m 47s (remain 74m 45s) Loss: 0.0017(0.0050) Grad Norm: 13822.4561  LR: 9.999269935468797e-06


Epoch: [1][1600/6898] Elapsed 22m 9s (remain 73m 20s) Loss: 0.0001(0.0049) Grad Norm: 276.9291  LR: 9.999169420756443e-06


Epoch: [1][1700/6898] Elapsed 23m 31s (remain 71m 51s) Loss: 0.0006(0.0049) Grad Norm: 5518.7627  LR: 9.999062426145132e-06


Epoch: [1][1800/6898] Elapsed 24m 52s (remain 70m 23s) Loss: 0.0037(0.0048) Grad Norm: 56807.0117  LR: 9.998948951773553e-06


Epoch: [1][1900/6898] Elapsed 26m 14s (remain 68m 58s) Loss: 0.0001(0.0048) Grad Norm: 711.6124  LR: 9.99882899778879e-06


Epoch: [1][2000/6898] Elapsed 27m 37s (remain 67m 35s) Loss: 0.0077(0.0047) Grad Norm: 47745.2852  LR: 9.998702564346325e-06


Epoch: [1][2100/6898] Elapsed 28m 57s (remain 66m 6s) Loss: 0.0005(0.0047) Grad Norm: 7888.9331  LR: 9.998569651610042e-06


Epoch: [1][2200/6898] Elapsed 30m 21s (remain 64m 46s) Loss: 0.0167(0.0047) Grad Norm: 225400.1719  LR: 9.998430259752222e-06


Epoch: [1][2300/6898] Elapsed 31m 44s (remain 63m 24s) Loss: 0.0001(0.0046) Grad Norm: 1181.3639  LR: 9.998284388953545e-06


Epoch: [1][2400/6898] Elapsed 33m 8s (remain 62m 4s) Loss: 0.0014(0.0046) Grad Norm: 25478.0234  LR: 9.998132039403086e-06


Epoch: [1][2500/6898] Elapsed 34m 34s (remain 60m 47s) Loss: 0.0113(0.0046) Grad Norm: 136653.5938  LR: 9.997973211298323e-06


Epoch: [1][2600/6898] Elapsed 36m 0s (remain 59m 28s) Loss: 0.0040(0.0046) Grad Norm: 27859.6562  LR: 9.997807904845123e-06


Epoch: [1][2700/6898] Elapsed 37m 24s (remain 58m 7s) Loss: 0.0060(0.0046) Grad Norm: 41644.1797  LR: 9.997636120257758e-06


Epoch: [1][2800/6898] Elapsed 38m 48s (remain 56m 45s) Loss: 0.0029(0.0046) Grad Norm: 14699.1406  LR: 9.997457857758896e-06


Epoch: [1][2900/6898] Elapsed 40m 11s (remain 55m 22s) Loss: 0.0083(0.0045) Grad Norm: 32292.1719  LR: 9.997273117579597e-06


Epoch: [1][3000/6898] Elapsed 41m 35s (remain 54m 0s) Loss: 0.0029(0.0045) Grad Norm: 41150.0078  LR: 9.997081899959324e-06


Epoch: [1][3100/6898] Elapsed 43m 0s (remain 52m 39s) Loss: 0.0079(0.0044) Grad Norm: 116875.3906  LR: 9.996884205145929e-06


Epoch: [1][3200/6898] Elapsed 44m 23s (remain 51m 16s) Loss: 0.0002(0.0044) Grad Norm: 4114.8677  LR: 9.996680033395664e-06


Epoch: [1][3300/6898] Elapsed 45m 48s (remain 49m 54s) Loss: 0.0002(0.0044) Grad Norm: 5128.7695  LR: 9.996469384973175e-06


Epoch: [1][3400/6898] Elapsed 47m 13s (remain 48m 33s) Loss: 0.0059(0.0044) Grad Norm: 169769.7812  LR: 9.996252260151506e-06


Epoch: [1][3500/6898] Elapsed 48m 37s (remain 47m 10s) Loss: 0.0013(0.0044) Grad Norm: 14936.5488  LR: 9.996028659212089e-06


Epoch: [1][3600/6898] Elapsed 50m 0s (remain 45m 47s) Loss: 0.0062(0.0043) Grad Norm: 186536.6562  LR: 9.995798582444759e-06


Epoch: [1][3700/6898] Elapsed 51m 21s (remain 44m 22s) Loss: 0.0113(0.0043) Grad Norm: 88859.5078  LR: 9.995562030147736e-06


Epoch: [1][3800/6898] Elapsed 52m 43s (remain 42m 57s) Loss: 0.0029(0.0043) Grad Norm: 19515.4316  LR: 9.995319002627643e-06


Epoch: [1][3900/6898] Elapsed 54m 6s (remain 41m 34s) Loss: 0.0023(0.0043) Grad Norm: 104090.4766  LR: 9.995069500199487e-06


Epoch: [1][4000/6898] Elapsed 55m 30s (remain 40m 11s) Loss: 0.0013(0.0043) Grad Norm: 35112.8242  LR: 9.994813523186671e-06


Epoch: [1][4100/6898] Elapsed 56m 56s (remain 38m 50s) Loss: 0.0001(0.0043) Grad Norm: 2260.3567  LR: 9.994551071920995e-06


Epoch: [1][4200/6898] Elapsed 58m 21s (remain 37m 28s) Loss: 0.0013(0.0042) Grad Norm: 44056.3750  LR: 9.994282146742643e-06


Epoch: [1][4300/6898] Elapsed 59m 43s (remain 36m 3s) Loss: 0.0008(0.0042) Grad Norm: 52513.2539  LR: 9.994006748000197e-06


Epoch: [1][4400/6898] Elapsed 61m 6s (remain 34m 40s) Loss: 0.0031(0.0042) Grad Norm: 78973.4531  LR: 9.99372487605063e-06


Epoch: [1][4500/6898] Elapsed 62m 30s (remain 33m 17s) Loss: 0.0057(0.0042) Grad Norm: 48939.0352  LR: 9.993436531259297e-06


Epoch: [1][4600/6898] Elapsed 63m 53s (remain 31m 53s) Loss: 0.0070(0.0042) Grad Norm: 88278.2500  LR: 9.993141713999955e-06


Epoch: [1][4700/6898] Elapsed 65m 16s (remain 30m 30s) Loss: 0.0027(0.0041) Grad Norm: 65571.8984  LR: 9.99284042465474e-06


Epoch: [1][4800/6898] Elapsed 66m 40s (remain 29m 7s) Loss: 0.0004(0.0041) Grad Norm: 16867.7715  LR: 9.992532663614185e-06


Epoch: [1][4900/6898] Elapsed 68m 5s (remain 27m 44s) Loss: 0.0031(0.0041) Grad Norm: 97865.3672  LR: 9.992218431277205e-06


Epoch: [1][5000/6898] Elapsed 69m 30s (remain 26m 21s) Loss: 0.0016(0.0041) Grad Norm: 50496.3438  LR: 9.99189772805111e-06


Epoch: [1][5100/6898] Elapsed 70m 52s (remain 24m 57s) Loss: 0.0008(0.0041) Grad Norm: 35016.1328  LR: 9.991570554351592e-06


Epoch: [1][5200/6898] Elapsed 72m 16s (remain 23m 34s) Loss: 0.0050(0.0041) Grad Norm: 178412.3438  LR: 9.991236910602735e-06


Epoch: [1][5300/6898] Elapsed 73m 38s (remain 22m 11s) Loss: 0.0022(0.0041) Grad Norm: 454662.8750  LR: 9.990896797237002e-06


Epoch: [1][5400/6898] Elapsed 75m 0s (remain 20m 47s) Loss: 0.0075(0.0041) Grad Norm: 142693.1250  LR: 9.990550214695248e-06


Epoch: [1][5500/6898] Elapsed 76m 24s (remain 19m 24s) Loss: 0.0024(0.0040) Grad Norm: 68486.9688  LR: 9.990197163426711e-06


Epoch: [1][5600/6898] Elapsed 77m 49s (remain 18m 1s) Loss: 0.0055(0.0040) Grad Norm: 66398.2109  LR: 9.989837643889015e-06


Epoch: [1][5700/6898] Elapsed 79m 10s (remain 16m 37s) Loss: 0.0000(0.0040) Grad Norm: 1325.4532  LR: 9.989471656548169e-06


Epoch: [1][5800/6898] Elapsed 80m 33s (remain 15m 14s) Loss: 0.0007(0.0040) Grad Norm: 29074.1660  LR: 9.989099201878561e-06


Epoch: [1][5900/6898] Elapsed 81m 56s (remain 13m 50s) Loss: 0.0062(0.0040) Grad Norm: 167666.2969  LR: 9.988720280362967e-06


Epoch: [1][6000/6898] Elapsed 83m 20s (remain 12m 27s) Loss: 0.0007(0.0040) Grad Norm: 27633.9395  LR: 9.988334892492542e-06


Epoch: [1][6100/6898] Elapsed 84m 45s (remain 11m 4s) Loss: 0.0006(0.0040) Grad Norm: 214395.8594  LR: 9.987943038766825e-06


Epoch: [1][6200/6898] Elapsed 86m 8s (remain 9m 40s) Loss: 0.0060(0.0040) Grad Norm: 754284.6250  LR: 9.987544719693735e-06


Epoch: [1][6300/6898] Elapsed 87m 34s (remain 8m 17s) Loss: 0.0015(0.0039) Grad Norm: 202696.6406  LR: 9.987139935789569e-06


Epoch: [1][6400/6898] Elapsed 88m 57s (remain 6m 54s) Loss: 0.0014(0.0039) Grad Norm: 197421.1562  LR: 9.986728687579009e-06


Epoch: [1][6500/6898] Elapsed 90m 22s (remain 5m 31s) Loss: 0.0022(0.0039) Grad Norm: 64534.4219  LR: 9.986310975595112e-06


Epoch: [1][6600/6898] Elapsed 91m 47s (remain 4m 7s) Loss: 0.0000(0.0039) Grad Norm: 1094.6815  LR: 9.985886800379311e-06


Epoch: [1][6700/6898] Elapsed 93m 13s (remain 2m 44s) Loss: 0.0004(0.0039) Grad Norm: 69333.4844  LR: 9.985456162481427e-06


Epoch: [1][6800/6898] Elapsed 94m 39s (remain 1m 21s) Loss: 0.0002(0.0039) Grad Norm: 13813.5957  LR: 9.985019062459642e-06


Epoch: [1][6897/6898] Elapsed 96m 1s (remain 0m 0s) Loss: 0.0012(0.0039) Grad Norm: 57039.5547  LR: 9.984588901738089e-06


Epoch 1 completed. Average loss: 0.0039
Starting evaluation
Starting validation...


EVAL: [0/1725] Elapsed 0m 0s (remain 13m 38s) Loss: 0.0018(0.0018) 


EVAL: [100/1725] Elapsed 0m 7s (remain 1m 53s) Loss: 0.0010(0.0030) 


EVAL: [200/1725] Elapsed 0m 13s (remain 1m 43s) Loss: 0.0019(0.0037) 


EVAL: [300/1725] Elapsed 0m 20s (remain 1m 35s) Loss: 0.0022(0.0032) 


EVAL: [400/1725] Elapsed 0m 26s (remain 1m 28s) Loss: 0.0002(0.0033) 


EVAL: [500/1725] Elapsed 0m 33s (remain 1m 21s) Loss: 0.0053(0.0031) 


EVAL: [600/1725] Elapsed 0m 40s (remain 1m 14s) Loss: 0.0011(0.0031) 


EVAL: [700/1725] Elapsed 0m 46s (remain 1m 8s) Loss: 0.0105(0.0030) 


EVAL: [800/1725] Elapsed 0m 53s (remain 1m 1s) Loss: 0.0069(0.0033) 


EVAL: [900/1725] Elapsed 0m 59s (remain 0m 54s) Loss: 0.0081(0.0035) 


EVAL: [1000/1725] Elapsed 1m 6s (remain 0m 48s) Loss: 0.0043(0.0036) 


EVAL: [1100/1725] Elapsed 1m 13s (remain 0m 41s) Loss: 0.0074(0.0038) 


EVAL: [1200/1725] Elapsed 1m 19s (remain 0m 34s) Loss: 0.0038(0.0038) 


EVAL: [1300/1725] Elapsed 1m 26s (remain 0m 28s) Loss: 0.0030(0.0039) 


EVAL: [1400/1725] Elapsed 1m 33s (remain 0m 21s) Loss: 0.0023(0.0038) 


EVAL: [1500/1725] Elapsed 1m 39s (remain 0m 14s) Loss: 0.0022(0.0037) 


EVAL: [1600/1725] Elapsed 1m 46s (remain 0m 8s) Loss: 0.0050(0.0037) 


EVAL: [1700/1725] Elapsed 1m 53s (remain 0m 1s) Loss: 0.0005(0.0036) 


EVAL: [1724/1725] Elapsed 1m 54s (remain 0m 0s) Loss: 0.0038(0.0035) 


Epoch 1 - avg_train_loss: 0.0039  avg_val_loss: 0.0035  time: 5887s


Epoch 1 - Score: 0.9780


Epoch 1 - Save Best Score: 0.9780 Model


Starting training

Starting training epoch 2 for fold 0


Epoch: [2][0/6898] Elapsed 0m 1s (remain 172m 0s) Loss: 0.0016(0.0016) Grad Norm: 15739.0996  LR: 9.984584435431668e-06


Epoch: [2][100/6898] Elapsed 1m 25s (remain 95m 36s) Loss: 0.0006(0.0028) Grad Norm: 7361.9819  LR: 9.984134542084126e-06


Epoch: [2][200/6898] Elapsed 2m 47s (remain 93m 3s) Loss: 0.0091(0.0029) Grad Norm: 23621.1348  LR: 9.983678188325765e-06


Epoch: [2][300/6898] Elapsed 4m 11s (remain 91m 49s) Loss: 0.0015(0.0027) Grad Norm: 30424.5098  LR: 9.98321537474811e-06


Epoch: [2][400/6898] Elapsed 5m 34s (remain 90m 23s) Loss: 0.0000(0.0028) Grad Norm: 28.8889  LR: 9.982746101951055e-06


Epoch: [2][500/6898] Elapsed 6m 57s (remain 88m 47s) Loss: 0.0023(0.0029) Grad Norm: 14682.8359  LR: 9.982270370542873e-06


Epoch: [2][600/6898] Elapsed 8m 22s (remain 87m 46s) Loss: 0.0039(0.0030) Grad Norm: 8693.3545  LR: 9.9817881811402e-06


Epoch: [2][700/6898] Elapsed 9m 48s (remain 86m 46s) Loss: 0.0035(0.0033) Grad Norm: 10553.2012  LR: 9.98129953436805e-06


Epoch: [2][800/6898] Elapsed 11m 16s (remain 85m 49s) Loss: 0.0053(0.0032) Grad Norm: 29442.2266  LR: 9.980804430859802e-06


Epoch: [2][900/6898] Elapsed 12m 41s (remain 84m 26s) Loss: 0.0005(0.0032) Grad Norm: 9363.5391  LR: 9.980302871257212e-06


Epoch: [2][1000/6898] Elapsed 14m 5s (remain 83m 1s) Loss: 0.0001(0.0031) Grad Norm: 462.4041  LR: 9.979794856210396e-06


Epoch: [2][1100/6898] Elapsed 15m 30s (remain 81m 38s) Loss: 0.0019(0.0032) Grad Norm: 29524.7793  LR: 9.979280386377841e-06


Epoch: [2][1200/6898] Elapsed 16m 54s (remain 80m 14s) Loss: 0.0001(0.0032) Grad Norm: 557.3111  LR: 9.978759462426399e-06


Epoch: [2][1300/6898] Elapsed 18m 20s (remain 78m 53s) Loss: 0.0031(0.0032) Grad Norm: 15798.3799  LR: 9.97823208503129e-06


Epoch: [2][1400/6898] Elapsed 19m 44s (remain 77m 27s) Loss: 0.0019(0.0032) Grad Norm: 10185.0000  LR: 9.977698254876099e-06


Epoch: [2][1500/6898] Elapsed 21m 7s (remain 75m 56s) Loss: 0.0001(0.0032) Grad Norm: 1475.0183  LR: 9.977157972652774e-06


Epoch: [2][1600/6898] Elapsed 22m 29s (remain 74m 24s) Loss: 0.0039(0.0032) Grad Norm: 9529.6992  LR: 9.976611239061623e-06


Epoch: [2][1700/6898] Elapsed 23m 50s (remain 72m 51s) Loss: 0.0024(0.0031) Grad Norm: 13157.7324  LR: 9.976058054811323e-06


Epoch: [2][1800/6898] Elapsed 25m 11s (remain 71m 18s) Loss: 0.0002(0.0031) Grad Norm: 3552.2886  LR: 9.975498420618907e-06


Epoch: [2][1900/6898] Elapsed 26m 35s (remain 69m 54s) Loss: 0.0072(0.0031) Grad Norm: 25261.6094  LR: 9.97493233720977e-06


Epoch: [2][2000/6898] Elapsed 27m 59s (remain 68m 31s) Loss: 0.0019(0.0031) Grad Norm: 28776.6074  LR: 9.974359805317669e-06


Epoch: [2][2100/6898] Elapsed 29m 24s (remain 67m 9s) Loss: 0.0060(0.0031) Grad Norm: 71333.0078  LR: 9.973780825684713e-06


Epoch: [2][2200/6898] Elapsed 30m 49s (remain 65m 47s) Loss: 0.0009(0.0031) Grad Norm: 5952.8799  LR: 9.973195399061376e-06


Epoch: [2][2300/6898] Elapsed 32m 13s (remain 64m 22s) Loss: 0.0025(0.0031) Grad Norm: 11155.5371  LR: 9.972603526206484e-06


Epoch: [2][2400/6898] Elapsed 33m 34s (remain 62m 53s) Loss: 0.0124(0.0031) Grad Norm: 132462.9844  LR: 9.972005207887219e-06


Epoch: [2][2500/6898] Elapsed 34m 57s (remain 61m 27s) Loss: 0.0041(0.0031) Grad Norm: 36515.3359  LR: 9.97140044487912e-06


Epoch: [2][2600/6898] Elapsed 36m 20s (remain 60m 2s) Loss: 0.0048(0.0031) Grad Norm: 23800.5723  LR: 9.970789237966076e-06


Epoch: [2][2700/6898] Elapsed 37m 44s (remain 58m 38s) Loss: 0.0019(0.0031) Grad Norm: 17059.5859  LR: 9.970171587940331e-06


Epoch: [2][2800/6898] Elapsed 39m 8s (remain 57m 15s) Loss: 0.0182(0.0031) Grad Norm: 96675.6562  LR: 9.96954749560248e-06


Epoch: [2][2900/6898] Elapsed 40m 33s (remain 55m 52s) Loss: 0.0018(0.0030) Grad Norm: 12847.9014  LR: 9.968916961761469e-06


Epoch: [2][3000/6898] Elapsed 41m 58s (remain 54m 30s) Loss: 0.0211(0.0030) Grad Norm: 54947.9648  LR: 9.968279987234591e-06


Epoch: [2][3100/6898] Elapsed 43m 21s (remain 53m 4s) Loss: 0.0007(0.0030) Grad Norm: 18056.9668  LR: 9.96763657284749e-06


Epoch: [2][3200/6898] Elapsed 44m 45s (remain 51m 41s) Loss: 0.0051(0.0030) Grad Norm: 37590.4141  LR: 9.966986719434159e-06


Epoch: [2][3300/6898] Elapsed 46m 6s (remain 50m 14s) Loss: 0.0044(0.0030) Grad Norm: 27633.6816  LR: 9.966330427836933e-06


Epoch: [2][3400/6898] Elapsed 47m 29s (remain 48m 50s) Loss: 0.0000(0.0030) Grad Norm: 165.0560  LR: 9.965667698906493e-06


Epoch: [2][3500/6898] Elapsed 48m 54s (remain 47m 27s) Loss: 0.0119(0.0030) Grad Norm: 79650.1016  LR: 9.964998533501867e-06


Epoch: [2][3600/6898] Elapsed 50m 18s (remain 46m 3s) Loss: 0.0000(0.0030) Grad Norm: 757.8278  LR: 9.964322932490422e-06


Epoch: [2][3700/6898] Elapsed 51m 42s (remain 44m 40s) Loss: 0.0014(0.0030) Grad Norm: 10565.2793  LR: 9.96364089674787e-06


Epoch: [2][3800/6898] Elapsed 53m 6s (remain 43m 16s) Loss: 0.0036(0.0030) Grad Norm: 54933.2031  LR: 9.962952427158263e-06


Epoch: [2][3900/6898] Elapsed 54m 30s (remain 41m 52s) Loss: 0.0018(0.0030) Grad Norm: 25147.0664  LR: 9.96225752461399e-06


Epoch: [2][4000/6898] Elapsed 55m 54s (remain 40m 28s) Loss: 0.0216(0.0030) Grad Norm: 340749.8750  LR: 9.961556190015781e-06


Epoch: [2][4100/6898] Elapsed 57m 17s (remain 39m 4s) Loss: 0.0019(0.0030) Grad Norm: 52425.2383  LR: 9.960848424272704e-06


Epoch: [2][4200/6898] Elapsed 58m 41s (remain 37m 40s) Loss: 0.0049(0.0030) Grad Norm: 132162.4844  LR: 9.960134228302158e-06


Epoch: [2][4300/6898] Elapsed 60m 5s (remain 36m 17s) Loss: 0.0001(0.0030) Grad Norm: 4356.5059  LR: 9.959413603029884e-06


Epoch: [2][4400/6898] Elapsed 61m 30s (remain 34m 53s) Loss: 0.0002(0.0030) Grad Norm: 10596.3271  LR: 9.95868654938995e-06


Epoch: [2][4500/6898] Elapsed 62m 54s (remain 33m 29s) Loss: 0.0004(0.0030) Grad Norm: 13535.7861  LR: 9.957953068324762e-06


Epoch: [2][4600/6898] Elapsed 64m 17s (remain 32m 5s) Loss: 0.0006(0.0030) Grad Norm: 29291.6836  LR: 9.957213160785053e-06


Epoch: [2][4700/6898] Elapsed 65m 41s (remain 30m 41s) Loss: 0.0004(0.0030) Grad Norm: 14574.2607  LR: 9.956466827729889e-06


Epoch: [2][4800/6898] Elapsed 67m 3s (remain 29m 17s) Loss: 0.0006(0.0030) Grad Norm: 62806.8594  LR: 9.955714070126663e-06


Epoch: [2][4900/6898] Elapsed 68m 27s (remain 27m 53s) Loss: 0.0027(0.0029) Grad Norm: 42552.3086  LR: 9.954954888951093e-06


Epoch: [2][5000/6898] Elapsed 69m 50s (remain 26m 29s) Loss: 0.0015(0.0029) Grad Norm: 73717.7578  LR: 9.954189285187228e-06


Epoch: [2][5100/6898] Elapsed 71m 14s (remain 25m 5s) Loss: 0.0000(0.0029) Grad Norm: 73.7413  LR: 9.95341725982744e-06


Epoch: [2][5200/6898] Elapsed 72m 39s (remain 23m 42s) Loss: 0.0000(0.0029) Grad Norm: 1452.3430  LR: 9.952638813872425e-06


Epoch: [2][5300/6898] Elapsed 74m 1s (remain 22m 18s) Loss: 0.0087(0.0029) Grad Norm: 195889.7344  LR: 9.951853948331198e-06


Epoch: [2][5400/6898] Elapsed 75m 26s (remain 20m 54s) Loss: 0.0000(0.0029) Grad Norm: 1107.0168  LR: 9.951062664221102e-06


Epoch: [2][5500/6898] Elapsed 76m 51s (remain 19m 31s) Loss: 0.0011(0.0029) Grad Norm: 143217.1719  LR: 9.950264962567792e-06


Epoch: [2][5600/6898] Elapsed 78m 14s (remain 18m 7s) Loss: 0.0103(0.0029) Grad Norm: 55533.3281  LR: 9.949460844405247e-06


Epoch: [2][5700/6898] Elapsed 79m 34s (remain 16m 42s) Loss: 0.0001(0.0029) Grad Norm: 15888.4561  LR: 9.94865031077576e-06


Epoch: [2][5800/6898] Elapsed 80m 55s (remain 15m 18s) Loss: 0.0027(0.0029) Grad Norm: 100782.0156  LR: 9.947833362729942e-06


Epoch: [2][5900/6898] Elapsed 82m 16s (remain 13m 53s) Loss: 0.0008(0.0029) Grad Norm: 50762.7969  LR: 9.947010001326718e-06


Epoch: [2][6000/6898] Elapsed 83m 37s (remain 12m 29s) Loss: 0.0004(0.0029) Grad Norm: 70907.6406  LR: 9.946180227633322e-06


Epoch: [2][6100/6898] Elapsed 85m 0s (remain 11m 6s) Loss: 0.0001(0.0029) Grad Norm: 15721.6279  LR: 9.945344042725302e-06


Epoch: [2][6200/6898] Elapsed 86m 24s (remain 9m 42s) Loss: 0.0017(0.0029) Grad Norm: 197467.7812  LR: 9.94450144768652e-06


Epoch: [2][6300/6898] Elapsed 87m 49s (remain 8m 19s) Loss: 0.0000(0.0029) Grad Norm: 6008.7729  LR: 9.943652443609143e-06


Epoch: [2][6400/6898] Elapsed 89m 13s (remain 6m 55s) Loss: 0.0132(0.0029) Grad Norm: 475428.6875  LR: 9.942797031593645e-06


Epoch: [2][6500/6898] Elapsed 90m 37s (remain 5m 32s) Loss: 0.0118(0.0029) Grad Norm: 319917.9062  LR: 9.941935212748808e-06


Epoch: [2][6600/6898] Elapsed 92m 0s (remain 4m 8s) Loss: 0.0019(0.0029) Grad Norm: 93269.1484  LR: 9.941066988191714e-06


Epoch: [2][6700/6898] Elapsed 93m 20s (remain 2m 44s) Loss: 0.0014(0.0029) Grad Norm: 82498.7266  LR: 9.940192359047756e-06


Epoch: [2][6800/6898] Elapsed 94m 40s (remain 1m 21s) Loss: 0.0020(0.0029) Grad Norm: 256638.4844  LR: 9.93931132645062e-06


Epoch: [2][6897/6898] Elapsed 96m 0s (remain 0m 0s) Loss: 0.0023(0.0029) Grad Norm: 70734.7109  LR: 9.938450607732207e-06
Epoch 2 completed. Average loss: 0.0029
Starting evaluation
Starting validation...


EVAL: [0/1725] Elapsed 0m 0s (remain 13m 4s) Loss: 0.0008(0.0008) 


EVAL: [100/1725] Elapsed 0m 7s (remain 1m 53s) Loss: 0.0001(0.0028) 


EVAL: [200/1725] Elapsed 0m 13s (remain 1m 43s) Loss: 0.0050(0.0037) 


EVAL: [300/1725] Elapsed 0m 20s (remain 1m 35s) Loss: 0.0015(0.0031) 


EVAL: [400/1725] Elapsed 0m 26s (remain 1m 28s) Loss: -0.0012(0.0031) 


EVAL: [500/1725] Elapsed 0m 33s (remain 1m 21s) Loss: 0.0097(0.0029) 


EVAL: [600/1725] Elapsed 0m 40s (remain 1m 14s) Loss: 0.0084(0.0029) 


EVAL: [700/1725] Elapsed 0m 46s (remain 1m 8s) Loss: 0.0049(0.0030) 


EVAL: [800/1725] Elapsed 0m 53s (remain 1m 1s) Loss: 0.0068(0.0034) 


EVAL: [900/1725] Elapsed 1m 0s (remain 0m 54s) Loss: 0.0078(0.0035) 


EVAL: [1000/1725] Elapsed 1m 6s (remain 0m 48s) Loss: -0.0004(0.0037) 


EVAL: [1100/1725] Elapsed 1m 13s (remain 0m 41s) Loss: 0.0003(0.0038) 


EVAL: [1200/1725] Elapsed 1m 19s (remain 0m 34s) Loss: 0.0019(0.0039) 


EVAL: [1300/1725] Elapsed 1m 26s (remain 0m 28s) Loss: -0.0002(0.0039) 


EVAL: [1400/1725] Elapsed 1m 33s (remain 0m 21s) Loss: 0.0064(0.0038) 


EVAL: [1500/1725] Elapsed 1m 39s (remain 0m 14s) Loss: 0.0022(0.0038) 


EVAL: [1600/1725] Elapsed 1m 46s (remain 0m 8s) Loss: 0.0039(0.0037) 


EVAL: [1700/1725] Elapsed 1m 53s (remain 0m 1s) Loss: -0.0007(0.0036) 


EVAL: [1724/1725] Elapsed 1m 54s (remain 0m 0s) Loss: 0.0042(0.0036) 


Epoch 2 - avg_train_loss: 0.0029  avg_val_loss: 0.0036  time: 5887s


Epoch 2 - Score: 0.9786


Epoch 2 - Save Best Score: 0.9786 Model


Starting training

Starting training epoch 3 for fold 0


Epoch: [3][0/6898] Elapsed 0m 1s (remain 155m 49s) Loss: 0.0001(0.0001) Grad Norm: 971.6568  LR: 9.938441702975689e-06


Epoch: [3][100/6898] Elapsed 1m 25s (remain 95m 58s) Loss: 0.0320(0.0025) Grad Norm: 58579.8867  LR: 9.937547994918364e-06


Epoch: [3][200/6898] Elapsed 2m 46s (remain 92m 21s) Loss: 0.0030(0.0024) Grad Norm: 16482.6855  LR: 9.936647886835472e-06


Epoch: [3][300/6898] Elapsed 4m 7s (remain 90m 18s) Loss: 0.0011(0.0023) Grad Norm: 2904.0078  LR: 9.935741379893732e-06


Epoch: [3][400/6898] Elapsed 5m 27s (remain 88m 33s) Loss: 0.0000(0.0022) Grad Norm: 290.6802  LR: 9.934828475268154e-06


Epoch: [3][500/6898] Elapsed 6m 50s (remain 87m 21s) Loss: 0.0000(0.0023) Grad Norm: 220.7211  LR: 9.933909174142042e-06


Epoch: [3][600/6898] Elapsed 8m 10s (remain 85m 42s) Loss: 0.0050(0.0023) Grad Norm: 19468.2969  LR: 9.932983477706985e-06


Epoch: [3][700/6898] Elapsed 9m 34s (remain 84m 38s) Loss: 0.0001(0.0022) Grad Norm: 737.6957  LR: 9.932051387162868e-06


Epoch: [3][800/6898] Elapsed 10m 55s (remain 83m 13s) Loss: 0.0001(0.0023) Grad Norm: 1141.0597  LR: 9.931112903717864e-06


Epoch: [3][900/6898] Elapsed 12m 18s (remain 81m 55s) Loss: 0.0000(0.0023) Grad Norm: 253.9905  LR: 9.93016802858843e-06


Epoch: [3][1000/6898] Elapsed 13m 39s (remain 80m 29s) Loss: 0.0003(0.0024) Grad Norm: 2651.0867  LR: 9.929216762999307e-06


Epoch: [3][1100/6898] Elapsed 15m 0s (remain 79m 2s) Loss: 0.0045(0.0024) Grad Norm: 16299.2500  LR: 9.928259108183523e-06


Epoch: [3][1200/6898] Elapsed 16m 24s (remain 77m 51s) Loss: 0.0017(0.0024) Grad Norm: 9947.3887  LR: 9.927295065382384e-06


Epoch: [3][1300/6898] Elapsed 17m 49s (remain 76m 41s) Loss: 0.0013(0.0025) Grad Norm: 5560.3589  LR: 9.926324635845478e-06


Epoch: [3][1400/6898] Elapsed 19m 13s (remain 75m 25s) Loss: 0.0001(0.0024) Grad Norm: 2337.9861  LR: 9.925347820830669e-06


Epoch: [3][1500/6898] Elapsed 20m 35s (remain 74m 3s) Loss: 0.0004(0.0025) Grad Norm: 13033.9131  LR: 9.924364621604103e-06


Epoch: [3][1600/6898] Elapsed 21m 57s (remain 72m 39s) Loss: 0.0012(0.0024) Grad Norm: 10811.7354  LR: 9.923375039440197e-06


Epoch: [3][1700/6898] Elapsed 23m 22s (remain 71m 25s) Loss: 0.0000(0.0025) Grad Norm: 344.9766  LR: 9.922379075621642e-06


Epoch: [3][1800/6898] Elapsed 24m 44s (remain 70m 0s) Loss: 0.0019(0.0024) Grad Norm: 37766.0234  LR: 9.921376731439403e-06


Epoch: [3][1900/6898] Elapsed 26m 5s (remain 68m 35s) Loss: 0.0002(0.0024) Grad Norm: 5273.7104  LR: 9.920368008192711e-06


Epoch: [3][2000/6898] Elapsed 27m 31s (remain 67m 21s) Loss: 0.0000(0.0025) Grad Norm: 308.2515  LR: 9.91935290718907e-06


Epoch: [3][2100/6898] Elapsed 28m 55s (remain 66m 2s) Loss: 0.0055(0.0024) Grad Norm: 35583.0508  LR: 9.918331429744247e-06


Epoch: [3][2200/6898] Elapsed 30m 20s (remain 64m 44s) Loss: 0.0002(0.0024) Grad Norm: 5920.8042  LR: 9.91730357718228e-06


Epoch: [3][2300/6898] Elapsed 31m 44s (remain 63m 24s) Loss: 0.0000(0.0025) Grad Norm: 61.2994  LR: 9.916269350835464e-06


Epoch: [3][2400/6898] Elapsed 33m 7s (remain 62m 2s) Loss: 0.0000(0.0025) Grad Norm: 53.8366  LR: 9.915228752044356e-06


Epoch: [3][2500/6898] Elapsed 34m 30s (remain 60m 40s) Loss: 0.0000(0.0024) Grad Norm: 45.5815  LR: 9.91418178215778e-06


Epoch: [3][2600/6898] Elapsed 35m 54s (remain 59m 18s) Loss: 0.0002(0.0024) Grad Norm: 5654.9902  LR: 9.913128442532809e-06


Epoch: [3][2700/6898] Elapsed 37m 15s (remain 57m 53s) Loss: 0.0002(0.0025) Grad Norm: 8241.7949  LR: 9.912068734534778e-06


Epoch: [3][2800/6898] Elapsed 38m 35s (remain 56m 26s) Loss: 0.0000(0.0024) Grad Norm: 130.7130  LR: 9.911002659537276e-06


Epoch: [3][2900/6898] Elapsed 39m 57s (remain 55m 3s) Loss: 0.0017(0.0024) Grad Norm: 20098.3164  LR: 9.909930218922143e-06


Epoch: [3][3000/6898] Elapsed 41m 20s (remain 53m 41s) Loss: 0.0026(0.0024) Grad Norm: 41065.4844  LR: 9.908851414079471e-06


Epoch: [3][3100/6898] Elapsed 42m 43s (remain 52m 18s) Loss: 0.0025(0.0024) Grad Norm: 19602.5898  LR: 9.907766246407606e-06


Epoch: [3][3200/6898] Elapsed 44m 3s (remain 50m 53s) Loss: 0.0010(0.0024) Grad Norm: 14581.2744  LR: 9.906674717313131e-06


Epoch: [3][3300/6898] Elapsed 45m 23s (remain 49m 27s) Loss: 0.0000(0.0024) Grad Norm: 231.3525  LR: 9.905576828210884e-06


Epoch: [3][3400/6898] Elapsed 46m 43s (remain 48m 2s) Loss: 0.0003(0.0024) Grad Norm: 6948.9604  LR: 9.90447258052394e-06


Epoch: [3][3500/6898] Elapsed 48m 4s (remain 46m 38s) Loss: 0.0008(0.0024) Grad Norm: 23638.4727  LR: 9.903361975683626e-06


Epoch: [3][3600/6898] Elapsed 49m 24s (remain 45m 14s) Loss: 0.0083(0.0024) Grad Norm: 144409.2188  LR: 9.902245015129497e-06


Epoch: [3][3700/6898] Elapsed 50m 46s (remain 43m 51s) Loss: 0.0032(0.0024) Grad Norm: 23613.8945  LR: 9.901121700309353e-06


Epoch: [3][3800/6898] Elapsed 52m 6s (remain 42m 27s) Loss: 0.0064(0.0024) Grad Norm: 53046.6953  LR: 9.89999203267923e-06


Epoch: [3][3900/6898] Elapsed 53m 27s (remain 41m 3s) Loss: 0.0000(0.0024) Grad Norm: 985.7111  LR: 9.898856013703398e-06


Epoch: [3][4000/6898] Elapsed 54m 47s (remain 39m 40s) Loss: 0.0000(0.0024) Grad Norm: 671.4966  LR: 9.897713644854359e-06


Epoch: [3][4100/6898] Elapsed 56m 9s (remain 38m 18s) Loss: 0.0011(0.0024) Grad Norm: 59631.6367  LR: 9.896564927612844e-06


Epoch: [3][4200/6898] Elapsed 57m 30s (remain 36m 55s) Loss: 0.0035(0.0024) Grad Norm: 66948.7109  LR: 9.895409863467817e-06


Epoch: [3][4300/6898] Elapsed 58m 51s (remain 35m 32s) Loss: 0.0076(0.0024) Grad Norm: 65566.1328  LR: 9.894248453916466e-06


Epoch: [3][4400/6898] Elapsed 60m 15s (remain 34m 11s) Loss: 0.0001(0.0024) Grad Norm: 2990.0598  LR: 9.893080700464203e-06


Epoch: [3][4500/6898] Elapsed 61m 38s (remain 32m 49s) Loss: 0.0006(0.0024) Grad Norm: 26430.2188  LR: 9.891906604624666e-06


Epoch: [3][4600/6898] Elapsed 63m 3s (remain 31m 28s) Loss: 0.0000(0.0024) Grad Norm: 1308.5687  LR: 9.890726167919712e-06


Epoch: [3][4700/6898] Elapsed 64m 27s (remain 30m 7s) Loss: 0.0015(0.0024) Grad Norm: 49054.6680  LR: 9.889539391879418e-06


Epoch: [3][4800/6898] Elapsed 65m 48s (remain 28m 44s) Loss: 0.0000(0.0024) Grad Norm: 1195.1071  LR: 9.888346278042074e-06


Epoch: [3][4900/6898] Elapsed 67m 10s (remain 27m 22s) Loss: 0.0040(0.0024) Grad Norm: 136663.8906  LR: 9.887146827954192e-06


Epoch: [3][5000/6898] Elapsed 68m 35s (remain 26m 1s) Loss: 0.0000(0.0024) Grad Norm: 1203.9294  LR: 9.885941043170491e-06


Epoch: [3][5100/6898] Elapsed 69m 58s (remain 24m 39s) Loss: 0.0017(0.0024) Grad Norm: 160404.0625  LR: 9.884728925253906e-06


Epoch: [3][5200/6898] Elapsed 71m 21s (remain 23m 16s) Loss: 0.0000(0.0024) Grad Norm: 349.1620  LR: 9.883510475775576e-06


Epoch: [3][5300/6898] Elapsed 72m 42s (remain 21m 54s) Loss: 0.0000(0.0024) Grad Norm: 1027.7078  LR: 9.882285696314846e-06


Epoch: [3][5400/6898] Elapsed 74m 7s (remain 20m 32s) Loss: 0.0000(0.0024) Grad Norm: 224.1151  LR: 9.881054588459278e-06


Epoch: [3][5500/6898] Elapsed 75m 27s (remain 19m 9s) Loss: 0.0008(0.0024) Grad Norm: 41018.6680  LR: 9.87981715380462e-06


Epoch: [3][5600/6898] Elapsed 76m 47s (remain 17m 46s) Loss: 0.0001(0.0024) Grad Norm: 19576.3555  LR: 9.878573393954834e-06


Epoch: [3][5700/6898] Elapsed 78m 7s (remain 16m 24s) Loss: 0.0151(0.0024) Grad Norm: 111830.9844  LR: 9.87732331052207e-06


Epoch: [3][5800/6898] Elapsed 79m 31s (remain 15m 2s) Loss: 0.0001(0.0024) Grad Norm: 9332.6289  LR: 9.876066905126687e-06


Epoch: [3][5900/6898] Elapsed 80m 55s (remain 13m 40s) Loss: 0.0001(0.0024) Grad Norm: 12887.6445  LR: 9.874804179397224e-06


Epoch: [3][6000/6898] Elapsed 82m 16s (remain 12m 17s) Loss: 0.0007(0.0024) Grad Norm: 72563.9297  LR: 9.873535134970426e-06


Epoch: [3][6100/6898] Elapsed 83m 40s (remain 10m 55s) Loss: 0.0019(0.0024) Grad Norm: 53148.8242  LR: 9.872259773491219e-06


Epoch: [3][6200/6898] Elapsed 85m 5s (remain 9m 33s) Loss: 0.0007(0.0024) Grad Norm: 103980.8125  LR: 9.870978096612721e-06


Epoch: [3][6300/6898] Elapsed 86m 30s (remain 8m 11s) Loss: 0.0005(0.0024) Grad Norm: 91814.1406  LR: 9.869690105996235e-06


Epoch: [3][6400/6898] Elapsed 87m 53s (remain 6m 49s) Loss: 0.0001(0.0024) Grad Norm: 8484.7217  LR: 9.86839580331125e-06


Epoch: [3][6500/6898] Elapsed 89m 16s (remain 5m 27s) Loss: 0.0114(0.0024) Grad Norm: 366892.0938  LR: 9.867095190235432e-06


Epoch: [3][6600/6898] Elapsed 90m 38s (remain 4m 4s) Loss: 0.0130(0.0024) Grad Norm: 606723.3750  LR: 9.86578826845463e-06


Epoch: [3][6700/6898] Elapsed 91m 59s (remain 2m 42s) Loss: 0.0027(0.0024) Grad Norm: 115773.4844  LR: 9.864475039662874e-06


Epoch: [3][6800/6898] Elapsed 93m 19s (remain 1m 19s) Loss: 0.0000(0.0024) Grad Norm: 1387.1295  LR: 9.863155505562356e-06


Epoch: [3][6897/6898] Elapsed 94m 37s (remain 0m 0s) Loss: 0.0014(0.0024) Grad Norm: 166705.5312  LR: 9.86186953469538e-06
Epoch 3 completed. Average loss: 0.0024
Starting evaluation
Starting validation...


EVAL: [0/1725] Elapsed 0m 0s (remain 12m 15s) Loss: -0.0005(-0.0005) 


EVAL: [100/1725] Elapsed 0m 7s (remain 1m 53s) Loss: 0.0023(0.0032) 


EVAL: [200/1725] Elapsed 0m 13s (remain 1m 43s) Loss: 0.0069(0.0044) 


EVAL: [300/1725] Elapsed 0m 20s (remain 1m 35s) Loss: 0.0007(0.0036) 


EVAL: [400/1725] Elapsed 0m 26s (remain 1m 28s) Loss: -0.0002(0.0035) 


EVAL: [500/1725] Elapsed 0m 33s (remain 1m 21s) Loss: 0.0070(0.0033) 


EVAL: [600/1725] Elapsed 0m 40s (remain 1m 15s) Loss: 0.0074(0.0033) 


EVAL: [700/1725] Elapsed 0m 46s (remain 1m 8s) Loss: 0.0078(0.0033) 


EVAL: [800/1725] Elapsed 0m 53s (remain 1m 1s) Loss: 0.0098(0.0039) 


EVAL: [900/1725] Elapsed 1m 0s (remain 0m 54s) Loss: 0.0099(0.0041) 


EVAL: [1000/1725] Elapsed 1m 6s (remain 0m 48s) Loss: -0.0004(0.0044) 


EVAL: [1100/1725] Elapsed 1m 13s (remain 0m 41s) Loss: 0.0056(0.0045) 


EVAL: [1200/1725] Elapsed 1m 19s (remain 0m 34s) Loss: 0.0051(0.0047) 


EVAL: [1300/1725] Elapsed 1m 26s (remain 0m 28s) Loss: 0.0015(0.0047) 


EVAL: [1400/1725] Elapsed 1m 33s (remain 0m 21s) Loss: 0.0093(0.0046) 


EVAL: [1500/1725] Elapsed 1m 39s (remain 0m 14s) Loss: 0.0031(0.0045) 


EVAL: [1600/1725] Elapsed 1m 46s (remain 0m 8s) Loss: 0.0052(0.0044) 


EVAL: [1700/1725] Elapsed 1m 53s (remain 0m 1s) Loss: -0.0013(0.0042) 


EVAL: [1724/1725] Elapsed 1m 54s (remain 0m 0s) Loss: 0.0026(0.0042) 


Epoch 3 - avg_train_loss: 0.0024  avg_val_loss: 0.0042  time: 5804s


Epoch 3 - Score: 0.9789


Epoch 3 - Save Best Score: 0.9789 Model


Starting training

Starting training epoch 4 for fold 0


Epoch: [4][0/6898] Elapsed 0m 1s (remain 147m 20s) Loss: 0.0000(0.0000) Grad Norm: 9.9372  LR: 9.861856246381599e-06


Epoch: [4][100/6898] Elapsed 1m 21s (remain 91m 34s) Loss: 0.0033(0.0021) Grad Norm: 8058.5752  LR: 9.860524232823562e-06


Epoch: [4][200/6898] Elapsed 2m 42s (remain 90m 16s) Loss: 0.0052(0.0023) Grad Norm: 14763.0752  LR: 9.859185919077785e-06


Epoch: [4][300/6898] Elapsed 4m 3s (remain 89m 5s) Loss: 0.0000(0.0022) Grad Norm: 53.9231  LR: 9.857841306878984e-06


Epoch: [4][400/6898] Elapsed 5m 23s (remain 87m 22s) Loss: 0.0087(0.0022) Grad Norm: 50875.2344  LR: 9.856490397970038e-06


Epoch: [4][500/6898] Elapsed 6m 44s (remain 85m 59s) Loss: 0.0013(0.0021) Grad Norm: 5480.8369  LR: 9.85513319410199e-06


Epoch: [4][600/6898] Elapsed 8m 4s (remain 84m 32s) Loss: 0.0028(0.0021) Grad Norm: 24851.9707  LR: 9.853769697034036e-06


Epoch: [4][700/6898] Elapsed 9m 25s (remain 83m 14s) Loss: 0.0020(0.0021) Grad Norm: 9998.7109  LR: 9.852399908533541e-06


Epoch: [4][800/6898] Elapsed 10m 49s (remain 82m 20s) Loss: 0.0001(0.0022) Grad Norm: 1955.9009  LR: 9.851023830376014e-06


Epoch: [4][900/6898] Elapsed 12m 9s (remain 80m 53s) Loss: 0.0000(0.0022) Grad Norm: 123.9668  LR: 9.84964146434512e-06


Epoch: [4][1000/6898] Elapsed 13m 32s (remain 79m 44s) Loss: 0.0001(0.0022) Grad Norm: 2558.2861  LR: 9.848252812232679e-06


Epoch: [4][1100/6898] Elapsed 14m 55s (remain 78m 36s) Loss: 0.0000(0.0022) Grad Norm: 238.1465  LR: 9.846857875838652e-06


Epoch: [4][1200/6898] Elapsed 16m 20s (remain 77m 29s) Loss: 0.0002(0.0022) Grad Norm: 3478.5564  LR: 9.845456656971152e-06


Epoch: [4][1300/6898] Elapsed 17m 43s (remain 76m 16s) Loss: 0.0000(0.0021) Grad Norm: 199.0674  LR: 9.844049157446425e-06


Epoch: [4][1400/6898] Elapsed 19m 5s (remain 74m 54s) Loss: 0.0023(0.0021) Grad Norm: 26379.8027  LR: 9.842635379088873e-06


Epoch: [4][1500/6898] Elapsed 20m 29s (remain 73m 39s) Loss: 0.0001(0.0021) Grad Norm: 2559.6621  LR: 9.841215323731023e-06


Epoch: [4][1600/6898] Elapsed 21m 51s (remain 72m 19s) Loss: 0.0010(0.0020) Grad Norm: 10155.3184  LR: 9.839788993213548e-06


Epoch: [4][1700/6898] Elapsed 23m 15s (remain 71m 5s) Loss: 0.0002(0.0021) Grad Norm: 934.9715  LR: 9.838356389385249e-06


Epoch: [4][1800/6898] Elapsed 24m 40s (remain 69m 51s) Loss: 0.0018(0.0021) Grad Norm: 15437.2266  LR: 9.836917514103056e-06


Epoch: [4][1900/6898] Elapsed 26m 4s (remain 68m 33s) Loss: 0.0094(0.0021) Grad Norm: 45673.3203  LR: 9.83547236923204e-06


Epoch: [4][2000/6898] Elapsed 27m 29s (remain 67m 17s) Loss: 0.0045(0.0021) Grad Norm: 23991.4062  LR: 9.834020956645386e-06


Epoch: [4][2100/6898] Elapsed 28m 55s (remain 66m 1s) Loss: 0.0029(0.0021) Grad Norm: 103938.2578  LR: 9.832563278224407e-06


Epoch: [4][2200/6898] Elapsed 30m 19s (remain 64m 43s) Loss: 0.0022(0.0021) Grad Norm: 43361.9727  LR: 9.831099335858542e-06


Epoch: [4][2300/6898] Elapsed 31m 44s (remain 63m 24s) Loss: 0.0000(0.0020) Grad Norm: 127.2610  LR: 9.829629131445342e-06


Epoch: [4][2400/6898] Elapsed 33m 9s (remain 62m 5s) Loss: 0.0012(0.0020) Grad Norm: 29916.1387  LR: 9.828152666890482e-06


Epoch: [4][2500/6898] Elapsed 34m 35s (remain 60m 49s) Loss: 0.0019(0.0020) Grad Norm: 14016.8164  LR: 9.826669944107747e-06


Epoch: [4][2600/6898] Elapsed 36m 0s (remain 59m 29s) Loss: 0.0029(0.0021) Grad Norm: 71156.4609  LR: 9.825180965019035e-06


Epoch: [4][2700/6898] Elapsed 37m 24s (remain 58m 7s) Loss: 0.0015(0.0021) Grad Norm: 21004.0703  LR: 9.823685731554353e-06


Epoch: [4][2800/6898] Elapsed 38m 48s (remain 56m 45s) Loss: 0.0022(0.0021) Grad Norm: 22464.6973  LR: 9.822184245651817e-06


Epoch: [4][2900/6898] Elapsed 40m 11s (remain 55m 22s) Loss: 0.0032(0.0021) Grad Norm: 8307.8115  LR: 9.820676509257641e-06


Epoch: [4][3000/6898] Elapsed 41m 35s (remain 54m 1s) Loss: 0.0001(0.0021) Grad Norm: 6487.0615  LR: 9.81916252432615e-06


Epoch: [4][3100/6898] Elapsed 42m 59s (remain 52m 38s) Loss: 0.0008(0.0021) Grad Norm: 26762.3574  LR: 9.817642292819766e-06


Epoch: [4][3200/6898] Elapsed 44m 23s (remain 51m 15s) Loss: 0.0000(0.0021) Grad Norm: 239.4819  LR: 9.816115816709e-06


Epoch: [4][3300/6898] Elapsed 45m 43s (remain 49m 48s) Loss: 0.0020(0.0021) Grad Norm: 21894.1387  LR: 9.814583097972465e-06


Epoch: [4][3400/6898] Elapsed 47m 4s (remain 48m 23s) Loss: 0.0000(0.0021) Grad Norm: 805.0276  LR: 9.813044138596865e-06


Epoch: [4][3500/6898] Elapsed 48m 24s (remain 46m 58s) Loss: 0.0146(0.0021) Grad Norm: 154631.8906  LR: 9.81149894057699e-06


Epoch: [4][3600/6898] Elapsed 49m 44s (remain 45m 32s) Loss: 0.0016(0.0021) Grad Norm: 43181.9883  LR: 9.809947505915718e-06


Epoch: [4][3700/6898] Elapsed 51m 5s (remain 44m 8s) Loss: 0.0024(0.0021) Grad Norm: 17133.0371  LR: 9.808389836624013e-06


Epoch: [4][3800/6898] Elapsed 52m 26s (remain 42m 43s) Loss: 0.0000(0.0021) Grad Norm: 226.1469  LR: 9.806825934720916e-06


Epoch: [4][3900/6898] Elapsed 53m 48s (remain 41m 20s) Loss: 0.0001(0.0021) Grad Norm: 1869.7736  LR: 9.80525580223355e-06


Epoch: [4][4000/6898] Elapsed 55m 8s (remain 39m 55s) Loss: 0.0075(0.0021) Grad Norm: 228850.7344  LR: 9.803679441197112e-06


Epoch: [4][4100/6898] Elapsed 56m 27s (remain 38m 30s) Loss: 0.0001(0.0021) Grad Norm: 4077.8914  LR: 9.802096853654876e-06


Epoch: [4][4200/6898] Elapsed 57m 46s (remain 37m 5s) Loss: 0.0013(0.0021) Grad Norm: 58771.7773  LR: 9.800508041658182e-06


Epoch: [4][4300/6898] Elapsed 59m 6s (remain 35m 41s) Loss: 0.0055(0.0021) Grad Norm: 113884.1484  LR: 9.79891300726644e-06


Epoch: [4][4400/6898] Elapsed 60m 30s (remain 34m 19s) Loss: 0.0001(0.0021) Grad Norm: 11356.9111  LR: 9.797311752547129e-06


Epoch: [4][4500/6898] Elapsed 61m 52s (remain 32m 56s) Loss: 0.0000(0.0021) Grad Norm: 96.8502  LR: 9.795704279575783e-06


Epoch: [4][4600/6898] Elapsed 63m 13s (remain 31m 33s) Loss: 0.0054(0.0021) Grad Norm: 394354.5625  LR: 9.794090590436006e-06


Epoch: [4][4700/6898] Elapsed 64m 35s (remain 30m 11s) Loss: 0.0000(0.0021) Grad Norm: 40.0782  LR: 9.792470687219448e-06


Epoch: [4][4800/6898] Elapsed 66m 1s (remain 28m 50s) Loss: 0.0181(0.0021) Grad Norm: 120726.6406  LR: 9.790844572025823e-06


Epoch: [4][4900/6898] Elapsed 67m 24s (remain 27m 28s) Loss: 0.0019(0.0021) Grad Norm: 66487.7500  LR: 9.789212246962893e-06


Epoch: [4][5000/6898] Elapsed 68m 48s (remain 26m 6s) Loss: 0.0007(0.0021) Grad Norm: 75419.0078  LR: 9.787573714146472e-06


Epoch: [4][5100/6898] Elapsed 70m 12s (remain 24m 43s) Loss: 0.0000(0.0021) Grad Norm: 290.4821  LR: 9.785928975700414e-06


Epoch: [4][5200/6898] Elapsed 71m 36s (remain 23m 21s) Loss: 0.0001(0.0021) Grad Norm: 17887.9180  LR: 9.784278033756623e-06


Epoch: [4][5300/6898] Elapsed 73m 0s (remain 21m 59s) Loss: 0.0067(0.0021) Grad Norm: 273493.0938  LR: 9.78262089045504e-06


Epoch: [4][5400/6898] Elapsed 74m 22s (remain 20m 37s) Loss: 0.0000(0.0021) Grad Norm: 106.0208  LR: 9.780957547943653e-06


Epoch: [4][5500/6898] Elapsed 75m 46s (remain 19m 14s) Loss: 0.0000(0.0021) Grad Norm: 1328.1692  LR: 9.779288008378469e-06


Epoch: [4][5600/6898] Elapsed 77m 7s (remain 17m 51s) Loss: 0.0009(0.0021) Grad Norm: 121603.7500  LR: 9.777612273923544e-06


Epoch: [4][5700/6898] Elapsed 78m 30s (remain 16m 29s) Loss: 0.0000(0.0021) Grad Norm: 1623.0942  LR: 9.775930346750953e-06


Epoch: [4][5800/6898] Elapsed 79m 50s (remain 15m 5s) Loss: 0.0000(0.0021) Grad Norm: 14.7922  LR: 9.774242229040803e-06


Epoch: [4][5900/6898] Elapsed 81m 10s (remain 13m 42s) Loss: 0.0013(0.0021) Grad Norm: 53775.1562  LR: 9.772547922981223e-06


Epoch: [4][6000/6898] Elapsed 82m 34s (remain 12m 20s) Loss: 0.0019(0.0021) Grad Norm: 395896.4375  LR: 9.770847430768366e-06


Epoch: [4][6100/6898] Elapsed 83m 58s (remain 10m 58s) Loss: 0.0016(0.0021) Grad Norm: 218470.3906  LR: 9.769140754606401e-06


Epoch: [4][6200/6898] Elapsed 85m 21s (remain 9m 35s) Loss: 0.0000(0.0021) Grad Norm: 4561.6328  LR: 9.767427896707512e-06


Epoch: [4][6300/6898] Elapsed 86m 45s (remain 8m 13s) Loss: 0.0000(0.0021) Grad Norm: 3448.0276  LR: 9.7657088592919e-06


Epoch: [4][6400/6898] Elapsed 88m 11s (remain 6m 50s) Loss: 0.0020(0.0021) Grad Norm: 241821.4062  LR: 9.763983644587768e-06


Epoch: [4][6500/6898] Elapsed 89m 34s (remain 5m 28s) Loss: 0.0018(0.0021) Grad Norm: 167694.9062  LR: 9.762252254831338e-06


Epoch: [4][6600/6898] Elapsed 90m 55s (remain 4m 5s) Loss: 0.0053(0.0021) Grad Norm: 326622.4688  LR: 9.760514692266823e-06


Epoch: [4][6700/6898] Elapsed 92m 17s (remain 2m 42s) Loss: 0.0000(0.0020) Grad Norm: 245.5715  LR: 9.758770959146443e-06


Epoch: [4][6800/6898] Elapsed 93m 40s (remain 1m 20s) Loss: 0.0000(0.0020) Grad Norm: 1997.6903  LR: 9.75702105773042e-06


Epoch: [4][6897/6898] Elapsed 95m 2s (remain 0m 0s) Loss: 0.0017(0.0021) Grad Norm: 218906.9844  LR: 9.755317762004239e-06
Epoch 4 completed. Average loss: 0.0021
Starting evaluation
Starting validation...


EVAL: [0/1725] Elapsed 0m 0s (remain 11m 34s) Loss: -0.0009(-0.0009) 


EVAL: [100/1725] Elapsed 0m 7s (remain 1m 52s) Loss: 0.0055(0.0042) 


EVAL: [200/1725] Elapsed 0m 13s (remain 1m 43s) Loss: 0.0045(0.0049) 


EVAL: [300/1725] Elapsed 0m 20s (remain 1m 35s) Loss: 0.0017(0.0040) 


EVAL: [400/1725] Elapsed 0m 26s (remain 1m 28s) Loss: -0.0005(0.0040) 


EVAL: [500/1725] Elapsed 0m 33s (remain 1m 21s) Loss: 0.0011(0.0037) 


EVAL: [600/1725] Elapsed 0m 40s (remain 1m 14s) Loss: 0.0129(0.0036) 


EVAL: [700/1725] Elapsed 0m 46s (remain 1m 8s) Loss: 0.0059(0.0036) 


EVAL: [800/1725] Elapsed 0m 53s (remain 1m 1s) Loss: 0.0091(0.0041) 


EVAL: [900/1725] Elapsed 0m 59s (remain 0m 54s) Loss: 0.0028(0.0043) 


EVAL: [1000/1725] Elapsed 1m 6s (remain 0m 48s) Loss: 0.0020(0.0045) 


EVAL: [1100/1725] Elapsed 1m 13s (remain 0m 41s) Loss: 0.0004(0.0047) 


EVAL: [1200/1725] Elapsed 1m 19s (remain 0m 34s) Loss: 0.0061(0.0048) 


EVAL: [1300/1725] Elapsed 1m 26s (remain 0m 28s) Loss: 0.0003(0.0048) 


EVAL: [1400/1725] Elapsed 1m 33s (remain 0m 21s) Loss: 0.0079(0.0047) 


EVAL: [1500/1725] Elapsed 1m 39s (remain 0m 14s) Loss: 0.0028(0.0047) 


EVAL: [1600/1725] Elapsed 1m 46s (remain 0m 8s) Loss: 0.0041(0.0047) 


EVAL: [1700/1725] Elapsed 1m 53s (remain 0m 1s) Loss: 0.0029(0.0045) 


EVAL: [1724/1725] Elapsed 1m 54s (remain 0m 0s) Loss: 0.0029(0.0045) 


Epoch 4 - avg_train_loss: 0.0021  avg_val_loss: 0.0045  time: 5828s


Epoch 4 - Score: 0.9790


Epoch 4 - Save Best Score: 0.9790 Model




Score: 0.9790




Score: 0.9790


[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m: [fold0] avg_train_loss █▄▂▁
[34m[1mwandb[0m:   [fold0] avg_val_loss ▁▁▆█
[34m[1mwandb[0m:          [fold0] epoch ▁▃▆█
[34m[1mwandb[0m:          [fold0] score ▁▆▇█
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m: [fold0] avg_train_loss 0.00205
[34m[1mwandb[0m:   [fold0] avg_val_loss 0.00446
[34m[1mwandb[0m:          [fold0] epoch 4
[34m[1mwandb[0m:          [fold0] score 0.97902
[34m[1mwandb[0m: 


[34m[1mwandb[0m: 🚀 View run [33m/workspace/deberta-v3-large[0m at: [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public/runs/kl09y3t4?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20250211_210932-kl09y3t4/logs[0m
