In [None]:
#Imports
import pandas as pd
import sys
import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from tqdm import tqdm

# Make the shared helper modules in ../utils (relative to the notebook's working
# directory) importable. Prepending gives them priority over same-named modules
# elsewhere on the path; the membership check keeps re-running this cell idempotent.
utils_dir = os.path.abspath(os.path.join('..', 'utils'))
if utils_dir not in sys.path:
    sys.path.insert(0, utils_dir)

from date_extractor_utils import clean_value, extract_absolute_dates, normalise_relative, extract_relative_dates
from general_utils import load_data
from bert_relative_date_utils import predict_relative_dates

Test Absolute Dates

In [None]:
# Smoke test: one numeric (DD/MM/YYYY) date and one month-name date in a short note.
test_text = "Patient was seen on 15/06/2025 for follow-up. Next appointment scheduled for January 4th, 2026."

absolute_dates = extract_absolute_dates(test_text)

print("Results:")
for match in absolute_dates:
    value, start, end = match['value'], match['start'], match['end']
    print(f"  '{value}' -> (start: {start}, end: {end})")

In [None]:
# Broader check: the same calendar date written in numeric and month-name layouts,
# plus dates embedded in running text. Each match is printed with its character span.
test_text = """
Various date formats:
1. Standard formats:
   - 15/06/2025
   - 2025-06-15
   - 15-06-2025
   
2. Month name formats:
   - June 15, 2025
   - 15 June 2025
   - Jun 15, 2025
   
3. Mixed in text:
   The patient was seen on 15/06/2025 and had a follow-up on June 15, 2025.
   Next appointment scheduled for January 4th, 2026.
"""

absolute_dates = extract_absolute_dates(test_text)
print("\nResults:")
for match in absolute_dates:
    value, start, end = match['value'], match['start'], match['end']
    print(f"  '{value}' -> (start: {start}, end: {end})")

In [None]:
# Edge cases: inputs that should produce no matches, or only partial/invalid ones.
edge_cases = [
    "",  # Empty string
    "No dates here",  # No dates
    "Invalid dates: 35/13/2025, 00/00/0000",  # Invalid dates
    "Partial dates: June 2025, 2025",  # Partial dates
]

for text in edge_cases:
    found = extract_absolute_dates(text)
    print(f"\nText: '{text}'")
    print(f"Found {len(found)} dates:")
    for item in found:
        value, start, end = item['value'], item['start'], item['end']
        print(f"  '{value}' -> (start: {start}, end: {end})")

In [None]:
# Run the absolute-date extractor on the first note of the synthetic dataset.
# The frame gets its own name (not `df`) so that re-running this cell later cannot
# overwrite the training dataframe that the evaluation cells below iterate over.
PREVIEW_CHARS = 500  # how much of the note to echo; all matches are always printed

df_synthetic = pd.read_csv("../data/dataset_synthetic1.csv")
sample_text = df_synthetic.iloc[0]['note_text']

dates = extract_absolute_dates(sample_text)

# Only show the ellipsis when the note was actually cut short.
preview = sample_text[:PREVIEW_CHARS]
suffix = "..." if len(sample_text) > PREVIEW_CHARS else ""
print(f"Text: {preview}{suffix}")
print(f"\nFound {len(dates)} dates:")
for date in dates:
    print(f"  '{date['value']}' -> (start: {date['start']}, end: {date['end']})")

Test Relative Dates (Regex)

In [None]:
# Smoke test for the regex extractor on three relative expressions;
# the returned list of matches is displayed as the cell output.
test_text = "Patient was seen last week for follow-up. Next appointment scheduled for tomorrow. Symptoms started 3 days ago."
relative_dates = extract_relative_dates(test_text)
relative_dates

In [None]:
# Broader regex check: sentences intended to exercise each relative-date pattern type.
# Every match is printed with the `pattern_type` the extractor assigned to it.
test_text = """
Patient was seen last week for follow-up. 
Next appointment scheduled for tomorrow. 
Symptoms started 3 days ago.
Last visit was on Monday.
Previous checkup was 2 weeks earlier.
Past few days have been difficult.
Several months ago the condition worsened.
Earlier this week the patient improved.
Last visit was productive.
Next few days will be critical.
"""

results = extract_relative_dates(test_text)

print("\nResults:")
for match in results:
    matched_text = match['value']
    pattern = match['pattern_type']
    print(f"  '{matched_text}' -> (pattern: {pattern})")

print(f"\nTotal patterns found: {len(results)}")

Test Against Labelled Training Data

In [None]:
# Training notes used for evaluation. The cells below read the `note_text`, `doc_id`
# and `dates_json` columns from this frame.
training_data_path = "../data/training_dataset_nph.csv"
df = load_data(training_data_path)
print(f"Main dataset: {df.shape}")

In [None]:
# Gold relative-date labels: one row per (doc_id, date_value) with an `is_valid` verdict.
df_clean_gold = pd.read_csv("../data/relative_date_gold.csv")

# category_maps[is_valid value] -> {doc_id: [date_value, ...]} for rows with that verdict.
# Categories keep their order of first appearance in the file.
gold_categories = df_clean_gold['is_valid'].unique()
category_maps = {
    category: (
        df_clean_gold.loc[df_clean_gold['is_valid'] == category]
        .groupby('doc_id')['date_value']
        .apply(list)
        .to_dict()
    )
    for category in gold_categories
}

# The evaluation loop scores predictions against the 'YES' (valid) dates only.
valid_rel_dates_map = category_maps.get('YES', {})

print(f"Loaded gold standard data for {len(gold_categories)} categories.")

In [None]:
# Relative-date extraction method used by the evaluation cells:
#   'bert'  - fine-tuned token-classification model loaded below
#   'regex' - rule-based `extract_relative_dates`
relative_date_method = 'regex'

# Fail fast on a typo here, rather than deep inside the evaluation loop
# (or never, in the debug cells that fall back to regex).
if relative_date_method not in ('bert', 'regex'):
    raise ValueError(f"Invalid method: {relative_date_method}. Must be either 'bert' or 'regex'.")

In [None]:
# Directory of the fine-tuned relative-date model; it is only needed in BERT mode.
model_load_path = '../models/bert_model_relative_dates/' if relative_date_method == 'bert' else None

In [None]:
# Load the fine-tuned relative-date tokenizer and model only in BERT mode;
# regex mode needs nothing loaded.
if relative_date_method != 'bert':
    print("Using regex-based relative date extraction.")
else:
    tokenizer_rel = AutoTokenizer.from_pretrained(model_load_path)
    model_rel = AutoModelForTokenClassification.from_pretrained(model_load_path)
    model_rel.eval()  # switch the torch module to evaluation (inference) mode
    print("BERT relative date model loaded successfully!")

In [None]:
# Score absolute and relative date extraction against the gold labels, note by note.
# Each entry appended to abs_results / rel_results is a per-document (TP, FP, FN) tuple
# computed over *sets* of normalised values, so repeated mentions of a date count once.
abs_results = []
rel_results = []

# Pick the relative-date extractor once. An invalid method therefore fails before any
# work is done, even when df is empty.
if relative_date_method == 'bert':
    def _extract_rel(note):
        """Relative-date predictions from the fine-tuned BERT model."""
        return predict_relative_dates(note, model_rel, tokenizer_rel)
elif relative_date_method == 'regex':
    _extract_rel = extract_relative_dates
else:
    raise ValueError(f"Invalid method: {relative_date_method}. Must be either 'bert' or 'regex'.")


def _parse_gold_abs(raw):
    """Return a row's `dates_json` value as a list of date dicts ([] when missing/malformed)."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return []
    # A missing column gives None and an empty cell is typically NaN, which is truthy, so
    # the old `or []` fallback let it through and iteration crashed. JSON may also decode
    # to a non-list. Anything that is not a list means "no gold dates".
    return raw if isinstance(raw, list) else []


def _confusion(pred, gold):
    """(TP, FP, FN) for one document's predicted vs gold value sets."""
    return len(pred & gold), len(pred - gold), len(gold - pred)


for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating"):
    text = row['note_text']
    doc_id = row.get('doc_id')  # key into the gold relative-date map

    # Gold values, normalised the same way as the predictions below.
    validated_abs = _parse_gold_abs(row.get('dates_json'))
    gold_abs_values = {clean_value(d['value']) for d in validated_abs if isinstance(d, dict) and d.get('value')}
    gold_rel_values = {normalise_relative(v) for v in valid_rel_dates_map.get(doc_id, [])}

    # Predicted values; entries without a usable 'value' are ignored.
    pred_abs = extract_absolute_dates(text)
    pred_rel = _extract_rel(text)
    pred_abs_values = {clean_value(p['value']) for p in pred_abs if isinstance(p, dict) and p.get('value')}
    pred_rel_values = {normalise_relative(p['value']) for p in pred_rel if isinstance(p, dict) and p.get('value')}

    abs_results.append(_confusion(pred_abs_values, gold_abs_values))
    rel_results.append(_confusion(pred_rel_values, gold_rel_values))

In [None]:
# Compute metrics
def compute_metrics(results):
    """Micro-averaged precision, recall and F1 from per-document confusion counts.

    Args:
        results: sequence of (tp, fp, fn) tuples, one per document.

    Returns:
        Tuple (precision, recall, f1, tp, fp, fn), where the counts are summed over
        all documents. A ratio whose denominator is zero is reported as 0.
    """
    tp = fp = fn = 0
    for counts in results:
        tp += counts[0]
        fp += counts[1]
        fn += counts[2]

    predicted = tp + fp
    actual = tp + fn
    precision = tp / predicted if predicted > 0 else 0
    recall = tp / actual if actual > 0 else 0

    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0
    return precision, recall, f1, tp, fp, fn

In [None]:
# Report micro-averaged metrics for both extractors. The unpacked names stay at module
# level so later cells can reuse them.
abs_precision, abs_recall, abs_f1, tp_abs, fp_abs, fn_abs = compute_metrics(abs_results)
rel_precision, rel_recall, rel_f1, tp_rel, fp_rel, fn_rel = compute_metrics(rel_results)

report_rows = [
    ("=== Absolute Dates ===", tp_abs, fp_abs, fn_abs, abs_precision, abs_recall, abs_f1),
    ("\n=== Relative Dates ===", tp_rel, fp_rel, fn_rel, rel_precision, rel_recall, rel_f1),
]
for header, tp, fp, fn, precision, recall, f1 in report_rows:
    print(header)
    print(f"TP={tp}, FP={fp}, FN={fn}")
    print(f"Precision={precision:.3f}, Recall={recall:.3f}, F1={f1:.3f}")

In [None]:
# Debug absolute dates on the first few notes. Gold and predicted values go through the
# same `clean_value` normalisation and filtering as the evaluation loop, so "Overlap"
# is exactly what the metrics count as true positives.
N_DEBUG_ROWS = 10

for _, row in df.head(N_DEBUG_ROWS).iterrows():
    text = row['note_text']

    abs_data = row.get('dates_json')
    if isinstance(abs_data, str):
        try:
            abs_data = json.loads(abs_data)
        except json.JSONDecodeError:
            abs_data = []
    # A missing column gives None and an empty cell is typically NaN. JSON may also
    # decode to a non-list. Treat all of these as "no gold dates" instead of crashing
    # when iterating.
    if not isinstance(abs_data, list):
        abs_data = []

    # Validated and predicted values; entries with an empty/None 'value' are skipped.
    gold_abs = {clean_value(d['value']) for d in abs_data if isinstance(d, dict) and d.get('value')}
    pred_abs = {clean_value(p['value']) for p in extract_absolute_dates(text) if isinstance(p, dict) and p.get('value')}

    # Only show rows where something exists
    if gold_abs or pred_abs:
        print(f"\nDoc {row.get('doc_id', 'N/A')}")
        print("Validated absolute:", gold_abs)
        print("Predicted absolute:", pred_abs)
        print("Overlap:", gold_abs & pred_abs)
        print("-" * 80)

In [None]:
# Debug relative-date extraction per note. False positives are broken down by the gold
# category (other than 'YES') they were labelled with, if any. Strings are compared
# lower-cased rather than `normalise_relative`-normalised, so this is a raw-text view.
for _, row in df.iterrows():
    text = row['note_text']
    doc_id = row.get('doc_id')

    # Gold sets per category, from the maps built when the gold CSV was loaded.
    gold_sets = {
        category: {v.lower() for v in maps.get(doc_id, [])}
        for category, maps in category_maps.items()
    }
    gold_yes = gold_sets.get('YES', set())
    # set().union(...) instead of set.union(...): the unbound form raises TypeError when
    # there are no categories at all.
    all_gold_values = set().union(*gold_sets.values())

    # Use the same extractor the evaluation used, so this view explains those metrics.
    if relative_date_method == 'bert':
        rel_preds = predict_relative_dates(text, model_rel, tokenizer_rel)
    else:
        rel_preds = extract_relative_dates(text)
    pred_raw = {p['value'].lower() for p in rel_preds if isinstance(p, dict) and p.get('value')}

    if not (gold_yes or pred_raw):
        continue  # nothing labelled valid and nothing predicted

    print(f"\n--- Doc {doc_id} ---")
    print(f"Gold (YES): {gold_yes if gold_yes else 'None'}")
    print(f"Predicted:  {pred_raw if pred_raw else 'None'}")
    print("-" * 20)

    # True Positives
    tp = gold_yes & pred_raw
    print(f"Correctly Predicted (TP): {tp if tp else 'None'}")

    # False Negatives
    fn = gold_yes - pred_raw
    print(f"Missed (FN): {fn if fn else 'None'}")

    # False Positives, attributed to the non-'YES' gold category they match (if any)
    fp = pred_raw - gold_yes
    if fp:
        print("Incorrectly Predicted (FP):")
        for category, gold_values in gold_sets.items():
            if category == 'YES':
                continue
            fp_category = fp & gold_values
            if fp_category:
                print(f"  - Is {category}: {fp_category}")

        # FPs that do not appear anywhere in the gold standard
        fp_novel = fp - all_gold_values
        if fp_novel:
            print(f"  - Is Novel (Not in Gold): {fp_novel}")

In [None]:
# Analysis of relative-date FPs and FNs. This uses exactly the evaluation loop's
# normalisation (`normalise_relative` on the raw gold/predicted strings) and the same
# extractor, so the totals printed below equal the relative-date FP/FN metric counts.

# Set display_mode to 'FP', 'FN', or 'BOTH' to control which errors are printed;
# the final totals always count both kinds.
display_mode = 'BOTH'
if display_mode not in ('FP', 'FN', 'BOTH'):
    raise ValueError(f"display_mode must be 'FP', 'FN' or 'BOTH', got {display_mode!r}")
show_fp = display_mode in ('FP', 'BOTH')
show_fn = display_mode in ('FN', 'BOTH')

# Store all errors for a final summary
all_errors = {'FP': [], 'FN': []}


def _norm_to_raw(raw_values):
    """Map each normalised value to the first raw string that produced it (for display)."""
    mapping = {}
    for raw in raw_values:
        mapping.setdefault(normalise_relative(raw), raw)
    return mapping


for _, row in df.iterrows():
    text = row['note_text']
    doc_id = row.get('doc_id')

    # Gold: the 'YES' (valid) relative dates for this document.
    gold_norm_to_raw = _norm_to_raw(valid_rel_dates_map.get(doc_id, []))

    # Prediction, filtered exactly like the evaluation loop.
    if relative_date_method == 'bert':
        rel_preds = predict_relative_dates(text, model_rel, tokenizer_rel)
    else:
        rel_preds = extract_relative_dates(text)
    pred_norm_to_raw = _norm_to_raw(
        p['value'] for p in rel_preds if isinstance(p, dict) and p.get('value')
    )

    false_positives = pred_norm_to_raw.keys() - gold_norm_to_raw.keys()
    false_negatives = gold_norm_to_raw.keys() - pred_norm_to_raw.keys()

    # Always record errors so the summary is complete whatever display_mode is.
    all_errors['FP'].extend(false_positives)
    all_errors['FN'].extend(false_negatives)

    if show_fp:
        for val_norm in false_positives:
            raw_val = pred_norm_to_raw[val_norm]
            print(f"Doc {doc_id}: FP > Gold: None > Predicted: '{raw_val}' (Norm: '{val_norm}') > Evaluation: Incorrect")

    if show_fn:
        for val_norm in false_negatives:
            raw_val = gold_norm_to_raw[val_norm]
            print(f"Doc {doc_id}: FN > Gold: '{raw_val}' (Norm: '{val_norm}') > Predicted: None > Evaluation: Missed")

# Final Summary
print("\n" + "="*30)
print(f"Total False Positives: {len(all_errors['FP'])}")
print(f"Total False Negatives: {len(all_errors['FN'])}")
print("="*30)