Imports

In [None]:
#imports
import pandas as pd
import numpy as np
import torch
import json
from transformers import AutoTokenizer, set_seed
from safetensors.torch import load_file
from tqdm import tqdm

# Import our modules
import sys
import os

# Resolve the sibling 'utils' and 'models' directories relative to the current
# working directory so the project modules imported below can be found.
utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
models_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'models'))

# Prepend each directory once; 'models' is handled last, so it ends up ahead of
# 'utils' on sys.path when both are newly added.
for module_dir in (utils_path, models_path):
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

from general_utils import load_data, prepare_all_samples, get_entity_date_pairs
from bert_training_utils import add_special_tokens
from bert_extractor_utils import preprocess_input, bert_extraction, mark_entities_full_text
from bert_model import BertRC

Data Loading

In [None]:
# Read the inference dataset through the project's load_data helper and report
# how many records came back.
inference_csv = "../data/inference_dataset.csv"
df = load_data(inference_csv)
print(f"Loaded {len(df)} records")

In [None]:
# Preview the first rows of df as a quick sanity check; the notebook renders
# the value of this last expression.
df.head()

In [None]:
# Convert df into the list of samples consumed by the inference loop below.
# Each sample is later read as a dict with 'doc_id', 'note_text',
# 'entities_list', 'dates' and (optionally) 'relative_dates'.
samples = prepare_all_samples(df)
print(f"Prepared {len(samples)} samples")
# Uncomment to inspect the first prepared sample:
# samples[0]

BERT Inference

In [None]:
# Fix the random seed (via transformers' set_seed) so repeated runs of this
# notebook are reproducible.
set_seed(42)

In [None]:
# Directory holding the fine-tuned model artifacts; the tokenizer and the
# model.safetensors weights file are both loaded from here below.
model_path = '../models/bert_model/'

In [None]:
# Name of the base checkpoint used to construct the BertRC architecture.
# This must be the same base model that was used for training, otherwise the
# saved weights will not line up with the model's parameters.
model_name = "google/bert_uncased_L-2_H-128_A-2"

In [None]:
# Load the tokenizer saved next to the fine-tuned weights.
# NOTE(review): presumably this includes the special tokens added at training
# time -- confirm against the training notebook.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Build the model architecture from the base checkpoint (binary classifier).
model = BertRC(model_name=model_name, tokenizer=tokenizer, num_labels=2)

# Load the saved weights onto the CPU (use device="cuda" to load onto a GPU).
# os.path.join avoids the doubled slash produced by model_path's trailing '/'.
weights_file = os.path.join(model_path, "model.safetensors")
state_dict = load_file(weights_file, device="cpu")

# strict=False tolerates key mismatches, which means a wrong checkpoint or
# base model would otherwise load "successfully" with untouched random weights.
# Surface any mismatch so it cannot go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} model parameter(s) were not "
        f"found in the checkpoint: {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint entr(y/ies) "
        f"were not used by the model: {load_result.unexpected_keys}"
    )

# Switch to evaluation mode for inference.
model.eval()
print("Model loaded successfully!")

In [None]:
# Classify every (entity, date) candidate pair of every sample and keep the
# pairs the model labels as positive (pred == 1).
predictions = []

# Inference only: disabling autograd avoids building computation graphs for
# each forward pass. This is a no-op if bert_extraction already does so.
with torch.no_grad():
    for sample in tqdm(samples, desc="Samples"):
        # Candidate pairs built from the absolute dates.
        pairs = get_entity_date_pairs(sample['entities_list'], sample['dates'])

        # Add candidate pairs built from relative dates when the sample has any.
        relative_dates = sample.get('relative_dates')
        if relative_dates:
            pairs = pairs + get_entity_date_pairs(
                sample['entities_list'], [], relative_dates
            )

        for pair in pairs:
            entity = pair['entity']
            date = pair['date_info']
            # bert_extraction also returns a confidence score; it is not used here.
            pred, _conf = bert_extraction(
                sample['note_text'], entity, date, model, tokenizer
            )
            if pred == 1:
                predictions.append({
                    'doc_id': sample['doc_id'],
                    'date_id': date['id'],
                    'date': date['value'],
                    'date_type': pair['date_type'],
                    'entity_id': entity['id'],
                    'entity_label': entity['value'],
                    # Fall back to the surface form when no preferred name exists.
                    'entity_preferred_name': entity.get('preferred_name', entity['value']),
                })

print(f"Total predictions: {len(predictions)}")

In [None]:
# Display the collected positive predictions; the notebook renders the value
# of this last expression.
predictions

In [None]:
# Save predictions as pretty-printed JSON.
output_path = '../outputs/bert_predictions.json'

# Create the output directory first: open(..., 'w') raises FileNotFoundError
# when the parent directory does not exist (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Explicit encoding keeps the file independent of the platform default.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(predictions, f, indent=2)

print("Saved predictions to outputs/bert_predictions.json")