In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Training uses the pseudo-annotated ("silver") split — judging by the file name, labels
# produced by a LLaMA-3-8B model — while validation and test use the expert split.
# To train on the expert data instead, use:
#   train_data_path = f"{_datasets_root}/expert_multi_task_data/train.csv"
_datasets_root = "/home/rnorouzini/JointLearning/datasets"
train_data_path = f"{_datasets_root}/pseudo_annotate_data/llama3_8b_processed.csv"
val_data_path = f"{_datasets_root}/expert_multi_task_data/val.csv"
test_data_path = f"{_datasets_root}/expert_multi_task_data/test.csv"

# Read the splits in train -> val -> test order.
train_df, val_df, test_df = (
    pd.read_csv(split_path)
    for split_path in (train_data_path, val_data_path, test_data_path)
)

In [3]:
def _build_causal_dataset(split_df):
    """Wrap one DataFrame split in a CausalDataset configured from MODEL_CONFIG / DATASET_CONFIG.

    Each split is tokenized with the configured encoder's tokenizer and the shared
    max_length, so all three splits are encoded identically.
    """
    return CausalDataset(
        split_df,
        tokenizer_name=MODEL_CONFIG["encoder_name"],
        max_length=DATASET_CONFIG["max_length"],
    )


train_dataset = _build_causal_dataset(train_df)
val_dataset = _build_causal_dataset(val_df)
test_dataset = _build_causal_dataset(test_df)

In [4]:
# Flattened per-task label lists from the training split. The value-count tables in this
# cell's output are printed by the helper functions, not by the code below.
labels_flat = label_value_counts(train_dataset)

cls_label_flat, bio_label_flat, rel_label_flat = (
    labels_flat["cls_labels_flat"],
    labels_flat["bio_labels_flat"],
    labels_flat["rel_labels_flat"],
)


def _ens_class_weights(flat_labels, num_labels_key):
    """Class weights via compute_class_weights(technique="ens"), skipping label -100.

    `num_labels_key` names the MODEL_CONFIG entry holding the task's class count.
    """
    return compute_class_weights(
        labels_list=flat_labels,
        num_classes=MODEL_CONFIG[num_labels_key],
        technique="ens",
        ignore_index=-100,
    )


# NOTE(review): these weights are never passed to train_model below — its log reports
# "CLS/BIO/REL Class Weights: None". They are also almost uniform (0.0010) except for the
# rare BIO labels 4 and 5, so they would change little even if they were wired in.
cls_weights = _ens_class_weights(cls_label_flat, "num_cls_labels")
bio_weights = _ens_class_weights(bio_label_flat, "num_bio_labels")
rel_weights = _ens_class_weights(rel_label_flat, "num_rel_labels")

print(f"CLS Weights: {cls_weights}")
print(f"BIO Weights: {bio_weights}")
print(f"REL Weights: {rel_weights}")

cls_labels_value_counts:
 1    90808
0     9192
Name: count, dtype: int64
bio_labels_value_counts:
  6      1705521
 3       784784
 1       674758
-100     200000
 2       117875
 0       113565
 5           84
 4            8
Name: count, dtype: int64
rel_labels_value_counts:
 0    249811
1    128406
Name: count, dtype: int64
CLS Weights: tensor([0.0010, 0.0010])
BIO Weights: tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.1254, 0.0124, 0.0010])
REL Weights: tensor([0.0010, 0.0010])


In [8]:
# Pad/batch with the same tokenizer the datasets were encoded with.
collator = CausalDatasetCollator(
    tokenizer=train_dataset.tokenizer
)

# Quick-debug switch: set to an int to train/validate on a random subset of that many
# examples per split; None (default) uses the full splits. The subsets get their own
# names so that `train_dataset` / `val_dataset` stay the original CausalDataset objects
# (a later cell reads `train_dataset.tokenizer`, which a `Subset` does not have), and
# re-running this cell never subsamples an already-subsampled dataset.
DEBUG_SUBSET_SIZE = None


def _maybe_subsample(dataset, size):
    """Return `dataset` unchanged if `size` is None, else a random Subset of at most `size` items."""
    if size is None:
        return dataset
    # Clamp so random.sample does not raise when the split is smaller than `size`.
    indices = random.sample(range(len(dataset)), min(size, len(dataset)))
    return torch.utils.data.Subset(dataset, indices)


train_loader_dataset = _maybe_subsample(train_dataset, DEBUG_SUBSET_SIZE)
val_loader_dataset = _maybe_subsample(val_dataset, DEBUG_SUBSET_SIZE)

train_dataloader = DataLoader(
    train_loader_dataset,
    batch_size=TRAINING_CONFIG["batch_size"],
    collate_fn=collator,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_loader_dataset,
    batch_size=TRAINING_CONFIG["batch_size"],
    collate_fn=collator,
    shuffle=False,
)

In [9]:
# Seed torch's RNG immediately before construction so any randomly initialized layers
# (presumably the task heads on top of the pretrained encoder — confirm in model.py) get
# the same initial weights on every Restart & Run All. train_model also receives
# seed=SEED, but it runs after this cell, so any seeding it does comes too late to
# affect these initial weights.
torch.manual_seed(SEED)

model = JointCausalModel(
    encoder_name=MODEL_CONFIG["encoder_name"],
    num_cls_labels=MODEL_CONFIG["num_cls_labels"],
    num_bio_labels=MODEL_CONFIG["num_bio_labels"],
    num_rel_labels=MODEL_CONFIG["num_rel_labels"],
    dropout=MODEL_CONFIG["dropout"]
)


In [10]:
# AdamW over every model parameter, with learning rate and weight decay taken from TRAINING_CONFIG.
# NOTE(review): weight decay is applied uniformly, including to any bias / LayerNorm
# parameters — many BERT fine-tuning recipes exclude those; confirm this is intended.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=TRAINING_CONFIG["learning_rate"],
    weight_decay=TRAINING_CONFIG["weight_decay"],
)

# Divide the LR by 10 once the monitored metric has not decreased for 2 consecutive
# scheduler steps. mode="min" assumes train_model steps this with a loss-like value — TODO confirm.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    patience=2,
    factor=0.1,
)

In [11]:
# Checkpoint file that train_model writes to (via its model_save_path argument).
_run_dir = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP"
model_save_path = f"{_run_dir}/expert_bert_GCE_weakSP_model.pt"

In [12]:
# Train on the silver (pseudo-annotated) data. With is_silver_training=True the trainer
# logs "Mode: Silver Data Training (GCE)" and reports no class weights; the *_weights
# computed earlier are not passed here.
# NOTE(review): patience_epochs equals num_epochs. Assuming it counts non-improving
# epochs, early stopping cannot fire before the final epoch, i.e. it is effectively
# disabled — confirm this is intended.
trained_model, training_history = train_model(
    # model and optimization
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    max_grad_norm=TRAINING_CONFIG["gradient_clip_val"],
    # data
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    # schedule
    num_epochs=TRAINING_CONFIG["num_epochs"],
    patience_epochs=TRAINING_CONFIG["num_epochs"],
    # runtime
    device=DEVICE,
    seed=SEED,
    # label maps and evaluation/reporting hooks
    id2label_cls=id2label_cls,
    id2label_bio=id2label_bio,
    id2label_rel=id2label_rel,
    eval_fn_metrics=evaluate_model,
    print_report_fn=print_eval_report,
    # output and mode
    model_save_path=model_save_path,
    is_silver_training=True,
)

--- Training Configuration ---
Device: cuda
Number of Epochs: 10
Seed: 8642
Optimizer: AdamW (LR: 1e-05, Weight Decay: 0.1)
Scheduler: ReduceLROnPlateau
Max Grad Norm: 1.0 (Mode: L2 norm if enabled)
Early Stopping Patience: 10
Model Save Path: /home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/expert_bert_GCE_weakSP_model.pt
Mode: Silver Data Training (GCE)
GCE q value: 0.7
Task loss weights not provided, using default: {'cls': 1.0, 'bio': 1.0, 'rel': 1.0}
CLS Class Weights: None
BIO Class Weights: None
REL Class Weights: None
----------------------------


Epoch 1/10 [Training]:   0%|          | 0/6250 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
                                                                                                                                 


Epoch 1/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.6193
  Average Validation Loss:         1.0684
  Overall Validation Avg F1 (Macro): 0.6124
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.5592
    Macro Precision: 0.7812
    Macro Recall:    0.6245
    Accuracy:        0.6178
    Per-class details:
      non-causal  : F1=0.3986 (P=1.0000, R=0.2489, Support=229.0)
      causal      : F1=0.7199 (P=0.5623, R=1.0000, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3617
    Macro Precision: 0.3310
    Macro Recall:    0.4428
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.342 (P=0.277, R=0.445, S=263.0)
      I-C       : F1=0.511 (P=0.388, R=0.748, S=1451.0)
      B-E       : F1=0.345 (P=0.293, R=0.421

                                                                                                                                 


Epoch 2/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4703
  Average Validation Loss:         0.9613
  Overall Validation Avg F1 (Macro): 0.6405
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6150
    Macro Precision: 0.7748
    Macro Recall:    0.6592
    Accuracy:        0.6533
    Per-class details:
      non-causal  : F1=0.4935 (P=0.9620, R=0.3319, Support=229.0)
      causal      : F1=0.7365 (P=0.5876, R=0.9864, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3781
    Macro Precision: 0.3430
    Macro Recall:    0.4575
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.361 (P=0.290, R=0.479, S=263.0)
      I-C       : F1=0.529 (P=0.402, R=0.770, S=1451.0)
      B-E       : F1=0.361 (P=0.301, R=0.450

                                                                                                                                 


Epoch 3/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.4094
  Average Validation Loss:         0.9829
  Overall Validation Avg F1 (Macro): 0.6383
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6241
    Macro Precision: 0.7779
    Macro Recall:    0.6657
    Accuracy:        0.6600
    Per-class details:
      non-causal  : F1=0.5080 (P=0.9634, R=0.3450, Support=229.0)
      causal      : F1=0.7402 (P=0.5924, R=0.9864, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3678
    Macro Precision: 0.3327
    Macro Recall:    0.4567
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.363 (P=0.283, R=0.506, S=263.0)
      I-C       : F1=0.511 (P=0.389, R=0.746, S=1451.0)
      B-E       : F1=0.350 (P=0.283, R=0.458

                                                                                                                                 


Epoch 4/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.3630
  Average Validation Loss:         0.9332
  Overall Validation Avg F1 (Macro): 0.6500
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6489
    Macro Precision: 0.7973
    Macro Recall:    0.6855
    Accuracy:        0.6800
    Per-class details:
      non-causal  : F1=0.5443 (P=0.9885, R=0.3755, Support=229.0)
      causal      : F1=0.7534 (P=0.6061, R=0.9955, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3790
    Macro Precision: 0.3428
    Macro Recall:    0.4593
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.382 (P=0.305, R=0.510, S=263.0)
      I-C       : F1=0.521 (P=0.401, R=0.744, S=1451.0)
      B-E       : F1=0.352 (P=0.288, R=0.454

                                                                                                                                 


Epoch 5/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.3070
  Average Validation Loss:         0.9556
  Overall Validation Avg F1 (Macro): 0.6461
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6370
    Macro Precision: 0.7937
    Macro Recall:    0.6768
    Accuracy:        0.6711
    Per-class details:
      non-causal  : F1=0.5256 (P=0.9880, R=0.3581, Support=229.0)
      causal      : F1=0.7483 (P=0.5995, R=0.9955, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3750
    Macro Precision: 0.3392
    Macro Recall:    0.4559
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.383 (P=0.305, R=0.513, S=263.0)
      I-C       : F1=0.520 (P=0.404, R=0.729, S=1451.0)
      B-E       : F1=0.342 (P=0.278, R=0.446

                                                                                                                                 


Epoch 6/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.2952
  Average Validation Loss:         0.9500
  Overall Validation Avg F1 (Macro): 0.6475
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6361
    Macro Precision: 0.7820
    Macro Recall:    0.6744
    Accuracy:        0.6689
    Per-class details:
      non-causal  : F1=0.5270 (P=0.9651, R=0.3624, Support=229.0)
      causal      : F1=0.7453 (P=0.5989, R=0.9864, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3760
    Macro Precision: 0.3402
    Macro Recall:    0.4564
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.378 (P=0.304, R=0.498, S=263.0)
      I-C       : F1=0.519 (P=0.404, R=0.724, S=1451.0)
      B-E       : F1=0.357 (P=0.289, R=0.469

                                                                                                                                 


Epoch 7/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.2886
  Average Validation Loss:         0.9682
  Overall Validation Avg F1 (Macro): 0.6486
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6351
    Macro Precision: 0.7873
    Macro Recall:    0.6745
    Accuracy:        0.6689
    Per-class details:
      non-causal  : F1=0.5240 (P=0.9762, R=0.3581, Support=229.0)
      causal      : F1=0.7462 (P=0.5984, R=0.9910, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3751
    Macro Precision: 0.3394
    Macro Recall:    0.4586
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.372 (P=0.299, R=0.494, S=263.0)
      I-C       : F1=0.527 (P=0.407, R=0.744, S=1451.0)
      B-E       : F1=0.354 (P=0.286, R=0.465

                                                                                                                                 


Epoch 8/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.2800
  Average Validation Loss:         0.9509
  Overall Validation Avg F1 (Macro): 0.6498
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6469
    Macro Precision: 0.7911
    Macro Recall:    0.6832
    Accuracy:        0.6778
    Per-class details:
      non-causal  : F1=0.5426 (P=0.9773, R=0.3755, Support=229.0)
      causal      : F1=0.7513 (P=0.6050, R=0.9910, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3752
    Macro Precision: 0.3395
    Macro Recall:    0.4563
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.371 (P=0.297, R=0.494, S=263.0)
      I-C       : F1=0.523 (P=0.406, R=0.732, S=1451.0)
      B-E       : F1=0.354 (P=0.287, R=0.461

                                                                                                                                 


Epoch 9/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.2790
  Average Validation Loss:         0.9405
  Overall Validation Avg F1 (Macro): 0.6517
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6499
    Macro Precision: 0.7921
    Macro Recall:    0.6854
    Accuracy:        0.6800
    Per-class details:
      non-causal  : F1=0.5472 (P=0.9775, R=0.3799, Support=229.0)
      causal      : F1=0.7526 (P=0.6066, R=0.9910, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3765
    Macro Precision: 0.3406
    Macro Recall:    0.4569
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.374 (P=0.300, R=0.498, S=263.0)
      I-C       : F1=0.522 (P=0.407, R=0.728, S=1451.0)
      B-E       : F1=0.357 (P=0.290, R=0.465

                                                                                                                                  


Epoch 10/10 Summary
--------------------------------------------------------------------------------
  Average Training Loss:           0.2785
  Average Validation Loss:         0.9356
  Overall Validation Avg F1 (Macro): 0.6503
--------------------------------------------------------------------------------
Task-Specific Validation Performance:

  [Task 1: Sentence Classification]
    Macro F1-Score:  0.6479
    Macro Precision: 0.7861
    Macro Recall:    0.6832
    Accuracy:        0.6778
    Per-class details:
      non-causal  : F1=0.5455 (P=0.9667, R=0.3799, Support=229.0)
      causal      : F1=0.7504 (P=0.6056, R=0.9864, Support=221.0)

  [Task 2: BIO Prediction (Token-BIO)]
    Macro F1-Score:  0.3767
    Macro Precision: 0.3409
    Macro Recall:    0.4568
    Per-tag details (P=Precision, R=Recall, F1=F1-Score, S=Support):
      B-C       : F1=0.376 (P=0.303, R=0.498, S=263.0)
      I-C       : F1=0.523 (P=0.408, R=0.728, S=1451.0)
      B-E       : F1=0.356 (P=0.288, R=0.46



In [13]:
# Export the model returned by train_model (best or last epoch — confirm in trainer.py)
# with save_pretrained.
# NOTE(review): the directory name reads "hf_exper_..." (probably a typo for "hf_expert_...");
# it is kept unchanged because other code may already load from this path.
hf_export_dir = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP"
trained_model.save_pretrained(hf_export_dir)

In [15]:
# Store the tokenizer alongside the exported model so the directory can be reloaded on its own.
# Requires `train_dataset` to still be the CausalDataset (a torch Subset has no `.tokenizer`).
# NOTE(review): "hf_exper_..." looks like a typo for "hf_expert_..."; kept for compatibility.
hf_export_dir = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP"
train_dataset.tokenizer.save_pretrained(hf_export_dir)

('/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP/tokenizer_config.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP/special_tokens_map.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP/vocab.txt',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP/added_tokens.json',
 '/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_GCE_weakSP/hf_exper_bert_GCE_weakSP/tokenizer.json')