Imports

In [None]:
#imports
from tqdm import tqdm
from transformers import pipeline

import sys
import os

# Put the sibling `utils` directory (../utils relative to the current working
# directory) at the front of sys.path so the project helper modules imported
# below can be resolved. The check avoids adding duplicate entries on re-runs.
utils_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, "utils"))
already_importable = utils_path in sys.path
if not already_importable:
    sys.path.insert(0, utils_path)

from llm_extractor import load_prompt_template, make_binary_prompt, llm_extraction, parse_llm_answer
from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from relative_date_extractor import add_relative_dates

Data Loading

In [None]:
# Load the annotated dataset; the path is relative to the notebook's working directory.
DATA_PATH = "../data/medcat_re_dataset.csv"
df = load_data(DATA_PATH)
n_records = len(df)
print(f"Loaded {n_records} records")

In [None]:
# Relative dates may already have been added via MedCAT trainer (the
# 'relative_dates_json' column); only run the extractor when that column is absent.
already_has_relative = 'relative_dates_json' in df.columns
if already_has_relative:
    print("Relative dates already present, skipping extraction")
else:
    df = add_relative_dates(df)
    print("Added relative dates")

In [None]:
# Inspect df to check that relative dates have been added (the 'relative_dates_json'
# column checked in the previous cell). Kept as a bare final expression so the
# notebook renders df.
df

In [None]:
# Build the per-row samples consumed by the extraction loop.
# Tip: inspect samples[0] to see the structure of a single sample.
samples = prepare_all_samples(df)
sample_count = len(samples)
print(f"Prepared {sample_count} samples")

LLM

In [None]:
# Hugging Face text2text pipeline running on CPU (device=-1).
# Larger alternative, kept for reference:
#   pipeline("text-generation", model="../Llama-3.2-3B-Instruct", device=-1)
MODEL_NAME = "google/flan-t5-small"
generator = pipeline(task="text2text-generation", model=MODEL_NAME, device=-1)

In [None]:
# Smoke-test the raw pipeline on a hand-written YES/NO relation question.
prompt = (
    "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? "
    "Answer YES or NO. "
    "Text: Patient diagnosed with asthma on 2024-08-02."
)
result = generator(prompt)
first_output = result[0]
print(first_output['generated_text'])

In [None]:
# Smoke-test the project's llm_extraction() wrapper with a hand-written YES/NO question.
prompt = (
    "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? "
    "Answer YES or NO. "
    "Text: Patient diagnosed with asthma on 2024-08-02."
)
response = llm_extraction(prompt, generator)
# Bare final expression so the notebook displays the wrapper's output.
response

In [None]:
# Prompt template file name handed to make_binary_prompt() in the extraction loop.
prompt_to_use = "prompt.txt"

In [None]:
# Process all entity-date pairs: build a prompt, run LLM extraction and keep
# the pairs the model answers positively.

# Set to True to echo every prompt. Off by default: each prompt embeds the full
# note text, so printing it per pair floods the output and persists clinical
# text in the saved notebook.
show_prompts = False


def _pairs_for_sample(sample):
    """Return one sample's entity-date pairs: absolute dates, then relative dates if any."""
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    relative_dates = sample.get('relative_dates')
    if relative_dates:
        pairs = pairs + get_entity_date_pairs(sample['entities_list'], [], relative_dates)
    return pairs


predictions = []

for sample in tqdm(samples, desc="Samples"):
    for pair in _pairs_for_sample(sample):
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
        if show_prompts:
            print(prompt)
        response = llm_extraction(prompt, generator)
        pred, conf = parse_llm_answer(response)
        # Only positive answers (pred == 1) become predictions; negatives are dropped.
        if pred == 1:
            # NOTE(review): predictions carry no note/sample identifier; confirm
            # calculate_metrics does not need one to match predictions to gold labels.
            predictions.append({
                'entity_label': pair['entity_label'],
                'date': pair['date'],
                'confidence': conf,
            })

In [None]:
# Look at the collected predictions (only pairs answered positively are kept).
# Bare final expression so the notebook renders the list.
predictions

In [None]:
# Calculate evaluation metrics for the predictions against the annotations in df
# (metric definitions live in utils.calculate_metrics).
metrics = calculate_metrics(predictions, df)
metrics