In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random
from transformers import AutoTokenizer
from utility import freeze_encoder_layers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the train / validation / test CSV splits.
train_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/train.csv"
val_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/val.csv"
test_data_path = "/home/rnorouzini/JointLearning/datasets/expert_multi_task_data/test.csv"

# Read the three splits (train, then val, then test) into DataFrames.
train_df, val_df, test_df = map(
    pd.read_csv,
    (train_data_path, val_data_path, test_data_path),
)

In [3]:
# Wrap each split's DataFrame in a CausalDataset. All three share the same
# `encoder_name` (used as the tokenizer name) and `max_length` settings
# taken from the project config.
train_dataset, val_dataset, test_dataset = (
    CausalDataset(
        split_df,
        tokenizer_name=MODEL_CONFIG["encoder_name"],
        max_length=DATASET_CONFIG["max_length"],
    )
    for split_df in (train_df, val_df, test_df)
)

In [4]:
# Per-task flattened label collections for the training split.
labels_flat = label_value_counts(train_dataset)

# For each task (CLS, BIO, REL): pull out its flattened labels, then derive
# class weights with the "ens" technique. -100 is passed as the ignore index.
# NOTE(review): "ens" presumably means effective-number-of-samples weighting;
# confirm in utility.compute_class_weights.
cls_label_flat = labels_flat["cls_labels_flat"]
cls_weights = compute_class_weights(
    labels_list=cls_label_flat,
    num_classes=MODEL_CONFIG["num_cls_labels"],
    technique="ens",
    ignore_index=-100,
)

bio_label_flat = labels_flat["bio_labels_flat"]
bio_weights = compute_class_weights(
    labels_list=bio_label_flat,
    num_classes=MODEL_CONFIG["num_bio_labels"],
    technique="ens",
    ignore_index=-100,
)

rel_label_flat = labels_flat["rel_labels_flat"]
rel_weights = compute_class_weights(
    labels_list=rel_label_flat,
    num_classes=MODEL_CONFIG["num_rel_labels"],
    technique="ens",
    ignore_index=-100,
)

# Show the resulting weight tensors for a quick sanity check.
print(f"CLS Weights: {cls_weights}")
print(f"BIO Weights: {bio_weights}")
print(f"REL Weights: {rel_weights}")

cls_labels_value_counts:
 0    1047
1    1035
Name: count, dtype: int64
bio_labels_value_counts:
  6      52764
 3       8717
 1       6948
-100     4164
 2       1320
 0       1179
 5        483
 4         79
Name: count, dtype: int64
rel_labels_value_counts:
 0    2887
1    1494
Name: count, dtype: int64
CLS Weights: tensor([0.0015, 0.0016])
BIO Weights: tensor([0.0014, 0.0010, 0.0014, 0.0010, 0.0132, 0.0026, 0.0010])
REL Weights: tensor([0.0011, 0.0013])


In [5]:
# Batch collator built around the tokenizer owned by the training dataset.
collator = CausalDatasetCollator(tokenizer=train_dataset.tokenizer)

# Debug aid: uncomment to run on a random 20-example subset of each split.
# train_dataset = torch.utils.data.Subset(train_dataset, random.sample(range(len(train_dataset)), 20))
# val_dataset = torch.utils.data.Subset(val_dataset, random.sample(range(len(val_dataset)), 20))

# Training batches are reshuffled; validation batches keep dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=TRAINING_CONFIG["batch_size"], collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=TRAINING_CONFIG["batch_size"], collate_fn=collator, shuffle=False)

In [6]:
# Tokenizer for the configured encoder.
# NOTE(review): the cells visible in this notebook use
# `train_dataset.tokenizer` for collation and export; confirm whether this
# separate `tokenizer` variable is still needed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["encoder_name"])

# Start from the weights of a previously saved checkpoint rather than a
# freshly initialised model.
model_path = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/expert_bert_GCE_weakSP_model.pt"
model = JointCausalModel(**MODEL_CONFIG)

# `weights_only=True` restricts unpickling to tensors and plain containers,
# so a tampered checkpoint file cannot run arbitrary code while loading.
# The loaded object is handed directly to `load_state_dict`, so a
# tensors-only load is all that is required here.
model.load_state_dict(
    torch.load(model_path, map_location=DEVICE, weights_only=True)
)
model.to(DEVICE)

# Leaves the model in eval mode and displays its architecture as the cell
# output. NOTE(review): training presumably switches back to train mode
# inside `train_model`; confirm there.
model.eval()

JointCausalModel(
  (enc): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [7]:
freeze_encoder_layers(model)

# Sanity check: every parameter whose name contains "encoder" must no longer
# require gradients. The first offender (if any) is reported in the error.
# NOTE(review): the substring test matches names under `enc.encoder...` but
# not `enc.embeddings...`; confirm embeddings are meant to be left unchecked.
unfrozen_names = [
    param_name
    for param_name, parameter in model.named_parameters()
    if "encoder" in param_name and parameter.requires_grad
]
assert not unfrozen_names, f"Parameter {unfrozen_names[0]} is not frozen!"
print("Encoder layers are frozen successfully.")

Encoder layers are frozen successfully.


In [8]:
# AdamW over all model parameters; learning rate and weight decay come from
# the training config.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TRAINING_CONFIG["learning_rate"],
    weight_decay=TRAINING_CONFIG["weight_decay"],
)

# Cut the learning rate to a tenth once the monitored quantity has failed to
# decrease for more than two consecutive scheduler steps.
# NOTE(review): which quantity is monitored is decided by whoever calls
# `scheduler.step(...)`; presumably validation loss inside `train_model`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.1,
    patience=2,
)

In [9]:
# Checkpoint file path handed to `train_model` as `model_save_path`.
model_save_path = r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/expert_bert_GCE_Softmax_Freeze_model.pt"

In [10]:
# Run the training loop; returns the trained model and its training history.
trained_model, training_history = train_model(
    # Model and data.
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=DEVICE,
    # Optimisation.
    optimizer=optimizer,
    scheduler=scheduler,
    max_grad_norm=TRAINING_CONFIG["gradient_clip_val"],
    # Run length and reproducibility.
    # NOTE(review): patience equals the total number of epochs, so early
    # stopping presumably never triggers before the last epoch; confirm
    # this is intended.
    num_epochs=TRAINING_CONFIG["num_epochs"],
    patience_epochs=TRAINING_CONFIG["num_epochs"],
    seed=SEED,
    # Per-task class weights for the loss.
    cls_class_weights=cls_weights,
    bio_class_weights=bio_weights,  # original author's note: only for softmax
    rel_class_weights=rel_weights,
    # Label-id to label-name maps for each task.
    id2label_cls=id2label_cls,
    id2label_bio=id2label_bio,
    id2label_rel=id2label_rel,
    # Evaluation and reporting callbacks.
    eval_fn_metrics=evaluate_model,
    print_report_fn=print_eval_report,
    # Checkpoint destination.
    model_save_path=model_save_path,
)

--- Training Configuration ---
Device: cuda
Number of Epochs: 20
Seed: 8642
Optimizer: AdamW (LR: 1e-05, Weight Decay: 0.1)
Scheduler: ReduceLROnPlateau
Gradient Clipping: Enabled (Max Norm: 1.0)
Early Stopping Patience: 20
Model Save Path: /home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/expert_bert_GCE_Softmax_Freeze_model.pt
Mode: Standard Training (CrossEntropy)
Task loss weights not provided, using default: {'cls': 1.0, 'bio': 1.0, 'rel': 1.0}
CLS Class Weights: Provided
BIO Class Weights: Provided
REL Class Weights: Provided
----------------------------


Epoch 1/20 [Training]:   0%|          | 0/131 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
                                                                                                                               


Epoch 1/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           3.8062
  Average Validation Loss:         3.3185
  Overall Validation Avg F1 (Macro): 0.7095
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7620
    Macro Precision: 0.7955
    Macro Recall:    0.7694
    Accuracy:        0.7667
    Per-class details:
      non-causal  : F1=0.7287 (P=0.8924, R=0.6157, Support=229.0)
      causal      : F1=0.7953 (P=0.6986, R=0.9231, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4495
    Macro Precision: 0.4075
    Macro Recall:    0.5407
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.400 (P=0.345, R=0.475, S=263.0)
      I-C       : F1=0.560 (P=0.490, R=0.653, S=1451.0)
      B-E       : F1=0.364 (P=0.326, R=0.413

                                                                                                                               


Epoch 2/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           2.4599
  Average Validation Loss:         1.8608
  Overall Validation Avg F1 (Macro): 0.7030
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7607
    Macro Precision: 0.7872
    Macro Recall:    0.7669
    Accuracy:        0.7644
    Per-class details:
      non-causal  : F1=0.7310 (P=0.8727, R=0.6288, Support=229.0)
      causal      : F1=0.7905 (P=0.7018, R=0.9050, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4506
    Macro Precision: 0.4082
    Macro Recall:    0.5467
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.409 (P=0.351, R=0.490, S=263.0)
      I-C       : F1=0.556 (P=0.501, R=0.624, S=1451.0)
      B-E       : F1=0.367 (P=0.317, R=0.435

                                                                                                                               


Epoch 3/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.7743
  Average Validation Loss:         1.4442
  Overall Validation Avg F1 (Macro): 0.7126
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7664
    Macro Precision: 0.7849
    Macro Recall:    0.7709
    Accuracy:        0.7689
    Per-class details:
      non-causal  : F1=0.7426 (P=0.8571, R=0.6550, Support=229.0)
      causal      : F1=0.7903 (P=0.7127, R=0.8869, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4601
    Macro Precision: 0.4204
    Macro Recall:    0.5658
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.420 (P=0.370, R=0.487, S=263.0)
      I-C       : F1=0.543 (P=0.520, R=0.567, S=1451.0)
      B-E       : F1=0.389 (P=0.341, R=0.454

                                                                                                                               


Epoch 4/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.5512
  Average Validation Loss:         1.3700
  Overall Validation Avg F1 (Macro): 0.7168
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7703
    Macro Precision: 0.7937
    Macro Recall:    0.7756
    Accuracy:        0.7733
    Per-class details:
      non-causal  : F1=0.7437 (P=0.8757, R=0.6463, Support=229.0)
      causal      : F1=0.7968 (P=0.7117, R=0.9050, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4584
    Macro Precision: 0.4230
    Macro Recall:    0.5623
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.426 (P=0.383, R=0.479, S=263.0)
      I-C       : F1=0.537 (P=0.526, R=0.549, S=1451.0)
      B-E       : F1=0.355 (P=0.322, R=0.395

                                                                                                                               


Epoch 5/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4683
  Average Validation Loss:         1.3232
  Overall Validation Avg F1 (Macro): 0.7202
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7850
    Macro Precision: 0.7999
    Macro Recall:    0.7885
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7659 (P=0.8674, R=0.6856, Support=229.0)
      causal      : F1=0.8041 (P=0.7323, R=0.8914, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4559
    Macro Precision: 0.4288
    Macro Recall:    0.5432
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.430 (P=0.398, R=0.468, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.512, S=1451.0)
      B-E       : F1=0.347 (P=0.331, R=0.365

                                                                                                                               


Epoch 6/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4650
  Average Validation Loss:         1.3291
  Overall Validation Avg F1 (Macro): 0.7206
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4552
    Macro Precision: 0.4285
    Macro Recall:    0.5419
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.429 (P=0.397, R=0.468, S=263.0)
      I-C       : F1=0.523 (P=0.541, R=0.507, S=1451.0)
      B-E       : F1=0.345 (P=0.330, R=0.362

                                                                                                                               


Epoch 7/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4337
  Average Validation Loss:         1.3105
  Overall Validation Avg F1 (Macro): 0.7210
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4565
    Macro Precision: 0.4298
    Macro Recall:    0.5431
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.434 (P=0.401, R=0.471, S=263.0)
      I-C       : F1=0.526 (P=0.543, R=0.510, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.362

                                                                                                                               


Epoch 8/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4322
  Average Validation Loss:         1.3029
  Overall Validation Avg F1 (Macro): 0.7184
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4298
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.526 (P=0.542, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.362

                                                                                                                               


Epoch 9/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4562
  Average Validation Loss:         1.2954
  Overall Validation Avg F1 (Macro): 0.7215
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4298
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.526 (P=0.542, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.362

                                                                                                                                


Epoch 10/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4647
  Average Validation Loss:         1.3114
  Overall Validation Avg F1 (Macro): 0.7242
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4569
    Macro Precision: 0.4299
    Macro Recall:    0.5438
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.526 (P=0.542, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 11/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4309
  Average Validation Loss:         1.2983
  Overall Validation Avg F1 (Macro): 0.7216
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 12/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4421
  Average Validation Loss:         1.2994
  Overall Validation Avg F1 (Macro): 0.7231
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 13/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4368
  Average Validation Loss:         1.2903
  Overall Validation Avg F1 (Macro): 0.7208
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 14/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4119
  Average Validation Loss:         1.3080
  Overall Validation Avg F1 (Macro): 0.7222
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 15/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4343
  Average Validation Loss:         1.2775
  Overall Validation Avg F1 (Macro): 0.7226
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4569
    Macro Precision: 0.4300
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 16/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4659
  Average Validation Loss:         1.2673
  Overall Validation Avg F1 (Macro): 0.7186
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4569
    Macro Precision: 0.4300
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 17/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4124
  Average Validation Loss:         1.3197
  Overall Validation Avg F1 (Macro): 0.7219
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4569
    Macro Precision: 0.4300
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 18/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4174
  Average Validation Loss:         1.2889
  Overall Validation Avg F1 (Macro): 0.7227
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4569
    Macro Precision: 0.4300
    Macro Recall:    0.5437
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.527 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 19/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4301
  Average Validation Loss:         1.2846
  Overall Validation Avg F1 (Macro): 0.7231
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5436
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.526 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36

                                                                                                                                


Epoch 20/20 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           1.4166
  Average Validation Loss:         1.3057
  Overall Validation Avg F1 (Macro): 0.7208
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.7848
    Macro Precision: 0.8011
    Macro Recall:    0.7886
    Accuracy:        0.7867
    Per-class details:
      non-causal  : F1=0.7647 (P=0.8715, R=0.6812, Support=229.0)
      causal      : F1=0.8049 (P=0.7306, R=0.8959, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.4568
    Macro Precision: 0.4299
    Macro Recall:    0.5436
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.436 (P=0.403, R=0.475, S=263.0)
      I-C       : F1=0.526 (P=0.543, R=0.511, S=1451.0)
      B-E       : F1=0.346 (P=0.331, R=0.36



In [11]:
# Export the trained model with its `save_pretrained` method into the
# `hf_expert_bert_GCE_Softmax_Freeze` directory.
trained_model.save_pretrained(
    r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze"
)

In [12]:
# Write the training dataset's tokenizer files into the same directory as the
# exported model. The call is the cell's last expression, so its return value
# (the written file paths) is displayed as the cell output.
train_dataset.tokenizer.save_pretrained(
    r"/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze"
)

('/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/tokenizer_config.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/special_tokens_map.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/vocab.txt',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/added_tokens.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_Softmax_Freeze/hf_expert_bert_GCE_Softmax_Freeze/tokenizer.json')