# Training deberta-v3-large model to generate high quality pseudolabels

Initially deberta-v3-large model raises OOM error. 

Optimized the training pipeline via Pytorch
- Gradient accumulation
- Gradient checkpoint
- Using bfloat data types
- Autocast

1 fold (with 4-5 epochs) takes around ~5 hours, each epoch ~50 minutes. %87+ accuracy. (See notebook version 6)

Another approach would be to use [huggingface accelerator](https://huggingface.co/docs/accelerate/en/index) as it automates this process without the need to customize manually as well as distribution of training while using multiple GPUs.

# Why Deberta-v3-large?
![Comparison of BERT variants in 2025](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/modernbert/modernbert_pareto_curve.png)

# Directory settings

In [1]:
# ====================================================
# Directory settings
# ====================================================
import os, shutil

OUTPUT_DIR = os.path.join(os.path.dirname(os.getcwd()), 'deberta-v3-large-finetuned-with-pseudolabels/')
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

for item in os.listdir(OUTPUT_DIR):
    item_path = os.path.join(OUTPUT_DIR, item)
    if os.path.isfile(item_path):  # If it's a file, remove it
        os.remove(item_path)
        print(f"Removed file: {item}")
    elif os.path.isdir(item_path):  # If it's a directory, remove it
        shutil.rmtree(item_path)
        print(f"Removed folder: {item}")

# Library

In [2]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

os.system('pip uninstall -y transformers')
os.system('python -m pip install transformers>=4.48.0')
os.system('pip uninstall -y tokenizer')
os.system('python -m pip install tokenizers')
import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Found existing installation: transformers 4.48.3
Uninstalling transformers-4.48.3:


  Successfully uninstalled transformers-4.48.3


[0m

[0m

[0m



[0m

tokenizers.__version__: 0.21.0
transformers.__version__: 4.48.3


env: TOKENIZERS_PARALLELISM=true


In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" # if running on Kaggle, use this path 
    model = "/workspace/deberta-v3-large" # if running on local machine, use this path  
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=1 #0.05
    num_warmup_steps=0
    epochs=5
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 12 #batch_size=16
    fc_dropout=0.2
    max_len=512
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=2 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True
    
if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

In [4]:
# ====================================================
# wandb: local notebook login
# ====================================================
if CFG.wandb:
    
    import wandb
    from dotenv import load_dotenv
    import os

    # Load environment variables from .env file in the parent folder
    load_dotenv(os.path.join(os.path.dirname(os.getcwd()), '.env'))

    secret_value_0 = os.getenv("WANDB_API_KEY")
    if secret_value_0:
        wandb.login(key=secret_value_0)
        anony = None
    else:
        anony = "must"
        print('If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. \nGet your W&B access token from here: https://wandb.ai/authorize')

    def class2dict(f):
        return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

    run = wandb.init(project='NBME-Public', 
                     name=CFG.model,
                     config=class2dict(CFG),
                     group=CFG.model,
                     job_type="train",
                     anonymous=anony)



[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


If you want to use your W&B account, create a .env file in the parent folder and provide your W&B access token as WANDB_API_KEY. 
Get your W&B access token from here: https://wandb.ai/authorize


[34m[1mwandb[0m: Currently logged in as: [33manony-moose-628529071760834444[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Tracking run with wandb version 0.19.6


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/workspace/wandb/run-20250211_183053-v5xaqz69[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33m/workspace/deberta-v3-large[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/anony-moose-628529071760834444/NBME-Public/runs/v5xaqz69?apiKey=e17c3e277c305c3106a5fedd7c64321c1329fc02[0m




# Helper functions for scoring

In [5]:
# From https://www.kaggle.com/theoviel/evaluation-metric-folds-baseline

def micro_f1(preds, truths):
    """
    Micro f1 on binary arrays.

    Args:
        preds (list of lists of ints): Predictions.
        truths (list of lists of ints): Ground truths.

    Returns:
        float: f1 score.
    """
    # Micro : aggregating over all instances
    preds = np.concatenate(preds)
    truths = np.concatenate(truths)
    return f1_score(truths, preds)


def spans_to_binary(spans, length=None):
    """
    Converts spans to a binary array indicating whether each character is in the span.

    Args:
        spans (list of lists of two ints): Spans.

    Returns:
        np array [length]: Binarized spans.
    """
    length = np.max(spans) if length is None else length
    binary = np.zeros(length)
    for start, end in spans:
        binary[start:end] = 1
    return binary


def span_micro_f1(preds, truths):
    """
    Micro f1 on spans.

    Args:
        preds (list of lists of two ints): Prediction spans.
        truths (list of lists of two ints): Ground truth spans.

    Returns:
        float: f1 score.
    """
    bin_preds = []
    bin_truths = []
    for pred, truth in zip(preds, truths):
        if not len(pred) and not len(truth):
            continue
        length = max(np.max(pred) if len(pred) else 0, np.max(truth) if len(truth) else 0)
        bin_preds.append(spans_to_binary(pred, length))
        bin_truths.append(spans_to_binary(truth, length))
    return micro_f1(bin_preds, bin_truths)

In [6]:
## UPDATED FOR HANDLING NONEs in THE PSEUDOLABELS / only including confident inferences
import ast

# Update the create_labels_for_scoring function
def create_labels_for_scoring(df):
    # Initialize the 'location_for_create_labels' column with empty lists
    df['location_for_create_labels'] = [ast.literal_eval(f'[]')] * len(df)
    
    for i in range(len(df)):
        lst = df.loc[i, 'location']
        if lst and lst != '':  # Check if lst is not None or empty
            # Ensure lst is a list of strings
            if isinstance(lst, str):
                lst = [lst]
            elif isinstance(lst, list):
                lst = [str(item) for item in lst]
            new_lst = ';'.join(lst)
            df.loc[i, 'location_for_create_labels'] = ast.literal_eval(f'[["{new_lst}"]]')

    # Create labels
    truths = []
    for location_list in df['location_for_create_labels'].values:
        truth = []
        if len(location_list) > 0:
            location = location_list[0]
            for loc in [s.split() for s in location.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                truth.append([start, end])
        truths.append(truth)
    
    return truths


def get_char_probs(texts, predictions, tokenizer):
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text, 
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
            start = offset_mapping[0]
            end = offset_mapping[1]
            results[i][start:end] = pred
    return results


def get_results(char_probs, th=0.5):
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results


def get_predictions(results):
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions

# Utils

In [7]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    score = span_micro_f1(y_true, y_pred)
    return score


def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

# Data Loading

In [8]:
# ====================================================
# Data Loading
# ====================================================
train = pd.read_csv('/workspace/data/train.csv')
train['annotation'] = train['annotation'].apply(ast.literal_eval)
train['location'] = train['location'].apply(ast.literal_eval)
features = pd.read_csv('/workspace/data/features.csv')
def preprocess_features(features):
    features.loc[27, 'feature_text'] = "Last-Pap-smear-1-year-ago"
    return features
features = preprocess_features(features)
patient_notes = pd.read_csv('/workspace/data/patient_notes.csv')

print(f"train.shape: {train.shape}")
display(train.head())
print(f"features.shape: {features.shape}")
display(features.head())
print(f"patient_notes.shape: {patient_notes.shape}")
display(patient_notes.head())


train.shape: (14300, 6)


Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724]
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693]
2,00016_002,0,16,2,[chest pressure],[203 217]
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]"
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258]


features.shape: (143, 3)


Unnamed: 0,feature_num,case_num,feature_text
0,0,0,Family-history-of-MI-OR-Family-history-of-myoc...
1,1,0,Family-history-of-thyroid-disorder
2,2,0,Chest-pressure
3,3,0,Intermittent-symptoms
4,4,0,Lightheaded


patient_notes.shape: (42146, 3)


Unnamed: 0,pn_num,case_num,pn_history
0,0,0,"17-year-old male, has come to the student heal..."
1,1,0,17 yo male with recurrent palpitations for the...
2,2,0,Dillon Cleveland is a 17 y.o. male patient wit...
3,3,0,a 17 yo m c/o palpitation started 3 mos ago; \...
4,4,0,17yo male with no pmh here for evaluation of p...


In [9]:
train['annotation_length'] = train['annotation'].apply(len) # refers to number of annotation
train

Unnamed: 0,id,case_num,pn_num,feature_num,annotation,location,annotation_length
0,00016_000,0,16,0,[dad with recent heart attcak],[696 724],1
1,00016_001,0,16,1,"[mom with ""thyroid disease]",[668 693],1
2,00016_002,0,16,2,[chest pressure],[203 217],1
3,00016_003,0,16,3,"[intermittent episodes, episode]","[70 91, 176 183]",2
4,00016_004,0,16,4,[felt as if he were going to pass out],[222 258],1
...,...,...,...,...,...,...,...
14295,95333_912,9,95333,912,[],[],0
14296,95333_913,9,95333,913,[],[],0
14297,95333_914,9,95333,914,[photobia],[274 282],1
14298,95333_915,9,95333,915,[no sick contacts],[421 437],1


In [10]:
pseudolabels = pd.read_csv('/workspace/data/pseudolabels.csv')
pseudolabels =  pseudolabels.merge(features, on=['feature_num', 'case_num'], how='left')
pseudolabels = pseudolabels.merge(patient_notes, on=['pn_num', 'case_num'], how='left')

pseudolabels = pseudolabels.dropna(subset=['location'])
pseudolabels = pseudolabels[pseudolabels['location'].apply(lambda x: len(x) > 0)]

def convert_location_format(location):
    if pd.isna(location) or not location:
        return []
    if isinstance(location, str):
        return [loc.strip() for loc in location.split(';')]
    return location

pseudolabels['location'] = pseudolabels['location'].apply(convert_location_format)

pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal..."
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...
...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...


In [11]:
pseudolabels['annotation_length'] = pseudolabels['location'].apply(len)

In [12]:
pseudolabels

Unnamed: 0,id,case_num,pn_num,feature_num,fold,location,feature_text,pn_history,annotation_length
0,00000_006,0,0,6,1,[521 526],Adderall-use,"17-year-old male, has come to the student heal...",1
1,00001_004,0,1,4,1,[179 195],Lightheaded,17 yo male with recurrent palpitations for the...,1
2,00001_006,0,1,6,1,[347 353],Adderall-use,17 yo male with recurrent palpitations for the...,1
3,00001_007,0,1,7,1,[220 235],Shortness-of-breath,17 yo male with recurrent palpitations for the...,1
4,00001_011,0,1,11,1,[1 5],17-year,17 yo male with recurrent palpitations for the...,1
...,...,...,...,...,...,...,...,...,...
122512,95331_912,9,95331,912,1,"[283 302, 401 421]",Family-history-of-migraines,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,2
122513,95331_913,9,95331,913,1,[8 9],Female,A 20 YO F CAME COMPLAIN A DULL 8/10 HEADACHE T...,1
122516,95332_906,9,95332,906,1,[319 327],Vomiting,Ms. Madden is a 20yo female who presents with ...,1
122518,95334_901,9,95334,901,1,[13 18],20-year,patient is a 20 yo F who presents with a heada...,1


In [13]:
train = pseudolabels
# drop fold column
train = train.drop(columns=['fold'])

# CV split

In [14]:
# ====================================================
# CV split
# ====================================================
train = train.reset_index(drop=True)

Fold = GroupKFold(n_splits=CFG.n_fold)
groups = train['pn_num'].values
for n, (train_index, val_index) in enumerate(Fold.split(train, train['location'], groups)):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

fold
0    17247
1    17247
2    17246
3    17246
4    17246
dtype: int64

In [15]:
if CFG.debug:
    display(train.groupby('fold').size())
    train = train.sample(n=1000, random_state=0).reset_index(drop=True)
    display(train.groupby('fold').size())

# Dataset

In [16]:
# ====================================================
# Define max_len
# ====================================================
for text_col in ['pn_history']:
    pn_history_lengths = []
    tk0 = tqdm(patient_notes[text_col].fillna("").values, total=len(patient_notes))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        pn_history_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(pn_history_lengths)}')

for text_col in ['feature_text']:
    features_lengths = []
    tk0 = tqdm(features[text_col].fillna("").values, total=len(features))
    for text in tk0:
        length = len(tokenizer(text, add_special_tokens=False)['input_ids'])
        features_lengths.append(length)
    LOGGER.info(f'{text_col} max(lengths): {max(features_lengths)}')

CFG.max_len = max(pn_history_lengths) + max(features_lengths) + 3 # cls & sep & sep
LOGGER.info(f"max_len: {CFG.max_len}")

  0%|          | 0/42146 [00:00<?, ?it/s]

  0%|          | 197/42146 [00:00<00:21, 1966.38it/s]

  1%|          | 481/42146 [00:00<00:16, 2473.97it/s]

  2%|▏         | 781/42146 [00:00<00:15, 2712.39it/s]

  3%|▎         | 1083/42146 [00:00<00:14, 2831.52it/s]

  3%|▎         | 1373/42146 [00:00<00:14, 2854.81it/s]

  4%|▍         | 1659/42146 [00:00<00:14, 2844.50it/s]

  5%|▍         | 1945/42146 [00:00<00:14, 2849.24it/s]

  5%|▌         | 2237/42146 [00:00<00:13, 2869.62it/s]

  6%|▌         | 2524/42146 [00:00<00:14, 2740.14it/s]

  7%|▋         | 2800/42146 [00:01<00:14, 2652.51it/s]

  7%|▋         | 3067/42146 [00:01<00:15, 2603.86it/s]

  8%|▊         | 3329/42146 [00:01<00:15, 2465.43it/s]

  8%|▊         | 3578/42146 [00:01<00:16, 2372.93it/s]

  9%|▉         | 3817/42146 [00:01<00:16, 2318.60it/s]

 10%|▉         | 4050/42146 [00:01<00:16, 2284.59it/s]

 10%|█         | 4279/42146 [00:01<00:16, 2256.77it/s]

 11%|█         | 4505/42146 [00:01<00:16, 2227.05it/s]

 11%|█         | 4728/42146 [00:01<00:16, 2208.88it/s]

 12%|█▏        | 4949/42146 [00:02<00:16, 2198.10it/s]

 12%|█▏        | 5177/42146 [00:02<00:16, 2219.46it/s]

 13%|█▎        | 5415/42146 [00:02<00:16, 2265.90it/s]

 13%|█▎        | 5653/42146 [00:02<00:15, 2298.62it/s]

 14%|█▍        | 5896/42146 [00:02<00:15, 2333.88it/s]

 15%|█▍        | 6136/42146 [00:02<00:15, 2352.80it/s]

 15%|█▌        | 6372/42146 [00:02<00:15, 2348.09it/s]

 16%|█▌        | 6611/42146 [00:02<00:15, 2357.93it/s]

 16%|█▋        | 6855/42146 [00:02<00:14, 2381.08it/s]

 17%|█▋        | 7094/42146 [00:02<00:14, 2380.54it/s]

 17%|█▋        | 7333/42146 [00:03<00:14, 2379.55it/s]

 18%|█▊        | 7571/42146 [00:03<00:14, 2372.73it/s]

 19%|█▊        | 7809/42146 [00:03<00:14, 2303.92it/s]

 19%|█▉        | 8040/42146 [00:03<00:15, 2244.96it/s]

 20%|█▉        | 8266/42146 [00:03<00:15, 2248.43it/s]

 20%|██        | 8492/42146 [00:03<00:15, 2239.96it/s]

 21%|██        | 8717/42146 [00:03<00:14, 2234.18it/s]

 21%|██        | 8952/42146 [00:03<00:14, 2266.63it/s]

 22%|██▏       | 9179/42146 [00:03<00:14, 2262.29it/s]

 22%|██▏       | 9412/42146 [00:03<00:14, 2281.07it/s]

 23%|██▎       | 9647/42146 [00:04<00:14, 2298.14it/s]

 23%|██▎       | 9881/42146 [00:04<00:13, 2308.93it/s]

 24%|██▍       | 10116/42146 [00:04<00:13, 2319.10it/s]

 25%|██▍       | 10352/42146 [00:04<00:13, 2329.88it/s]

 25%|██▌       | 10587/42146 [00:04<00:13, 2333.35it/s]

 26%|██▌       | 10823/42146 [00:04<00:13, 2340.03it/s]

 26%|██▌       | 11063/42146 [00:04<00:13, 2355.86it/s]

 27%|██▋       | 11301/42146 [00:04<00:13, 2361.16it/s]

 27%|██▋       | 11543/42146 [00:04<00:12, 2378.03it/s]

 28%|██▊       | 11787/42146 [00:04<00:12, 2394.00it/s]

 29%|██▊       | 12027/42146 [00:05<00:12, 2391.67it/s]

 29%|██▉       | 12267/42146 [00:05<00:12, 2393.72it/s]

 30%|██▉       | 12507/42146 [00:05<00:12, 2386.44it/s]

 30%|███       | 12746/42146 [00:05<00:12, 2371.51it/s]

 31%|███       | 12984/42146 [00:05<00:12, 2370.95it/s]

 31%|███▏      | 13222/42146 [00:05<00:12, 2360.77it/s]

 32%|███▏      | 13459/42146 [00:05<00:12, 2360.38it/s]

 32%|███▏      | 13696/42146 [00:05<00:12, 2358.14it/s]

 33%|███▎      | 13936/42146 [00:05<00:11, 2369.24it/s]

 34%|███▎      | 14174/42146 [00:05<00:11, 2371.56it/s]

 34%|███▍      | 14412/42146 [00:06<00:11, 2359.69it/s]

 35%|███▍      | 14651/42146 [00:06<00:11, 2367.05it/s]

 35%|███▌      | 14888/42146 [00:06<00:11, 2367.19it/s]

 36%|███▌      | 15125/42146 [00:06<00:11, 2352.74it/s]

 36%|███▋      | 15361/42146 [00:06<00:11, 2255.25it/s]

 37%|███▋      | 15591/42146 [00:06<00:11, 2265.85it/s]

 38%|███▊      | 15823/42146 [00:06<00:11, 2281.57it/s]

 38%|███▊      | 16056/42146 [00:06<00:11, 2295.69it/s]

 39%|███▊      | 16287/42146 [00:06<00:11, 2298.47it/s]

 39%|███▉      | 16518/42146 [00:06<00:11, 2300.21it/s]

 40%|███▉      | 16749/42146 [00:07<00:11, 2289.48it/s]

 40%|████      | 16981/42146 [00:07<00:10, 2296.72it/s]

 41%|████      | 17215/42146 [00:07<00:10, 2308.87it/s]

 41%|████▏     | 17450/42146 [00:07<00:10, 2319.65it/s]

 42%|████▏     | 17683/42146 [00:07<00:10, 2319.34it/s]

 43%|████▎     | 17915/42146 [00:07<00:10, 2316.94it/s]

 43%|████▎     | 18151/42146 [00:07<00:10, 2328.69it/s]

 44%|████▎     | 18384/42146 [00:07<00:10, 2312.97it/s]

 44%|████▍     | 18616/42146 [00:07<00:10, 2313.13it/s]

 45%|████▍     | 18848/42146 [00:07<00:10, 2313.84it/s]

 45%|████▌     | 19080/42146 [00:08<00:10, 2293.39it/s]

 46%|████▌     | 19310/42146 [00:08<00:09, 2290.83it/s]

 46%|████▋     | 19542/42146 [00:08<00:09, 2297.78it/s]

 47%|████▋     | 19772/42146 [00:08<00:09, 2292.93it/s]

 47%|████▋     | 20002/42146 [00:08<00:09, 2293.99it/s]

 48%|████▊     | 20234/42146 [00:08<00:09, 2300.76it/s]

 49%|████▊     | 20480/42146 [00:08<00:09, 2346.97it/s]

 49%|████▉     | 20731/42146 [00:08<00:08, 2395.61it/s]

 50%|████▉     | 20977/42146 [00:08<00:08, 2412.25it/s]

 50%|█████     | 21224/42146 [00:08<00:08, 2428.95it/s]

 51%|█████     | 21470/42146 [00:09<00:08, 2436.75it/s]

 52%|█████▏    | 21715/42146 [00:09<00:08, 2439.07it/s]

 52%|█████▏    | 21965/42146 [00:09<00:08, 2454.90it/s]

 53%|█████▎    | 22211/42146 [00:09<00:08, 2452.42it/s]

 53%|█████▎    | 22457/42146 [00:09<00:08, 2452.45it/s]

 54%|█████▍    | 22704/42146 [00:09<00:07, 2455.90it/s]

 54%|█████▍    | 22950/42146 [00:09<00:07, 2449.75it/s]

 55%|█████▌    | 23196/42146 [00:09<00:07, 2451.18it/s]

 56%|█████▌    | 23442/42146 [00:09<00:07, 2430.99it/s]

 56%|█████▌    | 23690/42146 [00:09<00:07, 2443.82it/s]

 57%|█████▋    | 23936/42146 [00:10<00:07, 2448.01it/s]

 57%|█████▋    | 24190/42146 [00:10<00:07, 2472.88it/s]

 58%|█████▊    | 24446/42146 [00:10<00:07, 2498.65it/s]

 59%|█████▊    | 24703/42146 [00:10<00:06, 2518.39it/s]

 59%|█████▉    | 24958/42146 [00:10<00:06, 2527.78it/s]

 60%|█████▉    | 25218/42146 [00:10<00:06, 2548.80it/s]

 60%|██████    | 25474/42146 [00:10<00:06, 2551.03it/s]

 61%|██████    | 25741/42146 [00:10<00:06, 2584.70it/s]

 62%|██████▏   | 26000/42146 [00:10<00:06, 2569.83it/s]

 62%|██████▏   | 26263/42146 [00:10<00:06, 2585.83it/s]

 63%|██████▎   | 26522/42146 [00:11<00:06, 2544.95it/s]

 64%|██████▎   | 26777/42146 [00:11<00:06, 2506.21it/s]

 64%|██████▍   | 27028/42146 [00:11<00:06, 2485.67it/s]

 65%|██████▍   | 27277/42146 [00:11<00:06, 2449.94it/s]

 65%|██████▌   | 27525/42146 [00:11<00:05, 2456.88it/s]

 66%|██████▌   | 27772/42146 [00:11<00:05, 2460.36it/s]

 66%|██████▋   | 28022/42146 [00:11<00:05, 2471.26it/s]

 67%|██████▋   | 28270/42146 [00:11<00:05, 2428.78it/s]

 68%|██████▊   | 28514/42146 [00:11<00:05, 2411.10it/s]

 68%|██████▊   | 28756/42146 [00:12<00:05, 2363.12it/s]

 69%|██████▉   | 28993/42146 [00:12<00:05, 2284.02it/s]

 69%|██████▉   | 29222/42146 [00:12<00:05, 2257.47it/s]

 70%|██████▉   | 29449/42146 [00:12<00:05, 2248.67it/s]

 70%|███████   | 29677/42146 [00:12<00:05, 2255.82it/s]

 71%|███████   | 29903/42146 [00:12<00:05, 2239.48it/s]

 71%|███████▏  | 30128/42146 [00:12<00:05, 2205.12it/s]

 72%|███████▏  | 30349/42146 [00:12<00:05, 2186.42it/s]

 73%|███████▎  | 30569/42146 [00:12<00:05, 2189.10it/s]

 73%|███████▎  | 30791/42146 [00:12<00:05, 2196.00it/s]

 74%|███████▎  | 31011/42146 [00:13<00:05, 2142.46it/s]

 74%|███████▍  | 31226/42146 [00:13<00:05, 2142.67it/s]

 75%|███████▍  | 31441/42146 [00:13<00:05, 2139.40it/s]

 75%|███████▌  | 31662/42146 [00:13<00:04, 2159.02it/s]

 76%|███████▌  | 31890/42146 [00:13<00:04, 2194.09it/s]

 76%|███████▌  | 32116/42146 [00:13<00:04, 2211.05it/s]

 77%|███████▋  | 32338/42146 [00:13<00:04, 2195.75it/s]

 77%|███████▋  | 32558/42146 [00:13<00:04, 2193.57it/s]

 78%|███████▊  | 32778/42146 [00:13<00:04, 2164.18it/s]

 78%|███████▊  | 32995/42146 [00:13<00:04, 2140.84it/s]

 79%|███████▉  | 33210/42146 [00:14<00:04, 2104.36it/s]

 79%|███████▉  | 33421/42146 [00:14<00:04, 2081.55it/s]

 80%|███████▉  | 33630/42146 [00:14<00:04, 2083.88it/s]

 80%|████████  | 33839/42146 [00:14<00:04, 2069.03it/s]

 81%|████████  | 34050/42146 [00:14<00:03, 2080.34it/s]

 81%|████████▏ | 34262/42146 [00:14<00:03, 2091.71it/s]

 82%|████████▏ | 34472/42146 [00:14<00:03, 2084.14it/s]

 82%|████████▏ | 34686/42146 [00:14<00:03, 2098.57it/s]

 83%|████████▎ | 34898/42146 [00:14<00:03, 2102.62it/s]

 83%|████████▎ | 35110/42146 [00:14<00:03, 2107.23it/s]

 84%|████████▍ | 35323/42146 [00:15<00:03, 2113.05it/s]

 84%|████████▍ | 35535/42146 [00:15<00:03, 2100.81it/s]

 85%|████████▍ | 35747/42146 [00:15<00:03, 2104.75it/s]

 85%|████████▌ | 35958/42146 [00:15<00:02, 2100.78it/s]

 86%|████████▌ | 36169/42146 [00:15<00:02, 2096.55it/s]

 86%|████████▋ | 36380/42146 [00:15<00:02, 2100.10it/s]

 87%|████████▋ | 36591/42146 [00:15<00:02, 2099.06it/s]

 87%|████████▋ | 36808/42146 [00:15<00:02, 2120.16it/s]

 88%|████████▊ | 37025/42146 [00:15<00:02, 2133.61it/s]

 88%|████████▊ | 37278/42146 [00:15<00:02, 2250.70it/s]

 89%|████████▉ | 37537/42146 [00:16<00:01, 2351.12it/s]

 90%|████████▉ | 37783/42146 [00:16<00:01, 2383.52it/s]

 90%|█████████ | 38031/42146 [00:16<00:01, 2410.24it/s]

 91%|█████████ | 38273/42146 [00:16<00:01, 2404.93it/s]

 91%|█████████▏| 38518/42146 [00:16<00:01, 2416.45it/s]

 92%|█████████▏| 38760/42146 [00:16<00:01, 2411.41it/s]

 93%|█████████▎| 39007/42146 [00:16<00:01, 2426.54it/s]

 93%|█████████▎| 39253/42146 [00:16<00:01, 2434.83it/s]

 94%|█████████▎| 39497/42146 [00:16<00:01, 2433.14it/s]

 94%|█████████▍| 39748/42146 [00:16<00:00, 2454.72it/s]

 95%|█████████▍| 40007/42146 [00:17<00:00, 2493.96it/s]

 96%|█████████▌| 40257/42146 [00:17<00:00, 2470.64it/s]

 96%|█████████▌| 40505/42146 [00:17<00:00, 2461.98it/s]

 97%|█████████▋| 40752/42146 [00:17<00:00, 2450.43it/s]

 97%|█████████▋| 40998/42146 [00:17<00:00, 2441.18it/s]

 98%|█████████▊| 41243/42146 [00:17<00:00, 2421.87it/s]

 98%|█████████▊| 41488/42146 [00:17<00:00, 2426.63it/s]

 99%|█████████▉| 41731/42146 [00:17<00:00, 2420.83it/s]

100%|█████████▉| 41974/42146 [00:17<00:00, 2407.93it/s]

100%|██████████| 42146/42146 [00:17<00:00, 2343.40it/s]


pn_history max(lengths): 323


  0%|          | 0/143 [00:00<?, ?it/s]

100%|██████████| 143/143 [00:00<00:00, 15479.13it/s]


feature_text max(lengths): 28


max_len: 354


In [17]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text, feature_text):

    # Omit token_type_ids for ModernBERT
    is_token_type_ids = CFG.model != 'answerdotai/ModernBERT-base'

    inputs = cfg.tokenizer(text, feature_text, 
                           add_special_tokens=True,
                           max_length=CFG.max_len,
                           padding="max_length",
                           return_offsets_mapping=False,
                           return_token_type_ids=is_token_type_ids)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs


def create_label(cfg, text, annotation_length, location_list):
    encoded = cfg.tokenizer(text,
                            add_special_tokens=True,
                            max_length=CFG.max_len,
                            padding="max_length",
                            return_offsets_mapping=True)
    offset_mapping = encoded['offset_mapping']
    ignore_idxes = np.where(np.array(encoded.sequence_ids()) != 0)[0]
    label = np.zeros(len(offset_mapping))
    label[ignore_idxes] = -1
    if annotation_length != 0:
        for location in location_list:
            for loc in [s.split() for s in location.split(';')]:
                start_idx = -1
                end_idx = -1
                start, end = int(loc[0]), int(loc[1])
                for idx in range(len(offset_mapping)):
                    if (start_idx == -1) & (start < offset_mapping[idx][0]):
                        start_idx = idx - 1
                    if (end_idx == -1) & (end <= offset_mapping[idx][1]):
                        end_idx = idx + 1
                if start_idx == -1:
                    start_idx = end_idx
                if (start_idx != -1) & (end_idx != -1):
                    label[start_idx:end_idx] = 1
    return torch.tensor(label, dtype=torch.bfloat16)


class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.feature_texts = df['feature_text'].values
        self.pn_historys = df['pn_history'].values
        self.annotation_lengths = df['annotation_length'].values
        self.locations = df['location'].values

    def __len__(self):
        return len(self.feature_texts)

    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, 
                               self.pn_historys[item], 
                               self.feature_texts[item])
        label = create_label(self.cfg, 
                             self.pn_historys[item], 
                             self.annotation_lengths[item], 
                             self.locations[item])
        return inputs, label

# Model

In [18]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
from torch.utils.checkpoint import checkpoint

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel(self.config)

        self.model.gradient_checkpointing_enable()
        self.fc_dropout = nn.Dropout(cfg.fc_dropout)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, inputs):
        # Debugging input shapes and types
        #print("Feature extraction inputs:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        # Process the inputs through the model
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        
        # Debugging the output of the model
        #print(f"Last hidden states: shape={last_hidden_states.shape}, dtype={last_hidden_states.dtype}")
        return last_hidden_states

    def forward(self, inputs):
        # Debugging the forward pass
        #print("Starting forward pass...")
        #print("Input keys and shapes:")
        #for k, v in inputs.items():
            #print(f"{k}: shape={v.shape}, dtype={v.dtype}, device={v.device}")
        
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):  # Enable AMP
            #print("Running model under AMP...")
            outputs = self.model(**inputs)
        
        # Debugging model output
        #print("Model outputs:")
        #if isinstance(outputs, tuple):
           # for idx, output in enumerate(outputs):
                #print(f"Output {idx}: shape={output.shape if hasattr(output, 'shape') else 'N/A'}")
        #else:
            #print(f"Outputs: {outputs}")
            #pass
        
        # Final linear layer
        feature = outputs.last_hidden_state
        output = self.fc(self.fc_dropout(feature))
        
        # Debugging final output
        #print(f"Final output shape: {output.shape}, dtype={output.dtype}")
        return output


# Helper functions

In [19]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x717b540fe9d0>

In [20]:
def valid_fn(valid_loader, model, criterion, device):
    """
    Validation function with detailed debugging for intermediate outputs.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()

    print("Starting validation...")
    for step, (inputs, labels) in enumerate(valid_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        
        with torch.no_grad():
            y_preds = model(inputs)  # Model predictions
        
        # Compute loss
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)

        # Store predictions
        preds.append(y_preds.sigmoid().to('cpu').numpy())

        end = time.time()

        # Debugging: Log step-wise progress and intermediate outputs
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
            #print(f"Sample Predictions: {y_preds.sigmoid()[:5].view(-1).tolist()}")  # Example predictions
            #print(f"Sample Labels: {labels[:5].view(-1).tolist()}")  # Example labels

    predictions = np.concatenate(preds)

    # Debugging: Check shape and content of predictions
    #print(f"Validation Predictions Shape: {predictions.shape}")
    # print(f"First 5 Predictions: {predictions[:5]}")

    return losses.avg, predictions


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """
    Training function with detailed debugging for loss and gradients.
    """
    print(f"\nStarting training epoch {epoch + 1} for fold {fold}")
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)

    losses = AverageMeter()
    start = time.time()
    global_step = 0

    for step, (inputs, labels) in enumerate(train_loader):
        # Move inputs and labels to the correct device
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)

        # Mixed precision training
        with torch.cuda.amp.autocast(enabled=CFG.apex, dtype=torch.bfloat16):
            y_preds = model(inputs)
            loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
            loss = torch.masked_select(loss, labels.view(-1, 1) != -1).mean()

        # Gradient accumulation
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps

        losses.update(loss.item(), batch_size)

        # Backpropagation
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)

        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()

        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print(f'Epoch: [{epoch+1}][{step}/{len(train_loader)}] '
                  f'Elapsed {timeSince(start, float(step+1)/len(train_loader))} '
                  f'Loss: {losses.val:.4f}({losses.avg:.4f}) '
                  f'Grad Norm: {grad_norm:.4f}  '
                  f'LR: {scheduler.get_lr()[0]}')

    print(f"Epoch {epoch + 1} completed. Average loss: {losses.avg:.4f}")
    return losses.avg


def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.sigmoid().to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

In [21]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_texts = valid_folds['pn_history'].values
    valid_labels = create_labels_for_scoring(valid_folds)
    
    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)

    checkpoint_path = f"/workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold{fold}_best.pth"
    if os.path.exists(checkpoint_path):
        LOGGER.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cuda", weights_only=False) # need to use weight_only = True for latest pytorch version
        model.load_state_dict(checkpoint['model'], strict=False)  # Load model weights
        model.to(device)
        model.train()  # Set model to training mode

    else:
        LOGGER.info("No checkpoint found, starting fresh.")
        
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr, 
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)
    
    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        if cfg.scheduler=='linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler=='cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        return scheduler
    
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    
    best_score = 0.

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        print("Starting training")
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        print("Starting evaluation")
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        predictions = predictions.reshape((len(valid_folds), CFG.max_len))
        
        # scoring
        char_probs = get_char_probs(valid_texts, predictions, CFG.tokenizer)
        results = get_results(char_probs, th=0.5)
        preds = get_predictions(results)
        score = get_score(valid_labels, preds)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        if CFG.wandb:
            wandb.log({f"[fold{fold}] epoch": epoch+1, 
                       f"[fold{fold}] avg_train_loss": avg_loss, 
                       f"[fold{fold}] avg_val_loss": avg_val_loss,
                       f"[fold{fold}] score": score})
        
        if best_score <= score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")
            

    #predictions = torch.load(OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             #map_location=torch.device('cpu'))['predictions']
    predictions = torch.load(OUTPUT_DIR + f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             map_location=torch.device('cpu'),weights_only=False)['predictions']

    
    valid_folds[[i for i in range(CFG.max_len)]] = predictions

    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

In [22]:
class CFG:
    wandb=True
    competition='NBME'
    _wandb_kernel='zehra'
    debug=False
    apex=True
    print_freq=100
    num_workers=4
    #model="microsoft/deberta-v3-large" #"answerdotai/ModernBERT-base"
    # model="/kaggle/input/deberta-v3-large/pytorch/0501/1" If running on Kaggle
    model = "/workspace/deberta-v3-large" # running on runpod instance
    scheduler='cosine' # ['linear', 'cosine']
    batch_scheduler=True
    num_cycles=0.05
    num_warmup_steps=0
    epochs=4
    encoder_lr = 1e-5 #encoder_lr=2e-5
    decoder_lr=2e-5
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    batch_size = 10 #batch_size=16
    fc_dropout=0.2
    max_len=354
    weight_decay=0.1 #0.01
    gradient_accumulation_steps=1 #1
    max_grad_norm=1000
    seed=42
    n_fold=5
    trn_fold=[0, 1, 2, 3, 4]
    train=True

tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

if CFG.debug:
    CFG.epochs = 3
    CFG.trn_fold = [0]

In [23]:
torch.cuda.empty_cache()

In [24]:
def run_fold_training(oof_df, fold):
    """
    Helper function to run training for a specific fold.
    Concatenates the out-of-fold predictions to oof_df.
    """
    _oof_df = train_loop(train, fold)
    oof_df = pd.concat([oof_df, _oof_df])
    LOGGER.info(f"========== fold: {fold} result ==========")
    evaluate_results(_oof_df)  # Evaluate results after each fold
    return oof_df

def evaluate_results(oof_df):
    """
    Function to handle the result evaluation, including scoring and logging.
    """
    labels = create_labels_for_scoring(oof_df)
    predictions = oof_df[[i for i in range(CFG.max_len)]].values
    char_probs = get_char_probs(oof_df['pn_history'].values, predictions, CFG.tokenizer)
    results = get_results(char_probs, th=0.5)
    preds = get_predictions(results)
    score = get_score(labels, preds)
    LOGGER.info(f'Score: {score:<.4f}')

In [25]:
if __name__ == '__main__':
    CFG.debug = True

    if CFG.train:
        oof_df = pd.DataFrame()

        # Run training for a single fold in debug mode or multiple folds otherwise
        if CFG.debug:
            fold = 0  # In debug mode, we only use fold 0
            if fold in CFG.trn_fold:
                oof_df = run_fold_training(oof_df, fold)
        else:
            for fold in range(CFG.n_fold):  # For all folds in non-debug mode
                if fold in CFG.trn_fold:
                    oof_df = run_fold_training(oof_df, fold)

        # Final evaluation and saving results
        oof_df = oof_df.reset_index(drop=True)
        LOGGER.info(f"========== CV ==========")
        evaluate_results(oof_df)  # Final evaluation of all results
        oof_df.to_pickle(OUTPUT_DIR + 'oof_df.pkl')

    if CFG.wandb:
        wandb.finish()



Loading checkpoint from /workspace/deberta-v3-large-5-folds-public/deberta-v3-large_fold0_best.pth


Starting training

Starting training epoch 1 for fold 0


Epoch: [1][0/6898] Elapsed 0m 2s (remain 277m 11s) Loss: 0.0122(0.0122) Grad Norm: 86621.5469  LR: 9.999999999675951e-06


Epoch: [1][100/6898] Elapsed 2m 0s (remain 135m 17s) Loss: 0.0072(0.0063) Grad Norm: 41214.7656  LR: 9.99999669437843e-06


Epoch: [1][200/6898] Elapsed 5m 27s (remain 181m 37s) Loss: 0.0105(0.0061) Grad Norm: 18489.2773  LR: 9.999986908109645e-06


Epoch: [1][300/6898] Elapsed 8m 54s (remain 195m 25s) Loss: 0.0019(0.0059) Grad Norm: 8818.4082  LR: 9.999970640882282e-06


Epoch: [1][400/6898] Elapsed 12m 23s (remain 200m 42s) Loss: 0.0017(0.0058) Grad Norm: 14749.0674  LR: 9.999947892717426e-06


Epoch: [1][500/6898] Elapsed 15m 50s (remain 202m 11s) Loss: 0.0050(0.0056) Grad Norm: 11383.1670  LR: 9.999918663644563e-06


Epoch: [1][600/6898] Elapsed 19m 16s (remain 201m 54s) Loss: 0.0008(0.0055) Grad Norm: 2717.6829  LR: 9.999882953701581e-06


Epoch: [1][700/6898] Elapsed 22m 44s (remain 201m 2s) Loss: 0.0048(0.0053) Grad Norm: 53704.6133  LR: 9.999840762934764e-06


Epoch: [1][800/6898] Elapsed 26m 12s (remain 199m 25s) Loss: 0.0015(0.0052) Grad Norm: 27202.2148  LR: 9.999792091398804e-06


Epoch: [1][900/6898] Elapsed 29m 39s (remain 197m 23s) Loss: 0.0547(0.0053) Grad Norm: 58664.6797  LR: 9.999736939156785e-06


Epoch: [1][1000/6898] Elapsed 33m 6s (remain 195m 3s) Loss: 0.0088(0.0053) Grad Norm: 21805.4766  LR: 9.999675306280197e-06


Epoch: [1][1100/6898] Elapsed 36m 35s (remain 192m 40s) Loss: 0.0039(0.0052) Grad Norm: 15249.6924  LR: 9.999607192848925e-06


Epoch: [1][1200/6898] Elapsed 40m 3s (remain 190m 0s) Loss: 0.0108(0.0051) Grad Norm: 45413.6055  LR: 9.999532598951263e-06


Epoch: [1][1300/6898] Elapsed 43m 30s (remain 187m 10s) Loss: 0.0000(0.0050) Grad Norm: 50.2443  LR: 9.999451524683896e-06


Epoch: [1][1400/6898] Elapsed 46m 58s (remain 184m 17s) Loss: 0.0019(0.0050) Grad Norm: 10048.0176  LR: 9.99936397015191e-06


Epoch: [1][1500/6898] Elapsed 50m 25s (remain 181m 17s) Loss: 0.0018(0.0049) Grad Norm: 17036.7109  LR: 9.999269935468797e-06


Epoch: [1][1600/6898] Elapsed 53m 52s (remain 178m 16s) Loss: 0.0001(0.0049) Grad Norm: 295.9210  LR: 9.999169420756443e-06


Epoch: [1][1700/6898] Elapsed 57m 20s (remain 175m 10s) Loss: 0.0004(0.0049) Grad Norm: 3862.3428  LR: 9.999062426145132e-06


Epoch: [1][1800/6898] Elapsed 60m 46s (remain 172m 0s) Loss: 0.0110(0.0048) Grad Norm: 87798.3203  LR: 9.998948951773553e-06


Epoch: [1][1900/6898] Elapsed 64m 13s (remain 168m 50s) Loss: 0.0001(0.0048) Grad Norm: 1733.2506  LR: 9.99882899778879e-06


Epoch: [1][2000/6898] Elapsed 67m 41s (remain 165m 40s) Loss: 0.0077(0.0047) Grad Norm: 53345.4414  LR: 9.998702564346325e-06


Epoch: [1][2100/6898] Elapsed 71m 10s (remain 162m 31s) Loss: 0.0008(0.0047) Grad Norm: 11223.8828  LR: 9.998569651610042e-06


Epoch: [1][2200/6898] Elapsed 74m 38s (remain 159m 17s) Loss: 0.0217(0.0047) Grad Norm: 173212.4062  LR: 9.998430259752222e-06


Epoch: [1][2300/6898] Elapsed 78m 6s (remain 156m 1s) Loss: 0.0001(0.0046) Grad Norm: 1280.6219  LR: 9.998284388953545e-06


Epoch: [1][2400/6898] Elapsed 81m 34s (remain 152m 47s) Loss: 0.0022(0.0046) Grad Norm: 45547.7148  LR: 9.998132039403086e-06


Epoch: [1][2500/6898] Elapsed 85m 2s (remain 149m 31s) Loss: 0.0131(0.0045) Grad Norm: 153785.6562  LR: 9.997973211298323e-06


Epoch: [1][2600/6898] Elapsed 88m 31s (remain 146m 15s) Loss: 0.0049(0.0046) Grad Norm: 33719.5703  LR: 9.997807904845123e-06


Epoch: [1][2700/6898] Elapsed 91m 59s (remain 142m 56s) Loss: 0.0053(0.0045) Grad Norm: 34434.7734  LR: 9.997636120257758e-06


Epoch: [1][2800/6898] Elapsed 95m 27s (remain 139m 37s) Loss: 0.0028(0.0045) Grad Norm: 25234.0059  LR: 9.997457857758896e-06


Epoch: [1][2900/6898] Elapsed 98m 55s (remain 136m 17s) Loss: 0.0086(0.0045) Grad Norm: 26884.9492  LR: 9.997273117579597e-06


Epoch: [1][3000/6898] Elapsed 102m 23s (remain 132m 57s) Loss: 0.0017(0.0044) Grad Norm: 27281.7949  LR: 9.997081899959324e-06


Epoch: [1][3100/6898] Elapsed 105m 51s (remain 129m 37s) Loss: 0.0046(0.0044) Grad Norm: 26328.7969  LR: 9.996884205145929e-06


Epoch: [1][3200/6898] Elapsed 109m 21s (remain 126m 18s) Loss: 0.0002(0.0044) Grad Norm: 4520.5742  LR: 9.996680033395664e-06


Epoch: [1][3300/6898] Elapsed 112m 50s (remain 122m 58s) Loss: 0.0001(0.0044) Grad Norm: 2131.8279  LR: 9.996469384973175e-06


Epoch: [1][3400/6898] Elapsed 116m 17s (remain 119m 34s) Loss: 0.0093(0.0044) Grad Norm: 32567.0195  LR: 9.996252260151506e-06


Epoch: [1][3500/6898] Elapsed 119m 45s (remain 116m 11s) Loss: 0.0013(0.0043) Grad Norm: 14637.7598  LR: 9.996028659212089e-06


Epoch: [1][3600/6898] Elapsed 123m 29s (remain 113m 4s) Loss: 0.0151(0.0043) Grad Norm: 136679.4219  LR: 9.995798582444759e-06


Epoch: [1][3700/6898] Elapsed 126m 56s (remain 109m 39s) Loss: 0.0094(0.0043) Grad Norm: 64251.7422  LR: 9.995562030147736e-06


Epoch: [1][3800/6898] Elapsed 130m 22s (remain 106m 13s) Loss: 0.0029(0.0043) Grad Norm: 21117.6172  LR: 9.995319002627643e-06


Epoch: [1][3900/6898] Elapsed 133m 50s (remain 102m 49s) Loss: 0.0023(0.0043) Grad Norm: 99178.1641  LR: 9.995069500199487e-06


Epoch: [1][4000/6898] Elapsed 137m 18s (remain 99m 25s) Loss: 0.0014(0.0043) Grad Norm: 36063.9258  LR: 9.994813523186671e-06


Epoch: [1][4100/6898] Elapsed 140m 44s (remain 95m 59s) Loss: 0.0001(0.0042) Grad Norm: 2401.2698  LR: 9.994551071920995e-06


Epoch: [1][4200/6898] Elapsed 144m 11s (remain 92m 33s) Loss: 0.0016(0.0042) Grad Norm: 44863.3438  LR: 9.994282146742643e-06


Epoch: [1][4300/6898] Elapsed 147m 39s (remain 89m 9s) Loss: 0.0006(0.0042) Grad Norm: 49994.0781  LR: 9.994006748000197e-06


Epoch: [1][4400/6898] Elapsed 151m 6s (remain 85m 44s) Loss: 0.0020(0.0042) Grad Norm: 57971.9570  LR: 9.99372487605063e-06


# Reference

https://www.kaggle.com/code/yasufuminakama/nbme-deberta-base-baseline-train


See the run log on Kaggle: https://www.kaggle.com/code/zehrakorkusuz/training-deberta-v3-large