Imports

In [1]:
#imports
from tqdm import tqdm
from transformers import pipeline
from openai import OpenAI
from huggingface_hub import HfApi

import sys
import os

# Resolve ../utils relative to the kernel's current working directory and put it at the
# front of sys.path so the project helpers (general_utils, llm_extractor_utils) can be imported.
# The membership check stops repeated runs of this cell from stacking duplicate entries.
utils_path = os.path.abspath(
    os.path.join(os.getcwd(), '..', 'utils')
)
already_on_path = utils_path in sys.path
if not already_on_path:
    sys.path.insert(0, utils_path)

from general_utils import load_data, prepare_all_samples, get_entity_date_pairs

from llm_extractor_utils import (
    make_binary_prompt, 
    llm_extraction_binary_hf,
    llm_extraction_binary_openai,
    parse_llm_answer,
    make_multi_prompt,
    llm_extraction_multi_openai,
    llm_extraction_multi_hf,
    llm_extraction_multi_structured_openai,
    llm_extraction_multi_structured_hf,
)

Data Loading

In [2]:
# Load the inference dataset and report how many rows it contains
dataset_path = "../data/inference_dataset.csv"
df = load_data(dataset_path)
n_records = len(df)
print(f"Loaded {n_records} records")

Loaded 101 records


In [3]:
# Inspect the first rows of the loaded dataframe (pandas' default preview size is 5)
df.head(5)

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 'ent_1', 'value': 'Ultrasound', 'cui':...","[{'id': 'abs_1', 'value': '30nd Jun 2024', 'st...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 'ent_1', 'value': 'anemia', 'cui': 'C0...","[{'id': 'abs_1', 'value': '27th Sep 2024', 'st...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C1...","[{'id': 'abs_1', 'value': '2024-10-04', 'start...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C0...","[{'id': 'abs_1', 'value': '13rd Feb 2025', 'st...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 'ent_1', 'value': 'nausea', 'cui': 'C0...","[{'id': 'abs_1', 'value': '18/11/24', 'start':...",[]


In [4]:
# Turn the dataframe into sample dicts (note text, entities, dates) consumed by the extraction loop
samples = prepare_all_samples(df)
n_samples = len(samples)
print(f"Prepared {n_samples} samples")

Prepared 101 samples


LLM Inference

In [5]:
# LLM backend: 'hf' runs a local HuggingFace text-generation pipeline, 'openai' calls the OpenAI API.
# The OpenAI route needs the OPENAI_API_KEY variable to be set (project convention: a .env file).
# NOTE(review): nothing in this notebook loads .env itself - confirm the helper modules do,
# or export OPENAI_API_KEY before starting the kernel.
provider = 'openai'

In [6]:
# Extraction strategy - one of:
#   'binary'           - the LLM is shown every candidate entity-date pair and answers Yes/No
#                        for each one. This needs far more LLM calls, so it can get expensive
#                        with the OpenAI API.
#   'multi'            - the LLM is shown one note at a time and returns all relations for
#                        that note in a single JSON answer.
#   'multi_structured' - same as 'multi', but using structured outputs.
method = 'multi'

In [None]:
#To see a list of available OpenAI API models
#Alternatively visit https://platform.openai.com/docs/models

#client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
#models = client.models.list()
#for model in models:
    #print(model.id)

In [None]:
#To see a list of available HuggingFace text generation models
#Alternatively visit https://huggingface.co/models?pipeline_tag=text-generation

#Uncomment the below code to see the full list. Note this list is long and can take a couple of mins to run
#api = HfApi()
#text_gen_models = api.list_models(filter="text-generation")
#for model_info in text_gen_models:
    #print(model_info.modelId)

In [7]:
# Model id to use: an OpenAI model name when provider == 'openai',
# or a HuggingFace Hub repository id when provider == 'hf'.
model = 'gpt-3.5-turbo'

# OpenAI examples
# model = 'gpt-3.5-turbo'   # cheap option for the binary method
# model = 'o4-mini'
# model = 'gpt-5-mini'      # best reasoning option

# HuggingFace examples
# model = 'gpt2'
# model = 'google/gemma-3-270m'
# model = 'Qwen/Qwen3-0.6B'
# model = 'roneneldan/TinyStories-1M'   # fast even on CPU, good for the binary method
# model = 'EleutherAI/gpt-neo-125m'
# model = 'microsoft/phi-1_5'
# model = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

In [8]:
# Build a local HuggingFace text-generation pipeline when provider == 'hf'.
# The OpenAI route needs no local model, so `generator` is explicitly None there.
DEVICE = -1  # -1 for CPU, 0 for the first GPU
generator = None
if provider == 'hf':
    pipeline_kwargs = {}
    if method == 'multi_structured':
        # NOTE(review): trust_remote_code=True executes Python code shipped inside the model repo.
        # It was added for Qwen models; only keep it for repositories you trust.
        pipeline_kwargs['trust_remote_code'] = True
    # The former model_kwargs={"chat_format": "chatml"} is not a transformers option: pipeline
    # model_kwargs are forwarded to the model's from_pretrained()/constructor, which does not
    # accept a 'chat_format' argument, so it is dropped. If llm_extraction_multi_structured_hf
    # passes a list of chat messages, the text-generation pipeline applies the tokenizer's
    # chat template itself.
    generator = pipeline("text-generation", model=model, device=DEVICE, **pipeline_kwargs)

In [9]:
# Prompt template file for each extraction method
PROMPT_FILES = {
    'binary': 'binary_prompt.txt',
    'multi': 'multi_prompt.txt',
    'multi_structured': 'multi_prompt_structured.txt',
}
if method not in PROMPT_FILES:
    # The accepted spelling is 'multi_structured' (underscore); the message previously said 'multi-structured'.
    raise ValueError(f"Invalid method: {method}. Must be either 'binary', 'multi' or 'multi_structured'.")
prompt_to_use = PROMPT_FILES[method]

In [None]:
# Process samples: build a prompt, run LLM extraction and collect the predicted relations.

# Validate the provider and pick the extraction call once, up front, so a bad setting fails
# before any LLM call. Previously this was only checked inside the loop, and in binary mode it
# was never checked at all if no note produced a candidate pair.
# The lambdas look up `generator` only when called, so it only has to exist for provider == 'hf'.
if provider not in ('openai', 'hf'):
    raise ValueError(f"Invalid provider: {provider}. Must be either 'openai' or 'hf'.")

extractors = {
    'binary': {
        'openai': lambda prompt: llm_extraction_binary_openai(prompt, model=model),
        'hf': lambda prompt: llm_extraction_binary_hf(prompt, generator),
    },
    'multi': {
        'openai': lambda prompt: llm_extraction_multi_openai(prompt, model=model),
        'hf': lambda prompt: llm_extraction_multi_hf(prompt, generator),
    },
    'multi_structured': {
        'openai': lambda prompt: llm_extraction_multi_structured_openai(prompt, model=model),
        'hf': lambda prompt: llm_extraction_multi_structured_hf(prompt, generator),  # To be implemented
    },
}
if method not in extractors:
    raise ValueError(f"Invalid method: {method}. Must be either 'binary', 'multi' or 'multi_structured'.")
extract = extractors[method][provider]

# NOTE(review): prediction dicts carry no note identifier, so relations from different notes
# cannot be told apart downstream; consider adding the note's doc_id if evaluation needs it.
predictions = []

if method == 'binary':
    # One outer bar over notes; each note's inner bar over its pairs is cleared when done
    # (leave=False) instead of leaving one finished bar per note in the output.
    for sample in tqdm(samples, desc="Processing notes"):
        # Candidate pairs for the absolute dates, plus the relative dates when the note has any
        pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
        if sample.get('relative_dates'):
            pairs = pairs + get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])

        for pair in tqdm(pairs, desc="Pairs", leave=False):
            prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
            # parse_llm_answer gives (prediction, confidence); only positive pairs are kept
            pred, conf = parse_llm_answer(extract(prompt))
            if pred == 1:
                predictions.append({
                    'entity_label': pair['entity_label'],
                    'date': pair['date'],
                    'confidence': conf,
                })
else:
    for sample in tqdm(samples, desc="Processing notes"):
        # One prompt per note; entities_list and dates are optional for make_multi_prompt
        prompt = make_multi_prompt(
            note_text=sample['note_text'],
            prompt_filename=prompt_to_use,
            entities_list=sample['entities_list'],
            dates=sample['dates'],
        )
        # A single call returns every relation found in the note.
        # NOTE(review): relative dates are not passed here, unlike the binary path - confirm intended.
        predictions.extend(extract(prompt))

Processing notes:   7%|▋         | 7/101 [00:29<06:36,  4.22s/it]

In [None]:
# Look at the predictions: there is one entry per predicted relation across all notes.
# Show the total plus a short preview rather than dumping the whole (potentially long) list.
N_PREVIEW = 10
print(f"{len(predictions)} predicted relations")
predictions[:N_PREVIEW]