In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random
from transformers import AutoTokenizer
from utility import freeze_encoder_layers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Absolute locations of the train/val/test CSV splits (expert_multi_task_data).
train_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/train.csv"
val_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/val.csv"
test_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/test.csv"

# Read the three splits in train -> val -> test order, one DataFrame per split.
train_df, val_df, test_df = map(
    pd.read_csv,
    (train_data_path, val_data_path, test_data_path),
)

In [3]:
# Every split is wrapped with the same tokenizer_name / max_length settings,
# so those keyword arguments are collected once and reused.
dataset_settings = {
    "tokenizer_name": MODEL_CONFIG["encoder_name"],
    "max_length": DATASET_CONFIG["max_length"],
}

train_dataset = CausalDataset(train_df, **dataset_settings)
val_dataset = CausalDataset(val_df, **dataset_settings)
test_dataset = CausalDataset(test_df, **dataset_settings)

In [4]:
# Flattened per-task label lists taken from the training split.
# NOTE(review): the value-count tables in this cell's output are presumably
# printed by label_value_counts itself -- confirm in utility.py.
labels_flat = label_value_counts(train_dataset)
cls_label_flat = labels_flat["cls_labels_flat"]
bio_label_flat = labels_flat["bio_labels_flat"]
rel_label_flat = labels_flat["rel_labels_flat"]

# Options shared by all three class-weight computations. -100 is handed over as
# the ignore index; "ens" is presumably effective-number-of-samples weighting
# (TODO confirm against utility.compute_class_weights).
weight_options = {"technique": "ens", "ignore_index": -100}

cls_weights = compute_class_weights(
    labels_list=cls_label_flat,
    num_classes=MODEL_CONFIG["num_cls_labels"],
    **weight_options,
)
bio_weights = compute_class_weights(
    labels_list=bio_label_flat,
    num_classes=MODEL_CONFIG["num_bio_labels"],
    **weight_options,
)
rel_weights = compute_class_weights(
    labels_list=rel_label_flat,
    num_classes=MODEL_CONFIG["num_rel_labels"],
    **weight_options,
)

# Show the resulting weight tensors for a quick sanity check.
print(f"CLS Weights: {cls_weights}")
print(f"BIO Weights: {bio_weights}")
print(f"REL Weights: {rel_weights}")

cls_labels_value_counts:
 0    1047
1    1035
Name: count, dtype: int64
bio_labels_value_counts:
  6      52764
 3       8717
 1       6948
-100     4164
 2       1320
 0       1179
 5        483
 4         79
Name: count, dtype: int64
rel_labels_value_counts:
 0    2887
1    1494
Name: count, dtype: int64
CLS Weights: tensor([0.0015, 0.0016])
BIO Weights: tensor([0.0014, 0.0010, 0.0014, 0.0010, 0.0132, 0.0026, 0.0010])
REL Weights: tensor([0.0011, 0.0013])


In [5]:
# The collator reuses the tokenizer instance already held by the training dataset.
collator = CausalDatasetCollator(tokenizer=train_dataset.tokenizer)

# Debug aid (disabled): uncomment to run on a random 20-sample subset of each split.
# train_dataset = torch.utils.data.Subset(train_dataset, random.sample(range(len(train_dataset)), 20))
# val_dataset = torch.utils.data.Subset(val_dataset, random.sample(range(len(val_dataset)), 20))

# Settings common to both loaders.
loader_options = {
    "batch_size": TRAINING_CONFIG["batch_size"],
    "collate_fn": collator,
}

# Training batches are shuffled; validation keeps dataset order.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [6]:
# Tokenizer for the encoder named in MODEL_CONFIG.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["encoder_name"])

# Warm-start: build the model from MODEL_CONFIG, then load weights from an
# earlier checkpoint (expert_bert_GCE_weakSP).
model_path = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/expert_bert_GCE_weakSP_model.pt"
model = JointCausalModel(**MODEL_CONFIG)

# weights_only=True restricts unpickling to tensors, primitive types and plain
# containers, so a tampered checkpoint file cannot execute arbitrary code.
# The file is consumed by load_state_dict, so nothing beyond tensors is needed.
# The result is passed inline (not bound to a name) so the temporary copy of
# the weights is released as soon as they have been copied into the model.
model.load_state_dict(
    torch.load(model_path, map_location=DEVICE, weights_only=True)
)
model.to(DEVICE)

# Left as the last expression so the cell displays the module tree.
# NOTE(review): the model is put in eval mode here although training follows;
# presumably train_model switches to train mode itself -- confirm in trainer.py.
model.eval()

JointCausalModel(
  (enc): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [7]:
freeze_encoder_layers(model)

# Sanity check: every parameter whose qualified name contains "encoder" must
# have gradients disabled; the first offender aborts the cell.
# NOTE(review): in the printed module tree the BERT backbone is registered as
# `enc`, with `embeddings` and `encoder` children, so names such as
# `enc.embeddings.*` do not match this filter and are not checked here --
# confirm whether freeze_encoder_layers is meant to cover them.
for name, param in model.named_parameters():
    if "encoder" not in name:
        continue
    assert not param.requires_grad, f"Parameter {name} is not frozen!"

print("Encoder layers are frozen successfully.")

Encoder layers are frozen successfully.


In [8]:
# AdamW hyper-parameters come from TRAINING_CONFIG.
adamw_settings = {
    "lr": TRAINING_CONFIG["learning_rate"],
    "weight_decay": TRAINING_CONFIG["weight_decay"],
}
# NOTE(review): model.parameters() also yields the parameters frozen earlier in
# this notebook; passing only those with requires_grad=True would make the
# intent explicit -- check first that train_model never unfreezes layers.
optimizer = optim.AdamW(model.parameters(), **adamw_settings)

# Multiply the learning rate by 0.1 once the monitored quantity has failed to
# decrease ("min" mode) for more than 2 consecutive epochs.
# NOTE(review): the logged validation metrics are almost constant from roughly
# epoch 8 onward; repeated 10x cuts may be shrinking the LR to a negligible
# value early -- consider logging the LR per epoch to confirm.
plateau_settings = {"mode": "min", "factor": 0.1, "patience": 2}
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_settings)

In [9]:
# Path handed to train_model as `model_save_path` (presumably where the trained
# checkpoint is written -- confirm in trainer.train_model). It points to a
# different run directory (expert_bert_GCE_Softmax_Freeze) than the checkpoint
# loaded earlier (expert_bert_GCE_weakSP), so the source weights are not overwritten.
model_save_path = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/expert_bert_GCE_Softmax_Freeze_model.pt"

In [10]:
# Run training on the frozen-encoder model. train_model returns the trained
# model together with its training history.
trained_model, training_history = train_model(
    # --- model and data ---
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=DEVICE,
    # --- optimisation ---
    optimizer=optimizer,
    scheduler=scheduler,
    max_grad_norm=TRAINING_CONFIG["gradient_clip_val"],
    # --- run length and reproducibility ---
    num_epochs=TRAINING_CONFIG["num_epochs"],
    # NOTE(review): patience equals the epoch budget, so early stopping
    # presumably can never trigger before the final epoch -- confirm intent.
    patience_epochs=TRAINING_CONFIG["num_epochs"],
    seed=SEED,
    # --- label maps used for reporting ---
    id2label_cls=id2label_cls,
    id2label_bio=id2label_bio,
    id2label_rel=id2label_rel,
    # --- per-task class weights ---
    cls_class_weights=cls_weights,
    bio_class_weights=bio_weights,  # Only for softmax
    rel_class_weights=rel_weights,
    # --- evaluation hooks and checkpoint destination ---
    eval_fn_metrics=evaluate_model,
    print_report_fn=print_eval_report,
    model_save_path=model_save_path,
)

--- Training Configuration ---
Device: cuda
Number of Epochs: 30
Seed: 8642
Optimizer: AdamW (LR: 5e-05, Weight Decay: 0.01)
Scheduler: ReduceLROnPlateau
Max Grad Norm: 1.0 (Mode: L2 norm if enabled)
Early Stopping Patience: 30
Model Save Path: /home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/expert_bert_GCE_Softmax_Freeze_model.pt
Mode: Standard Training (CrossEntropy)
Task loss weights not provided, using default: {'cls': 1.0, 'bio': 1.0, 'rel': 1.0}
CLS Class Weights: Provided
BIO Class Weights: Provided
REL Class Weights: Provided
----------------------------


Epoch 1/30 [Training]:   0%|          | 0/131 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
                                                                                                                               


Epoch 1/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           2.6753
  Average Validation Loss:         1.3516
  Overall Validation Avg F1 (Macro): 0.6891
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7533
    Macro Precision: 0.7990
    Macro Recall:    0.7632
    Accuracy:        0.7600
    Per-class details:
      non-causal  : F1=0.7128 (P=0.9116, R=0.5852, Support=229.0)
      causal      : F1=0.7939 (P=0.6865, R=0.9412, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3774
    Macro Precision: 0.3883
    Macro Recall:    0.3744
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.422 (P=0.429, R=0.414, S=263.0)
      I-C       : F1=0.491 (P=0.489, R=0.494, S=1451.0)
      B-E       : F1=0.283 (P=0.392, R=0.221

                                                                                                                               


Epoch 2/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4584
  Average Validation Loss:         1.2655
  Overall Validation Avg F1 (Macro): 0.7163
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7628
    Macro Precision: 0.8069
    Macro Recall:    0.7720
    Accuracy:        0.7689
    Per-class details:
      non-causal  : F1=0.7249 (P=0.9195, R=0.5983, Support=229.0)
      causal      : F1=0.8008 (P=0.6944, R=0.9457, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4537
    Macro Precision: 0.4426
    Macro Recall:    0.4727
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.448 (P=0.426, R=0.471, S=263.0)
      I-C       : F1=0.529 (P=0.541, R=0.518, S=1451.0)
      B-E       : F1=0.360 (P=0.389, R=0.336

                                                                                                                               


Epoch 3/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.3732
  Average Validation Loss:         1.2208
  Overall Validation Avg F1 (Macro): 0.7366
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7671
    Macro Precision: 0.7969
    Macro Recall:    0.7737
    Accuracy:        0.7711
    Per-class details:
      non-causal  : F1=0.7366 (P=0.8889, R=0.6288, Support=229.0)
      causal      : F1=0.7976 (P=0.7049, R=0.9186, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5014
    Macro Precision: 0.5030
    Macro Recall:    0.5055
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.448 (P=0.478, R=0.422, S=263.0)
      I-C       : F1=0.538 (P=0.540, R=0.536, S=1451.0)
      B-E       : F1=0.357 (P=0.422, R=0.310

                                                                                                                               


Epoch 4/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.3447
  Average Validation Loss:         1.1994
  Overall Validation Avg F1 (Macro): 0.7426
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7793
    Macro Precision: 0.8032
    Macro Recall:    0.7845
    Accuracy:        0.7822
    Per-class details:
      non-causal  : F1=0.7538 (P=0.8876, R=0.6550, Support=229.0)
      causal      : F1=0.8048 (P=0.7189, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5026
    Macro Precision: 0.4812
    Macro Recall:    0.5515
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.459 (P=0.429, R=0.494, S=263.0)
      I-C       : F1=0.551 (P=0.522, R=0.584, S=1451.0)
      B-E       : F1=0.357 (P=0.386, R=0.332

                                                                                                                               


Epoch 5/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.3013
  Average Validation Loss:         1.1901
  Overall Validation Avg F1 (Macro): 0.7393
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7793
    Macro Precision: 0.8032
    Macro Recall:    0.7845
    Accuracy:        0.7822
    Per-class details:
      non-causal  : F1=0.7538 (P=0.8876, R=0.6550, Support=229.0)
      causal      : F1=0.8048 (P=0.7189, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5061
    Macro Precision: 0.4897
    Macro Recall:    0.5480
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.456 (P=0.434, R=0.479, S=263.0)
      I-C       : F1=0.548 (P=0.543, R=0.553, S=1451.0)
      B-E       : F1=0.361 (P=0.395, R=0.332

                                                                                                                               


Epoch 6/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2789
  Average Validation Loss:         1.1951
  Overall Validation Avg F1 (Macro): 0.7430
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7793
    Macro Precision: 0.8032
    Macro Recall:    0.7845
    Accuracy:        0.7822
    Per-class details:
      non-causal  : F1=0.7538 (P=0.8876, R=0.6550, Support=229.0)
      causal      : F1=0.8048 (P=0.7189, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5073
    Macro Precision: 0.4905
    Macro Recall:    0.5483
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.455 (P=0.442, R=0.468, S=263.0)
      I-C       : F1=0.550 (P=0.550, R=0.549, S=1451.0)
      B-E       : F1=0.356 (P=0.385, R=0.332

                                                                                                                               


Epoch 7/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2771
  Average Validation Loss:         1.1754
  Overall Validation Avg F1 (Macro): 0.7418
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7840
    Macro Precision: 0.8065
    Macro Recall:    0.7889
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7600 (P=0.8889, R=0.6638, Support=229.0)
      causal      : F1=0.8080 (P=0.7240, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4860
    Macro Recall:    0.5382
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.453 (P=0.442, R=0.464, S=263.0)
      I-C       : F1=0.549 (P=0.553, R=0.545, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.336

                                                                                                                               


Epoch 8/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2775
  Average Validation Loss:         1.1733
  Overall Validation Avg F1 (Macro): 0.7402
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5022
    Macro Precision: 0.4862
    Macro Recall:    0.5382
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.444, R=0.464, S=263.0)
      I-C       : F1=0.549 (P=0.553, R=0.545, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.336

                                                                                                                               


Epoch 9/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.3126
  Average Validation Loss:         1.1644
  Overall Validation Avg F1 (Macro): 0.7464
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7840
    Macro Precision: 0.8065
    Macro Recall:    0.7889
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7600 (P=0.8889, R=0.6638, Support=229.0)
      causal      : F1=0.8080 (P=0.7240, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5020
    Macro Precision: 0.4862
    Macro Recall:    0.5380
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.444, R=0.464, S=263.0)
      I-C       : F1=0.549 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.336

                                                                                                                                


Epoch 10/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2580
  Average Validation Loss:         1.1784
  Overall Validation Avg F1 (Macro): 0.7455
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5020
    Macro Precision: 0.4862
    Macro Recall:    0.5380
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.444, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 11/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2633
  Average Validation Loss:         1.1855
  Overall Validation Avg F1 (Macro): 0.7440
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5020
    Macro Precision: 0.4862
    Macro Recall:    0.5380
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.444, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 12/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2808
  Average Validation Loss:         1.1722
  Overall Validation Avg F1 (Macro): 0.7449
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5022
    Macro Precision: 0.4864
    Macro Recall:    0.5380
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 13/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2847
  Average Validation Loss:         1.1658
  Overall Validation Avg F1 (Macro): 0.7425
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 14/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2748
  Average Validation Loss:         1.1835
  Overall Validation Avg F1 (Macro): 0.7432
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 15/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2603
  Average Validation Loss:         1.1511
  Overall Validation Avg F1 (Macro): 0.7463
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 16/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2807
  Average Validation Loss:         1.1576
  Overall Validation Avg F1 (Macro): 0.7418
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 17/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2724
  Average Validation Loss:         1.1958
  Overall Validation Avg F1 (Macro): 0.7441
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 18/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2623
  Average Validation Loss:         1.1739
  Overall Validation Avg F1 (Macro): 0.7433
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 19/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2722
  Average Validation Loss:         1.1644
  Overall Validation Avg F1 (Macro): 0.7429
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 20/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2670
  Average Validation Loss:         1.1681
  Overall Validation Avg F1 (Macro): 0.7437
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 21/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2823
  Average Validation Loss:         1.1783
  Overall Validation Avg F1 (Macro): 0.7414
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 22/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2859
  Average Validation Loss:         1.1872
  Overall Validation Avg F1 (Macro): 0.7402
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 23/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2718
  Average Validation Loss:         1.1748
  Overall Validation Avg F1 (Macro): 0.7422
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 24/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2873
  Average Validation Loss:         1.1782
  Overall Validation Avg F1 (Macro): 0.7452
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 25/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2721
  Average Validation Loss:         1.1564
  Overall Validation Avg F1 (Macro): 0.7425
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 26/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2623
  Average Validation Loss:         1.1701
  Overall Validation Avg F1 (Macro): 0.7426
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 27/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2517
  Average Validation Loss:         1.1758
  Overall Validation Avg F1 (Macro): 0.7433
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 28/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2825
  Average Validation Loss:         1.1708
  Overall Validation Avg F1 (Macro): 0.7425
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 29/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2647
  Average Validation Loss:         1.1442
  Overall Validation Avg F1 (Macro): 0.7406
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33

                                                                                                                                


Epoch 30/30 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.2699
  Average Validation Loss:         1.1746
  Overall Validation Avg F1 (Macro): 0.7395
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7816
    Macro Precision: 0.8048
    Macro Recall:    0.7867
    Accuracy:        0.7844
    Per-class details:
      non-causal  : F1=0.7569 (P=0.8882, R=0.6594, Support=229.0)
      causal      : F1=0.8064 (P=0.7214, R=0.9140, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.5021
    Macro Precision: 0.4864
    Macro Recall:    0.5379
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.454 (P=0.445, R=0.464, S=263.0)
      I-C       : F1=0.548 (P=0.553, R=0.544, S=1451.0)
      B-E       : F1=0.361 (P=0.391, R=0.33



In [11]:
# Export the trained model through its `save_pretrained` method.
# The destination is the run-specific "hf_expert_bert_GCE_Softmax_Freeze" folder.
model_output_dir = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze"
trained_model.save_pretrained(model_output_dir)

In [12]:
# Save the tokenizer attached to the training dataset into the run-specific
# "hf_expert_bert_GCE_Softmax_Freeze" export folder. The call is left as the
# cell's last expression so the returned tuple of written file paths is shown.
tokenizer_output_dir = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze"
train_dataset.tokenizer.save_pretrained(tokenizer_output_dir)

('/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/tokenizer_config.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/special_tokens_map.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/vocab.txt',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/added_tokens.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/tokenizer.json')