Imports

In [1]:
#imports
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, pipeline

from utils import load_data, prepare_all_samples, get_entity_date_pairs, calculate_metrics
from naive_extractor import naive_extraction
from bert_extractor import bert_extraction
from llama_extractor import llama_extraction, make_binary_prompt, parse_llama_answer

Data Loading

In [2]:
# Load data
df = load_data("data/synthetic.csv")
print(f"Loaded {len(df)} records")
#df

Loaded 101 records


In [3]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
#samples[0]

Prepared 101 samples


Naive Extractor

In [4]:
#Process samples
predictions = []

for sample in samples:
    relationships = naive_extraction(sample['note_text'], sample['entities_list'], sample['dates'])
    predictions.extend(relationships)

In [5]:
# Look at predictions
#predictions

In [6]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.1277533039647577,
 'recall': 0.13551401869158877,
 'f1': 0.13151927437641722,
 'tp': 29,
 'fp': 198,
 'fn': 185}

BERT

In [7]:
# Load model and tokenizer
model_name = "google/bert_uncased_L-2_H-128_A-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google/bert_uncased_L-2_H-128_A-2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Your extraction logic in notebook
predictions = []
for sample in samples:
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    
    for pair in pairs:
        text = f"Does '{pair['entity_label']}' relate to '{pair['date']}'? {sample['note_text'][:200]}..."
        pred, conf = bert_extraction(text, model, tokenizer)
        
        if pred == 1:
            predictions.append({'entity_label': pair['entity_label'], 'date': pair['date'], 'confidence': conf})

In [9]:
predictions

[{'entity_label': 'asthma',
  'date': '2024-08-02',
  'confidence': 0.509070873260498},
 {'entity_label': 'asthma',
  'date': '2024-10-23',
  'confidence': 0.5114834904670715},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-08-02',
  'confidence': 0.5100752115249634},
 {'entity_label': 'pituitary_adenoma',
  'date': '2024-10-23',
  'confidence': 0.5126566290855408},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-08-02',
  'confidence': 0.5073235630989075},
 {'entity_label': 'rheumatoid_arthritis',
  'date': '2024-10-23',
  'confidence': 0.5100324153900146},
 {'entity_label': 'pneumonia',
  'date': '2024-08-02',
  'confidence': 0.5072735548019409},
 {'entity_label': 'pneumonia',
  'date': '2024-10-23',
  'confidence': 0.5096699595451355},
 {'entity_label': 'gerd',
  'date': '2024-08-02',
  'confidence': 0.5039689540863037},
 {'entity_label': 'gerd',
  'date': '2024-10-23',
  'confidence': 0.5065612196922302},
 {'entity_label': 'meningitis',
  'date': '2024-08-02',
  '

In [10]:
# Calculate metrics
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0.12982456140350876,
 'recall': 0.17289719626168223,
 'f1': 0.14829659318637273,
 'tp': 37,
 'fp': 248,
 'fn': 177}

Llama

In [23]:
# Load in notebook (you control this)
#generator = pipeline("text-generation", model="../Llama-3.2-3B-Instruct", device=-1)
generator = pipeline(
    "text2text-generation",
    model="google/flan-t5-small",
    device=-1
)

Device set to use cpu


In [13]:
# Create prompt in notebook
#prompt = f"In the following text, confirm if {text} Answer YES or NO:"
#response = llama_extraction(prompt, generator)

In [14]:
prompt = "Does the following text indicate a relationship between 'asthma' and '2024-08-02'? Answer YES or NO. Text: Patient diagnosed with asthma on 2024-08-02."
result = generator(prompt)
print(result[0]['generated_text'])

No


In [15]:
for sample in tqdm(samples[:1], desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs[:1]:
        #print(pair)
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'])
        #print(prompt)
        response = llama_extraction(prompt, generator)
        print(response)
        pred, conf = parse_llama_answer(response)
        print(pred, conf)

Samples: 100%|██████████| 1/1 [00:00<00:00,  1.46it/s]

No
0 0.0





In [16]:
predictions = []

for sample in tqdm(samples, desc="Samples"):
    pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    for pair in pairs:
        prompt = make_binary_prompt(pair['entity'], pair['date_info'], sample['note_text'])
        response = llama_extraction(prompt, generator)
        pred, conf = parse_llama_answer(response)
        if pred == 1:
            predictions.append({
                'entity_label': pair['entity_label'],
                'date': pair['date'],
                'confidence': conf
            })

Samples: 100%|██████████| 101/101 [16:47<00:00,  9.98s/it]


In [19]:
predictions

[]

In [20]:
metrics = calculate_metrics(predictions, df)
metrics

{'precision': 0, 'recall': 0.0, 'f1': 0, 'tp': 0, 'fp': 0, 'fn': 214}