# Corrected Time-to-Event and Intervention Selection Analysis

This notebook contains the corrected implementation with proper temporal filtering.

**Two critical fixes have been applied:**

1. CUTOFF_DATE corrected to 2024-12-31 (was incorrectly 2024-06-01)
2. Temporal filter added to diagnosis features: `dx = dx[dx["createdAt"] < cutoff_date].copy()`


In [None]:
import pandas as pd
import re
from datetime import datetime
from typing import Optional

# Feature cutoff. The builders in this file keep only rows whose timestamp is
# strictly earlier than this value; pd.Timestamp("2024-12-31") is midnight at
# the start of that day, so records dated 2024-12-31 itself are excluded.
CUTOFF_DATE = pd.Timestamp("2024-12-31")

# Keyword fragments used to bucket free-text diagnosis descriptions.
# Each list is OR-ed into one case-insensitive regex (see _DX_RX below) and
# applied with .search(), so a fragment matches anywhere inside a word unless
# it is anchored with \b.  A description may fall into several categories.
#
# Short fragments that also occur inside unrelated clinical words are anchored
# so they do not produce false category hits (examples next to each fix).
#
# NOTE(review): "abscess" appears under both infection and dental, and "ulcer"
# under both gi and skin_wound, so any abscess/ulcer sets both flags. Left
# as-is because it may be deliberate -- confirm with the feature owner.
_DX_CATS = {
    "renal":        [r"\brenal\b", r"kidney", r"neph"],
    "infection":    [r"infection", r"sepsis", r"uti\b", r"pneumonia", r"cellulitis", r"abscess"],
    "injury":       [r"injur", r"fracture", r"sprain", r"lacerat", r"contusion", r"burn"],
    # \bsob\b: bare "sob" also matched "sober" / "sobriety".
    "respiratory":  [r"asthma", r"copd", r"respir", r"bronch", r"\bsob\b", r"dyspnea"],
    "cardiac":      [r"card", r"mi\b", r"myocard", r"afib", r"arrhythm", r"chf", r"heart"],
    "neuro":        [r"neuro", r"stroke", r"cva\b", r"seizure", r"tia\b", r"migraine", r"headache"],
    "psych":        [r"psych", r"depress", r"anx", r"bipolar", r"schizo", r"suicid"],
    # aches?\b|aching: bare "ache" also matched "tracheostomy" / "cachexia".
    "pain":         [r"pain", r"aches?\b|aching", r"cramp", r"spasm"],
    "pregnancy":    [r"pregnan", r"ob\b", r"obstet", r"gyn\b", r"miscar", r"prenatal"],
    "gi":           [r"nausea", r"vomit|emesis", r"diarrhea", r"abdom|abd\b|belly", r"gi\b", r"gastr|ulcer"],
    "gu":           [r"urinar|dysuria|hematur|pyel|prostat", r"gu\b"],
    "endocrine":    [r"diabet|hypergly|hypogly|thyroid|adrenal|endocr"],
    # \bonc: bare "onc" also matched "bronchitis" / "concussion".
    "oncology":     [r"cancer|\bonc|tumor|malignan|neoplasm"],
    # \bstab(...)\b: bare "stab" also matched "stable" / "unstable angina".
    "trauma":       [r"mvc|mva|gunshot|\bstab(?:s|bed|bing)?\b|assault|fall\b|trauma"],
    "skin_wound":   [r"rash|dermat|wound|ulcer|decub|psoria|eczema"],
    # \bmeth\b|methamphet|methadone: bare "meth" also matched "methotrexate",
    # "trimethoprim" and "methicillin".
    "substance":    [r"etoh|alcohol|opioid|overdose|cocaine|\bmeth\b|methamphet|methadone|fentanyl|substance"],
    "social":       [r"homeless|housing|transport|food|utility|childcare|violence|safety"],
    # \b(?:intra)?oral\b: bare "oral" also matched "femoral" / "behavioral".
    "dental":       [r"dental|tooth|teeth|\b(?:intra)?oral\b|abscess"],
    "other":        [r"ekg|tbd|other|unspecified|unknown"],
}

# One compiled, case-insensitive alternation per category, built once at import.
_DX_RX = {k: re.compile("|".join(v), re.I) for k, v in _DX_CATS.items()}


def build_dx_features(cutoff_date: pd.Timestamp = CUTOFF_DATE) -> pd.DataFrame:
    """
    Count diagnosis-category keyword hits per patient, using history only.

    Diagnoses are read from the "Diagnoses" table, rows whose ``createdAt``
    is missing/unparseable or not strictly before ``cutoff_date`` are dropped
    (this is the guard against temporal leakage), and the remaining rows are
    mapped to ``patient_id`` through the member map. Each description is then
    matched against every pattern in ``_DX_RX``.

    Args:
        cutoff_date: Exclusive upper bound on ``createdAt``.

    Returns:
        One row per ``patient_id`` (as str) with, for every category that
        matched at least once, ``dx_cat_<category>_count`` and
        ``dx_cat_<category>_any`` (0/1), plus ``dx_any_count`` (sum of all
        ``*_count`` columns). When there are no diagnoses or no category
        hits, an empty frame with only a ``patient_id`` column is returned.
    """
    raw = _read_sql('SELECT "patientId","description","createdAt" FROM "Diagnoses"', LH_ENGINE)
    if raw.empty:
        return pd.DataFrame(columns=["patient_id"])

    # Unparseable timestamps become NaT and are discarded first ...
    raw["createdAt"] = pd.to_datetime(raw["createdAt"], errors="coerce")
    dated = raw[raw["createdAt"].notna()].copy()
    # ... then only rows strictly before the cutoff survive.
    is_history = dated["createdAt"] < cutoff_date
    history = dated[is_history].copy()

    # Join on string-typed ids so int/str mismatches cannot drop matches.
    member_map = _member_map()
    history["patientId"] = history["patientId"].astype(str)
    member_map["member_id"] = member_map["member_id"].astype(str)
    mapped = history.merge(member_map, left_on="patientId", right_on="member_id", how="left")
    mapped = mapped[mapped["patient_id"].notna()].copy()

    def matched_categories(description) -> list[str]:
        # Non-string descriptions (e.g. NaN) are treated as empty text.
        text = str(description) if isinstance(description, str) else ""
        # Normalise the separators \ / ; | to commas before matching.
        text = re.sub(r"[\\/;|]+", ",", text)
        return sorted({name for name, rx in _DX_RX.items() if rx.search(text)})

    per_row_cats = mapped["description"].apply(matched_categories)
    # One (patient, category) pair per category hit on each diagnosis row.
    pairs = [
        (patient, category)
        for patient, categories in zip(mapped["patient_id"].astype(str), per_row_cats)
        for category in categories
    ]
    if not pairs:
        return pd.DataFrame(columns=["patient_id"])

    long_counts = (
        pd.DataFrame(pairs, columns=["patient_id", "category"])
        .groupby(["patient_id", "category"])
        .size()
        .reset_index(name="count")
    )
    wide = long_counts.pivot(index="patient_id", columns="category", values="count")
    wide = wide.fillna(0).astype(int)
    wide.columns = [f"dx_cat_{category}_count" for category in wide.columns]

    # 0/1 indicator twin for every count column.
    indicators = (wide > 0).astype(int)
    indicators.columns = [name.replace("_count", "_any") for name in wide.columns]

    result = pd.concat([wide, indicators], axis=1).reset_index()
    count_columns = [name for name in result.columns if name.endswith("_count")]
    result["dx_any_count"] = result[count_columns].sum(axis=1)
    return result


def build_features(members: pd.DataFrame, 
                   adt: pd.DataFrame, 
                   interventions: pd.DataFrame, 
                   engaged_ids: set[str], 
                   dx_feat: pd.DataFrame,
                   cutoff_date: pd.Timestamp = CUTOFF_DATE) -> pd.DataFrame:
    """
    Build one feature row per member from records strictly before the cutoff.

    Args:
        members: Member table; must contain ``patient_id`` and ``birthDate``.
            Rows whose ``birthDate`` cannot be parsed are dropped.
        adt: Utilisation events with ``patient_id``, ``event_datetime`` and
            the indicator columns ``is_admit``, ``is_ed``, ``is_obs``.
        interventions: Encounters with ``patient_id`` and
            ``encounter_datetime``.
        engaged_ids: Patient ids (as str) flagged as engaged.
        dx_feat: Per-patient diagnosis features keyed by ``patient_id``.
        cutoff_date: Exclusive upper bound applied to ``event_datetime`` and
            ``encounter_datetime``; also the reference date for ``age``.
            ``dx_feat`` is NOT re-filtered here and must already respect it.

    Returns:
        A copy of ``members`` (minus rows without a birth date) with ``age``,
        ``is_engaged``, the ``hist_*`` utilisation/encounter features and the
        columns of ``dx_feat`` left-joined on ``patient_id``. Count and flag
        columns are 0 (not NaN) for members with no matching history.
    """
    feat = members.copy()

    feat["birthDate"] = pd.to_datetime(feat["birthDate"], errors="coerce")
    feat = feat.dropna(subset=["birthDate"]).copy()
    # Age in years as of the cutoff, so it never depends on the run date.
    feat["age"] = (cutoff_date - feat["birthDate"]).dt.days / 365.25
    feat["is_engaged"] = feat["patient_id"].astype(str).isin(engaged_ids).astype(int)

    # Utilisation history: only events strictly before the cutoff.
    adt_h = adt[adt["event_datetime"] < cutoff_date]
    util = adt_h.groupby("patient_id").agg(
        hist_event_count=("event_datetime", "count"),
        hist_admit_count=("is_admit", "sum"),
        hist_had_admit=("is_admit", "max"),
        hist_ed_count=("is_ed", "sum"),
        hist_had_ed=("is_ed", "max"),
        hist_obs_count=("is_obs", "sum"),
        hist_had_obs=("is_obs", "max"),
    ).reset_index()

    # Encounter history: number of pre-cutoff encounters per patient.
    inter_h = interventions[interventions["encounter_datetime"] < cutoff_date]
    inter_agg = inter_h.groupby("patient_id").size().reset_index(name="hist_encounter_count")

    feat = feat.merge(util, on="patient_id", how="left")
    feat = feat.merge(inter_agg, on="patient_id", how="left")
    feat = feat.merge(dx_feat, on="patient_id", how="left")

    # Left joins leave NaN for members without history; for counts and flags
    # "no history" means 0.
    hist_cols = ["hist_event_count", "hist_admit_count", "hist_ed_count",
                 "hist_obs_count", "hist_encounter_count",
                 "hist_had_admit", "hist_had_ed", "hist_had_obs"]
    for c in hist_cols:
        if c in feat.columns:
            feat[c] = feat[c].fillna(0).astype(int)

    # Same treatment for the diagnosis features: without it, members with no
    # pre-cutoff diagnoses carry NaN and every dx count column is upcast to
    # float. Only columns that were integer-typed in dx_feat are touched, and
    # they get their original dtype back. A column that collided with an
    # existing name was suffixed by merge and is skipped by the membership
    # test.
    for c in dx_feat.columns:
        if c == "patient_id" or c not in feat.columns:
            continue
        src_dtype = dx_feat[c].dtype
        if getattr(src_dtype, "kind", "") in ("i", "u"):
            feat[c] = feat[c].fillna(0).astype(src_dtype)

    return feat


def verify_temporal_integrity(dx_features: pd.DataFrame, cutoff_date: pd.Timestamp) -> bool:
    """
    Check that every patient with diagnosis features has pre-cutoff diagnoses.

    A patient can only legitimately appear in ``dx_features`` if at least one
    of their diagnoses was recorded strictly before ``cutoff_date``. A patient
    who has features but no such diagnosis must have obtained them from rows
    at/after the cutoff (or from undated rows), i.e. the temporal filter was
    not applied when the features were built.

    The mere existence of later diagnoses in the database is NOT treated as
    leakage: correctly filtered features coexist with future rows.

    Limitation: this is a necessary condition only. It cannot detect inflated
    counts for a patient who has diagnoses on both sides of the cutoff.

    Args:
        dx_features: Diagnosis feature frame with a ``patient_id`` column.
        cutoff_date: Exclusive upper bound that the features must respect.

    Returns:
        True when no violation is found.

    Raises:
        ValueError: If any featured patient has no diagnosis before the cutoff.
    """
    featured = set(dx_features["patient_id"].astype(str))
    if not featured:
        # No patients carry features, so nothing can have leaked.
        return True

    dx_all = _read_sql('SELECT "patientId","createdAt" FROM "Diagnoses"', LH_ENGINE)
    history_ids: set[str] = set()
    if not dx_all.empty:
        dx_all["createdAt"] = pd.to_datetime(dx_all["createdAt"], errors="coerce")
        # NaT compares False, so undated rows never count as history.
        dx_hist = dx_all[dx_all["createdAt"] < cutoff_date].copy()

        # Normalise both join keys to str, as build_dx_features does, so an
        # int/str mismatch cannot make the merge miss every row. Work on a
        # copy so the object returned by _member_map() is not modified.
        mm = _member_map().copy()
        dx_hist["patientId"] = dx_hist["patientId"].astype(str)
        mm["member_id"] = mm["member_id"].astype(str)
        dx_hist = dx_hist.merge(mm, left_on="patientId", right_on="member_id", how="left")

        # Compare as str because the feature frame's ids are cast to str above.
        history_ids = set(dx_hist["patient_id"].dropna().astype(str))

    unexplained = featured - history_ids
    if unexplained:
        raise ValueError(
            f"Temporal leakage detected: {len(unexplained):,} patients have diagnosis "
            f"features but no diagnosis recorded before {cutoff_date}"
        )

    return True
