In [None]:
import pandas as pd
import numpy as np
from diff_diff import CallawaySantAnna, SunAbraham

# ------------------ Paths ------------------
panel_fp = r"C:/Repositories/white-bowblis-nhmc/data/clean/panel.csv"

# ------------------ Load + prep ------------------
# Read the facility identifier as text up front. Letting pandas infer the type
# strips leading zeros from all-numeric IDs and, when the column mixes numeric
# and alphanumeric IDs, can leave a mix of ints and strings that stringify
# differently -- which would split one facility into two unit IDs.
df = pd.read_csv(panel_fp, dtype={"cms_certification_number": str})

# astype(str) turns any missing ID into the literal string "nan", so the key
# column is uniformly str.
df["cms_certification_number"] = df["cms_certification_number"].astype(str)

# Parse year_month like "YYYY/MM" or "YYYY-MM" -> datetime at first day of month
ym_clean = df["year_month"].astype(str).str.replace("/", "-", regex=False)
df["ym_date"] = pd.to_datetime(ym_clean + "-01", errors="coerce")

# Fail loudly on unparseable periods rather than carrying NaT into the time index.
unparsed = df["ym_date"].isna()
if unparsed.any():
    bad = df.loc[unparsed, "year_month"].head(10).tolist()
    raise ValueError(f"Could not parse some year_month values. Examples: {bad}")

# Use an integer time index (monotone) for Sun-Abraham relative-time calculations
# This avoids issues with datetimes in some implementations.
# year * 12 + month makes consecutive calendar months differ by exactly 1.
df = df.sort_values(["cms_certification_number", "ym_date"]).reset_index(drop=True)
df["t_index"] = (df["ym_date"].dt.year * 12 + df["ym_date"].dt.month).astype(int)

# Safe logs (match your R mk_log)
def mk_log(x):
    """Return the natural log of `x`, with NaN wherever the log is undefined.

    `x` is coerced to numeric first (unparseable entries become NaN). Entries
    that are missing, zero, or negative yield NaN. The result is a float
    NumPy array of the same shape as the coerced input (not a Series).
    """
    vals = np.asarray(pd.to_numeric(x, errors="coerce"), dtype=float)
    # Start from all-NaN and take the log only where it is defined, so NumPy
    # never evaluates log(0) or log(negative) and emits no RuntimeWarning.
    out = np.full(vals.shape, np.nan)
    np.log(vals, out=out, where=vals > 0)
    return out

# Add a logged version (ln_rn, ln_lpn, ln_cna, ln_total) of every staffing
# outcome that is present in the panel; absent outcomes are skipped.
hppd_present = [
    col
    for col in ("rn_hppd", "lpn_hppd", "cna_hppd", "total_hppd")
    if col in df.columns
]
for y in hppd_present:
    df[f"ln_{y.replace('_hppd', '')}"] = mk_log(df[y])

# ------------------ Baseline chain status (2017Q1) ------------------
# A facility's baseline is its first non-missing `chain` value (in date order)
# among rows dated 2017-01-01 through 2017-03-31. Facilities with no such value
# get NaN from the left merge. Without a `chain` column, everyone gets NaN.
if "chain" not in df.columns:
    df["baseline_chain_2017Q1"] = np.nan
else:
    in_2017q1 = df["ym_date"].between("2017-01-01", "2017-03-31")
    base = (
        df.loc[in_2017q1]
        .sort_values(["cms_certification_number", "ym_date"])
        .groupby("cms_certification_number")["chain"]
        .first()
        .rename("baseline_chain_2017Q1")
        .reset_index()
    )
    df = df.merge(base, on="cms_certification_number", how="left")

# ------------------ Construct first_treat ------------------
# The Sun-Abraham fit below takes a per-unit first_treat column:
#   0            -> unit is never treated
#   t_index >= 1 -> first period in which the unit is treated
# It is derived here from post == 1, which is only right if post stays at 1
# once treatment has started.

# Missing or non-numeric post values are treated as "not yet treated" (0).
df["post"] = pd.to_numeric(df["post"], errors="coerce").fillna(0).astype(int)

# Earliest treated period per unit; units that never have post == 1 are absent.
treated_rows = df[df["post"].eq(1)]
first_treat_map = treated_rows.groupby("cms_certification_number")["t_index"].min()

# Units missing from the map are the never-treated ones -> 0.
first_treat_raw = df["cms_certification_number"].map(first_treat_map)
df["first_treat"] = first_treat_raw.fillna(0).astype(int)

# Sanity check: post should be 0 before first_treat and 1 after (if it's a step function)
# If this fails for many units, your 'post' is not a simple adoption indicator.
# We won't hard-fail, but we will warn.
def check_post_step(data, n_check=200):
    """Warn if `post` is not a clean 0 -> 1 step at `first_treat`.

    Only treated units (first_treat != 0) are sampled, so never-treated units
    neither use up the `n_check` budget nor inflate the reported denominator.
    A unit is flagged when any row before its first_treat has post > 0, or any
    row at or after first_treat has post < 1.

    Args:
        data: frame with columns cms_certification_number, t_index, post and
            first_treat. first_treat is assumed constant within a unit.
        n_check: maximum number of treated units to check, taken in order of
            first appearance; None checks every treated unit.

    Returns:
        None. Prints a two-line warning when at least one unit is flagged.
    """
    unit_col = "cms_certification_number"

    treated = data.loc[data["first_treat"] != 0]
    units = treated[unit_col].drop_duplicates()
    if n_check is not None:
        units = units.head(n_check)
    if units.empty:
        return

    sample = treated.loc[treated[unit_col].isin(units)]

    # One vectorised pass over the sampled rows instead of re-filtering the
    # whole frame once per unit.
    before = sample["t_index"] < sample["first_treat"]
    violation = (before & (sample["post"] > 0)) | (~before & (sample["post"] < 1))
    bad = int(violation.groupby(sample[unit_col]).any().sum())

    if bad > 0:
        print(f"[warn] post is not a clean step function for {bad}/{len(units)} checked units.")
        print("       If this is large, we should build first_treat from your true treatment-date column instead.")

check_post_step(df)

# ------------------ Define subsets (match your R) ------------------
is_prepand = (df["ym_date"] >= "2017-01-01") & (df["ym_date"] <= "2019-12-31")
is_pandemic = (df["ym_date"] >= "2020-04-01") & (df["ym_date"] <= "2024-06-30")

# Rows with anticipation2 == 0. This is built once as a boolean Series: if the
# column is absent, df.get("anticipation2", 0) == 0 is the scalar True, and
# df.loc[True] raises KeyError. Without the column no row is excluded.
if "anticipation2" in df.columns:
    no_anticipation = df["anticipation2"] == 0
else:
    no_anticipation = pd.Series(True, index=df.index)

# Each entry is an independent copy, so later changes to one subset cannot
# leak into df or into another subset.
datasets = {
    "full_with_anticipation": df.copy(),
    "full_without_anticipation": df.loc[no_anticipation].copy(),
    "prepandemic_wo": df.loc[is_prepand & no_anticipation].copy(),
    "pandemic_wo": df.loc[is_pandemic & no_anticipation].copy(),
    "baseline_chain_2017q1_wo": df.loc[(df["baseline_chain_2017Q1"] == 1) & no_anticipation].copy(),
    "baseline_nonchain_2017q1_wo": df.loc[(df["baseline_chain_2017Q1"] == 0) & no_anticipation].copy(),
}

# Outcome columns in reporting order: levels first, then their logs.
outs_order = ["rn_hppd", "lpn_hppd", "cna_hppd", "total_hppd"]
log_outs = ["ln_rn", "ln_lpn", "ln_cna", "ln_total"]

# ------------------ Sun-Abraham runner ------------------
def run_sa(data: pd.DataFrame, outcome: str, label: str, bootstrap: bool = False):
    """Fit a SunAbraham model for one outcome and print its results.

    Rows missing the outcome, unit id, t_index or first_treat are dropped
    before fitting. The estimator is configured with never-treated controls,
    zero anticipation, and cms_certification_number as the cluster variable.

    Args:
        data: panel with cms_certification_number, t_index, first_treat and
            the `outcome` column.
        outcome: name of the outcome column.
        label: dataset name, used only in the printed output.
        bootstrap: if True, also pass n_bootstrap=999, Rademacher bootstrap
            weights and seed=42 to the estimator.

    Returns:
        The fitted results object, or None if no usable rows remain.
    """
    required = [outcome, "cms_certification_number", "t_index", "first_treat"]
    sample = data.dropna(subset=required).copy()
    if sample.empty:
        print(f"[skip] {label}: outcome={outcome} has no usable observations.")
        return None

    # Shared estimator settings; the bootstrap options are layered on top.
    estimator_kwargs = {
        "control_group": "never_treated",
        "anticipation": 0,
        "cluster": "cms_certification_number",
    }
    if bootstrap:
        estimator_kwargs.update(
            n_bootstrap=999,
            bootstrap_weights="rademacher",
            seed=42,
        )

    results = SunAbraham(**estimator_kwargs).fit(
        sample,
        outcome=outcome,
        unit="cms_certification_number",
        time="t_index",
        first_treat="first_treat",
    )

    rule = "=" * 90
    n_units = sample["cms_certification_number"].nunique()
    print("\n" + rule)
    print(f"{label} | outcome = {outcome} | N = {len(sample):,} | units = {n_units:,}")
    print(rule)
    results.print_summary()

    # Compact event-study table, ordered by relative time as an integer.
    effects = results.event_study_effects
    print("\nEvent-study effects (relative time e):")
    for rel_time in sorted(effects.keys(), key=int):
        est = effects[rel_time]
        print(f"  e={rel_time:>3}: {est['effect']:+.4f}  (SE {est['se']:.4f})")

    print(f"\nOverall ATT: {results.overall_att:+.4f}  (SE {results.overall_se:.4f})")
    return results

# ------------------ Run everything ------------------
BOOTSTRAP = False  # set True if you want bootstrap inference

# all_results[dataset name]["levels" | "logs"][outcome] -> value returned by run_sa
all_results = {}

# For each dataset, levels are fitted first, then logs. Outcomes that are not
# columns of the dataset are skipped.
outcome_groups = (("levels", outs_order), ("logs", log_outs))

for dname, dset in datasets.items():
    all_results[dname] = {"levels": {}, "logs": {}}
    for group, outcomes in outcome_groups:
        for y in outcomes:
            if y not in dset.columns:
                continue
            all_results[dname][group][y] = run_sa(dset, y, label=dname, bootstrap=BOOTSTRAP)

print("\nDone.")