Imports

In [None]:
#imports
from tqdm import tqdm
from transformers import pipeline, BitsAndBytesConfig
from openai import OpenAI
from huggingface_hub import HfApi

import sys
import os

# Make the project's shared ``utils`` folder importable: it is expected one
# level above the current working directory, and is placed at the front of
# sys.path unless it is already listed there.
_utils_candidate = os.path.join(os.getcwd(), '..', 'utils')
utils_path = os.path.abspath(_utils_candidate)

if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics

from llm_extractor_utils import (
    make_binary_prompt, 
    llm_extraction_binary_hf,
    llm_extraction_binary_openai,
    parse_llm_answer,
    make_multi_prompt,
    llm_extraction_multi_openai,
    llm_extraction_multi_hf
)

Data Loading

In [None]:
# Load data
# The CSV path is relative to the notebook's working directory.
dataset_csv = "../data/training_dataset_synthetic1.csv"
df = load_data(dataset_csv)
print(f"Loaded {len(df)} records")

In [None]:
#Inspect df
#df.head()

In [None]:
# Prepare all samples
# Converts the loaded data (`df`) into `samples`. The extraction cell further
# down reads 'note_text', 'entities_list', 'dates' and, when present,
# 'relative_dates' from each sample.
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
#samples[0]

LLM Evaluation

In [None]:
#Choose which LLM backend to use: HuggingFace ('hf') or OpenAI ('openai')
#Note: when choosing OpenAI, OPENAI_API_KEY needs to be set in the .env file
#Any other value makes the extraction cell raise a ValueError once it starts processing
provider = 'openai'

In [None]:
#Set device
#NOTE(review): in the cells shown here `device` only appears in commented-out
#pipeline arguments (`#device=device`), so changing it has no visible effect - confirm before relying on it
device = -1  # Set to -1 for CPU, 0 for GPU

In [None]:
#Choose whether to do binary prediction 'binary' or multi-prediction 'multi'
#Binary prediction: the LLM is passed every possible entity-date pair and asked to predict
#whether they are related (Yes or No). This requires many more LLM calls, so it can be
#expensive when using the OpenAI API.
#Multi prediction: the LLM is passed each note in turn and asked to extract all correct
#relations in one go, as JSON.
method = 'multi'

In [None]:
#Choose whether to use a zero-shot prompt 'zero' or a few-shot prompt 'few'
#Together with `method`, this selects the prompt file in the "Prompt to use" cell
prompt_type = 'few'

In [None]:
#To see a list of available OpenAI API models
#Alternatively visit https://platform.openai.com/docs/models

#client = OpenAI(api_key = os.getenv('OPENAI_API_KEY'))
#models = client.models.list()
#for model in models:
    #print(model.id)

In [None]:
#To see a list of available HuggingFace text generation models
#Alternatively visit https://huggingface.co/models?pipeline_tag=text-generation

#Uncomment the code below to see the full list. Note that this list is long and can take a couple of minutes to run
#api = HfApi()
#text_gen_models = api.list_models(filter="text-generation")
#for model_info in text_gen_models:
    #print(model_info.modelId)

In [None]:
# Define LLM to use - this should be the model id from OpenAI or HuggingFace
# (it must match the `provider` chosen above)
model = 'gpt-3.5-turbo'

#Examples of OpenAI models
#model = 'gpt-3.5-turbo' #cheap option for binary method
#model = 'o4-mini'
#model = 'gpt-5-mini' #cost-effective reasoning model, good for multi method

#Examples of HF models
#model = 'gpt2'
#model = 'google/gemma-3-270m'
#model = 'Qwen/Qwen3-0.6B'
#model = 'roneneldan/TinyStories-1M' #fast option even on CPU, good for binary method
#model = 'EleutherAI/gpt-neo-125m'
#model = 'microsoft/phi-1_5'
#model = 'TinyLlama/TinyLlama-1.1B-chat-v1.0'
#model = '../Llama-3.1-8B-Instruct' #local folder example

In [None]:
# Configure quantization if using HF option
# `quantization_config` is only defined when provider is 'hf'; nothing is
# needed for the OpenAI option.
if provider == 'hf':
    # 8-bit loading, with fp32 CPU offload enabled
    bnb_options = {
        "load_in_8bit": True,
        "llm_int8_enable_fp32_cpu_offload": True,
    }
    quantization_config = BitsAndBytesConfig(**bnb_options)

In [None]:
# Define generator if using HF option
# `generator` is only created when provider is 'hf'.
if provider == 'hf':
    # Arguments common to both methods. The model is loaded with the
    # quantization settings from `quantization_config`; `device` is not
    # passed (it was commented out: #device=device).
    pipeline_options = {
        "model": model,
        "model_kwargs": {"quantization_config": quantization_config},
    }

    if method == 'multi':
        # Extra generation settings used only for multi-extraction.
        # Disabled options kept for reference:
        #   pad_token_id=2  (common pad token ID for Llama models)
        #   eos_token_id=2  (end of sequence token)
        pipeline_options.update(
            max_new_tokens=1000,
            do_sample=False,
            return_full_text=False,  # Only return the newly generated text
        )

    # For any method other than 'multi' this is regular text generation
    generator = pipeline("text-generation", **pipeline_options)

In [None]:
# Prompt to use
# Lookup table: (method, prompt_type) -> prompt template filename
prompt_files = {
    ('binary', 'zero'): 'binary_prompt_zero_shot.txt',
    ('binary', 'few'): 'binary_prompt_few_shot.txt',
    ('multi', 'zero'): 'multi_prompt_zero_shot.txt',
    ('multi', 'few'): 'multi_prompt_few_shot.txt',
}

try:
    prompt_to_use = prompt_files[(method, prompt_type)]
except (KeyError, TypeError):
    # Unknown combination: report both settings, as there is no way to tell
    # which one was mistyped
    raise ValueError(
        f"Invalid method: {method} or prompt type {prompt_type}. "
        "Method must be either 'binary' or 'multi' and prompt type must be either 'zero' or 'few'."
    ) from None

In [None]:
#Process samples, make prompt, do llm extraction and make predictions
#Results are collected in `predictions`, which the metrics cell scores afterwards
predictions = []

# Fail fast on a bad configuration. The provider is checked once up front
# instead of once per prompt, and an unknown method now raises instead of
# silently skipping both branches and leaving `predictions` empty.
if provider not in ('openai', 'hf'):
    raise ValueError(f"Invalid provider: {provider}. Must be either 'openai' or 'hf'.")
if method not in ('binary', 'multi'):
    raise ValueError(f"Invalid method: {method}. Must be either 'binary' or 'multi'.")

use_openai = provider == 'openai'

if method == 'binary':
    # Binary method: one LLM call per candidate entity-date pair
    for sample in tqdm(samples, desc="Samples"):

        # Get absolute date pairs
        absolute_pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])

        # Get relative date pairs if available (missing or empty -> absolute pairs only)
        if sample.get('relative_dates'):
            relative_pairs = get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])
            pairs = absolute_pairs + relative_pairs
        else:
            pairs = absolute_pairs

        for pair in pairs:
            #Create binary prompt
            prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)

            #Get response from the chosen provider
            if use_openai:
                response = llm_extraction_binary_openai(prompt, model=model)
            else:
                response = llm_extraction_binary_hf(prompt, generator)

            #Parse response into binary prediction; only positive pairs are kept
            pred, conf = parse_llm_answer(response)
            if pred == 1:
                predictions.append({
                    'entity_label': pair['entity_label'],
                    'date': pair['date'],
                    'confidence': conf
                })

else:
    # Multi method: one LLM call per note, returning all relationships at once
    for sample in tqdm(samples, desc="Processing notes"):

        # Create multi-extraction prompt
        # NOTE(review): only sample['dates'] is passed here; the relative dates
        # used by the binary branch are not - confirm whether make_multi_prompt
        # should receive them as well.
        prompt = make_multi_prompt(
            note_text=sample['note_text'],
            prompt_filename=prompt_to_use,
            entities_list=sample['entities_list'],
            dates=sample['dates']
        )

        # Get all relationships in one call
        if use_openai:
            relationships = llm_extraction_multi_openai(prompt, model=model)
        else:
            relationships = llm_extraction_multi_hf(prompt, generator)

        # Add to predictions
        predictions.extend(relationships)

In [None]:
#Look at predictions
#predictions

In [None]:
#Calculate metrics
#Scores the collected predictions against the loaded data (`df`); the bare
#`metrics` expression on the last line displays the result as the cell output
metrics = calculate_metrics(predictions, df)
metrics