In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory holding the expert-annotated multi-task CSV splits. It is defined
# once so that pointing the notebook at another copy of the data is a
# single-line edit instead of three.
DATA_DIR = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data"

# Per-split CSV files inside DATA_DIR.
train_data_path = f"{DATA_DIR}/train.csv"
val_data_path = f"{DATA_DIR}/val.csv"
test_data_path = f"{DATA_DIR}/test.csv"

# Load every split into its own DataFrame.
train_df = pd.read_csv(train_data_path)
val_df = pd.read_csv(val_data_path)
test_df = pd.read_csv(test_data_path)

In [4]:
def _build_split_dataset(frame):
    """Wrap one split's DataFrame in a CausalDataset.

    Every split uses the same tokenizer name and maximum length, both read
    from the project configuration.
    """
    return CausalDataset(
        frame,
        tokenizer_name=MODEL_CONFIG["encoder_name"],
        max_length=DATASET_CONFIG["max_length"],
    )


# One dataset per split, built in train / val / test order.
train_dataset = _build_split_dataset(train_df)
val_dataset = _build_split_dataset(val_df)
test_dataset = _build_split_dataset(test_df)

In [5]:
# Per-task label collections gathered from the training dataset.
labels_flat = label_value_counts(train_dataset)

cls_label_flat = labels_flat["cls_labels_flat"]
bio_label_flat = labels_flat["bio_labels_flat"]
rel_label_flat = labels_flat["rel_labels_flat"]

# Options shared by all three weight computations: the "ens" technique, with
# label value -100 passed as the index to ignore.
_weighting_options = {"technique": "ens", "ignore_index": -100}

# One weight tensor per task head, sized by that head's number of labels.
cls_weights = compute_class_weights(
    labels_list=cls_label_flat,
    num_classes=MODEL_CONFIG["num_cls_labels"],
    **_weighting_options,
)
bio_weights = compute_class_weights(
    labels_list=bio_label_flat,
    num_classes=MODEL_CONFIG["num_bio_labels"],
    **_weighting_options,
)
rel_weights = compute_class_weights(
    labels_list=rel_label_flat,
    num_classes=MODEL_CONFIG["num_rel_labels"],
    **_weighting_options,
)

print(f"CLS Weights: {cls_weights}")
print(f"BIO Weights: {bio_weights}")
print(f"REL Weights: {rel_weights}")

cls_labels_value_counts:
 0    1047
1    1035
Name: count, dtype: int64
bio_labels_value_counts:
  6      52764
 3       8717
 1       6948
-100     4164
 2       1320
 0       1179
 5        483
 4         79
Name: count, dtype: int64
rel_labels_value_counts:
 0    2887
1    1494
Name: count, dtype: int64
CLS Weights: tensor([0.0015, 0.0016])
BIO Weights: tensor([0.0014, 0.0010, 0.0014, 0.0010, 0.0132, 0.0026, 0.0010])
REL Weights: tensor([0.0011, 0.0013])


In [6]:
# The collator reuses the tokenizer already held by the training dataset.
collator = CausalDatasetCollator(tokenizer=train_dataset.tokenizer)

# Debugging aid (disabled): shrink both splits to a random 20-example subset
# for a quick smoke run.
# train_dataset = torch.utils.data.Subset(train_dataset, random.sample(range(len(train_dataset)), 20))
# val_dataset = torch.utils.data.Subset(val_dataset, random.sample(range(len(val_dataset)), 20))

# Batch size and collation are identical for both loaders; only shuffling
# differs (training batches are shuffled, validation order is kept fixed).
_loader_options = {
    "batch_size": TRAINING_CONFIG["batch_size"],
    "collate_fn": collator,
}
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_options)

In [7]:
# Each constructor keyword below is read from the MODEL_CONFIG entry of the
# same name, so the model is built straight from the configuration.
_MODEL_CONFIG_KEYS = (
    "encoder_name",
    "num_cls_labels",
    "num_bio_labels",
    "num_rel_labels",
    "dropout",
)
model = JointCausalModel(**{key: MODEL_CONFIG[key] for key in _MODEL_CONFIG_KEYS})


In [8]:
# Plateau-scheduler settings, named so they are easy to find and tune.
_LR_DECAY_FACTOR = 0.1      # the learning rate is multiplied by this on a plateau
_LR_PLATEAU_PATIENCE = 2    # non-improving scheduler steps tolerated before decaying

# AdamW over all model parameters; learning rate and weight decay come from
# the training configuration.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=TRAINING_CONFIG["weight_decay"],
    lr=TRAINING_CONFIG["learning_rate"],
)

# mode="min" means the monitored value is expected to decrease.
# NOTE(review): which value train_model feeds to scheduler.step() is not
# visible here. If it is the validation loss, note that in the logged run the
# validation loss rises from epoch 3 onward while macro-F1 keeps improving,
# so the LR would be cut repeatedly early on — confirm this is intended.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=_LR_DECAY_FACTOR,
    patience=_LR_PLATEAU_PATIENCE,
)

In [9]:
# Checkpoint file handed to train_model as `model_save_path`.
model_save_path = (
    "/home/rnorouzini/JointLearning/src/jointlearning/"
    "expert_bert_softmax/expert_bert_softmax_model.pt"
)

In [10]:
# Run the full training loop. All arguments are passed by keyword, so their
# order here is purely for readability.
trained_model, training_history = train_model(
    # --- model, data and runtime ---
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=DEVICE,
    seed=SEED,
    # --- optimisation ---
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=TRAINING_CONFIG["num_epochs"],
    max_grad_norm=TRAINING_CONFIG["gradient_clip_val"],
    # Patience is set to the full epoch budget.
    # NOTE(review): presumably this is an early-stopping patience, in which
    # case early stopping can never trigger — confirm against trainer.py.
    patience_epochs=TRAINING_CONFIG["num_epochs"],
    # --- class weights computed from the training labels ---
    cls_class_weights=cls_weights,
    bio_class_weights=bio_weights,  # only for softmax
    rel_class_weights=rel_weights,
    # --- label maps, evaluation hooks and checkpoint location ---
    id2label_cls=id2label_cls,
    id2label_bio=id2label_bio,
    id2label_rel=id2label_rel,
    eval_fn_metrics=evaluate_model,      # metric computation callback
    print_report_fn=print_eval_report,   # per-epoch report printer
    model_save_path=model_save_path,
)

Epoch 1/20 [Training]:   0%|          | 0/131 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
                                                                                                                               


Epoch 1/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           4.6890
  Average Validation Loss:         1.5285
  Overall Validation Avg F1 (Macro): 0.6625
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7698
    Macro Precision: 0.8133
    Macro Recall:    0.7786
    Accuracy:        0.7756
    Per-class details:
      non-causal  : F1=0.7335 (P=0.9267, R=0.6070, Support=229.0)
      causal      : F1=0.8061 (P=0.7000, R=0.9502, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3752
    Macro Precision: 0.4539
    Macro Recall:    0.4179
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.323 (P=0.290, R=0.365, S=263.0)
      I-C       : F1=0.453 (P=0.622, R=0.356, S=1451.0)
      B-E       : F1=0.049 (P=0.583, R=0.026

                                                                                                                               


Epoch 2/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           2.9168
  Average Validation Loss:         1.2967
  Overall Validation Avg F1 (Macro): 0.7052
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7828
    Macro Precision: 0.8147
    Macro Recall:    0.7893
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7538 (P=0.9130, R=0.6419, Support=229.0)
      causal      : F1=0.8118 (P=0.7163, R=0.9367, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4805
    Macro Precision: 0.5111
    Macro Recall:    0.5302
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.467 (P=0.455, R=0.479, S=263.0)
      I-C       : F1=0.600 (P=0.607, R=0.593, S=1451.0)
      B-E       : F1=0.190 (P=0.554, R=0.114

                                                                                                                               


Epoch 3/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.8206
  Average Validation Loss:         1.4290
  Overall Validation Avg F1 (Macro): 0.7387
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8198
    Macro Precision: 0.8225
    Macro Recall:    0.8208
    Accuracy:        0.8200
    Per-class details:
      non-causal  : F1=0.8146 (P=0.8558, R=0.7773, Support=229.0)
      causal      : F1=0.8251 (P=0.7893, R=0.8643, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5391
    Macro Precision: 0.4972
    Macro Recall:    0.6335
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.466, R=0.578, S=263.0)
      I-C       : F1=0.621 (P=0.575, R=0.675, S=1451.0)
      B-E       : F1=0.422 (P=0.391, R=0.458

                                                                                                                               


Epoch 4/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2496
  Average Validation Loss:         1.6650
  Overall Validation Avg F1 (Macro): 0.7494
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8332
    Macro Precision: 0.8354
    Macro Recall:    0.8340
    Accuracy:        0.8333
    Per-class details:
      non-causal  : F1=0.8292 (P=0.8667, R=0.7948, Support=229.0)
      causal      : F1=0.8373 (P=0.8042, R=0.8733, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5523
    Macro Precision: 0.5043
    Macro Recall:    0.6333
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.474 (P=0.393, R=0.597, S=263.0)
      I-C       : F1=0.638 (P=0.612, R=0.666, S=1451.0)
      B-E       : F1=0.441 (P=0.393, R=0.502

                                                                                                                               


Epoch 5/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.7662
  Average Validation Loss:         1.8860
  Overall Validation Avg F1 (Macro): 0.7494
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8372
    Macro Precision: 0.8456
    Macro Recall:    0.8391
    Accuracy:        0.8378
    Per-class details:
      non-causal  : F1=0.8274 (P=0.9021, R=0.7642, Support=229.0)
      causal      : F1=0.8470 (P=0.7891, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5560
    Macro Precision: 0.5033
    Macro Recall:    0.6698
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.514 (P=0.449, R=0.601, S=263.0)
      I-C       : F1=0.645 (P=0.625, R=0.667, S=1451.0)
      B-E       : F1=0.451 (P=0.400, R=0.517

                                                                                                                               


Epoch 6/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.6169
  Average Validation Loss:         2.0160
  Overall Validation Avg F1 (Macro): 0.7496
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8371
    Macro Precision: 0.8465
    Macro Recall:    0.8392
    Accuracy:        0.8378
    Per-class details:
      non-causal  : F1=0.8266 (P=0.9062, R=0.7598, Support=229.0)
      causal      : F1=0.8476 (P=0.7868, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5635
    Macro Precision: 0.5207
    Macro Recall:    0.6464
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.506 (P=0.452, R=0.574, S=263.0)
      I-C       : F1=0.639 (P=0.657, R=0.622, S=1451.0)
      B-E       : F1=0.462 (P=0.416, R=0.520

                                                                                                                               


Epoch 7/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.5364
  Average Validation Loss:         2.1468
  Overall Validation Avg F1 (Macro): 0.7532
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8345
    Macro Precision: 0.8481
    Macro Recall:    0.8372
    Accuracy:        0.8356
    Per-class details:
      non-causal  : F1=0.8213 (P=0.9189, R=0.7424, Support=229.0)
      causal      : F1=0.8477 (P=0.7774, R=0.9321, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5712
    Macro Precision: 0.5241
    Macro Recall:    0.6583
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.518 (P=0.463, R=0.589, S=263.0)
      I-C       : F1=0.649 (P=0.648, R=0.651, S=1451.0)
      B-E       : F1=0.465 (P=0.429, R=0.509

                                                                                                                               


Epoch 8/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4411
  Average Validation Loss:         2.1058
  Overall Validation Avg F1 (Macro): 0.7585
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8393
    Macro Precision: 0.8493
    Macro Recall:    0.8415
    Accuracy:        0.8400
    Per-class details:
      non-causal  : F1=0.8286 (P=0.9110, R=0.7598, Support=229.0)
      causal      : F1=0.8500 (P=0.7876, R=0.9231, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5770
    Macro Precision: 0.5374
    Macro Recall:    0.6501
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.506 (P=0.455, R=0.570, S=263.0)
      I-C       : F1=0.648 (P=0.665, R=0.631, S=1451.0)
      B-E       : F1=0.467 (P=0.431, R=0.509

                                                                                                                               


Epoch 9/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4528
  Average Validation Loss:         2.1359
  Overall Validation Avg F1 (Macro): 0.7594
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8395
    Macro Precision: 0.8474
    Macro Recall:    0.8413
    Accuracy:        0.8400
    Per-class details:
      non-causal  : F1=0.8302 (P=0.9026, R=0.7686, Support=229.0)
      causal      : F1=0.8487 (P=0.7922, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5782
    Macro Precision: 0.5417
    Macro Recall:    0.6455
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.513 (P=0.466, R=0.570, S=263.0)
      I-C       : F1=0.646 (P=0.670, R=0.624, S=1451.0)
      B-E       : F1=0.467 (P=0.434, R=0.506

                                                                                                                                


Epoch 10/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4420
  Average Validation Loss:         2.1431
  Overall Validation Avg F1 (Macro): 0.7601
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5775
    Macro Precision: 0.5393
    Macro Recall:    0.6472
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.506 (P=0.457, R=0.567, S=263.0)
      I-C       : F1=0.646 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.469 (P=0.435, R=0.50

                                                                                                                                


Epoch 11/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4368
  Average Validation Loss:         2.1544
  Overall Validation Avg F1 (Macro): 0.7580
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5774
    Macro Precision: 0.5394
    Macro Recall:    0.6470
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.506 (P=0.457, R=0.567, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.469 (P=0.435, R=0.50

                                                                                                                                


Epoch 12/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4306
  Average Validation Loss:         2.1536
  Overall Validation Avg F1 (Macro): 0.7584
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5775
    Macro Precision: 0.5396
    Macro Recall:    0.6470
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.506 (P=0.457, R=0.567, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.470 (P=0.437, R=0.50

                                                                                                                                


Epoch 13/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4367
  Average Validation Loss:         2.1771
  Overall Validation Avg F1 (Macro): 0.7592
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5410
    Macro Recall:    0.6470
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.457, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 14/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4339
  Average Validation Loss:         2.1345
  Overall Validation Avg F1 (Macro): 0.7567
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5410
    Macro Recall:    0.6470
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.457, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 15/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4218
  Average Validation Loss:         2.1305
  Overall Validation Avg F1 (Macro): 0.7568
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 16/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4319
  Average Validation Loss:         2.1581
  Overall Validation Avg F1 (Macro): 0.7589
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 17/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4241
  Average Validation Loss:         2.1536
  Overall Validation Avg F1 (Macro): 0.7555
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 18/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4316
  Average Validation Loss:         2.1404
  Overall Validation Avg F1 (Macro): 0.7618
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 19/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4175
  Average Validation Loss:         2.1458
  Overall Validation Avg F1 (Macro): 0.7603
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50

                                                                                                                                


Epoch 20/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4320
  Average Validation Loss:         2.1481
  Overall Validation Avg F1 (Macro): 0.7569
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8417
    Macro Precision: 0.8501
    Macro Recall:    0.8436
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8322 (P=0.9072, R=0.7686, Support=229.0)
      causal      : F1=0.8512 (P=0.7930, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5788
    Macro Precision: 0.5412
    Macro Recall:    0.6469
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.508 (P=0.459, R=0.570, S=263.0)
      I-C       : F1=0.647 (P=0.667, R=0.627, S=1451.0)
      B-E       : F1=0.468 (P=0.435, R=0.50



In [11]:
# Export the trained model with save_pretrained. The directory name is
# spelled "hf_exper_bert_softmax" on purpose here: the tokenizer export uses
# the same directory, so the two must stay identical.
hf_export_dir = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax"
trained_model.save_pretrained(hf_export_dir)

In [17]:
# Save the training dataset's tokenizer into the same directory the model was
# exported to, so the two can be loaded together from one place.
hf_export_dir = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax"
train_dataset.tokenizer.save_pretrained(hf_export_dir)

('/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax/tokenizer_config.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax/special_tokens_map.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax/vocab.txt',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax/added_tokens.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/hf_exper_bert_softmax/tokenizer.json')