In [1]:
#Imports
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, Trainer, TrainingArguments, EarlyStoppingCallback, set_seed
from datasets import Dataset
import evaluate

# Import our modules
import sys
import os

scripts_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts'))
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

from utils import load_data, prepare_all_samples
from bert_training import create_training_pairs, handle_class_imbalance, add_special_tokens, tokenize_function, compute_metrics
from bert_extractor import preprocess_input, bert_extraction
from bert_model import BertRC

In [2]:
#Set seed
set_seed(42)

In [3]:
#Data loading
df = load_data("../data/synthetic.csv")
df

Unnamed: 0,patient,note_id,note,document_timestamp,extracted_disorders,formatted_dates,relationship_gold
0,1,0,Ultrasound (30nd Jun 2024): no significant fin...,14/05/2025,"[{'label': 'asthma', 'start': 57, 'end': 63}, ...","[{'original': '(02nd Aug 2024)', 'parsed': '20...","[{'date': '2024-08-02', 'date_position': 311, ..."
1,2,1,Labs (27th Sep 2024): anemia. resolving Skin:...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 307,...","[{'original': '(27th Sep 2024)', 'parsed': '20...","[{'date': '2024-09-27', 'date_position': 5, 'd..."
2,3,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,14/05/2025,"[{'label': 'osteoarthritis', 'start': 43, 'end...","[{'original': '(2024-10-04)', 'parsed': '2004-...","[{'date': '2024-10-04', 'date_position': 16, '..."
3,4,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,14/05/2025,"[{'label': 'schizophrenia', 'start': 437, 'end...","[{'original': '(13rd Feb 2025)', 'parsed': '20...","[{'date': '2025-02-13', 'date_position': 14, '..."
4,5,4,New pt((18/11/24)): pt presents with nausea/vo...,14/05/2025,"[{'label': 'diabetes_mellitus', 'start': 440, ...","[{'original': '(18/11/24)', 'parsed': '2024-11...","[{'date': '2024-11-18', 'date_position': 7, 'd..."
...,...,...,...,...,...,...,...
96,7,96,Visit((08/10/24)): pt presents with joint pain...,14/05/2025,"[{'label': 'macroadenoma', 'start': 112, 'end'...","[{'original': '(11/12/2024)', 'parsed': '2024-...","[{'date': '2024-12-11', 'date_position': 624, ..."
97,8,97,F/U (31 Aug 2024): resolved A review of system...,14/05/2025,"[{'label': 'macroadenoma', 'start': 315, 'end'...","[{'original': '(31 Aug 2024)', 'parsed': '2024...","[{'date': '2024-08-31', 'date_position': 4, 'd..."
98,9,98,Phone note((12-10-2024)): slightly improved. o...,14/05/2025,"[{'label': 'tension_headache', 'start': 108, '...","[{'original': '(01/03/2025)', 'parsed': '2025-...","[{'date': '2025-03-01', 'date_position': 228, ..."
99,10,99,F/U (2025-02-23): fluctuating. confirmed multi...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 41, ...","[{'original': '(2025-02-23)', 'parsed': '2023-...","[{'date': '2025-02-23', 'date_position': 6, 'd..."


In [4]:
#Prepare samples
samples = prepare_all_samples(df)
samples

[{'note_text': "Ultrasound (30nd Jun 2024): no significant findings.imp: asthma\n\nShe denies any nausea, vomiting, or diarrhea.\nC Patient reports compliance with current medication regimen. Basic metabolic panel within normal limits with sodium 140, potassium 4.2, creatinine 0.9.\nPatient is afebrile with normal vital signs. T (02nd Aug 2024): reveals asthma.imp: asthma\n\nX-ray (12nd Sep 2024): shows 3.1cm mass in brain.imp: pituitary_adenoma\n\nCLINIC VISIT (16 Sep'24): nausea/vomiting worsening confirmed rheumatoid_arthritis switch to aspirin\n\nPast medical history is non-contributory.\nURGENT REVIEW (23rd Oct 2024): headache x1 day.r Will order additional laboratory studies at next visit if symptoms persist. Heart: Regular rate and rhythm, no murmurs. Patient has a history of meningitis. GI: Bowel sounds present in all four quadrants.\n Liver function tests show mild elevation in ALT and AST, likely due to medication effect.Chest X-ray reveals clear lung fields without infiltrat

In [5]:
#Create training pairs
processed_df = create_training_pairs(samples)
processed_df

Unnamed: 0,text,marked_text,ent1_start,ent1_end,ent2_start,ent2_end,label,patient_id,note_id,distance
0,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,311,326,1,1,0,254
1,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,57,63,587,602,0,1,0,530
2,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,311,326,0,1,0,99
3,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,410,427,587,602,0,1,0,177
4,Ultrasound (30nd Jun 2024): no significant fin...,Ultrasound (30nd Jun 2024): no significant fin...,491,511,311,326,0,1,0,180
...,...,...,...,...,...,...,...,...,...,...
1237,CLINIC VISIT (15/06/2025): Patient was Current...,CLINIC VISIT [E2] (15/06/2025) [/E2]: Patient ...,1165,1175,13,25,1,1,100,1152
1238,CLINIC VISIT (15/06/2025): Patient was Current...,CLINIC VISIT [E2] (15/06/2025) [/E2]: Patient ...,1041,1058,13,25,0,1,100,1028
1239,CLINIC VISIT (15/06/2025): Patient was Current...,CLINIC VISIT [E2] (15/06/2025) [/E2]: Patient ...,274,297,13,25,0,1,100,261
1240,CLINIC VISIT (15/06/2025): Patient was Current...,CLINIC VISIT [E2] (15/06/2025) [/E2]: Patient ...,827,845,13,25,0,1,100,814


In [6]:
#Handle class imbalance
balanced_df, class_weights = handle_class_imbalance(processed_df, method='weighted')
class_weights

tensor([0.2721, 1.7279])

In [7]:
#Train-test split
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42, stratify=balanced_df['label'])

In [8]:
#Tokenizer
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = add_special_tokens(tokenizer)

In [9]:
#Data prep
train_dataset = Dataset.from_pandas(train_df[['marked_text', 'label']])
test_dataset = Dataset.from_pandas(test_df[['marked_text', 'label']])

train_tokenized = train_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)
test_tokenized = test_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)

train_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

Map:   0%|          | 0/993 [00:00<?, ? examples/s]

Map:   0%|          | 0/249 [00:00<?, ? examples/s]

In [10]:
#Model creation
model = BertRC(
    model_name=model_name,
    tokenizer=tokenizer,
    num_labels=2,
    class_weights=class_weights
)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [12]:
#Training args
training_args = TrainingArguments(
    output_dir="./bert_rc_results",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to=[],
    seed=42,
)

In [13]:
#Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

  trainer = Trainer(


In [14]:
#Training
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted
1,0.7122,0.822181,0.863454,0.463362,0.863454,0.800183
2,0.6735,0.719191,0.815261,0.488661,0.815261,0.785719
3,0.4967,0.75417,0.787149,0.504933,0.787149,0.776641




TrainOutput(global_step=375, training_loss=0.6302923380533855, metrics={'train_runtime': 5818.9495, 'train_samples_per_second': 0.512, 'train_steps_per_second': 0.064, 'total_flos': 0.0, 'train_loss': 0.6302923380533855, 'epoch': 3.0})

In [15]:
#Evaluation
eval_results = trainer.evaluate(test_tokenized)
print("Test Results:")
for metric, value in eval_results.items():
    if not metric.startswith('eval_'):
        continue
    clean_metric = metric.replace('eval_', '')
    print(f"{clean_metric}: {value:.4f}")



Test Results:
loss: 0.7542
accuracy: 0.7871
f1_macro: 0.5049
f1_micro: 0.7871
f1_weighted: 0.7766
runtime: 89.2447
samples_per_second: 2.7900
steps_per_second: 0.1790


In [None]:
#Save model
trainer.save_model("./bert_rc_final_model")
tokenizer.save_pretrained("./bert_rc_final_model")