Imports

In [30]:
#imports
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback, set_seed
from datasets import Dataset
import evaluate
from tqdm import tqdm

# Import our modules
import sys
import os

utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
models_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'models'))

if utils_path not in sys.path:
    sys.path.insert(0, utils_path)
if models_path not in sys.path:
    sys.path.insert(0, models_path)

from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from bert_training import create_training_pairs, handle_class_imbalance, add_special_tokens, tokenize_function, compute_metrics, build_gold_lookup, get_label_for_pair
from bert_extractor import preprocess_input, bert_extraction, mark_entities_full_text
from bert_model import BertRC

In [3]:
# Set seed for reproducibility
set_seed(42)

Data Loading

In [4]:
# Load data
df = load_data("../data/synthetic.csv")
print(f"Loaded {len(df)} records")
df

Loaded 101 records


Unnamed: 0,patient,note_id,note,document_timestamp,extracted_disorders,formatted_dates,relationship_gold
0,1,0,Ultrasound (30nd Jun 2024): no significant fin...,14/05/2025,"[{'label': 'asthma', 'start': 57, 'end': 63}, ...","[{'original': '(02nd Aug 2024)', 'parsed': '20...","[{'date': '2024-08-02', 'date_position': 311, ..."
1,2,1,Labs (27th Sep 2024): anemia. resolving Skin:...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 307,...","[{'original': '(27th Sep 2024)', 'parsed': '20...","[{'date': '2024-09-27', 'date_position': 5, 'd..."
2,3,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,14/05/2025,"[{'label': 'osteoarthritis', 'start': 43, 'end...","[{'original': '(2024-10-04)', 'parsed': '2004-...","[{'date': '2024-10-04', 'date_position': 16, '..."
3,4,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,14/05/2025,"[{'label': 'schizophrenia', 'start': 437, 'end...","[{'original': '(13rd Feb 2025)', 'parsed': '20...","[{'date': '2025-02-13', 'date_position': 14, '..."
4,5,4,New pt((18/11/24)): pt presents with nausea/vo...,14/05/2025,"[{'label': 'diabetes_mellitus', 'start': 440, ...","[{'original': '(18/11/24)', 'parsed': '2024-11...","[{'date': '2024-11-18', 'date_position': 7, 'd..."
...,...,...,...,...,...,...,...
96,7,96,Visit((08/10/24)): pt presents with joint pain...,14/05/2025,"[{'label': 'macroadenoma', 'start': 112, 'end'...","[{'original': '(11/12/2024)', 'parsed': '2024-...","[{'date': '2024-12-11', 'date_position': 624, ..."
97,8,97,F/U (31 Aug 2024): resolved A review of system...,14/05/2025,"[{'label': 'macroadenoma', 'start': 315, 'end'...","[{'original': '(31 Aug 2024)', 'parsed': '2024...","[{'date': '2024-08-31', 'date_position': 4, 'd..."
98,9,98,Phone note((12-10-2024)): slightly improved. o...,14/05/2025,"[{'label': 'tension_headache', 'start': 108, '...","[{'original': '(01/03/2025)', 'parsed': '2025-...","[{'date': '2025-03-01', 'date_position': 228, ..."
99,10,99,F/U (2025-02-23): fluctuating. confirmed multi...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 41, ...","[{'original': '(2025-02-23)', 'parsed': '2023-...","[{'date': '2025-02-23', 'date_position': 6, 'd..."


In [5]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
samples[0]

Prepared 101 samples


{'note_text': "Ultrasound (30nd Jun 2024): no significant findings.imp: asthma\n\nShe denies any nausea, vomiting, or diarrhea.\nC Patient reports compliance with current medication regimen. Basic metabolic panel within normal limits with sodium 140, potassium 4.2, creatinine 0.9.\nPatient is afebrile with normal vital signs. T (02nd Aug 2024): reveals asthma.imp: asthma\n\nX-ray (12nd Sep 2024): shows 3.1cm mass in brain.imp: pituitary_adenoma\n\nCLINIC VISIT (16 Sep'24): nausea/vomiting worsening confirmed rheumatoid_arthritis switch to aspirin\n\nPast medical history is non-contributory.\nURGENT REVIEW (23rd Oct 2024): headache x1 day.r Will order additional laboratory studies at next visit if symptoms persist. Heart: Regular rate and rhythm, no murmurs. Patient has a history of meningitis. GI: Bowel sounds present in all four quadrants.\n Liver function tests show mild elevation in ALT and AST, likely due to medication effect.Chest X-ray reveals clear lung fields without infiltrate

Test Pre-Processing & Utility Functions

In [6]:
# Example note for testing
note_text = (
    "Patient diagnosed with asthma on 2024-08-02. "
    "Diabetes was ruled out on 2024-08-02. "
    "Family history of hypertension, last reviewed in 2022. "
    "Patient may have pneumonia, last seen on 2024-08-02."
)

# Example entity and date spans
entity = {'start': 23, 'end': 29, 'label': 'asthma'}
date = {'start': 33, 'end': 43, 'parsed': '2024-08-02', 'original': '2024-08-02'}

print("Example note:", note_text)
print("Entity:", entity)
print("Date:", date)

Example note: Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.
Entity: {'start': 23, 'end': 29, 'label': 'asthma'}
Date: {'start': 33, 'end': 43, 'parsed': '2024-08-02', 'original': '2024-08-02'}


In [31]:
# Test each function explicitly
print("Testing build_gold_lookup...")
gold_map = build_gold_lookup(samples[0]['relationship_gold'])
print(f"Gold map: {gold_map}")

Testing build_gold_lookup...
Gold map: {(57, 311): 'link', (1143, 587): 'link'}


In [32]:
print("Testing get_label_for_pair...")
label = get_label_for_pair(57, 311, gold_map)  # Example positions
print(f"Label: {label}")

Testing get_label_for_pair...
Label: link


In [33]:
print("Testing mark_entities_full_text...")
marked = mark_entities_full_text(note_text, 23, 29, 33, 43, "asthma", "2024-08-02")
print(f"Marked text: {marked}")

Testing mark_entities_full_text...
Marked text: Patient diagnosed with [E1] asthma [/E1] on [E2] 2024-08-02 [/E2]. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.


In [34]:
# Test preprocessing
preprocessed = preprocess_input(note_text, entity, date)
print("\nPreprocessed input:")
print(preprocessed['marked_text'])


Preprocessed input:
Patient diagnosed with asthma on 2024-08-02. Diabetes was[E1]  ruled [/E1] out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.[E2]  [/E2]


In [35]:
# ============================================================================
# COMPREHENSIVE ENTITY MARKING AND PREPROCESSING TESTS
# ============================================================================

# Test with the full first sample from our data
sample = samples[0]
print("=" * 60)
print("TESTING WITH FULL SAMPLE")
print("=" * 60)
print(f"Sample note length: {len(sample['note_text'])}")
print(f"Number of disorders: {len(sample['entities_list'])}")
print(f"Number of dates: {len(sample['dates'])}")
print(f"Number of gold relationships: {len(sample['relationship_gold'])}")

# Test all disorder-date combinations
print("\nTesting all disorder-date combinations:")
for i, disorder in enumerate(sample['entities_list']):
    for j, date in enumerate(sample['dates']):
        print(f"\n--- Combination {i+1}-{j+1}: {disorder['label']} + {date['parsed']} ---")
        
        # Test preprocessing
        processed = preprocess_input(sample['note_text'], disorder, date)
        
        # Show the marked text (truncated for readability)
        marked_text = processed['marked_text']
        print(f"Original text length: {len(sample['note_text'])}")
        print(f"Marked text length: {len(marked_text)}")
        
        # Show a snippet around the marked entities
        disorder_start = processed['ent1_start']
        date_start = processed['ent2_start']
        
        # Find the context around both entities
        context_start = max(0, min(disorder_start, date_start) - 50)
        context_end = min(len(marked_text), max(disorder_start, date_start) + 100)
        context = marked_text[context_start:context_end]
        
        print(f"Context snippet: ...{context}...")
        
        # Test gold lookup
        gold_map = build_gold_lookup(sample['relationship_gold'])
        label = get_label_for_pair(disorder['start'], date['start'], gold_map)
        print(f"Gold label: {label}")

# Test edge cases
print("\n" + "=" * 60)
print("TESTING EDGE CASES")
print("=" * 60)

# Test with entities at the very beginning and end of text
print("Testing entities at text boundaries...")
first_disorder = sample['entities_list'][0]
last_date = sample['dates'][-1]

processed_edge = preprocess_input(sample['note_text'], first_disorder, last_date)
print(f"First disorder position: {first_disorder['start']}-{first_disorder['end']}")
print(f"Last date position: {last_date['start']}-{last_date['end']}")

# Show beginning and end of marked text
print(f"Marked text start: {processed_edge['marked_text'][:100]}...")
print(f"Marked text end: ...{processed_edge['marked_text'][-100:]}")

# Test with overlapping entities (if any)
print("\nTesting for potential overlapping entities...")
for i, disorder in enumerate(sample['entities_list']):
    for j, date in enumerate(sample['dates']):
        if abs(disorder['start'] - date['start']) < 10:  # Close entities
            print(f"Close entities found: {disorder['label']} at {disorder['start']}, {date['parsed']} at {date['start']}")
            processed_close = preprocess_input(sample['note_text'], disorder, date)
            print(f"Marked text: {processed_close['marked_text'][disorder['start']-20:date['end']+20]}")

# Test gold relationship mapping
print("\n" + "=" * 60)
print("TESTING GOLD RELATIONSHIP MAPPING")
print("=" * 60)

gold_map = build_gold_lookup(sample['relationship_gold'])
print(f"Gold map: {gold_map}")

# Test each gold relationship
for rel in sample['relationship_gold']:
    print(f"\nGold relationship: {rel['date']} (position {rel['date_position']})")
    for diag in rel['diagnoses']:
        print(f"  - {diag['diagnosis']} (position {diag['position']})")
        # Find the corresponding disorder in entities_list
        matching_disorder = None
        for disorder in sample['entities_list']:
            if disorder['start'] == diag['position']:
                matching_disorder = disorder
                break
        
        if matching_disorder:
            # Find the corresponding date
            matching_date = None
            for date in sample['dates']:
                if date['start'] == rel['date_position']:
                    matching_date = date
                    break
            
            if matching_date:
                print(f"    -> Found matching pair: {matching_disorder['label']} + {matching_date['parsed']}")
                processed_gold = preprocess_input(sample['note_text'], matching_disorder, matching_date)
                print(f"    -> Marked text snippet: ...{processed_gold['marked_text'][matching_disorder['start']-30:matching_date['end']+30]}...")
            else:
                print(f"    -> No matching date found for position {rel['date_position']}")
        else:
            print(f"    -> No matching disorder found for position {diag['position']}")

print("\n" + "=" * 60)
print("COMPREHENSIVE TESTING COMPLETE")
print("=" * 60)

TESTING WITH FULL SAMPLE
Sample note length: 1319
Number of disorders: 6
Number of dates: 2
Number of gold relationships: 2

Testing all disorder-date combinations:

--- Combination 1-1: asthma + 2024-08-02 ---
Original text length: 1319
Marked text length: 1341
Context snippet: ...und (30nd Jun 2024): no significant findings.imp: [E1] asthma [/E1]

She denies any nausea, vomiting, or diarrhea.
C Patient reports compliance with current medication regimen. Basic metabolic panel within normal limits with sodium 140, potassium 4.2, creatinine 0.9.
Patient is afebrile with normal vital signs. T [E2] (02nd Aug 2024) [/E2]: reveals asthma.imp: asthma

X-ray (12nd Sep 2024): shows 3.1c...
Gold label: link

--- Combination 1-2: asthma + 2024-10-23 ---
Original text length: 1319
Marked text length: 1341
Context snippet: ...und (30nd Jun 2024): no significant findings.imp: [E1] asthma [/E1]

She denies any nausea, vomiting, or diarrhea.
C Patient reports compliance with current medication regime

BERT Base

In [8]:
# Load base model and tokenizer
model_name = "google/bert_uncased_L-2_H-128_A-2"
#model_name = "emilyalsentzer/Bio_ClinicalBERT"
base_tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Test base model on a few samples
print("\nTesting base model (no finetuning):")
base_predictions = []

for sample in tqdm(samples[:10], desc="Base model testing"):  # Test on first 10 samples
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, base_model, base_tokenizer)
        if pred == 1:
            base_predictions.append({'entity_label': entity['label'], 'date': date.get('parsed'), 'confidence': conf})

print(f"Base model predictions: {len(base_predictions)}")


Testing base model (no finetuning):


Base model testing: 100%|██████████| 10/10 [00:00<00:00, 11.06it/s]

Base model predictions: 128





In [10]:
#Look at predictions
base_predictions

[{'entity_label': 'asthma',
  'date': '2024-08-02',
  'confidence': 0.5452772378921509},
 {'entity_label': 'asthma',
  'date': '2024-10-23',
  'confidence': 0.5476818084716797},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-08-02',
  'confidence': 0.5448076128959656},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-10-23',
  'confidence': 0.5468413829803467},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-08-02',
  'confidence': 0.5446170568466187},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-10-23',
  'confidence': 0.5465338826179504},
 {'entity_label': 'pneumonia',
  'date': '2024-08-02',
  'confidence': 0.5450191497802734},
 {'entity_label': 'pneumonia',
  'date': '2024-10-23',
  'confidence': 0.5448727011680603},
 {'entity_label': 'gerd',
  'date': '2024-08-02',
  'confidence': 0.5450191497802734},
 {'entity_label': 'gerd',
  'date': '2024-10-23',
  'confidence': 0.5448727011680603},
 {'entity_label': 'meningitis',
  'date': '2024-08-02',
  

In [11]:
# Calculate base model metrics
base_metrics = calculate_metrics(base_predictions, df)
print("Base model metrics:", base_metrics)

Base model metrics: {'precision': 0.140625, 'recall': 0.08411214953271028, 'f1': 0.10526315789473684, 'tp': 18, 'fp': 110, 'fn': 196}


Data Preparation for Finetuning

In [12]:
# Create training pairs using best approach
processed_df = create_training_pairs(samples)
print(f"\nCreated {len(processed_df)} training pairs")


Created 1242 training pairs


In [13]:
# Handle class imbalance
balanced_df, class_weights = handle_class_imbalance(processed_df, method='weighted')
print(f"Class weights: {class_weights}")

Class weights: tensor([0.2721, 1.7279])


In [14]:
# Train-test split
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42, stratify=balanced_df['label'])
print(f"Train: {len(train_df)}, Test: {len(test_df)}")

Train: 993, Test: 249


In [15]:
# Setup tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = add_special_tokens(tokenizer)

In [17]:
# Resize model embeddings to match new tokenizer size
base_model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(30526, 128, padding_idx=0)

In [None]:
# Prepare PyTorch datasets
train_dataset = Dataset.from_pandas(train_df[['marked_text', 'label']])
test_dataset = Dataset.from_pandas(test_df[['marked_text', 'label']])

In [19]:
# Tokenize
train_tokenized = train_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)
test_tokenized = test_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)

Map:   0%|          | 0/993 [00:00<?, ? examples/s]

Map:   0%|          | 0/249 [00:00<?, ? examples/s]

In [20]:
# Set format for PyTorch
train_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

BERT Finetuning

In [21]:
# Create custom model with span pooling
model = BertRC(
    model_name=model_name,
    tokenizer=tokenizer,
    num_labels=2,
    class_weights=class_weights
)

In [22]:
# Resize model embeddings to match new tokenizer size
model.backbone.resize_token_embeddings(len(tokenizer))

Embedding(30526, 128, padding_idx=0)

In [23]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./bert_rc_results",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to=[],
    seed=42,
)

In [24]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

  trainer = Trainer(


In [25]:
#Train
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted
1,0.6897,0.702927,0.795181,0.548801,0.795181,0.791164
2,0.7045,0.713931,0.843373,0.588255,0.843373,0.823849
3,0.7021,0.69395,0.799197,0.563342,0.799197,0.796619




TrainOutput(global_step=375, training_loss=0.6900090840657552, metrics={'train_runtime': 60.202, 'train_samples_per_second': 49.483, 'train_steps_per_second': 6.229, 'total_flos': 0.0, 'train_loss': 0.6900090840657552, 'epoch': 3.0})

In [26]:
# Evaluate on test set
eval_results = trainer.evaluate(test_tokenized)
print("\nTest Results:")
for metric, value in eval_results.items():
    if not metric.startswith('eval_'):
        continue
    clean_metric = metric.replace('eval_', '')
    print(f"{clean_metric}: {value:.4f}")




Test Results:
loss: 0.7139
accuracy: 0.8434
f1_macro: 0.5883
f1_micro: 0.8434
f1_weighted: 0.8238
runtime: 0.7401
samples_per_second: 336.4230
steps_per_second: 21.6180


In [None]:
# Save the final model
trainer.save_model("./bert_rc_final_model")
tokenizer.save_pretrained("./bert_rc_final_model")
print("\nModel saved to ./bert_rc_final_model")