Imports

In [1]:
#Imports
import logging
import json
import pandas as pd

from medcat.cdb import CDB
from medcat.config_rel_cat import ConfigRelCAT
from medcat.rel_cat import RelCAT

import sys, os
utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', 'utils'))
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data

Data Loading

In [2]:
# Load data
df = load_data("../data/inference_dataset.csv")
print(f"Loaded {len(df)} records")

Loaded 101 records


In [3]:
#Inspect df
df.head()

Unnamed: 0,doc_id,note_text,entities_json,dates_json,relative_dates_json
0,0,Ultrasound (30nd Jun 2024): no significant fin...,"[{'id': 'ent_1', 'value': 'Ultrasound', 'cui':...","[{'id': 'abs_1', 'value': '30nd Jun 2024', 'st...",[]
1,1,Labs (27th Sep 2024): anemia. resolving Skin:...,"[{'id': 'ent_1', 'value': 'anemia', 'cui': 'C0...","[{'id': 'abs_1', 'value': '27th Sep 2024', 'st...",[]
2,2,URGENT REVIEW (2024-10-04): cough. suspect ost...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C1...","[{'id': 'abs_1', 'value': '2024-10-04', 'start...",[]
3,3,URGENT REVIEW (13rd Feb 2025) MRI of the brain...,"[{'id': 'ent_1', 'value': 'REVIEW', 'cui': 'C0...","[{'id': 'abs_1', 'value': '13rd Feb 2025', 'st...",[]
4,4,New pt((18/11/24)): pt presents with nausea/vo...,"[{'id': 'ent_1', 'value': 'nausea', 'cui': 'C0...","[{'id': 'abs_1', 'value': '18/11/24', 'start':...",[]


RelCAT Inference

In [4]:
#Load trained RelCAT model
relCAT = RelCAT.load("../models/relcat_models")

INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized
INFO:medcat.utils.relation_extraction.base_component:BaseComponent_RelationExtraction initialized


In [7]:
#Define cuis for absolute and relative dates (these should align with the cuis that were used to add these terms in MedCAT Trainer)
DATE_CUI = "410671006"
RELATIVE_DATE_CUI = "410671007"

In [8]:
# Make predictions
predictions = []

doc_ids = df['doc_id'].unique()

for doc_id in doc_ids:
    # Get the row
    row = df[df['doc_id'] == doc_id].iloc[0]
    
    # Parse the JSON columns
    entities = row["entities_json"] if isinstance(row["entities_json"], list) else json.loads(row["entities_json"])
    dates = row["dates_json"] if isinstance(row["dates_json"], list) else json.loads(row["dates_json"])
    relative_dates = row["relative_dates_json"] if isinstance(row["relative_dates_json"], list) else json.loads(row["relative_dates_json"]) if "relative_dates_json" in row else []
    
    # Combine absolute and relative dates
    all_dates = dates + relative_dates
    
    # Create annotations in the same format as training
    annotations = []
    for entity in entities:
        annotations.append({
            "value": entity["value"],
            "cui": entity.get("cui"),
            "start": entity.get("start"),
            "end": entity.get("end")
        })
    for date in all_dates:
        annotations.append({
            "value": date["value"],
            "cui": DATE_CUI,
            "start": date.get("start"),
            "end": date.get("end")
        })
    
    try:
        # Run inference
        output_doc_with_relations = relCAT.predict_text_with_anns(
            text=row["note_text"], 
            annotations=annotations
        )
        
        # Collect results - only keep date-entity pairs
        for relation in output_doc_with_relations._.relations:
            # Check if this is a date-entity pair (not entity-entity)
            if (relation['ent1_text'] in [d['value'] for d in dates] and 
                relation['ent2_text'] in [e['value'] for e in entities]) or \
               (relation['ent2_text'] in [d['value'] for d in dates] and 
                relation['ent1_text'] in [e['value'] for e in entities]):
                
                all_predictions.append({
                    'entity_label': relation['ent1_text'],
                    'date': relation['ent2_text'],
                    'confidence': relation['confidence'],
                    'doc_id': doc_id
                })
                
    except Exception as e:
        print(f"Error processing document {doc_id}: {e}")
        continue

print(f"Processed {len(doc_ids)} documents")
print(f"Total predictions: {len(predictions)}")

Error processing document 0: min() arg is an empty sequence


INFO:medcat.rel_cat:total relations for doc: 346
INFO:medcat.rel_cat:processing...
100%|██████████| 346/346 [02:18<00:00,  2.50it/s]


Error processing document 1: name 'all_predictions' is not defined
Error processing document 2: min() arg is an empty sequence


INFO:medcat.rel_cat:total relations for doc: 382
INFO:medcat.rel_cat:processing...
  8%|▊         | 32/382 [00:08<01:35,  3.65it/s]

KeyboardInterrupt: 

In [None]:
# Look at predictions
#predictions

In [None]:
# Show results
print("Test Set Results:")
print(f"Total predictions: {len(all_predictions)}")

# Show first 10 predictions
print("\nFirst 10 predictions:")
for i, pred in enumerate(all_predictions[:10]):
    print(f"{i+1}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f}) [doc: {pred['doc_id']}]")

# Show high confidence predictions
high_conf = [p for p in all_predictions if p['confidence'] > 0.7]
print(f"\nHigh confidence predictions (>0.7): {len(high_conf)}")
for i, pred in enumerate(high_conf[:5]):  # Show first 5
    print(f"{i+1}. {pred['entity_label']} -> {pred['date']} (conf: {pred['confidence']:.3f})")

In [None]:
# Let's debug the exact counts
print(f"Total test pairs: {len(df)}")
print(f"Total predictions: {len(all_predictions)}")

# Count how many test pairs were actually predicted
predicted_count = 0
for _, row in df.iterrows():
    found = False
    for pred in all_predictions:
        if (pred['doc_id'] == row['doc_id'] and 
            ((pred['entity_label'] == row['ent1'] and pred['date'] == row['ent2']) or
             (pred['entity_label'] == row['ent2'] and pred['date'] == row['ent1']))):
            found = True
            break
    if found:
        predicted_count += 1

print(f"Test pairs that were predicted: {predicted_count}")
print(f"Test pairs that were NOT predicted: {len(df) - predicted_count}")

# Also check if there are predictions for documents not in test set
test_doc_ids = set(df['doc_id'].unique())
pred_doc_ids = set([p['doc_id'] for p in all_predictions])
print(f"Test doc IDs: {test_doc_ids}")
print(f"Prediction doc IDs: {pred_doc_ids}")
print(f"Extra predictions (not in test): {len(pred_doc_ids - test_doc_ids)}")

In [None]:
# Create predictions for all test pairs
all_test_predictions = []

for _, row in df.iterrows():
    # Check if this pair was predicted as a link
    found = False
    for pred in all_predictions:
        if (pred['doc_id'] == row['doc_id'] and 
            ((pred['entity_label'] == row['ent1'] and pred['date'] == row['ent2']) or
             (pred['entity_label'] == row['ent2'] and pred['date'] == row['ent1']))):
            all_test_predictions.append('LINK')
            found = True
            break
    
    if not found:
        all_test_predictions.append('NO_LINK')

# Now calculate metrics on all test pairs
y_true_all = df['label'].tolist()
y_pred_all = all_test_predictions

print(f"\nAll Test Pairs Metrics:")
print(f"Accuracy: {sum(1 for t, p in zip(y_true_all, y_pred_all) if t == p) / len(y_true_all):.3f}")
print(classification_report(y_true_all, y_pred_all, labels=['LINK', 'NO_LINK']))