In [1]:
import wandb
import yaml

from pathlib import Path
from torch.utils.data import DataLoader

from classifier.file_reader import read_files_from_folder
from classifier.dataset import BertPandasDataset, collate_fn, create_bert_datasets, preprocess_dataframe
from classifier.model import ContinualMultilabelBERTClassifier, MultilabelBERTClassifier

# Directory this notebook is assumed to live in. The relative file name is
# anchored to the kernel's current working directory, so this only points at
# the notebook folder when the kernel was started from there.
FOLDER_PATH = Path("train_classifier.ipynb").absolute().parent
print(FOLDER_PATH)


/home/woi/code/Energy-Optimal-Inferencing/classifier


In [2]:
# --- Experiment configuration ---
# NOTE(review): SEED, MINIBATCH_SIZE, N_EPOCHS and TEST_VAL_SET_SIZE are not
# referenced by any cell visible in this notebook; the later cells hard-code
# their own values. Confirm which settings are authoritative.
SEED = 42
DATASET = "boolq"
MODEL_NAME = "answerdotai/ModernBERT-base"
MINIBATCH_SIZE = 64
N_EPOCHS = 50
TEST_VAL_SET_SIZE = 0.15

# Both paths are derived from DATASET so that switching benchmarks is a
# one-line change instead of three separate string edits.
benchmark_config_path = FOLDER_PATH.parent / "config" / "messplus" / f"{DATASET}.yaml"

# Read the benchmark YAML and keep only its classifier section.
# The `with` block closes the file on exit; no explicit close() is needed.
with benchmark_config_path.open("r") as f:
    classifier_config = yaml.safe_load(f)["classifier_model"]

# Load every CSV of recorded inference outputs for the chosen benchmark.
df = read_files_from_folder(f"{FOLDER_PATH.parent}/data/inference_outputs/{DATASET}", file_ext=".csv")
display(df.head())

33


Unnamed: 0_level_0,input_text,benchmark_name,label_small,acc_small,energy_consumption_small,inference_time_small,label_medium,acc_medium,energy_consumption_medium,inference_time_medium
doc_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,does ethanol take more energy make that produces,boolq,0.0,0.0,13.306,0.157293,0.0,0.0,160.453,0.564679
1,is house tax and property tax are same,boolq,0.0,0.0,20.845,0.143748,0.0,0.0,166.045,0.557246
2,is pain experienced in a missing body part or ...,boolq,1.0,1.0,20.205,0.130104,1.0,1.0,132.83,0.499714
3,is harry potter and the escape from gringotts ...,boolq,1.0,1.0,21.512,0.103894,1.0,1.0,159.363,0.513891
4,is there a difference between hydroxyzine hcl ...,boolq,1.0,1.0,22.972,0.128254,1.0,1.0,130.709,0.508619


In [3]:
# Number of entries in the text column, i.e. how many inputs were loaded.
display(df["input_text"].shape[0])

3270

In [4]:
# Columns used for training: one text column and one label column per model.
text_col = ["input_text"]
label_cols = ["label_small", "label_medium"]

# Keep only the columns the classifier needs and run the project's preprocessing.
dataset = preprocess_dataframe(df[[*text_col, *label_cols]], label_cols=label_cols)

# Build the train / validation datasets (the helper also returns the tokenizer).
train_dataset, val_dataset, tokenizer = create_bert_datasets(
    dataset, text_col, label_cols, model_name=MODEL_NAME, max_length=1024, val_ratio=0.10
)

# Wrap both splits in DataLoaders using the project's collate function;
# only the training split is shuffled.
# NOTE(review): batch_size and val_ratio are literals here and differ from the
# MINIBATCH_SIZE / TEST_VAL_SET_SIZE constants defined earlier -- confirm intent.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn)

display(f"Training dataset size: {len(train_dataset)}")
display(f"Validation dataset size: {len(val_dataset)}")

'Training dataset size: 2943'

'Validation dataset size: 327'

## Full model training
Training the full model yields strong results but starts to overfit very quickly.
We also observe per-batch instabilities (visible as spikes in the batch loss).
I tried adjusting the classifier architecture to account for those instabilities.
We might need some form of regularization to stabilize the loss.

In [5]:
# Hyperparameters that are also encoded in the W&B run name. They are defined
# once so the run label cannot drift from the values actually used (the name
# previously said "mom-0.9" while training ran with momentum=0.85).
TRAIN_BATCH_SIZE = 16
MOMENTUM = 0.85

classifier = MultilabelBERTClassifier(
    model_name=MODEL_NAME,  # Replace with your preferred BERT variant
    num_labels=len(label_cols),
    learning_rate=1e-3,
    momentum=MOMENTUM,
    weight_decay=0.01,
    batch_size=TRAIN_BATCH_SIZE,
    # NOTE(review): the datasets above were created with max_length=1024;
    # confirm which limit actually applies during training.
    max_length=128,
    warmup_ratio=0.05,
    threshold=0.5,
    freeze_bert_layers=True,
    config=classifier_config,
)

# The run is used as a context manager, so it is finished automatically when
# the block exits (including on error); no separate wandb.finish() is needed.
with wandb.init(
    entity="tum-i13",
    project="mess-plus-classifier-training-offline",
    name=f"minibatch_size-{TRAIN_BATCH_SIZE}-mom-{MOMENTUM}",
):
    # Train the model.
    # NOTE(review): with epochs=1 an early-stopping patience of 2 can never
    # trigger -- confirm the intended epoch count.
    classifier.fit(train_dataset, val_dataset, epochs=1, early_stopping_patience=2)


INFO:classifier.model:Using device: cuda
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mherbertw[0m ([33mtum-i13[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO:classifier.model:Initializing custom BERTClassifier: answerdotai/ModernBERT-base with 2 labels
Epoch 1/5 [Training]: 100%|██████████| 184/184 [00:27<00:00,  6.63it/s, loss=0.5148, batch_loss=0.6345, lr=0.001]
Epoch 1/5 [Validation]: 100%|██████████| 21/21 [00:03<00:00,  6.29it/s, val_loss=0.7940, avg_val_loss=0.5499]
INFO:classifier.model:Epoch 1/5 - Time: 31.10s
INFO:classifier.model:  Train Loss: 0.5425 - Val Loss: 0.5499
INFO:classifier.model:  Val Metrics - Accuracy: 0.7645, F1: 0.8666, F1(macro): 0.8649
INFO:classifier.model:  Per-label metrics:
INFO:classifier.model:    Label 0: F1=0.8216, Prec=0.6972, Rec=1.0000
INFO:classifier.model:    Label 1: F1=0.9082, Prec=0.8318, Rec=1.0000
INFO:classifier.model:  ✓ Best model saved!
Epoch 2/5 [Training]: 100%|██████████| 184/184 [00:26<00:00,  6.86it/s, loss=0.5408, batch_loss=0.8077, lr=0.001]
Epoch 2/5 [Validation]: 100%|██████████| 21/21 [00:03<00:00,  7.00it/s, val_loss=0.7884, avg_val_loss=0.5481]
INFO:classifier.model:Epoch 2/

0,1
batch,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
batch_loss,▆▅▄▄▅▄▇▃▄▅▂▄▃█▃▇█▇▆▃▇▇▃▆▁▂▇▅▂▇▂▅▂▇▃▆▃▅▄▅
epoch,▁▅█
learning_rate,▁▅▇██▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂
running_loss,██▇▆▄▃▄▂▄▃▄▄▃▃▄▂▄▃▅▄▆▂▂▂▂▂▂▂▂▂▁▂▂▄▃▁▃▄▄▃
time/epoch_seconds,█▁▁
train/loss,█▂▁
val/accuracy,▁▁▁
val/f1_macro,▁▁▁
val/f1_micro,▁▁▁

0,1
batch,551.0
batch_loss,0.53441
epoch,3.0
learning_rate,0.00042
running_loss,0.54455
time/epoch_seconds,29.79943
train/loss,0.52695
val/accuracy,0.76453
val/f1_macro,0.8649
val/f1_micro,0.86655


In [6]:
# Smoke-test the trained classifier on two example questions; the returned
# value is shown as the cell output.
sample_questions = [
    "does ethanol take more energy make that produces",
    "is the liver part of the excretory system",
]
classifier.predict(texts=sample_questions)

(array([[1, 1],
        [1, 1]]),
 array([[0.7006271, 0.8165475],
        [0.7049741, 0.8402771]], dtype=float32))

## Continual learning approach

In [7]:
# DISABLED EXPERIMENT: everything in this cell is commented out and does not run.
# The sketch below builds a ContinualMultilabelBERTClassifier and feeds it one
# dataframe row at a time through incremental_fit, validating against
# val_dataset, and stops once idx reaches 50.
# NOTE(review): df.loc[idx] yields a single row rather than a DataFrame --
# confirm that BertPandasDataset accepts that before re-enabling this cell.
# NOTE(review): rows are taken from the raw `df`, not from the preprocessed
# `dataset` -- confirm that this is intended.
#
# cont_model = ContinualMultilabelBERTClassifier(
#     model_name=MODEL_NAME,  # Replace with your preferred BERT variant
#     num_labels=len(label_cols),
#     learning_rate=8e-7,
#     weight_decay=0.01,
#     batch_size=16,
#     max_length=128,
#     warmup_ratio=0.1,
#     threshold=0.5,
#     freeze_bert_layers=True,
#     memory_size=0
# )
#
#
# for idx in range(len(dataset)):
#     print(f"Fetching sample {idx}/{len(dataset)}...")
#     sample = BertPandasDataset(df.loc[idx], text_col, label_cols, tokenizer, 128)
#     cont_model.incremental_fit(
#         new_train_dataset=sample,
#         new_val_dataset=val_dataset,
#     )
#
#     if idx % 50 == 0 and idx != 0:
#         display(f"Done.")
#         break
