Imports

In [None]:
#Imports
import json
import logging
import os
import sys

import numpy as np
import pandas as pd

from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT

# Put the sibling ``utils`` directory (one level up from the working
# directory) at the front of the import path so ``general_utils`` resolves.
# abspath() anchors a relative path at os.getcwd(), so no explicit getcwd().
utils_path = os.path.abspath(os.path.join(os.pardir, "utils"))
if not any(entry == utils_path for entry in sys.path):
    sys.path.insert(0, utils_path)

from general_utils import load_data

Data Loading

In [None]:
# Load the inference dataset and report how many rows it contains
df = load_data("../data/inference_dataset.csv")
print(f"Loaded {df.shape[0]} records")

In [None]:
# Preview the first five rows of the loaded data
df.head(n=5)

RelCAT Inference

In [None]:
# Load the trained RelCAT model from its saved directory
RELCAT_MODEL_DIR = "../models/relcat_model"
relCAT = RelCAT.load(RELCAT_MODEL_DIR)

In [None]:
# CUIs attached to absolute and relative date annotations at inference time.
# They must match the CUIs under which these terms were added in MedCAT
# Trainer, so inference annotations are labelled the same way as in training.
DATE_CUI, RELATIVE_DATE_CUI = "410671006", "118578006"

In [None]:
# Make predictions
#
# For every document, the stored entity and date annotations are passed to
# RelCAT. Only predicted relations that link a date (absolute or relative) to
# a clinical entity are kept; entity-entity and date-date pairs are dropped.

_logger = logging.getLogger(__name__)

# Columns the loop cannot work without; "relative_dates_json" is optional.
_REQUIRED_COLUMNS = {"doc_id", "note_text", "entities_json", "dates_json"}


def _parse_json_list(row, column):
    """Return the list stored in ``row[column]``.

    The cell may already hold a decoded list or a JSON-encoded string.
    A missing column, a missing value (None/NaN/NA) or a blank string all
    yield an empty list.

    Raises:
        ValueError: if the cell holds a string that is not valid JSON.
        TypeError: if the cell holds some other non-JSON, non-list object.
    """
    if column not in row.index:
        return []
    value = row[column]
    if isinstance(value, list):
        return value
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return []
    if isinstance(value, str) and not value.strip():
        return []
    return json.loads(value)


def _build_annotations(entities, dates, relative_dates):
    """Build the annotation dicts passed to RelCAT, in the same format as training.

    Entities keep their own CUI; absolute and relative dates are tagged with
    DATE_CUI and RELATIVE_DATE_CUI respectively.
    """
    def _annotation(item, cui):
        return {
            "value": item["value"],
            "cui": cui,
            "start": item.get("start"),
            "end": item.get("end"),
        }

    return (
        [_annotation(entity, entity.get("cui")) for entity in entities]
        + [_annotation(date, DATE_CUI) for date in dates]
        + [_annotation(date, RELATIVE_DATE_CUI) for date in relative_dates]
    )


def _as_builtin(value):
    """Convert a NumPy scalar (as handed out by pandas) to a plain Python value."""
    return value.item() if isinstance(value, np.generic) else value


_missing_columns = _REQUIRED_COLUMNS - set(df.columns)
if _missing_columns:
    raise KeyError(f"Input data is missing required column(s): {sorted(_missing_columns)}")

predictions = []
failed_doc_ids = []

doc_ids = df["doc_id"].unique()
# First row per document, in first-appearance order (same rows as filtering
# by each doc_id and taking .iloc[0], without rescanning the frame each time).
doc_rows = df.drop_duplicates(subset="doc_id", keep="first")

for _, row in doc_rows.iterrows():
    # Plain Python value so the predictions stay JSON-serialisable.
    doc_id = _as_builtin(row["doc_id"])

    try:
        entities = _parse_json_list(row, "entities_json")
        dates = _parse_json_list(row, "dates_json")
        relative_dates = _parse_json_list(row, "relative_dates_json")

        # Surface-text lookups used to classify each side of a predicted relation.
        # NOTE: relations are matched by text, so repeated mentions with identical
        # text collapse onto the last one seen, and a relative date shadows an
        # absolute date with the same text.
        entity_map = {
            entity["value"]: (entity["id"], entity.get("preferred_name") or entity["value"])
            for entity in entities
        }
        date_map = {date["value"]: (date["id"], "absolute") for date in dates}
        date_map.update({date["value"]: (date["id"], "relative") for date in relative_dates})

        annotations = _build_annotations(entities, dates, relative_dates)
    except (ValueError, TypeError, KeyError) as err:
        _logger.error("Skipping document %s: malformed annotation data (%s)", doc_id, err)
        failed_doc_ids.append(doc_id)
        continue

    try:
        # Run inference
        output_doc_with_relations = relCAT.predict_text_with_anns(
            text=row["note_text"],
            annotations=annotations,
        )
    except Exception:  # model-side failures vary; log with traceback and move on
        _logger.exception("Error processing document %s", doc_id)
        failed_doc_ids.append(doc_id)
        continue

    # Collect results - only keep date-entity pairs
    for relation in output_doc_with_relations._.relations:
        ent1_text, ent2_text = relation["ent1_text"], relation["ent2_text"]
        if ent1_text in date_map and ent2_text in entity_map:
            date_text, entity_text = ent1_text, ent2_text
        elif ent2_text in date_map and ent1_text in entity_map:
            date_text, entity_text = ent2_text, ent1_text
        else:
            continue  # entity-entity or date-date pair

        date_id, date_type = date_map[date_text]
        entity_id, preferred_name = entity_map[entity_text]

        predictions.append({
            "doc_id": doc_id,
            "date_id": date_id,
            "date": date_text,
            "date_type": date_type,
            "entity_id": entity_id,
            "entity_label": entity_text,
            "entity_preferred_name": preferred_name,
            # Keep RelCAT's predicted class and score so that pairs predicted as a
            # negative / no-relation class (if the model has one) can be filtered
            # downstream; .get() because these keys are not guaranteed here.
            "relation": relation.get("relation"),
            "confidence": relation.get("confidence"),
        })

print(f"Processed {len(doc_rows) - len(failed_doc_ids)} of {len(doc_ids)} documents")
if failed_doc_ids:
    print(f"Failed documents ({len(failed_doc_ids)}): {failed_doc_ids}")
print(f"Total predictions: {len(predictions)}")

In [None]:
# Look at predictions: show only the first few as a table, since the full
# list may be long to print in the notebook.
pd.DataFrame(predictions).head()

In [None]:
# Save predictions
PREDICTIONS_OUTPUT_PATH = "../outputs/relcat_predictions.json"


def _json_default(obj):
    """Serialise NumPy scalars (e.g. int64 doc ids taken from pandas) as plain Python values.

    Raises:
        TypeError: for any other non-JSON type, matching json.dump's own behaviour.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# The outputs directory may not exist yet on a fresh checkout.
os.makedirs(os.path.dirname(PREDICTIONS_OUTPUT_PATH), exist_ok=True)
with open(PREDICTIONS_OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(predictions, f, indent=2, default=_json_default)

print(f"Saved {len(predictions)} predictions to {PREDICTIONS_OUTPUT_PATH}")