Imports

In [17]:
#imports
from tqdm import tqdm
from transformers import pipeline, BitsAndBytesConfig
from openai import OpenAI
from huggingface_hub import HfApi

import sys
import os

# Absolute path of the `utils` directory that sits one level above the
# current working directory.
utils_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'utils'))

# Prepend it to the import search path only when it is not there yet, so that
# re-running this cell does not add duplicate entries.
if sys.path.count(utils_path) == 0:
    sys.path.insert(0, utils_path)

from general_utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics

from llm_extractor_utils import (
    make_binary_prompt, 
    llm_extraction_binary_hf,
    llm_extraction_binary_openai,
    parse_llm_answer,
    make_multi_prompt,
    llm_extraction_multi_openai,
    llm_extraction_multi_hf,
    llm_extraction_multi_structured_openai,
    llm_extraction_multi_structured_hf,
)

Data Loading

In [18]:
# Read the synthetic training dataset (path is relative to the notebook's
# working directory) and report how many records were loaded.
df = load_data("../data/training_dataset_synthetic.csv")
print(f"Loaded {len(df)} records")

Loaded 101 records


In [19]:
# Preview the first five rows of the loaded data
df.head(n=5)

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relations_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 300001, 'value': 'asthma', 'cui': 'pla...","[{'id': 308001, 'value': '02nd Aug 2024', 'sta...","[{'date': '02nd Aug 2024', 'entity': 'asthma',...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 300001, 'value': 'multiple_sclerosis',...","[{'id': 308001, 'value': '27th Sep 2024', 'sta...","[{'date': '27th Sep 2024', 'entity': 'multiple...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 300001, 'value': 'osteoarthritis', 'cu...","[{'id': 308001, 'value': '2024-10-04', 'start'...","[{'date': '2024-10-04', 'entity': 'osteoarthri...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 300001, 'value': 'schizophrenia', 'cui...","[{'id': 308001, 'value': '13rd Feb 2025', 'sta...","[{'date': '13rd Feb 2025', 'entity': 'schizoph...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 300001, 'value': 'diabetes_mellitus', ...","[{'id': 308001, 'value': '18/11/24', 'start': ...","[{'date': '18/11/24', 'entity': 'diabetes_mell...",[]


In [20]:
# Convert every row of `df` into a sample and report how many were produced
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
# Uncomment to inspect the first prepared sample:
#samples[0]

Prepared 101 samples


LLM Evaluation

In [21]:
# LLM backend: 'hf' runs a HuggingFace pipeline, 'openai' calls the OpenAI API.
# The 'openai' option requires OPENAI_API_KEY to be set in the .env file.
provider = 'openai'

In [22]:
# Device index handed to the HuggingFace pipeline: -1 means CPU, 0 means GPU.
# NOTE: in this notebook only the 'multi_structured' HF pipeline passes `device`;
# the other HF pipelines have the argument commented out.
device = -1

In [23]:
# Prediction strategy, one of 'binary', 'multi' or 'multi_structured':
#   - 'binary': the LLM is given every possible entity-date pair and answers
#     whether the pair is related (Yes or No). This needs many more LLM calls,
#     so it can be expensive when using the OpenAI API.
#   - 'multi': the LLM is given each note in turn and asked to extract all
#     correct relations in one go, as JSON.
#   - 'multi_structured': same as 'multi' but using structured outputs.
method = 'multi'

In [24]:
#To see a list of available OpenAI API models
#Alternatively visit https://platform.openai.com/docs/models

#client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
#models = client.models.list()
#for model in models:
    #print(model.id)

In [25]:
#To see a list of available HuggingFace text generation models
#Alternatively visit https://huggingface.co/models?pipeline_tag=text-generation

#Uncomment the below code to see the full list. Note this list is long and can take a couple of mins to run
#api = HfApi()
#text_gen_models = api.list_models(filter="text-generation")
#for model_info in text_gen_models:
    #print(model_info.modelId)

In [26]:
# Identifier of the LLM to run: a model id from OpenAI or HuggingFace
# (a path to a local model folder is also shown as an example below).
model = 'gpt-5-mini'

# OpenAI examples
#model = 'gpt-3.5-turbo' #cheap option for binary method
#model = 'o4-mini'
#model = 'gpt-5-mini' #best reasoning option

# HuggingFace examples
#model = 'gpt2'
#model = 'google/gemma-3-270m'
#model= 'Qwen/Qwen3-0.6B'
#model= 'roneneldan/TinyStories-1M' #fast option even on CPU, good for binary method
#model = 'EleutherAI/gpt-neo-125m'
#model = 'microsoft/phi-1_5'
#model = 'TinyLlama/TinyLlama-1.1B-chat-v1.0'
#model = '../Llama-3.1-8B-Instruct' #local folder example

In [None]:
# Quantization settings used by the HuggingFace ('hf') pipelines below.
# The model is loaded in 8-bit; the fp32 CPU offload flag is the transformers
# option for keeping some parts of an int8 model on the CPU in fp32.
# NOTE: this object is built regardless of `provider`, but it is only consumed
# by the HF 'multi' and 'binary' pipelines.
quantization_config = BitsAndBytesConfig(
    llm_int8_enable_fp32_cpu_offload=True,
    load_in_8bit=True,
)

In [None]:
# Build the HuggingFace text-generation pipeline. This only happens when
# provider == 'hf'; for any other provider no `generator` is created.
if provider == 'hf':
    # Collect the method-specific pipeline arguments first, then create the
    # pipeline once with the arguments shared by all methods.
    if method == 'multi_structured':
        # NOTE(review): `model_kwargs` are forwarded to the model loader; confirm
        # that "chat_format" is accepted there for the chosen model.
        pipeline_kwargs = {
            "device": device,
            "model_kwargs": {"chat_format": "chatml"},  # For chat completion format
        }
    elif method == 'multi':
        pipeline_kwargs = {
            # "device": device,  (left disabled on this branch)
            "model_kwargs": {"quantization_config": quantization_config},
            "max_new_tokens": 1000,
            "do_sample": False,
            # "pad_token_id": 2,  # Common pad token ID for Llama models
            # "eos_token_id": 2,  # End of sequence token
            "return_full_text": False,  # Only return the newly generated text
        }
    else:
        # Regular text generation for the binary method
        pipeline_kwargs = {
            # "device": device,  (left disabled on this branch)
            "model_kwargs": {"quantization_config": quantization_config},
        }

    generator = pipeline("text-generation", model=model, **pipeline_kwargs)

In [28]:
# Select the prompt template file that matches the chosen prediction method.
# Raises ValueError for an unknown method; the message lists the exact accepted
# values (note the underscore in 'multi_structured').
if method == 'binary':
    prompt_to_use = 'binary_prompt.txt'
elif method == 'multi':
    prompt_to_use = 'multi_prompt.txt'
elif method == 'multi_structured':
    prompt_to_use = 'multi_prompt_structured.txt'
else:
    raise ValueError(f"Invalid method: {method}. Must be either 'binary', 'multi' or 'multi_structured'.")

In [29]:
# Process samples: build the prompt, run the LLM extraction and collect the
# predicted entity-date relations in `predictions`.
#
# The configuration is validated up front so that an invalid `provider` or
# `method` fails immediately instead of part-way through a long (and, with the
# OpenAI API, billable) run. Raises ValueError in either case.
if provider not in ('openai', 'hf'):
    raise ValueError(f"Invalid provider: {provider}. Must be either 'openai' or 'hf'.")
if method not in ('binary', 'multi', 'multi_structured'):
    raise ValueError(f"Invalid method: {method}. Must be either 'binary', 'multi' or 'multi_structured'.")

predictions = []

if method == 'binary':
    # One LLM call per candidate (entity, date) pair of every note.
    for sample in tqdm(samples, desc="Processing notes"):

        # Candidate pairs built from the absolute dates
        pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])

        # Append the pairs built from relative dates when the sample has any
        if sample.get('relative_dates'):
            pairs = pairs + get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])

        # leave=False clears the per-note bar when it finishes, so the output
        # is not flooded with one finished progress bar per note.
        for pair in tqdm(pairs, desc="Pairs", leave=False):
            # Create binary prompt
            prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)

            # Get the response from the chosen provider (validated above)
            if provider == 'openai':
                response = llm_extraction_binary_openai(prompt, model=model)
            else:
                response = llm_extraction_binary_hf(prompt, generator)

            # Parse the response into a binary prediction; keep positives only
            pred, conf = parse_llm_answer(response)
            if pred == 1:
                predictions.append({
                    'entity_label': pair['entity_label'],
                    'date': pair['date'],
                    'confidence': conf
                })

else:
    # 'multi' / 'multi_structured': one LLM call per note extracts all relations.
    for sample in tqdm(samples, desc="Processing notes"):

        # Create multi-extraction prompt
        prompt = make_multi_prompt(
            note_text=sample['note_text'],
            prompt_filename=prompt_to_use,
            entities_list=sample['entities_list'],  # optional - does not need to be passed
            dates=sample['dates']  # optional - does not need to be passed
        )

        # Get all relationships in one call
        if method == 'multi':
            if provider == 'openai':
                relationships = llm_extraction_multi_openai(prompt, model=model)
            else:
                relationships = llm_extraction_multi_hf(prompt, generator)
        else:
            if provider == 'openai':
                relationships = llm_extraction_multi_structured_openai(prompt, model=model)
            else:
                relationships = llm_extraction_multi_structured_hf(prompt, generator)  # To be implemented

        # Add to predictions. The helpers' return contract is not visible here,
        # so an empty or None result is skipped rather than allowed to abort
        # the whole run on a single note.
        if relationships:
            predictions.extend(relationships)

Processing notes:   0%|          | 0/101 [00:00<?, ?it/s]

Processing notes: 100%|██████████| 101/101 [39:29<00:00, 23.46s/it]


In [30]:
# Look at the first few predictions only: displaying the whole list floods the
# notebook output. Use len(predictions) for the total count.
predictions[:10]

[{'date': '02nd Aug 2024',
  'entity_label': 'asthma',
  'date_id': 308001,
  'entity_id': 300001},
 {'date': '23rd Oct 2024',
  'entity_label': 'pneumonia',
  'date_id': 308002,
  'entity_id': 300004},
 {'date': '27th Sep 2024',
  'entity_label': 'congenital malformation',
  'date_id': 308001,
  'entity_id': 300006},
 {'date': '27th Sep 2024',
  'entity_label': 'multiple_sclerosis',
  'date_id': 308001,
  'entity_id': 300001},
 {'date': '29th Jan 2025',
  'entity_label': 'tension_headache',
  'date_id': 308002,
  'entity_id': 300003},
 {'date': '2024-10-04',
  'entity_label': 'osteoarthritis',
  'date_id': 308001,
  'entity_id': 300001},
 {'date': '07.05.25',
  'entity_label': 'copd',
  'date_id': 308002,
  'entity_id': 300007},
 {'date': '07.05.25',
  'entity_label': 'hyperlipidemia',
  'date_id': 308002,
  'entity_id': 300008},
 {'date': '13rd Feb 2025',
  'entity_label': 'schizophrenia',
  'date_id': 308001,
  'entity_id': 300001},
 {'date': '13rd Feb 2025',
  'entity_label': 'stro

In [31]:
# Calculate evaluation metrics by scoring the predictions against `df`,
# then display the resulting metrics as the cell output.
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.5822222222222222,
 'recall': 0.7197802197802198,
 'f1': 0.6437346437346436,
 'tp': 131,
 'fp': 94,
 'fn': 51}