Imports

In [1]:
#imports
import pandas as pd
import torch
import json
from transformers import AutoTokenizer, set_seed
from safetensors.torch import load_file
from tqdm import tqdm

# Import our modules
import sys
import os

utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
models_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'models'))

if utils_path not in sys.path:
    sys.path.insert(0, utils_path)
if models_path not in sys.path:
    sys.path.insert(0, models_path)

from general_utils import load_data, prepare_all_samples, get_entity_date_pairs
from bert_extractor_utils import bert_extraction
from bert_model import BertRC

Data Loading

In [2]:
# Load data
df = load_data("../data/inference_dataset.csv")
print(f"Loaded {len(df)} records")

Loaded 101 records


In [3]:
#Inspect df
df.head()

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 'ent_1', 'value': 'Ultrasound', 'prefe...","[{'id': 'abs_1', 'value': '30nd Jun 2024', 'st...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 'ent_1', 'value': 'anemia', 'preferred...","[{'id': 'abs_1', 'value': '27th Sep 2024', 'st...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 'ent_1', 'value': 'cough', 'preferred_...","[{'id': 'abs_1', 'value': '2024-10-04', 'start...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 'ent_1', 'value': 'MRI', 'preferred_na...","[{'id': 'abs_1', 'value': '13rd Feb 2025', 'st...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 'ent_1', 'value': 'nausea/vomiting', '...","[{'id': 'abs_1', 'value': '18/11/24', 'start':...",[]


In [4]:
# Prepare all samples
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
#samples[0]

Prepared 101 samples


BERT Inference

In [5]:
# Set seed for reproducibility
set_seed(42)

In [6]:
#Set path to load model from
model_path = '../models/bert_model/'

In [7]:
#Set model name - this should be the same as the base model used for training
model_name = "google/bert_uncased_L-2_H-128_A-2"

In [8]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

#Load trained model
model = BertRC(model_name=model_name, tokenizer=tokenizer, num_labels=2)

#Load saved weights
state_dict = load_file(f"{model_path}/model.safetensors", device="cpu") # Or "cuda"

#Apply the weights to the model instance
model.load_state_dict(state_dict, strict=False)

#Set the model to evaluation mode
model.eval()
print("Model loaded successfully!")

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Model loaded successfully!


In [9]:
# Make predictions
predictions = []

for sample in tqdm(samples, desc="Samples"):
    # Get absolute date pairs
    absolute_pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])
    
    # Get relative date pairs if available
    if sample.get('relative_dates') and len(sample['relative_dates']) > 0:
        relative_pairs = get_entity_date_pairs(sample['entities_list'], [], sample['relative_dates'])
        pairs = absolute_pairs + relative_pairs
    else:
        pairs = absolute_pairs
    
    for pair in pairs:
        entity = pair['entity']
        date = pair['date_info']
        pred, conf = bert_extraction(sample['note_text'], entity, date, model, tokenizer)
        if pred == 1:
            predictions.append({
                'doc_id': sample['doc_id'],
                'date_id': date['id'],
                'date': date['value'],
                'date_type': pair['date_type'],
                'entity_id': entity['id'],
                'entity_label': entity['value'],
                'entity_preferred_name': entity.get('preferred_name', entity['value'])
            })

print(f"Total predictions: {len(predictions)}")

Samples: 100%|██████████| 101/101 [01:55<00:00,  1.15s/it]

Total predictions: 4349





In [10]:
#Look at prediction
predictions

[{'doc_id': 0,
  'date_id': 'abs_1',
  'date': '30nd Jun 2024',
  'date_type': 'absolute',
  'entity_id': 'ent_1',
  'entity_label': 'Ultrasound',
  'entity_preferred_name': 'Ultrasound'},
 {'doc_id': 0,
  'date_id': 'abs_2',
  'date': '02nd Aug 2024',
  'date_type': 'absolute',
  'entity_id': 'ent_1',
  'entity_label': 'Ultrasound',
  'entity_preferred_name': 'Ultrasound'},
 {'doc_id': 0,
  'date_id': 'abs_3',
  'date': '12nd Sep 2024',
  'date_type': 'absolute',
  'entity_id': 'ent_1',
  'entity_label': 'Ultrasound',
  'entity_preferred_name': 'Ultrasound'},
 {'doc_id': 0,
  'date_id': 'abs_4',
  'date': '16 Sep',
  'date_type': 'absolute',
  'entity_id': 'ent_1',
  'entity_label': 'Ultrasound',
  'entity_preferred_name': 'Ultrasound'},
 {'doc_id': 0,
  'date_id': 'abs_5',
  'date': '23rd Oct 2024',
  'date_type': 'absolute',
  'entity_id': 'ent_1',
  'entity_label': 'Ultrasound',
  'entity_preferred_name': 'Ultrasound'},
 {'doc_id': 0,
  'date_id': 'abs_6',
  'date': '16st Nov 2024'

In [11]:
# Save predictions
with open('../outputs/bert_predictions.json', 'w') as f:
    json.dump(predictions, f, indent=2)

print("Saved predictions to outputs/bert_predictions.json")

Saved predictions to outputs/bert_predictions.json
