Imports

In [1]:
#imports
from tqdm import tqdm

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, pipeline, Trainer, TrainingArguments

from bert_training import add_special_tokens, make_training_pairs, split_train_val, balance_classes, gold_lookup

from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from naive_extractor import naive_extraction
from bert_extractor import get_context_window, mark_entity, mark_date, preprocess_input, bert_extraction
from llm_extractor import load_prompt_template, make_binary_prompt, llm_extraction, parse_llm_answer

Data Loading

In [2]:
# Load data
df = load_data("data/synthetic.csv")
print(f"Loaded {len(df)} records")
#df

Loaded 101 records


In [3]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
#samples[0]

Prepared 101 samples


Naive Extractor

In [4]:
#Process samples
predictions = []

for sample in samples:
    relationships = naive_extraction(sample['entities_list'], sample['dates'])
    predictions.extend(relationships)

In [5]:
# Look at predictions
#predictions

In [6]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.1277533039647577,
 'recall': 0.13551401869158877,
 'f1': 0.13151927437641722,
 'tp': 29,
 'fp': 198,
 'fn': 185}

BERT Base

In [7]:
# Load local model and tokenizer
#model_path = "./bert_model_training/base_model"
model_path = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2, ignore_mismatched_sizes=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
#Example note
note_text = (
    "Patient diagnosed with asthma on 2024-08-02. "
    "Diabetes was ruled out on 2024-08-02. "
    "Family history of hypertension, last reviewed in 2022. "
    "Patient may have pneumonia, last seen on 2024-08-02."
)

# Example entity and date spans (positions are for illustration)
entity = {'start': 23, 'end': 29, 'label': 'asthma'}  # "asthma"
date = {'start': 33, 'end': 43, 'parsed': '2024-08-02'}  # "2024-08-02"

note_text, entity, date

('Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. Family history of hypertension, last reviewed in 2022. Patient may have pneumonia, last seen on 2024-08-02.',
 {'start': 23, 'end': 29, 'label': 'asthma'},
 {'start': 33, 'end': 43, 'parsed': '2024-08-02'})

In [9]:
#Test context window
context = get_context_window(note_text, entity['start'], date['start'], window_size=50)
print("Context window:\n", context)

Context window:
 Patient diagnosed with asthma on 2024-08-02. Diabetes was ruled out on 2024-08-02. 


In [10]:
#Test entity marking
entity_text = note_text[entity['start']:entity['end']]
offset = context.find(entity_text)
entity_rel = {'start': offset, 'end': offset + len(entity_text)}

marked_entity = mark_entity(context, entity_rel)
print("Entity marked:\n", marked_entity)

Entity marked:
 Patient diagnosed with [E]asthma[E] on 2024-08-02. Diabetes was ruled out on 2024-08-02. 


In [11]:
#Test date marking
date_text = note_text[date['start']:date['end']]
offset_date = context.find(date_text)
date_rel = {'start': offset_date, 'end': offset_date + len(date_text)}

marked_date = mark_date(context, date_rel)
print("Date marked:\n", marked_date)

Date marked:
 Patient diagnosed with asthma on [D]2024-08-02[D]. Diabetes was ruled out on 2024-08-02. 


The below cell needs fixing!

In [12]:
#Do full pre-processing
preprocessed = preprocess_input(note_text, entity, date, window_size=20)
print("Preprocessed input:\n", preprocessed)

Preprocessed input:
 ient diagnosed with [E]asthma[[D]E] on 2024[D]-08-02. Diabetes


In [13]:
# Process samples
predictions = []

#for sample in tqdm(samples, desc="Samples"):
    #pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    #for pair in pairs:
        #text = f"Does '{pair['entity_label']}' relate to '{pair['date']}'? {sample['note_text']}..."
        #pred, conf = bert_extraction(text, model, tokenizer)
        #if pred == 1:
            #predictions.append({'entity_label': pair['entity_label'], 'date': pair['date'], 'confidence': conf})

for sample in tqdm(samples, desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, model, tokenizer, window_size=100)
        if pred == 1:
            predictions.append({'entity_label': entity['label'], 'date': date['parsed'], 'confidence': conf})

Samples: 100%|██████████| 101/101 [00:08<00:00, 11.77it/s]


In [14]:
#Look at predictions
predictions

[{'entity_label': 'asthma',
  'date': '2024-08-02',
  'confidence': 0.5390741229057312},
 {'entity_label': 'asthma',
  'date': '2024-10-23',
  'confidence': 0.5179486274719238},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-08-02',
  'confidence': 0.5275401473045349},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-10-23',
  'confidence': 0.5234643816947937},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-08-02',
  'confidence': 0.5361946821212769},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-10-23',
  'confidence': 0.5237062573432922},
 {'entity_label': 'pneumonia',
  'date': '2024-08-02',
  'confidence': 0.5249199867248535},
 {'entity_label': 'pneumonia',
  'date': '2024-10-23',
  'confidence': 0.5298224091529846},
 {'entity_label': 'gerd',
  'date': '2024-08-02',
  'confidence': 0.5323746204376221},
 {'entity_label': 'gerd',
  'date': '2024-10-23',
  'confidence': 0.5361893177032471},
 {'entity_label': 'meningitis',
  'date': '2024-08-02',
  

In [15]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.13973063973063973,
 'recall': 0.7757009345794392,
 'f1': 0.23680456490727533,
 'tp': 166,
 'fp': 1022,
 'fn': 48}

BERT Finetuning

In [16]:
# 1. Add special tokens to tokenizer (do this before model.resize_token_embeddings)

# Add special tokens to tokenizer
add_special_tokens(tokenizer)

# Resize model embeddings to match new tokenizer size
model.resize_token_embeddings(len(tokenizer))

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(30524, 128, padding_idx=0)

In [17]:
# 2. Prepare training pairs
training_pairs = make_training_pairs(samples, gold_lookup, window_size=100)

In [18]:
# 3. (Optional) Balance classes for better F1
training_pairs = balance_classes(training_pairs, ratio=1.0)  # 1:1 pos:neg

In [19]:
# 4. Split into train/val
train_pairs, val_pairs = split_train_val(training_pairs, val_frac=0.2)

In [20]:
# 5. Convert to HuggingFace Datasets
train_dataset = Dataset.from_pandas(train_pairs)
val_dataset = Dataset.from_pandas(val_pairs)

In [21]:
# 6. Tokenization function
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)

In [22]:
#Tokenization
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_val = val_dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/278 [00:00<?, ? examples/s]

Map:   0%|          | 0/70 [00:00<?, ? examples/s]

In [23]:
#Look at tokenized datasets
tokenized_train

Dataset({
    features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 278
})

In [24]:
#Define training args
training_args = TrainingArguments(
    output_dir="./bert_model_training/bert_finetuned",
    per_device_train_batch_size=16,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    save_total_limit=1,
    load_best_model_at_end=True,
)

In [25]:
#Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)

In [26]:
#Train
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.6978,0.693009
2,0.6878,0.69094
3,0.6897,0.687727
4,0.669,0.686871
5,0.6775,0.684701
6,0.6474,0.680297
7,0.6426,0.676889
8,0.6278,0.673971
9,0.6264,0.675374
10,0.632,0.674538


TrainOutput(global_step=180, training_loss=0.6566579050487942, metrics={'train_runtime': 38.8, 'train_samples_per_second': 71.649, 'train_steps_per_second': 4.639, 'total_flos': 882988492800.0, 'train_loss': 0.6566579050487942, 'epoch': 10.0})

In [27]:
#Process samples
predictions = []

for sample in tqdm(samples, desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, model, tokenizer, window_size=100)
        if pred == 1:
            predictions.append({'entity_label': entity['label'], 'date': date['parsed'], 'confidence': conf})

Samples: 100%|██████████| 101/101 [00:10<00:00, 10.02it/s]


In [28]:
#Look at predictions
predictions

[{'entity_label': 'pneumonia',
  'date': '2024-10-23',
  'confidence': 0.5751230120658875},
 {'entity_label': 'gerd',
  'date': '2024-10-23',
  'confidence': 0.5810917019844055},
 {'entity_label': 'meningitis',
  'date': '2024-10-23',
  'confidence': 0.5744342803955078},
 {'entity_label': 'depression',
  'date': '2025-05-07',
  'confidence': 0.5590400099754333},
 {'entity_label': 'schizophrenia',
  'date': '2025-03-05',
  'confidence': 0.5576508045196533},
 {'entity_label': 'copd',
  'date': '2025-03-05',
  'confidence': 0.5729040503501892},
 {'entity_label': 'stroke',
  'date': '2025-03-05',
  'confidence': 0.5437808632850647},
 {'entity_label': 'microadenoma',
  'date': '2025-03-05',
  'confidence': 0.5483373999595642},
 {'entity_label': 'diabetes_mellitus',
  'date': '2024-11-18',
  'confidence': 0.5358163118362427},
 {'entity_label': 'diabetes_mellitus',
  'date': '2014-12-24',
  'confidence': 0.5347342491149902},
 {'entity_label': 'pituitary macroadenoma',
  'date': '2024-11-18',


In [29]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.3265993265993266,
 'recall': 0.4532710280373832,
 'f1': 0.3796477495107632,
 'tp': 97,
 'fp': 200,
 'fn': 117}

LLM

In [None]:
# Define generator
#generator = pipeline("text-generation", model="../Llama-3.2-3B-Instruct", device=-1)
generator = pipeline("text2text-generation", model="google/flan-t5-small", device=-1)

In [None]:
# Test simple prompt using generator
prompt = "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? Answer YES or NO. Text: Patient diagnosed with asthma on 2024-08-02."
result = generator(prompt)
print(result[0]['generated_text'])

In [None]:
# Test simple prompt using llm_extraction() function
prompt = "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? Answer YES or NO. Text: Patient diagnosed with asthma on 2024-08-02."
response = llm_extraction(prompt, generator)
response

In [None]:
#Prompt to use
prompt_to_use = 'prompt.txt'

In [None]:
#Process all date-entity pairs, make prompt, do llm extraction and make prediction
predictions = []

for sample in tqdm(samples, desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    #for pair in pairs[:1]:
    for pair in pairs:
        #print(pair)
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
        #print(prompt)
        response = llm_extraction(prompt, generator)
        #print(response)
        pred, conf = parse_llm_answer(response)
        #print(pred, conf)
        if pred == 1:
            predictions.append({
                'entity_label': pair['entity_label'],
                'date': pair['date'],
                'confidence': conf
            })

In [None]:
#Look at prediction
predictions

In [None]:
#Calculcate metrics
metrics = calculate_metrics(predictions, df)
metrics