In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Absolute locations of the expert multi-task CSV splits.
train_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/train.csv"
val_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/val.csv"
test_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/test.csv"

# Read the splits in train -> val -> test order, one DataFrame per path.
train_df, val_df, test_df = map(
    pd.read_csv,
    (train_data_path, val_data_path, test_data_path),
)

In [3]:
# Wrap every split in a CausalDataset using the same encoder tokenizer and
# maximum sequence length; construction order is train, val, test.
train_dataset, val_dataset, test_dataset = (
    CausalDataset(
        split_df,
        tokenizer_name=MODEL_CONFIG["encoder_name"],
        max_length=DATASET_CONFIG["max_length"],
    )
    for split_df in (train_df, val_df, test_df)
)
# %%

In [4]:
# Flattened label lists for the training split, keyed per task.
labels_flat = label_value_counts(train_dataset)

cls_label_flat, bio_label_flat, rel_label_flat = (
    labels_flat[key]
    for key in ("cls_labels_flat", "bio_labels_flat", "rel_labels_flat")
)

# Per-task class weights ("ens" technique); labels equal to -100 are passed
# as the ignore index.
cls_weights, bio_weights, rel_weights = (
    compute_class_weights(
        labels_list=task_labels,
        num_classes=MODEL_CONFIG[num_classes_key],
        technique="ens",
        ignore_index=-100,
    )
    for task_labels, num_classes_key in (
        (cls_label_flat, "num_cls_labels"),
        (bio_label_flat, "num_bio_labels"),
        (rel_label_flat, "num_rel_labels"),
    )
)

print(
    f"CLS Weights: {cls_weights}",
    f"BIO Weights: {bio_weights}",
    f"REL Weights: {rel_weights}",
    sep="\n",
)
# %%

cls_labels_value_counts:
 0    1047
1    1035
Name: count, dtype: int64
bio_labels_value_counts:
  6      52764
 3       8717
 1       6948
-100     4164
 2       1320
 0       1179
 5        483
 4         79
Name: count, dtype: int64
rel_labels_value_counts:
 0    2887
1    1494
Name: count, dtype: int64
CLS Weights: tensor([0.0015, 0.0016])
BIO Weights: tensor([0.0014, 0.0010, 0.0014, 0.0010, 0.0132, 0.0026, 0.0010])
REL Weights: tensor([0.0011, 0.0013])


In [5]:
# The collator reuses the tokenizer held by the training dataset.
collator = CausalDatasetCollator(tokenizer=train_dataset.tokenizer)

# Debug aid: uncomment to run on a small random subset (20 samples per split).
# train_dataset = torch.utils.data.Subset(train_dataset, random.sample(range(len(train_dataset)), 20))
# val_dataset = torch.utils.data.Subset(val_dataset, random.sample(range(len(val_dataset)), 20))

# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    collate_fn=collator,
    batch_size=TRAINING_CONFIG["batch_size"],
)

# Validation batches keep the dataset order.
val_dataloader = DataLoader(
    dataset=val_dataset,
    shuffle=False,
    collate_fn=collator,
    batch_size=TRAINING_CONFIG["batch_size"],
)

In [6]:
# Tokenizer for the encoder named in MODEL_CONFIG.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["encoder_name"])

# Warm start: the weights come from the earlier "expert_bert_GCE_weakSP" run.
model_path = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/expert_bert_GCE_weakSP_model.pt"
model = JointCausalModel(**MODEL_CONFIG)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded. The
# result is passed straight to load_state_dict, so nothing else is needed from
# the file. The load stays inline so no extra copy of the weights is kept alive.
model.load_state_dict(
    torch.load(model_path, map_location=DEVICE, weights_only=True)
)
model.to(DEVICE)

# Left as the last expression so the cell displays the module summary.
# NOTE(review): this puts the model in eval mode right before fine-tuning;
# presumably train_model switches it back to train mode - confirm.
model.eval()

JointCausalModel(
  (enc): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [7]:
# AdamW over every model parameter; step size and decay come from TRAINING_CONFIG.
optimizer = optim.AdamW(
    params=model.parameters(),
    weight_decay=TRAINING_CONFIG["weight_decay"],
    lr=TRAINING_CONFIG["learning_rate"],
)

# Multiply the learning rate by 0.1 once the monitored quantity (treated as
# lower-is-better) has failed to improve for more than 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    patience=2,
    factor=0.1,
    mode="min",
)

In [8]:
# Checkpoint destination handed to train_model. The path is relative, so it is
# resolved against the kernel's current working directory.
model_save_path = (
    r"src/jointlearning/expert_bert_GCE_Softmax_Normal/"
    r"expert_bert_GCE_Softmax_Normal_model.pt"
)

In [9]:
trained_model, training_history = train_model(
    # Model, data and run environment.
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=DEVICE,
    seed=SEED,
    # Optimisation schedule.
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=TRAINING_CONFIG["num_epochs"],
    max_grad_norm=TRAINING_CONFIG["gradient_clip_val"],
    # Patience is set to the full epoch budget (same config value as num_epochs).
    patience_epochs=TRAINING_CONFIG["num_epochs"],
    # Loss weighting: per-task multipliers plus per-class weights.
    task_loss_weights={"cls": 1.0, "bio": 4.0, "rel": 1.0},
    cls_class_weights=cls_weights,
    bio_class_weights=bio_weights,  # Only for softmax
    rel_class_weights=rel_weights,
    # Label maps and the evaluation / reporting callbacks.
    id2label_cls=id2label_cls,
    id2label_bio=id2label_bio,
    id2label_rel=id2label_rel,
    eval_fn_metrics=evaluate_model,
    print_report_fn=print_eval_report,
    # Where the checkpoint is written.
    model_save_path=model_save_path,
)

--- Training Configuration ---
Device: cuda
Number of Epochs: 20
Seed: 8642
Optimizer: AdamW (LR: 1e-05, Weight Decay: 0.1)
Scheduler: ReduceLROnPlateau
Max Grad Norm: 1.0 (Mode: L2 norm if enabled)
Early Stopping Patience: 20
Model Save Path: src/jointlearning/expert_bert_GCE_Softmax_Normal/expert_bert_GCE_Softmax_Normal_model.pt
Mode: Standard Training (CrossEntropy)
Using task loss weights: {'cls': 1.0, 'bio': 4.0, 'rel': 1.0}
CLS Class Weights: Provided
BIO Class Weights: Provided
REL Class Weights: Provided
----------------------------


Epoch 1/20 [Training]:   0%|          | 0/131 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
                                                                                                                               


Epoch 1/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           5.5857
  Average Validation Loss:         1.2026
  Overall Validation Avg F1 (Macro): 0.7282
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8162
    Macro Precision: 0.8335
    Macro Recall:    0.8197
    Accuracy:        0.8178
    Per-class details:
      non-causal  : F1=0.7990 (P=0.9106, R=0.7118, Support=229.0)
      causal      : F1=0.8333 (P=0.7565, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4295
    Macro Precision: 0.4270
    Macro Recall:    0.4332
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.465 (P=0.442, R=0.490, S=263.0)
      I-C       : F1=0.596 (P=0.623, R=0.571, S=1451.0)
      B-E       : F1=0.414 (P=0.423, R=0.406

                                                                                                                               


Epoch 2/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           2.9043
  Average Validation Loss:         1.0754
  Overall Validation Avg F1 (Macro): 0.7401
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8390
    Macro Precision: 0.8527
    Macro Recall:    0.8417
    Accuracy:        0.8400
    Per-class details:
      non-causal  : F1=0.8261 (P=0.9243, R=0.7467, Support=229.0)
      causal      : F1=0.8519 (P=0.7811, R=0.9367, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4535
    Macro Precision: 0.4385
    Macro Recall:    0.4721
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.511 (P=0.453, R=0.586, S=263.0)
      I-C       : F1=0.648 (P=0.652, R=0.644, S=1451.0)
      B-E       : F1=0.452 (P=0.430, R=0.476

                                                                                                                               


Epoch 3/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           2.2416
  Average Validation Loss:         1.1568
  Overall Validation Avg F1 (Macro): 0.7837
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8507
    Macro Precision: 0.8573
    Macro Recall:    0.8523
    Accuracy:        0.8511
    Per-class details:
      non-causal  : F1=0.8431 (P=0.9091, R=0.7860, Support=229.0)
      causal      : F1=0.8584 (P=0.8056, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5566
    Macro Precision: 0.5712
    Macro Recall:    0.5633
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.521 (P=0.476, R=0.574, S=263.0)
      I-C       : F1=0.663 (P=0.662, R=0.664, S=1451.0)
      B-E       : F1=0.458 (P=0.452, R=0.465

                                                                                                                               


Epoch 4/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.7347
  Average Validation Loss:         1.1811
  Overall Validation Avg F1 (Macro): 0.7876
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8419
    Macro Precision: 0.8475
    Macro Recall:    0.8433
    Accuracy:        0.8422
    Per-class details:
      non-causal  : F1=0.8345 (P=0.8950, R=0.7817, Support=229.0)
      causal      : F1=0.8493 (P=0.8000, R=0.9050, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5750
    Macro Precision: 0.5289
    Macro Recall:    0.6930
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.530 (P=0.479, R=0.593, S=263.0)
      I-C       : F1=0.671 (P=0.680, R=0.662, S=1451.0)
      B-E       : F1=0.473 (P=0.447, R=0.502

                                                                                                                               


Epoch 5/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.3772
  Average Validation Loss:         1.2679
  Overall Validation Avg F1 (Macro): 0.7837
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8462
    Macro Precision: 0.8537
    Macro Recall:    0.8479
    Accuracy:        0.8467
    Per-class details:
      non-causal  : F1=0.8376 (P=0.9082, R=0.7773, Support=229.0)
      causal      : F1=0.8547 (P=0.7992, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5686
    Macro Precision: 0.5238
    Macro Recall:    0.6618
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.526 (P=0.477, R=0.586, S=263.0)
      I-C       : F1=0.673 (P=0.667, R=0.680, S=1451.0)
      B-E       : F1=0.473 (P=0.442, R=0.509

                                                                                                                               


Epoch 6/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2948
  Average Validation Loss:         1.3352
  Overall Validation Avg F1 (Macro): 0.7855
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8484
    Macro Precision: 0.8564
    Macro Recall:    0.8502
    Accuracy:        0.8489
    Per-class details:
      non-causal  : F1=0.8396 (P=0.9128, R=0.7773, Support=229.0)
      causal      : F1=0.8571 (P=0.8000, R=0.9231, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5676
    Macro Precision: 0.5237
    Macro Recall:    0.6622
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.518 (P=0.475, R=0.570, S=263.0)
      I-C       : F1=0.675 (P=0.675, R=0.676, S=1451.0)
      B-E       : F1=0.474 (P=0.441, R=0.513

                                                                                                                               


Epoch 7/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2603
  Average Validation Loss:         1.3274
  Overall Validation Avg F1 (Macro): 0.7873
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8552
    Macro Precision: 0.8618
    Macro Recall:    0.8567
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8478 (P=0.9141, R=0.7904, Support=229.0)
      causal      : F1=0.8626 (P=0.8095, R=0.9231, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5649
    Macro Precision: 0.5236
    Macro Recall:    0.6591
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.476, R=0.563, S=263.0)
      I-C       : F1=0.668 (P=0.677, R=0.660, S=1451.0)
      B-E       : F1=0.466 (P=0.441, R=0.494

                                                                                                                               


Epoch 8/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2433
  Average Validation Loss:         1.3357
  Overall Validation Avg F1 (Macro): 0.7855
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8552
    Macro Precision: 0.8618
    Macro Recall:    0.8567
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8478 (P=0.9141, R=0.7904, Support=229.0)
      causal      : F1=0.8626 (P=0.8095, R=0.9231, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5654
    Macro Precision: 0.5229
    Macro Recall:    0.6606
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.518 (P=0.475, R=0.570, S=263.0)
      I-C       : F1=0.669 (P=0.676, R=0.663, S=1451.0)
      B-E       : F1=0.472 (P=0.443, R=0.506

                                                                                                                               


Epoch 9/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2445
  Average Validation Loss:         1.3370
  Overall Validation Avg F1 (Macro): 0.7880
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8574
    Macro Precision: 0.8645
    Macro Recall:    0.8590
    Accuracy:        0.8578
    Per-class details:
      non-causal  : F1=0.8498 (P=0.9188, R=0.7904, Support=229.0)
      causal      : F1=0.8650 (P=0.8103, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5661
    Macro Precision: 0.5233
    Macro Recall:    0.6620
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.518 (P=0.475, R=0.570, S=263.0)
      I-C       : F1=0.670 (P=0.675, R=0.665, S=1451.0)
      B-E       : F1=0.478 (P=0.447, R=0.513

                                                                                                                                


Epoch 10/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1887
  Average Validation Loss:         1.3538
  Overall Validation Avg F1 (Macro): 0.7894
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8528
    Macro Precision: 0.8609
    Macro Recall:    0.8546
    Accuracy:        0.8533
    Per-class details:
      non-causal  : F1=0.8443 (P=0.9179, R=0.7817, Support=229.0)
      causal      : F1=0.8613 (P=0.8039, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5657
    Macro Precision: 0.5217
    Macro Recall:    0.6632
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.515 (P=0.470, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.675, R=0.669, S=1451.0)
      B-E       : F1=0.478 (P=0.444, R=0.51

                                                                                                                                


Epoch 11/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1772
  Average Validation Loss:         1.3481
  Overall Validation Avg F1 (Macro): 0.7876
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5659
    Macro Precision: 0.5220
    Macro Recall:    0.6632
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.675, R=0.669, S=1451.0)
      B-E       : F1=0.478 (P=0.444, R=0.51

                                                                                                                                


Epoch 12/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1790
  Average Validation Loss:         1.3481
  Overall Validation Avg F1 (Macro): 0.7891
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5659
    Macro Precision: 0.5220
    Macro Recall:    0.6632
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.675, R=0.669, S=1451.0)
      B-E       : F1=0.478 (P=0.444, R=0.51

                                                                                                                                


Epoch 13/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1998
  Average Validation Loss:         1.3443
  Overall Validation Avg F1 (Macro): 0.7884
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5659
    Macro Precision: 0.5222
    Macro Recall:    0.6631
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.675, R=0.668, S=1451.0)
      B-E       : F1=0.479 (P=0.446, R=0.51

                                                                                                                                


Epoch 14/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1873
  Average Validation Loss:         1.3529
  Overall Validation Avg F1 (Macro): 0.7887
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5659
    Macro Precision: 0.5222
    Macro Recall:    0.6631
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.675, R=0.668, S=1451.0)
      B-E       : F1=0.479 (P=0.446, R=0.51

                                                                                                                                


Epoch 15/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1617
  Average Validation Loss:         1.3215
  Overall Validation Avg F1 (Macro): 0.7916
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5663
    Macro Precision: 0.5227
    Macro Recall:    0.6631
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.676, R=0.668, S=1451.0)
      B-E       : F1=0.479 (P=0.447, R=0.51

                                                                                                                                


Epoch 16/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2000
  Average Validation Loss:         1.3354
  Overall Validation Avg F1 (Macro): 0.7877
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5662
    Macro Precision: 0.5227
    Macro Recall:    0.6626
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.517 (P=0.473, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.676, R=0.669, S=1451.0)
      B-E       : F1=0.476 (P=0.444, R=0.51

                                                                                                                                


Epoch 17/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1845
  Average Validation Loss:         1.3704
  Overall Validation Avg F1 (Macro): 0.7886
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5666
    Macro Precision: 0.5230
    Macro Recall:    0.6632
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.516 (P=0.472, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.676, R=0.669, S=1451.0)
      B-E       : F1=0.479 (P=0.447, R=0.51

                                                                                                                                


Epoch 18/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1780
  Average Validation Loss:         1.3562
  Overall Validation Avg F1 (Macro): 0.7887
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5670
    Macro Precision: 0.5235
    Macro Recall:    0.6632
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.517 (P=0.473, R=0.570, S=263.0)
      I-C       : F1=0.672 (P=0.676, R=0.669, S=1451.0)
      B-E       : F1=0.479 (P=0.447, R=0.51

                                                                                                                                


Epoch 19/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1747
  Average Validation Loss:         1.3406
  Overall Validation Avg F1 (Macro): 0.7899
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5669
    Macro Precision: 0.5235
    Macro Recall:    0.6627
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.517 (P=0.473, R=0.570, S=263.0)
      I-C       : F1=0.673 (P=0.676, R=0.669, S=1451.0)
      B-E       : F1=0.477 (P=0.446, R=0.51

                                                                                                                                


Epoch 20/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.1910
  Average Validation Loss:         1.3456
  Overall Validation Avg F1 (Macro): 0.7883
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.8551
    Macro Precision: 0.8627
    Macro Recall:    0.8568
    Accuracy:        0.8556
    Per-class details:
      non-causal  : F1=0.8471 (P=0.9184, R=0.7860, Support=229.0)
      causal      : F1=0.8632 (P=0.8071, R=0.9276, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5669
    Macro Precision: 0.5235
    Macro Recall:    0.6627
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.517 (P=0.473, R=0.570, S=263.0)
      I-C       : F1=0.673 (P=0.676, R=0.669, S=1451.0)
      B-E       : F1=0.477 (P=0.446, R=0.51



In [10]:
# Export the fine-tuned model to the Hugging Face-style directory for this run.
trained_model.save_pretrained(
    r"/home/rnorouzini/JointLearning/src/jointlearning/"
    r"expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal"
)

In [11]:
# Store the tokenizer files in the same directory as the exported model.
# Kept as the cell's last expression so the written file paths are displayed.
train_dataset.tokenizer.save_pretrained(
    r"/home/rnorouzini/JointLearning/src/jointlearning/"
    r"expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal"
)

('/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal/tokenizer_config.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal/special_tokens_map.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal/vocab.txt',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal/added_tokens.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Normal/hf_exper_bert_GCE_Softmax_Normal/tokenizer.json')