Imports

In [1]:
#imports
from tqdm import tqdm
from transformers import pipeline

import sys
import os

# Resolve the sibling ``utils`` directory relative to the current working
# directory so the project helper modules imported below can be found.
utils_path = os.path.abspath(
    os.path.join(os.getcwd(), '..', 'utils')
)

# Prepend only when absent, so re-running this cell does not pile up duplicates.
if sys.path.count(utils_path) == 0:
    sys.path[:0] = [utils_path]

from llm_extractor_utils import make_binary_prompt, llm_extraction, parse_llm_answer
from general_utils import load_data, prepare_all_samples, get_entity_date_pairs

Data Loading

In [2]:
# Load the inference dataset through the project helper `load_data`.
# The CSV path is relative, so this cell presumably expects the kernel's working
# directory to be the notebook's own folder (confirm if run from elsewhere).
df = load_data("../data/inference_dataset.csv")
# Quick sanity check: report how many records were read.
print(f"Loaded {len(df)} records")

Loaded 101 records


In [3]:
# Preview the first five rows to check the columns that were loaded.
df.head(5)

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 'ent_1', 'value': 'Ultrasound', 'cui':...","[{'id': 'abs_1', 'value': '30nd Jun 2024', 'st...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 'ent_1', 'value': 'anemia', 'cui': 'C0...","[{'id': 'abs_1', 'value': '27th Sep 2024', 'st...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C1...","[{'id': 'abs_1', 'value': '2024-10-04', 'start...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C0...","[{'id': 'abs_1', 'value': '13rd Feb 2025', 'st...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 'ent_1', 'value': 'nausea', 'cui': 'C0...","[{'id': 'abs_1', 'value': '18/11/24', 'start':...",[]


In [4]:
# Convert the loaded frame into per-note samples with the project helper.
# The inference loop further down reads 'entities_list', 'dates',
# 'relative_dates' (optional) and 'note_text' from each sample.
samples = prepare_all_samples(df)
# Quick sanity check: report how many samples were produced.
print(f"Prepared {len(samples)} samples")

Prepared 101 samples


LLM Inference

In [13]:
# Hugging Face model id used to build the text-generation pipeline.
# Any text-generation model from the Hub can be substituted, see:
# https://huggingface.co/models?pipeline_tag=text-generation
# NOTE(review): 'gpt2' is a small base model and is presumably only a placeholder
# for a stronger model; confirm before relying on the predictions.
model = 'gpt2'

In [14]:
# Build the text-generation pipeline for the chosen model.
# device=-1 keeps inference on the CPU (the pipeline reports "Device set to use cpu").
generator = pipeline("text-generation", model=model, device=-1)

Device set to use cpu


In [15]:
# Name of the prompt file; it is passed as the last argument to
# `make_binary_prompt` for every entity/date pair.
# NOTE(review): how the name is resolved to a path is decided inside
# `make_binary_prompt`; verify the expected location there.
prompt_to_use = 'prompt.txt'

In [None]:
# Run the LLM over every (entity, date) pair of every sample and keep only the
# pairs whose parsed answer is positive (pred == 1).
# Each kept prediction also stores the position of its sample in `samples`, so
# identical entity/date pairs coming from different notes stay distinguishable.
predictions = []

for sample_index, sample in enumerate(tqdm(samples, desc="Samples")):
    entities = sample['entities_list']

    # Pairs against the absolute dates are always built.
    pairs = get_entity_date_pairs(entities, sample['dates'])

    # Relative dates are optional: a missing key, None or an empty list all
    # mean there is nothing to add.
    relative_dates = sample.get('relative_dates')
    if relative_dates:
        pairs = pairs + get_entity_date_pairs(entities, [], relative_dates)

    for pair in pairs:
        # One binary question per pair, asked against the full note text.
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
        response = llm_extraction(prompt, generator)
        pred, conf = parse_llm_answer(response)

        # Only positive answers are recorded.
        if pred != 1:
            continue

        predictions.append({
            'sample_index': sample_index,
            'entity_label': pair['entity_label'],
            'date': pair['date'],
            'confidence': conf,
        })

Samples:   0%|          | 0/101 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more det

In [None]:
# Display the collected positive predictions (one dict per accepted entity/date pair).
predictions