Imports

In [1]:
#imports
import sys
import os

scripts_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts'))
if scripts_path not in sys.path:
    sys.path.insert(0, scripts_path)

from tqdm import tqdm
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from bert_training import tokenize_function, add_special_tokens, make_training_pairs, balance_classes, gold_lookup
from bert_extractor import get_context_window, mark_entity, mark_date, preprocess_input, bert_extraction

Data Loading

In [3]:
# Load data
df = load_data("../data/synthetic.csv")
print(f"Loaded {len(df)} records")
df

Loaded 101 records


Unnamed: 0,patient,note_id,note,document_timestamp,extracted_disorders,formatted_dates,relationship_gold
0,1,0,Ultrasound (30nd Jun 2024): no significant fin...,14/05/2025,"[{'label': 'asthma', 'start': 57, 'end': 63}, ...","[{'original': '(02nd Aug 2024)', 'parsed': '20...","[{'date': '2024-08-02', 'date_position': 311, ..."
1,2,1,Labs (27th Sep 2024): anemia. resolving Skin:...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 307,...","[{'original': '(27th Sep 2024)', 'parsed': '20...","[{'date': '2024-09-27', 'date_position': 5, 'd..."
2,3,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,14/05/2025,"[{'label': 'osteoarthritis', 'start': 43, 'end...","[{'original': '(2024-10-04)', 'parsed': '2004-...","[{'date': '2024-10-04', 'date_position': 16, '..."
3,4,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,14/05/2025,"[{'label': 'schizophrenia', 'start': 437, 'end...","[{'original': '(13rd Feb 2025)', 'parsed': '20...","[{'date': '2025-02-13', 'date_position': 14, '..."
4,5,4,New pt((18/11/24)): pt presents with nausea/vo...,14/05/2025,"[{'label': 'diabetes_mellitus', 'start': 440, ...","[{'original': '(18/11/24)', 'parsed': '2024-11...","[{'date': '2024-11-18', 'date_position': 7, 'd..."
...,...,...,...,...,...,...,...
96,7,96,Visit((08/10/24)): pt presents with joint pain...,14/05/2025,"[{'label': 'macroadenoma', 'start': 112, 'end'...","[{'original': '(11/12/2024)', 'parsed': '2024-...","[{'date': '2024-12-11', 'date_position': 624, ..."
97,8,97,F/U (31 Aug 2024): resolved A review of system...,14/05/2025,"[{'label': 'macroadenoma', 'start': 315, 'end'...","[{'original': '(31 Aug 2024)', 'parsed': '2024...","[{'date': '2024-08-31', 'date_position': 4, 'd..."
98,9,98,Phone note((12-10-2024)): slightly improved. o...,14/05/2025,"[{'label': 'tension_headache', 'start': 108, '...","[{'original': '(01/03/2025)', 'parsed': '2025-...","[{'date': '2025-03-01', 'date_position': 228, ..."
99,10,99,F/U (2025-02-23): fluctuating. confirmed multi...,14/05/2025,"[{'label': 'multiple_sclerosis', 'start': 41, ...","[{'original': '(2025-02-23)', 'parsed': '2023-...","[{'date': '2025-02-23', 'date_position': 6, 'd..."


In [4]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
samples[0]

Prepared 101 samples


{'note_text': "Ultrasound (30nd Jun 2024): no significant findings.imp: asthma\n\nShe denies any nausea, vomiting, or diarrhea.\nC Patient reports compliance with current medication regimen. Basic metabolic panel within normal limits with sodium 140, potassium 4.2, creatinine 0.9.\nPatient is afebrile with normal vital signs. T (02nd Aug 2024): reveals asthma.imp: asthma\n\nX-ray (12nd Sep 2024): shows 3.1cm mass in brain.imp: pituitary_adenoma\n\nCLINIC VISIT (16 Sep'24): nausea/vomiting worsening confirmed rheumatoid_arthritis switch to aspirin\n\nPast medical history is non-contributory.\nURGENT REVIEW (23rd Oct 2024): headache x1 day.r Will order additional laboratory studies at next visit if symptoms persist. Heart: Regular rate and rhythm, no murmurs. Patient has a history of meningitis. GI: Bowel sounds present in all four quadrants.\n Liver function tests show mild elevation in ALT and AST, likely due to medication effect.Chest X-ray reveals clear lung fields without infiltrate

BERT Base

In [5]:
# Load local model and tokenizer
#model_path = "./bert_model_training/base_model"
model_path = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2, ignore_mismatched_sizes=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
#Example note
note_text = (
    "Patient diagnosed with asthma on 2024-08-02. "
    "Diabetes was ruled out on 2024-08-02. "
    "Family history of hypertension, last reviewed in 2022. "
    "Patient may have pneumonia, last seen on 2024-08-02."
)

# Example entity and date spans (positions are for illustration)
entity = {'start': 23, 'end': 29, 'label': 'asthma'}  # "asthma"
date = {'start': 33, 'end': 43, 'parsed': '2024-08-02'}  # "2024-08-02"

note_text, entity, date

('Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.',
 {'start': 23, 'end': 29, 'label': 'asthma'},
 {'start': 33, 'end': 43, 'parsed': '2024-08-02'})

In [7]:
#Test context window
context = get_context_window(note_text, entity['start'], date['start'], window_size=50)
print("Context window:\n", context)

Context window:
 Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. 


In [8]:
#Test entity marking
entity_text = note_text[entity['start']:entity['end']]
offset = context.find(entity_text)
entity_rel = {'start': offset, 'end': offset + len(entity_text)}

marked_entity = mark_entity(context, entity_rel)
print("Entity marked:\n", marked_entity)

Entity marked:
 Patient diagnosed with [E]asthma[E] on 2024-08-02. Diabetes was ruled out on 2024-08-02. 


In [9]:
#Test date marking
date_text = note_text[date['start']:date['end']]
offset_date = context.find(date_text)
date_rel = {'start': offset_date, 'end': offset_date + len(date_text)}

marked_date = mark_date(context, date_rel)
print("Date marked:\n", marked_date)

Date marked:
 Patient diagnosed with asthma on [D]2024-08-02[D]. Diabetes was ruled out on 2024-08-02. 


In [10]:
#Do full pre-processing
preprocessed = preprocess_input(note_text, entity, date, window_size=20)
print("Preprocessed input:\n", preprocessed)
preprocessed

Preprocessed input:
 ient diagnosed with [E]asthma[E] on [D]2024-08-02[D]. Diabetes


'ient diagnosed with [E]asthma[E] on [D]2024-08-02[D]. Diabetes'

In [15]:
# Process samples
predictions = []

#for sample in tqdm(samples, desc="Samples"):
    #pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    #for pair in pairs:
        #text = f"Does '{pair['entity_label']}' relate to '{pair['date']}'? {sample['note_text']}..."
        #pred, conf = bert_extraction(text, model, tokenizer)
        #if pred == 1:
            #predictions.append({'entity_label': pair['entity_label'], 'date': pair['date'], 'confidence': conf})

for sample in tqdm(samples, desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']  # Should have 'start', 'end', 'original'
        #print(date)
        pred, conf = bert_extraction(sample['note_text'], entity, date, model, tokenizer, window_size=100)
        if pred == 1:
            predictions.append({'entity_label': entity['label'], 'date': date.get('parsed'), 'confidence': conf})

Samples: 100%|██████████| 101/101 [00:08<00:00, 12.16it/s]


In [13]:
#Look at predictions
predictions

[{'entity_label': 'asthma',
  'date': '2024-08-02',
  'confidence': 0.5719550848007202},
 {'entity_label': 'asthma',
  'date': '2024-10-23',
  'confidence': 0.551507294178009},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-08-02',
  'confidence': 0.5719370245933533},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-10-23',
  'confidence': 0.5550824403762817},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-08-02',
  'confidence': 0.5640566945075989},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-10-23',
  'confidence': 0.5527292490005493},
 {'entity_label': 'pneumonia',
  'date': '2024-08-02',
  'confidence': 0.5390223860740662},
 {'entity_label': 'pneumonia',
  'date': '2024-10-23',
  'confidence': 0.5255318284034729},
 {'entity_label': 'gerd',
  'date': '2024-08-02',
  'confidence': 0.5509217977523804},
 {'entity_label': 'gerd',
  'date': '2024-10-23',
  'confidence': 0.5377340316772461},
 {'entity_label': 'meningitis',
  'date': '2024-08-02',
  '

In [14]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.1399662731871838,
 'recall': 0.7757009345794392,
 'f1': 0.23714285714285713,
 'tp': 166,
 'fp': 1020,
 'fn': 48}

BERT Finetuning

In [36]:
#Add special tokens to tokenizer
add_special_tokens(tokenizer)

# Resize model embeddings to match new tokenizer size
model.resize_token_embeddings(len(tokenizer))

Embedding(30524, 128, padding_idx=0)

In [None]:
#Do some layer freezing here?

In [None]:
#Split samples
train_samples, val_samples = train_test_split(samples, test_size=0.2, random_state=42)
print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")

In [None]:
#Prepare training pairs
training_pairs = make_training_pairs(train_samples, gold_lookup, window_size=100)
val_pairs = make_training_pairs(val_samples, gold_lookup, window_size=100)

In [None]:
#Balance classes (only training set)
training_pairs = balance_classes(training_pairs, ratio=1.0)

In [None]:
#Convert to Datasets
train_dataset = Dataset.from_pandas(training_pairs)
val_dataset = Dataset.from_pandas(val_pairs)

In [None]:
#Tokenization
tokenized_train = train_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
tokenized_val = val_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)

In [None]:
#Look at tokenized dataset
tokenized_train

In [None]:
#Define training args
training_args = TrainingArguments(
    output_dir="./bert_model_training/bert_finetuned",
    per_device_train_batch_size=16,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [None]:
#Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)

In [None]:
#Train
trainer.train()

In [None]:
#Process validation samples
predictions = []

for sample in tqdm(val_samples, desc="Validation Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, model, tokenizer, window_size=100)
        if pred == 1:
            predictions.append({'entity_label': entity['label'], 'date': date['parsed'], 'confidence': conf})

In [None]:
#Look at predictions
predictions

In [None]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics