Imports

In [1]:
#imports
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback, set_seed
from datasets import Dataset
import evaluate
from tqdm import tqdm

# Import our modules
import sys
import os

utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
models_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'models'))

if utils_path not in sys.path:
    sys.path.insert(0, utils_path)
if models_path not in sys.path:
    sys.path.insert(0, models_path)

from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from bert_training import create_training_pairs, handle_class_imbalance, add_special_tokens, tokenize_function, compute_metrics, build_gold_lookup, get_label_for_pair
from bert_extractor import preprocess_input, bert_extraction, mark_entities_full_text
from bert_model import BertRC
from relative_date_extractor import add_relative_dates

In [2]:
# Set seed for reproducibility
set_seed(42)

Data Loading

In [3]:
# Load data
df = load_data("../data/medcat_re_dataset.csv")
print(f"Loaded {len(df)} records")
#df

Loaded 5 records


In [4]:
#Add relative dates if not already added via MedCAT trainer 
if 'relative_dates_json' not in df.columns:
    df = add_relative_dates(df)
    print("Added relative dates")
else:
    print("Relative dates already present, skipping extraction")

Added relative dates


In [5]:
#Inspect df to check that relative dates have been added
#df

In [6]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
#samples[0]

Prepared 5 samples


Test Pre-Processing & Utility Functions

In [7]:
# Example note for testing
note_text = (
    "Patient diagnosed with asthma on 2024-08-02. "
    "Diabetes was ruled out on 2024-08-02. "
    "Family history of hypertension, last reviewed in 2022. "
    "Patient may have pneumonia, last seen on 2024-08-02."
)

# Example entity and date spans
entity = {'start': 23, 'end': 29, 'value': 'asthma'}
date   = {'start': 33, 'end': 43, 'value': '2024-08-02'}

print("Example note:", note_text)
print("Entity:", entity)
print("Date:", date)

Example note: Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.
Entity: {'start': 23, 'end': 29, 'value': 'asthma'}
Date: {'start': 33, 'end': 43, 'value': '2024-08-02'}


In [8]:
# Test each function explicitly
print("Testing build_gold_lookup...")
gold_map = build_gold_lookup(samples[0]['links_json'])
print(f"Gold map: {gold_map}")

Testing build_gold_lookup...
Gold map: {('rheumatoid_arthritis', "16 Sep'24"), ('pituitary_adenoma', '12nd Sep 2024'), ('GERD', '17.12.24'), ('headache', '23rd Oct 2024')}


In [9]:
print("Testing get_label_for_pair...")
label = get_label_for_pair(57, 311, gold_map)  # Example positions
print(f"Label: {label}")

Testing get_label_for_pair...
Label: no_link


In [10]:
print("Testing mark_entities_full_text...")
marked = mark_entities_full_text(note_text, 23, 29, 33, 43, "asthma", "2024-08-02")
print(f"Marked text: {marked}")

Testing mark_entities_full_text...
Marked text: Patient diagnosed with [E1] asthma [/E1] on [E2] 2024-08-02 [/E2]. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.


In [11]:
# Test preprocessing
preprocessed = preprocess_input(note_text, entity, date)
print("\nPreprocessed input:")
print(preprocessed['marked_text'])


Preprocessed input:
Patient diagnosed with [E1] asthma [/E1] on [E2] 2024-08-02 [/E2]. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.


In [12]:
# ============================================================================
# COMPREHENSIVE ENTITY MARKING AND PREPROCESSING TESTS (new schema)
# ============================================================================

# Test with the full first sample from our data
sample = samples[0]
print("=" * 60)
print("TESTING WITH FULL SAMPLE")
print("=" * 60)
print(f"Sample note length: {len(sample['note_text'])}")
print(f"Number of entities: {len(sample['entities_list'])}")
print(f"Number of dates: {len(sample['dates'])}")
print(f"Number of gold relationships: {len(sample['links_json'])}")

# Test all entity-date combinations
print("\nTesting all entity-date combinations:")
for i, entity in enumerate(sample['entities_list']):
    for j, date in enumerate(sample['dates']):
        print(f"\n--- Combination {i+1}-{j+1}: {entity['value']} + {date['value']} ---")
        
        # Test preprocessing
        processed = preprocess_input(sample['note_text'], entity, date)
        
        # Show the marked text (truncated for readability)
        marked_text = processed['marked_text']
        print(f"Original text length: {len(sample['note_text'])}")
        print(f"Marked text length: {len(marked_text)}")
        
        # Show a snippet around the marked entities
        ent_start = processed['ent1_start']
        date_start = processed['ent2_start']
        
        # Find the context around both entities
        context_start = max(0, min(ent_start, date_start) - 50)
        context_end = min(len(marked_text), max(ent_start, date_start) + 100)
        context = marked_text[context_start:context_end]
        print(f"Context snippet: ...{context}...")
        
        # Test gold lookup (value-based)
        gold_set = build_gold_lookup(sample['links_json'])
        label = get_label_for_pair(entity['value'], date['value'], gold_set)
        print(f"Gold label: {label}")

# Test edge cases
print("\n" + "=" * 60)
print("TESTING EDGE CASES")
print("=" * 60)

# Test with entities at the very beginning and end of text
print("Testing entities at text boundaries...")
first_entity = sample['entities_list'][0]
last_date = sample['dates'][-1]

processed_edge = preprocess_input(sample['note_text'], first_entity, last_date)
print(f"First entity position: {first_entity['start']}-{first_entity['end']}")
print(f"Last date position: {last_date['start']}-{last_date['end']}")

# Show beginning and end of marked text
print(f"Marked text start: {processed_edge['marked_text'][:100]}...")
print(f"Marked text end: ...{processed_edge['marked_text'][-100:]}")

# Test for potential overlapping entities
print("\nTesting for potential overlapping entities...")
for i, entity in enumerate(sample['entities_list']):
    for j, date in enumerate(sample['dates']):
        if abs(entity['start'] - date['start']) < 10:  # Close entities
            print(f"Close entities found: {entity['value']} at {entity['start']}, {date['value']} at {date['start']}")
            processed_close = preprocess_input(sample['note_text'], entity, date)
            print(f"Marked text: {processed_close['marked_text'][entity['start']-20:date['end']+20]}")

# Test gold relationship mapping (value pairs)
print("\n" + "=" * 60)
print("TESTING GOLD RELATIONSHIP MAPPING")
print("=" * 60)

gold_set = build_gold_lookup(sample['links_json'])
print(f"Gold set size: {len(gold_set)}")

# Show each gold relationship
for rel in sample['links_json']:
    print(f"Gold relationship: {rel['entity']} <-> {rel['date']}")

TESTING WITH FULL SAMPLE
Sample note length: 1319
Number of entities: 64
Number of dates: 6
Number of gold relationships: 4

Testing all entity-date combinations:

--- Combination 1-1: history of meningitis + 30nd Jun 2024 ---
Original text length: 1319
Marked text length: 1341
Context snippet: ...Ultrasound ([E2] 30nd Jun 2024 [/E2]): no significant findings.imp: asthma

She denies any nausea, vomiting, or diarrhea.
C Patient reports compliance with current medication regimen. Basic metabolic panel within normal limits with sodium 140, potassium 4.2, creatinine 0.9.
Patient is afebrile with normal vital signs. T (02nd Aug 2024): reveals asthma.imp: asthma

X-ray (12nd Sep 2024): shows 3.1cm mass in brain.imp: pituitary_adenoma

CLINIC VISIT (16 Sep'24): nausea/vomiting worsening confirmed rheumatoid_arthritis switch to aspirin

Past medical history is non-contributory.
URGENT REVIEW (23rd Oct 2024): headache x1 day.r Will order additional laboratory studies at next visit if symptoms p

BERT Base

In [13]:
# Load base model and tokenizer
model_name = "google/bert_uncased_L-2_H-128_A-2"
#model_name = "emilyalsentzer/Bio_ClinicalBERT"
base_tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Test base model on a few samples
print("\nTesting base model (no finetuning):")
base_predictions = []

for sample in tqdm(samples[:10], desc="Base model testing"):  # Test on first 10 samples
    # Get absolute date pairs
    absolute_pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    
    # Get relative date pairs if available
    if sample.get('relative_dates') and len(sample['relative_dates']) > 0:
        relative_pairs = get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])
        pairs = absolute_pairs + relative_pairs
    else:
        pairs = absolute_pairs
    
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, base_model, base_tokenizer)
        if pred == 1:
            base_predictions.append({'entity_label': entity['value'], 'date': date['value'], 'confidence': conf})

print(f"Base model predictions: {len(base_predictions)}")


Testing base model (no finetuning):


Base model testing: 100%|██████████| 5/5 [00:05<00:00,  1.00s/it]

Base model predictions: 756





In [32]:
#Look at predictions
base_predictions

[{'entity_label': 'history of meningitis',
  'date': '30nd Jun 2024',
  'confidence': 0.5474071502685547},
 {'entity_label': 'history of meningitis',
  'date': '12nd Sep 2024',
  'confidence': 0.5467742085456848},
 {'entity_label': 'history of meningitis',
  'date': "16 Sep'24",
  'confidence': 0.5474152565002441},
 {'entity_label': 'history of meningitis',
  'date': '23rd Oct 2024',
  'confidence': 0.5460683703422546},
 {'entity_label': 'history of meningitis',
  'date': '16st Nov 2024',
  'confidence': 0.5437679886817932},
 {'entity_label': 'history of meningitis',
  'date': '17.12.24',
  'confidence': 0.5437679886817932},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '30nd Jun 2024',
  'confidence': 0.5458624362945557},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '12nd Sep 2024',
  'confidence': 0.5452666282653809},
 {'entity_label': 'rheumatoid_arthritis',
  'date': "16 Sep'24",
  'confidence': 0.546122670173645},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2

In [17]:
# Calculate base model metrics
base_metrics = calculate_metrics(base_predictions, df)
print("Base model metrics:", base_metrics)

Base model metrics: {'precision': 0.05465288035450517, 'recall': 1.0, 'f1': 0.10364145658263305, 'tp': 37, 'fp': 640, 'fn': 0}


Data Preparation for Finetuning

In [18]:
# Create training pairs using best approach
processed_df = create_training_pairs(samples)
print(f"\nCreated {len(processed_df)} training pairs")


Created 756 training pairs


In [19]:
# Handle class imbalance
balanced_df, class_weights = handle_class_imbalance(processed_df, method='weighted')
print(f"Class weights: {class_weights}")

Class weights: tensor([0.1164, 1.8836])


In [20]:
# Train-test split
train_df, test_df = train_test_split(balanced_df, test_size=0.2, random_state=42, stratify=balanced_df['label'])
print(f"Train: {len(train_df)}, Test: {len(test_df)}")

Train: 604, Test: 152


In [21]:
# Setup tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer = add_special_tokens(tokenizer)

In [22]:
# Resize model embeddings to match new tokenizer size
base_model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(30526, 128, padding_idx=0)

In [23]:
# Prepare PyTorch datasets
train_dataset = Dataset.from_pandas(train_df[['marked_text', 'label']])
test_dataset = Dataset.from_pandas(test_df[['marked_text', 'label']])

In [24]:
# Tokenize
train_tokenized = train_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)
test_tokenized = test_dataset.map(lambda x: tokenize_function(x, tokenizer, max_length=256), batched=True)

Map:   0%|          | 0/604 [00:00<?, ? examples/s]

Map:   0%|          | 0/152 [00:00<?, ? examples/s]

In [25]:
# Set format for PyTorch
train_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
test_tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

BERT Finetuning

In [26]:
# Create custom model with span pooling
model = BertRC(
    model_name=model_name,
    tokenizer=tokenizer,
    num_labels=2,
    class_weights=class_weights
)

In [27]:
# Resize model embeddings to match new tokenizer size
model.backbone.resize_token_embeddings(len(tokenizer))

Embedding(30526, 128, padding_idx=0)

In [28]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./bert_rc_results",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=50,
    fp16=torch.cuda.is_available(),
    report_to=[],
    seed=42,
)

In [29]:
# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

  trainer = Trainer(


In [30]:
#Train
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro,F1 Weighted
1,0.599,0.985021,0.940789,0.484746,0.940789,0.912087
2,0.7237,1.422796,0.940789,0.484746,0.940789,0.912087
3,0.9405,1.560878,0.940789,0.484746,0.940789,0.912087




TrainOutput(global_step=228, training_loss=0.740918937482332, metrics={'train_runtime': 35.7328, 'train_samples_per_second': 50.71, 'train_steps_per_second': 6.381, 'total_flos': 0.0, 'train_loss': 0.740918937482332, 'epoch': 3.0})

In [31]:
# Evaluate on test set
eval_results = trainer.evaluate(test_tokenized)
print("\nTest Results:")
for metric, value in eval_results.items():
    if not metric.startswith('eval_'):
        continue
    clean_metric = metric.replace('eval_', '')
    print(f"{clean_metric}: {value:.4f}")




Test Results:
loss: 0.9850
accuracy: 0.9408
f1_macro: 0.4847
f1_micro: 0.9408
f1_weighted: 0.9121
runtime: 0.4319
samples_per_second: 351.9570
steps_per_second: 23.1550


In [None]:
# Save the final model
trainer.save_model("./bert_rc_final_model")
tokenizer.save_pretrained("./bert_rc_final_model")
print("\nModel saved to ./bert_rc_final_model")