Imports

In [None]:
# Imports: third-party libraries first, then project-local modules.
from tqdm import tqdm

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    pipeline,
    set_seed,
)

from bert_training import make_training_pairs, gold_lookup

from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from naive_extractor import naive_extraction
from bert_extractor import bert_extraction
from llm_extractor import load_prompt_template, make_binary_prompt, llm_extraction, parse_llm_answer

Data Loading

In [3]:
# Load the synthetic dataset that every extractor below is scored against.
DATA_PATH = "data/synthetic.csv"
df = load_data(DATA_PATH)
n_records = len(df)
print(f"Loaded {n_records} records")

Loaded 101 records


In [4]:
# Turn the loaded records into per-note samples. The cells below read each
# sample's 'entities_list', 'dates' and 'note_text' fields.
samples = prepare_all_samples(df)
n_samples = len(samples)
print(f"Prepared {n_samples} samples")

Prepared 101 samples


Naive Extractor

In [6]:
# Naive extractor: run it on every sample and flatten the relationships it
# returns into a single list for scoring.
predictions = [
    relationship
    for sample in samples
    for relationship in naive_extraction(sample['entities_list'], sample['dates'])
]

In [8]:
# Look at predictions: show the count and a short preview instead of leaving
# this cell empty. The full list has hundreds of entries (tp + fp in the
# metrics below).
print(f"{len(predictions)} predicted relationships")
predictions[:10]

In [9]:
# Calculate metrics: score the naive predictions against the annotations in df.
# Later sections overwrite the shared `metrics` name, so also keep this
# result under a section-specific name for side-by-side comparison.
naive_metrics = calculate_metrics(predictions, df)
metrics = naive_metrics
naive_metrics

{'precision': 0.1277533039647577,
 'recall': 0.13551401869158877,
 'f1': 0.13151927437641722,
 'tp': 29,
 'fp': 198,
 'fn': 185}

BERT Base (pretrained encoder with an untrained classification head)

In [10]:
# Load the pretrained encoder and tokenizer as a 2-class sequence classifier.
# Class 1 is treated as "entity relates to date" by the inference loops below.
# Local alternative: model_path = "./bert_model_training/base_model"
model_path = "google/bert_uncased_L-2_H-128_A-2"

# For this checkpoint, 'classifier.weight' and 'classifier.bias' are newly
# (randomly) initialised on load (see the warning below). Seed first so the
# "BERT Base" results, and the starting point for fine-tuning, are
# reproducible after a kernel restart.
SEED = 42
set_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2, ignore_mismatched_sizes=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# BERT base inference: for every candidate (entity, date) pair of every note,
# ask the classifier a question built from the pair plus the first 200
# characters of the note. Keep the pairs predicted as class 1.
# NOTE(review): the classification head has not been trained at this point
# (it was newly initialised when the model was loaded), so this is effectively
# a chance-level baseline. The confidences in the output all sit near 0.50.
predictions = []

for sample in tqdm(samples, desc="Samples"):
    candidate_pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for candidate in candidate_pairs:
        entity_label = candidate['entity_label']
        date = candidate['date']
        text = f"Does '{entity_label}' relate to '{date}'? {sample['note_text'][:200]}..."
        pred, conf = bert_extraction(text, model, tokenizer)
        if pred != 1:
            continue
        predictions.append({'entity_label': entity_label, 'date': date, 'confidence': conf})

Samples: 100%|██████████| 101/101 [00:05<00:00, 19.18it/s]


In [12]:
# Look at predictions: print the count and a short preview. Dumping the whole
# list floods the notebook, and the saved output was truncated.
print(f"{len(predictions)} predicted relationships")
predictions[:10]

[{'entity_label': 'multiple_sclerosis',
  'date': '2024-09-27',
  'confidence': 0.5041855573654175},
 {'entity_label': 'multiple_sclerosis',
  'date': '2025-01-29',
  'confidence': 0.503922700881958},
 {'entity_label': 'bronchitis',
  'date': '2024-09-27',
  'confidence': 0.5079386830329895},
 {'entity_label': 'bronchitis',
  'date': '2025-01-29',
  'confidence': 0.5076344609260559},
 {'entity_label': 'tension_headache',
  'date': '2024-09-27',
  'confidence': 0.5036704540252686},
 {'entity_label': 'tension_headache',
  'date': '2025-01-29',
  'confidence': 0.5033239126205444},
 {'entity_label': 'gerd',
  'date': '2024-09-27',
  'confidence': 0.503506600856781},
 {'entity_label': 'gerd',
  'date': '2025-01-29',
  'confidence': 0.5032444000244141},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-09-27',
  'confidence': 0.5085455775260925},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2025-01-29',
  'confidence': 0.5084171891212463},
 {'entity_label': 'congenital malfor

In [13]:
# Calculate metrics for the (untrained-head) BERT base predictions against df.
# Keep a section-specific copy, because later sections overwrite `metrics`.
bert_base_metrics = calculate_metrics(predictions, df)
metrics = bert_base_metrics
bert_base_metrics

{'precision': 0.1308641975308642,
 'recall': 0.24766355140186916,
 'f1': 0.17124394184168015,
 'tp': 53,
 'fp': 352,
 'fn': 161}

BERT Fine-tuning

In [None]:
# Prepare data: build training pairs from the samples and the gold lookup,
# wrap them in a `datasets.Dataset`, and tokenize the "text" column.
# NOTE(review): presumably make_training_pairs also emits a label column that
# the Trainer consumes. Confirm in bert_training.
MAX_SEQ_LENGTH = 128

df_train = make_training_pairs(samples, gold_lookup)
dataset = Dataset.from_pandas(df_train)


def tokenize_function(example):
    """Tokenize a batch of training examples.

    Args:
        example: batch mapping passed by ``Dataset.map(..., batched=True)``;
            only its ``"text"`` entry is read.

    Returns:
        The tokenizer's output, truncated and padded to exactly
        ``MAX_SEQ_LENGTH`` (128) tokens so every row has the same length.
    """
    batch_texts = example["text"]
    return tokenizer(batch_texts, truncation=True, padding="max_length", max_length=MAX_SEQ_LENGTH)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
# Look at the tokenized dataset, and fail fast if tokenization did not add the
# model inputs the Trainer needs.
assert "input_ids" in tokenized_datasets.column_names, "tokenization did not produce input_ids"
tokenized_datasets

In [None]:
# Define training args: evaluate and checkpoint once per epoch, cap the number
# of saved checkpoints, and reload the best checkpoint when training ends.
FINETUNED_OUTPUT_DIR = "./bert_model_training/bert_finetuned"

training_args = TrainingArguments(
    output_dir=FINETUNED_OUTPUT_DIR,
    # optimisation
    num_train_epochs=10,
    per_device_train_batch_size=16,
    # evaluation / checkpoint cadence: transformers requires these two to
    # match when load_best_model_at_end=True
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    # logging
    logging_steps=10,
)

In [None]:
# Create trainer with a held-out evaluation split.
# With load_best_model_at_end=True the Trainer keeps the checkpoint with the
# best evaluation metric (eval loss by default). Evaluating on the training set
# itself makes that selection meaningless, so split off 20% of the pairs,
# seeded with the training seed for reproducibility.
# NOTE(review): the split is per (entity, date) pair, so pairs from the same
# note can land on both sides. Splitting by note would be stricter. TODO.
train_eval_split = tokenized_datasets.train_test_split(test_size=0.2, seed=training_args.seed)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_eval_split["train"],
    eval_dataset=train_eval_split["test"],
)

In [None]:
# Train. The evaluation cell below reuses `model`, relying on the Trainer
# updating that same object in place.
train_output = trainer.train()
train_output

In [None]:
# Fine-tuned BERT inference. This uses exactly the same question +
# truncated-note template as the BERT base section, so the two metric sets are
# comparable. Pairs predicted as class 1 are kept.
# NOTE(review): `samples` is also what make_training_pairs() built the training
# data from, so these metrics are in-sample and optimistic. Hold out whole
# notes for an honest estimate. TODO.

# Put the model in inference mode explicitly (dropout off) rather than relying
# on whatever mode training left it in.
model.eval()

predictions = []
for sample in tqdm(samples, desc="Samples"):
    for pair in get_entity_date_pairs(sample['entities_list'], sample['dates']):
        text = f"Does '{pair['entity_label']}' relate to '{pair['date']}'? {sample['note_text'][:200]}..."
        pred, conf = bert_extraction(text, model, tokenizer)
        if pred == 1:
            predictions.append({'entity_label': pair['entity_label'], 'date': pair['date'], 'confidence': conf})

In [None]:
# Look at predictions: print the count and a short preview instead of dumping
# the whole list.
print(f"{len(predictions)} predicted relationships")
predictions[:10]

In [None]:
# Calculate metrics for the fine-tuned BERT predictions against df.
# Keep a section-specific copy, because later sections overwrite `metrics`.
bert_finetuned_metrics = calculate_metrics(predictions, df)
metrics = bert_finetuned_metrics
bert_finetuned_metrics

LLM

In [None]:
# Define generator: a small text2text model run on CPU (device=-1).
# Alternative that was tried (local decoder-only checkpoint):
#   pipeline("text-generation", model="../Llama-3.2-3B-Instruct", device=-1)
LLM_MODEL_NAME = "google/flan-t5-small"
generator = pipeline("text2text-generation", model=LLM_MODEL_NAME, device=-1)

Device set to use cpu


In [15]:
# Sanity check the raw generator on an unambiguous positive example.
# flan-t5-small answered "No" here (see output), which already suggests it
# will struggle on the real extraction task below.
prompt = (
    "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? "
    "Answer YES or NO. "
    "Text: Patient diagnosed with asthma on 2024-08-02."
)
outputs = generator(prompt)
print(outputs[0]['generated_text'])

No


In [None]:
# Same sanity check, but through llm_extraction(), the wrapper that the
# extraction loop below uses.
prompt = (
    "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? "
    "Answer YES or NO. "
    "Text: Patient diagnosed with asthma on 2024-08-02."
)
response = llm_extraction(prompt, generator)
response

'No'

In [18]:
# Prompt template file handed to make_binary_prompt() in the extraction loop.
prompt_to_use = "prompt.txt"

In [None]:
# LLM extraction: for every candidate (entity, date) pair of every note, build
# a yes/no prompt from the template file, query the generator, parse the answer,
# and keep the pairs parsed as positive (pred == 1).
# NOTE(review): one generator call per pair. The saved run took ~18 minutes on
# CPU for 101 samples.
predictions = []

for sample in tqdm(samples, desc="Samples"):
    for pair in get_entity_date_pairs(sample['entities_list'], sample['dates']):
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
        response = llm_extraction(prompt, generator)
        pred, conf = parse_llm_answer(response)
        if pred != 1:
            continue
        predictions.append(
            {
                'entity_label': pair['entity_label'],
                'date': pair['date'],
                'confidence': conf,
            }
        )

Samples: 100%|██████████| 101/101 [18:18<00:00, 10.87s/it]


In [20]:
# Look at predictions: print the count and a short preview. The saved run
# produced none.
print(f"{len(predictions)} predicted relationships")
predictions[:10]

[]

In [21]:
# Calculate metrics for the LLM predictions against df. With no predictions
# (as in the saved run) precision is reported as 0 and every gold relationship
# counts as a false negative.
# Keep a section-specific copy alongside the shared `metrics` name.
llm_metrics = calculate_metrics(predictions, df)
metrics = llm_metrics
llm_metrics

{'precision': 0, 'recall': 0.0, 'f1': 0, 'tp': 0, 'fp': 0, 'fn': 214}