Imports

In [None]:
# Imports
import pandas as pd
import json
from tqdm import tqdm
from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.cat import CAT
from medcat.config import Config
import spacy
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

import os
import sys

# Make the sibling ../utils directory (relative to the current working directory)
# importable, so the project helper modules imported below resolve.
utils_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'utils'))
if utils_path not in sys.path:
    sys.path.insert(0, utils_path)

from general_utils import load_data, create_non_empty_filter
from date_extractor_utils import extract_absolute_dates, extract_relative_dates
from bert_relative_date_utils import predict_relative_dates

Data Loading

In [None]:
# Data Loading
df = pd.read_csv("../data/dataset_synthetic1.csv")

# The entity/date extraction loop further down reads these columns; fail here
# (before any model is loaded) rather than partway through processing.
required_columns = {'patient_id', 'doc_id', 'note_text'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(f"Input data is missing required columns: {sorted(missing_columns)}")

print(f"Loaded {len(df)} records for {df['patient_id'].nunique()} patients")

In [None]:
# Inspect df: preview the first five rows of the raw input
df.head(n=5)

MedCAT Model

In [None]:
# Load the fine-tuned concept database - this should be the one that was
# fine-tuned in MedCAT Trainer
cdb_path = "../models/cdb_AewJ3qR.dat"
cdb = CDB.load(cdb_path)

In [None]:
# Load the vocab - again, the same one that was used in MedCAT Trainer
vocab_path = "../models/vocab.dat"
vocab = Vocab.load(vocab_path)

In [None]:
# Initialise the model. en_core_web_md must be installed first:
#   python -m spacy download en_core_web_md
# NOTE(review): `nlp` is not used by any cell shown here; loading it presumably just
# fails fast when the spaCy model is missing - confirm before removing.
nlp = spacy.load("en_core_web_md")

# Linker built from the fine-tuned CDB (using the CDB's own config) plus the Trainer vocab
cat = CAT(cdb=cdb, vocab=vocab, config=cdb.config)

In [None]:
# Pre-trained MedCAT model pack (set the path for your environment). It is loaded only to
# obtain meta-annotations (read later as 'Presence' and 'Subject') that MedCAT Trainer
# models don't provide - a bit hacky, but enough for this purpose.
model_pack_path = "../models/20230227__kch_gstt_trained_model_494c3717f637bb89.zip"
cat_2 = CAT.load_model_pack(model_pack_path)

Relative Date Model

In [None]:
# Set relative date method - can be 'bert' (fine-tuned token classifier) or 'regex'
relative_date_method = 'regex'

# Validate immediately so a typo fails before the relative-date model is loaded or any
# document is processed (previously this was only detected inside the processing loop).
if relative_date_method not in ('bert', 'regex'):
    raise ValueError(f"Invalid method: {relative_date_method}. Must be either 'bert' or 'regex'.")

In [None]:
# Model load path - only the 'bert' method reads a model from disk, so the
# name is left undefined for 'regex'
if relative_date_method == 'bert':
    model_load_path = '../models/bert_model_relative_dates/'

In [None]:
# Load the fine-tuned relative date extractor (only needed for the 'bert' method)
if relative_date_method == 'bert':
    relative_model_path = model_load_path
    tokenizer_rel, model_rel = (
        AutoTokenizer.from_pretrained(relative_model_path),
        AutoModelForTokenClassification.from_pretrained(relative_model_path),
    )
    # eval() switches the torch module to inference mode (e.g. disables dropout)
    model_rel.eval()

    print("Relative date model loaded successfully!")

Add Entities & Dates

In [None]:
# Process each document: link concepts with the fine-tuned model (`cat`), attach
# negation/subject meta-annotations from the model pack (`cat_2`) by exact character
# span, and extract absolute and relative dates.
results = []

# Define the set of categories you want to keep
keep_categories = {'disorder', 'substance', 'finding'}


def _meta_value(ent, task):
    """Return the 'value' of meta-annotation `task` on a MedCAT entity, or None if absent."""
    meta_anns = getattr(ent._, 'meta_anns', None)
    if meta_anns and task in meta_anns:
        return meta_anns[task].get('value')
    return None


def _build_meta_anns_lookup(doc_2):
    """Map (start_char, end_char) of each linked entity in `doc_2` to its negation/subject values."""
    lookup = {}
    if doc_2 is None:  # NOTE(review): defensive guard in case CAT returns no Doc for a text
        return lookup
    for ent_2 in doc_2.ents:
        if ent_2._.cui == -1:
            continue
        lookup[(ent_2.start_char, ent_2.end_char)] = {
            # Presence is compared against the *string* 'False'
            # (assumes the model pack emits string labels - confirm)
            'negated': _meta_value(ent_2, 'Presence') == 'False',
            'subject': _meta_value(ent_2, 'Subject'),
        }
    return lookup


def _split_preferred_name(preferred_name_str):
    """Split a 'Name (category)' preferred name into (name, category); category is None without a suffix."""
    if preferred_name_str and preferred_name_str.endswith(')') and ' (' in preferred_name_str:
        name, category = preferred_name_str.rsplit(' (', 1)
        return name, category.rstrip(')')
    return preferred_name_str, None


def _extract_entities(doc, meta_anns_lookup):
    """Return (kept entity dicts, count of in-category entities lacking a cat_2 span match).

    An entity is kept when it has a CUI, its category is in `keep_categories`, cat_2 produced
    an entity with the identical character span, and that cat_2 entity is not negated.
    Subject is recorded on each entity but not used for filtering.
    """
    entities = []
    unmatched = 0
    if doc is None:  # NOTE(review): defensive guard in case CAT returns no Doc for a text
        return entities, unmatched
    for ent in doc.ents:
        if ent._.cui == -1:
            continue
        preferred_name, category_name = _split_preferred_name(
            cat.cdb.cui2preferred_name.get(ent._.cui)
        )
        if category_name not in keep_categories:
            continue
        meta_anns = meta_anns_lookup.get((ent.start_char, ent.end_char))
        if meta_anns is None:
            # No exact span match in cat_2 -> no negation info, so the entity is dropped
            unmatched += 1
            continue
        if meta_anns['negated']:
            # Keep only if NOT negated
            continue
        entities.append({
            'id': f"ent_{len(entities) + 1}",
            'value': ent.text,
            'preferred_name': preferred_name,
            'cui': ent._.cui,
            'category': category_name,
            'negated': meta_anns['negated'],
            'subject': meta_anns['subject'],
            'start': ent.start_char,
            'end': ent.end_char
        })
    return entities, unmatched


# Choose the relative-date extractor once, outside the loop, so an invalid setting
# fails before any document is processed.
if relative_date_method == 'bert':
    def _extract_relative(text):
        """Relative dates via the fine-tuned BERT token classifier."""
        return predict_relative_dates(text, model_rel, tokenizer_rel)
elif relative_date_method == 'regex':
    _extract_relative = extract_relative_dates
else:
    raise ValueError(f"Invalid method: {relative_date_method}. Must be either 'bert' or 'regex'.")

unmatched_entity_count = 0
# iterrows() has no length, so pass total= to get a real progress bar
for _, row in tqdm(df.iterrows(), total=len(df), desc='row'):
    text = row['note_text']

    # Meta-annotations from cat_2 first, then linked entities from cat
    meta_anns_lookup = _build_meta_anns_lookup(cat_2(text))
    entities, n_unmatched = _extract_entities(cat(text), meta_anns_lookup)
    unmatched_entity_count += n_unmatched

    # Create final result row
    results.append({
        'patient_id': row['patient_id'],
        'doc_id': row['doc_id'],
        'note_text': text,
        'entities_json': json.dumps(entities),
        'dates_json': json.dumps(extract_absolute_dates(text)),
        'relative_dates_json': json.dumps(_extract_relative(text))
    })

if unmatched_entity_count:
    print(f"Dropped {unmatched_entity_count} in-category entities with no exactly-matching "
          f"cat_2 span (no meta-annotations available)")

In [None]:
# Convert to df. Columns are listed explicitly so an empty `results` still yields a frame
# with the expected schema (pd.DataFrame([]) has no 'patient_id' column, so the
# nunique() below would raise KeyError).
inference_df = pd.DataFrame(
    results,
    columns=['patient_id', 'doc_id', 'note_text', 'entities_json', 'dates_json', 'relative_dates_json'],
)
print(f"Created inference dataset with {len(inference_df)} records and {inference_df['patient_id'].nunique()} patients")

In [None]:
# Inspect df: preview the first five rows of the inference dataset
inference_df.head(n=5)

In [None]:
# Save csv without the RangeIndex so the file's columns match the DataFrame's columns
inference_df.to_csv(path_or_buf="../data/inference_dataset_synthetic1.csv", index=False)

Basic Checks

In [None]:
# Load data - only when inference_df is not already in memory (e.g. a fresh kernel
# starting at this section); otherwise keep the frame built above.
# NOTE(review): load_data may return JSON columns in a different form than the in-memory
# frame (see the dtype check below) - confirm against general_utils.
if 'inference_df' not in globals():
    inference_df = load_data("../data/inference_dataset_synthetic1.csv")

In [None]:
# Check length: record and distinct-patient counts
n_records = len(inference_df)
n_patients = inference_df['patient_id'].nunique()
print(f"Loaded {n_records} records for {n_patients} patients")

In [None]:
# Check data types: distribution of Python types held in each JSON column,
# with a separator printed between columns
json_columns = ['entities_json', 'dates_json', 'relative_dates_json']
for position, column in enumerate(json_columns):
    if position:
        print("\n" + "="*50 + "\n")
    print(f"--- Data type distribution for '{column}' ---")
    print(inference_df[column].apply(type).value_counts())

In [None]:
# Create filters dynamically for each column: one boolean mask per JSON column
# (presumably True where that row's payload is non-empty, per create_non_empty_filter)
has_entities = create_non_empty_filter(inference_df['entities_json'])
has_absolute_dates = create_non_empty_filter(inference_df['dates_json'])
has_relative_dates = create_non_empty_filter(inference_df['relative_dates_json'])


def _distinct_patients(mask):
    """Count distinct patient_id values among the rows selected by boolean `mask`."""
    return inference_df.loc[mask, 'patient_id'].nunique()


# 1-3. Row-level counts
rows_with_entities_count = has_entities.sum()
rows_with_absolute_dates_count = has_absolute_dates.sum()
rows_with_relative_dates_count = has_relative_dates.sum()

# 4-6. Patient-level counts
patients_with_entities_count = _distinct_patients(has_entities)
patients_with_absolute_dates_count = _distinct_patients(has_absolute_dates)
patients_with_relative_dates_count = _distinct_patients(has_relative_dates)

# 7. Distinct patients with entities AND (absolute dates OR relative dates)
has_any_date = has_absolute_dates | has_relative_dates
combined_filter = has_entities & has_any_date
patients_with_entities_and_dates_count = _distinct_patients(combined_filter)

rule = "-" * 30
print(rule)
print(f"Number of rows with entities: {rows_with_entities_count}")
print(f"Number of rows with absolute dates: {rows_with_absolute_dates_count}")
print(f"Number of rows with relative dates: {rows_with_relative_dates_count}")
print(rule)
print(f"Number of distinct patients with entities: {patients_with_entities_count}")
print(f"Number of distinct patients with absolute dates: {patients_with_absolute_dates_count}")
print(f"Number of distinct patients with relative dates: {patients_with_relative_dates_count}")
print(rule)
print(f"Number of distinct patients with entities AND (absolute OR relative dates): {patients_with_entities_and_dates_count}")