Imports

In [1]:
#imports
from tqdm import tqdm
from transformers import pipeline

import sys
import os

# Absolute path of the `utils` directory that sits next to the current working
# directory (i.e. <cwd>/../utils). The lookup is relative to os.getcwd().
utils_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'utils'))

# Put it at the front of the import search path, but only once, so re-running
# this cell does not stack duplicate entries.
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from llm_extractor_utils import make_binary_prompt, llm_extraction, parse_llm_answer
from general_utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics

Data Loading

In [2]:
# Load the training dataset (path is relative to the working directory)
training_csv = "../data/training_dataset.csv"
df = load_data(training_csv)
print(f"Loaded {len(df)} records")

Loaded 5 records


In [None]:
# Inspect df: display the result of df.head() as the cell output
# (the leading rows, assuming load_data returns a pandas DataFrame - TODO confirm)
df.head()

In [5]:
# Prepare all samples: turn the loaded data into the collection of samples
# that the evaluation loop iterates over, and report how many were produced
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
# Uncomment to inspect the first prepared sample:
#samples[0]

Prepared 5 samples


LLM Evaluation

In [7]:
# Define LLM to use (any text generation model from HuggingFace can be used, see: https://huggingface.co/models?pipeline_tag=text-generation)
# Note: `model` holds the model identifier string, not a loaded model object;
# it is passed to `pipeline(...)` when the generator is created.
model = 'gpt2'

In [8]:
# Define generator: a HuggingFace text-generation pipeline for the chosen model.
# device=-1 runs the pipeline on CPU (the cell output reports "Device set to use cpu").
generator = pipeline("text-generation", model=model, device=-1)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cpu


In [10]:
# Define prompt to use: this value is handed to make_binary_prompt as its last argument.
# Presumably the file name of a prompt template that the helper resolves itself -
# TODO confirm where 'prompt.txt' is looked up.
prompt_to_use = 'prompt.txt'

In [12]:
#Build a prompt for every entity-date pair, query the LLM and keep the positive answers
predictions = []

for sample in tqdm(samples, desc="Samples"):
    # Pairs formed with the sample's absolute dates; used as-is unless relative dates exist
    absolute_pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    pairs = absolute_pairs

    # When the sample carries a non-empty 'relative_dates' entry, add those pairs as well
    # (the absolute-date argument is left empty for that call)
    if sample.get('relative_dates') and len(sample['relative_dates']) > 0:
        relative_pairs = get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])
        pairs = absolute_pairs + relative_pairs

    for pair in pairs:
        # One binary question per pair, asked against the sample's note text
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'], prompt_to_use)
        response = llm_extraction(prompt, generator)
        pred, conf = parse_llm_answer(response)

        # Anything other than a prediction of 1 is discarded
        if pred != 1:
            continue

        predictions.append(dict(
            entity_label=pair['entity_label'],
            date=pair['date'],
            confidence=conf,
        ))

Samples:   0%|          | 0/5 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more detai

In [15]:
#Look at prediction
#predictions

In [14]:
# Calculate metrics: compare the collected predictions against the loaded data.
# The returned dict is the cell output (precision, recall, f1, tp, fp, fn).
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.05580693815987934,
 'recall': 1.0,
 'f1': 0.10571428571428572,
 'tp': 37,
 'fp': 626,
 'fn': 0}