In [None]:
# =========================
# Cell 1 — Imports, paths, robust readers, QC + plotting helpers
# =========================
from __future__ import annotations
import math
import gzip
import numpy as np
import pandas as pd
from pathlib import Path
from astropy.io import fits
import matplotlib.pyplot as plt

# optional ROOT reader (for FITRES ROOT)
try:
    import uproot  # type: ignore
except Exception:
    uproot = None

# ---- user paths (HIGHZ + LOWZ) ----
# NOTE(review): absolute, machine-specific root; every path below hangs off it.
BASE = Path("/Users/tz/Documents/GitHub/Lsst-Twilight-SNe/Motivation/data_snana_EP_LSST")

# 1_SIM (HEAD/PHOT): one HEAD + one PHOT FITS file per run (HIGHZ, LOWZ).
# These are consumed by read_head_fits() / read_phot_fits().
HEAD_HIGHZ = BASE / "1_SIM/PIP_EP-LSST_LSST_P21_HIGHZ/PIP_EP-LSST_LSST_P21_HIGHZ_SNIaMODEL00-0001_HEAD.FITS.gz"
PHOT_HIGHZ = BASE / "1_SIM/PIP_EP-LSST_LSST_P21_HIGHZ/PIP_EP-LSST_LSST_P21_HIGHZ_SNIaMODEL00-0001_PHOT.FITS.gz"

HEAD_LOWZ  = BASE / "1_SIM/PIP_EP-LSST_LSST_P21_LOWZ/PIP_EP-LSST_LSST_P21_LOWZ_SNIaMODEL00-0001_HEAD.FITS.gz"
PHOT_LOWZ  = BASE / "1_SIM/PIP_EP-LSST_LSST_P21_LOWZ/PIP_EP-LSST_LSST_P21_LOWZ_SNIaMODEL00-0001_PHOT.FITS.gz"

# 2_LCFIT (dirs + canonical filenames)
# The directories double as search roots for read_fitres_any().
FITRES_HIGHZ_DIR = BASE / "2_LCFIT/PIP_EP-LSST_LSST_P21_HIGHZ"
FITRES_LOWZ_DIR  = BASE / "2_LCFIT/PIP_EP-LSST_LSST_P21_LOWZ"
FITRES_HIGHZ_ASCII = FITRES_HIGHZ_DIR / "FITOPT000.FITRES.gz"
FITRES_LOWZ_ASCII  = FITRES_LOWZ_DIR  / "FITOPT000.FITRES.gz"   # try ascii first
FITRES_LOWZ_ROOT   = FITRES_LOWZ_DIR  / "FITOPT000.ROOT.gz"     # fallback if uproot installed

# ---- analysis controls ----
# Upper redshift bound applied to the HEAD and FITRES tables and to the z-grid.
Z_MAX = 1.20   # keep within this z for both detection and cosmology

# -------------------------
# Utilities
# -------------------------
def _to_int64_safe(x):
    if pd.isna(x):
        return pd.NA
    try:
        if isinstance(x, bytes):
            x = x.decode(errors="ignore")
        if isinstance(x, str):
            x = x.strip()
        return np.int64(int(float(x)))
    except Exception:
        return pd.NA

def _clean_chars_inplace(df: pd.DataFrame) -> None:
    for c in df.columns:
        if df[c].dtype.kind in ("S", "O", "U"):
            try:
                df[c] = df[c].astype(str).str.strip()
            except Exception:
                pass

def _to_numeric_if_possible(s: pd.Series) -> pd.Series:
    """Try numeric conversion. If it raises, keep original."""
    try:
        return pd.to_numeric(s)
    except Exception:
        return s

def read_head_fits(head_path: Path) -> pd.DataFrame:
    """Read HDU 1 of a HEAD FITS file into a DataFrame with helper columns.

    Added/normalised columns:
      * ``ID_int``  - nullable Int64 built from ``SNID`` (or the row index).
      * ``z``       - numeric copy of ``REDSHIFT_FINAL`` or, failing that,
                      ``REDSHIFT_TRUE`` (column is absent if neither exists).
      * ``PEAKMJD`` - coerced to numeric when present.
      * ``NOBS``/``PTROBS_MIN``/``PTROBS_MAX`` - coerced to nullable Int64.
    """
    with fits.open(head_path) as hdul:
        arr = np.array(hdul[1].data)
    # Swap the byte order before handing the record array to pandas.
    # ndarray.newbyteorder() was removed in NumPy 2.0; viewing the swapped
    # bytes with the swapped dtype is the equivalent that works on 1.x and 2.x.
    df = pd.DataFrame(arr.byteswap().view(arr.dtype.newbyteorder()))
    _clean_chars_inplace(df)

    # ID
    id_source = df["SNID"] if "SNID" in df.columns else df.index
    df["ID_int"] = pd.Series([_to_int64_safe(v) for v in id_source], dtype="Int64")

    # z, mjd
    zcol = next((c for c in ("REDSHIFT_FINAL", "REDSHIFT_TRUE") if c in df.columns), None)
    if zcol is not None:
        df["z"] = pd.to_numeric(df[zcol], errors="coerce")
    if "PEAKMJD" in df.columns:
        df["PEAKMJD"] = pd.to_numeric(df["PEAKMJD"], errors="coerce")

    # ensure integers for pointers if present
    for c in ("NOBS", "PTROBS_MIN", "PTROBS_MAX"):
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce").astype("Int64")
    return df

def read_fitres_ascii_gz(path: Path) -> pd.DataFrame:
    """Parse a gzipped SNANA FITRES ASCII table into a DataFrame.

    Column names come from the first ``VARNAMES:`` line. Every later line that
    is non-empty, not a ``#`` comment and not a keyword line is split on
    whitespace; a leading ``xxx:`` tag is dropped, short rows are skipped and
    long rows are truncated to the number of names. Columns are converted to
    numeric where the whole column parses.

    Raises:
        RuntimeError: if no ``VARNAMES:`` line is present.
    """
    keyword_prefixes = ("VARNAMES", "NVAR", "END", "VERSION", "SNANA")
    names = None
    rows = []
    with gzip.open(path, "rt") as f:
        for raw in f:
            text = raw.strip()
            if names is None:
                # Still hunting for the header; everything before it is ignored.
                if text.upper().startswith("VARNAMES:"):
                    names = text.split()[1:]
                continue
            if not text or text.startswith("#"):
                continue
            if text.upper().startswith(keyword_prefixes):
                continue
            fields = text.split()
            if fields[0].endswith(":"):
                fields = fields[1:]
            if len(fields) >= len(names):
                rows.append(fields[:len(names)])
    if names is None:
        raise RuntimeError(f"VARNAMES not found in {path}")

    df = pd.DataFrame(rows, columns=names)
    for c in df.columns:
        df[c] = _to_numeric_if_possible(df[c])
    return df

def read_fitres_root(path: Path) -> pd.DataFrame:
    """Load a FITRES table from a ROOT file via ``uproot``.

    Well-known keys are tried first; otherwise the first object exposing an
    ``arrays`` method is used. Text columns are stripped and the index reset.

    Raises:
        RuntimeError: if ``uproot`` is unavailable or no suitable object exists.
    """
    if uproot is None:
        raise RuntimeError("uproot is not available to read ROOT FITRES.")
    known_keys = ("FITRES", "FITOPT000", "FITRES/FITRES")
    with uproot.open(path) as f:
        tree = next((f[k] for k in known_keys if k in f), None)
        if tree is None:
            # Fall back to whatever looks like a tree.
            tree = next((obj for _, obj in f.items() if hasattr(obj, "arrays")), None)
        if tree is None:
            raise RuntimeError(f"No TTree found in {path}")
        df = tree.arrays(library="pd")
    _clean_chars_inplace(df)
    return df.reset_index(drop=True)

def read_fitres_any(*candidates: Path) -> pd.DataFrame:
    """
    Read the first usable FITRES table among ``candidates``.

    Each candidate may be a file, a directory (searched for ``*.FITRES.gz``
    and then ``*.ROOT*``) or ``None`` (ignored). ASCII tables win over ROOT
    ones; ROOT is only attempted when ``uproot`` imported successfully.
    When nothing qualifies a warning is printed and an empty DataFrame is
    returned.
    """
    paths: list[Path] = []
    for cand in candidates:
        if cand is None:
            continue
        cand = Path(cand)
        if cand.is_file():
            paths.append(cand)
        elif cand.is_dir():
            paths.extend(sorted(cand.glob("*.FITRES.gz")))
            paths.extend(sorted(cand.glob("*.ROOT*")))

    # prefer ASCII (case-insensitive ".FITRES.gz" ending)
    ascii_hit = next((p for p in paths if p.name.upper().endswith(".FITRES.GZ")), None)
    if ascii_hit is not None:
        return read_fitres_ascii_gz(ascii_hit)

    # then ROOT (if possible)
    root_hit = next(
        (p for p in paths if (".ROOT" in p.name.upper()) and (uproot is not None)),
        None,
    )
    if root_hit is not None:
        return read_fitres_root(root_hit)

    print(f"[WARN] No readable FITRES among: {[str(pp) for pp in candidates]}. Continuing with empty FITRES.")
    return pd.DataFrame()

def standardize_fitres(df: pd.DataFrame) -> pd.DataFrame:
    """Add the canonical ``ID_int`` / ``z`` / ``PKMJD`` columns to a FITRES table.

    The input frame is modified in place and also returned. ``None`` or an
    empty frame yields a new empty frame carrying just the three columns.
    ``z`` and ``PKMJD`` are taken from the first matching source column and
    are set to ``None`` when no source column exists.
    """
    if df is None or df.empty:
        return pd.DataFrame(columns=["ID_int","z","PKMJD"])
    _clean_chars_inplace(df)

    def _first_numeric(names):
        # First listed column that exists, coerced to numeric; None if absent.
        for name in names:
            if name in df.columns:
                return pd.to_numeric(df[name], errors="coerce")
        return None

    # ID: CIDint is already integral; other sources go through the safe parser.
    if "CIDint" in df.columns:
        df["ID_int"] = pd.to_numeric(df["CIDint"], errors="coerce").astype("Int64")
    else:
        if "CID" in df.columns:
            id_source = df["CID"]
        elif "SNID" in df.columns:
            id_source = df["SNID"]
        else:
            id_source = df.index
        df["ID_int"] = pd.Series([_to_int64_safe(v) for v in id_source], dtype="Int64")

    # redshift
    df["z"] = _first_numeric(("zHD","zCMB","z","ZCMB","Z"))

    # peak mjd
    df["PKMJD"] = _first_numeric(("PKMJD","PKMJD_SALT2","PKMJD_SNIa"))
    return df

def densest_year_window(mjd: np.ndarray, width_days: float = 365.25) -> tuple[float,float,int]:
    """Find the window of length ``width_days`` holding the most finite MJDs.

    Returns ``(t0, t1, count)`` where ``t0`` is a data point, ``t1 = t0 +
    width_days`` and ``count`` is the number of points with ``t - t0 <=
    width_days``. The earliest window wins ties. With no finite input the
    result is ``(nan, nan, 0)``.
    """
    times = np.sort(mjd[np.isfinite(mjd)])
    if times.size == 0:
        return (np.nan, np.nan, 0)

    left = 0
    best_start, best_count = times[0], 1
    for right, t_right in enumerate(times):
        # Advance the left edge until the span fits inside the window.
        while t_right - times[left] > width_days:
            left += 1
        count = right - left + 1
        if count > best_count:
            best_start, best_count = times[left], count
    return (best_start, best_start + width_days, best_count)

# ---- PHOT reader, ±10d/SN>5 check, Rosselli-style QC, N(z) + plot ----
def read_phot_fits(path: Path) -> pd.DataFrame:
    """Read HDU 1 of a PHOT FITS file into a DataFrame.

    ``MJD``, the flux/error columns and ``PHOTFLAG`` are coerced to numeric
    when present (unparseable cells become NaN); other columns are untouched.
    """
    with fits.open(path) as hdul:
        arr = np.array(hdul[1].data)
    # Swap the byte order before handing the record array to pandas.
    # ndarray.newbyteorder() was removed in NumPy 2.0; viewing the swapped
    # bytes with the swapped dtype is the equivalent that works on 1.x and 2.x.
    df = pd.DataFrame(arr.byteswap().view(arr.dtype.newbyteorder()))
    for c in ("MJD", "FLUXCAL", "FLUXCALERR", "FLUX", "FLUXERR", "PHOTFLAG"):
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")
    return df

# ---------- Saturation helpers ----------
from astropy.io import fits

def discover_sat_mask_from_headers(phot_path: Path) -> int | None:
    """Look up ``PHOTFLAG_SATURATE`` in the headers of a PHOT FITS file.

    Returns the value of the first HDU header that carries the keyword, as an
    ``int``. Any failure (unreadable file, non-integer value, keyword absent)
    yields ``None``; this lookup is deliberately best-effort.
    """
    key = "PHOTFLAG_SATURATE"
    try:
        with fits.open(phot_path) as hdus:
            found = next((hdu.header[key] for hdu in hdus if key in hdu.header), None)
            return None if found is None else int(found)
    except Exception:
        return None

def epoch_saturation_mask_df(phot_df: pd.DataFrame, phot_path: Path | None = None) -> np.ndarray:
    """
    Build an epoch-level saturation boolean mask for a PHOT DataFrame.

    Priority:
      1) ``PHOTFLAG & PHOTFLAG_SATURATE`` when the table has ``PHOTFLAG``, a
         ``phot_path`` is given and its headers define the bit.
      2) Sentinel on calibrated flux: |FLUXCAL| < 1e-6 and |FLUXCALERR| > 1e7.
      3) The same sentinel on FLUX / FLUXERR.
      4) Otherwise no epoch is flagged.

    Always returns a plain boolean ``np.ndarray`` of length ``len(phot_df)``
    (positional, independent of the frame's index), and never modifies
    ``phot_df``: numeric coercion happens on temporary copies.
    """
    def _numeric(col: str) -> pd.Series:
        return pd.to_numeric(phot_df[col], errors="coerce")

    # Try PHOTFLAG bit
    if ("PHOTFLAG" in phot_df.columns) and (phot_path is not None):
        sat_mask = discover_sat_mask_from_headers(phot_path)
        if sat_mask is not None:
            flags = _numeric("PHOTFLAG").fillna(0).astype(np.int64).to_numpy()
            return (flags & int(sat_mask)) > 0

    # Sentinel fallbacks: calibrated columns first, then raw flux columns.
    for flux_col, err_col in (("FLUXCAL", "FLUXCALERR"), ("FLUX", "FLUXERR")):
        if {flux_col, err_col}.issubset(phot_df.columns):
            flux = _numeric(flux_col).to_numpy(dtype=float, na_value=np.nan)
            err = _numeric(err_col).to_numpy(dtype=float, na_value=np.nan)
            return (np.abs(flux) < 1e-6) & (np.abs(err) > 1e7)

    # If neither convention is present, assume no epochs are flagged
    return np.zeros(len(phot_df), dtype=bool)

def _slice_indices_from_pointers(pmin, pmax, n_rows: int) -> tuple[int,int]:
    """
    Map SNANA HEAD pointers to a Python slice [a:b).
    Tries 1-based inclusive first, then 0-based inclusive.
    """
    pmin = int(pmin) if pd.notna(pmin) else 0
    pmax = int(pmax) if pd.notna(pmax) else -1
    a = max(0, pmin - 1); b = min(n_rows, pmax)
    if b > a:
        return a, b
    a = max(0, pmin); b = min(n_rows, pmax + 1)
    if b > a:
        return a, b
    return (0, 0)

def nosat_mask_for_run(
    fit_df: pd.DataFrame,
    head_df: pd.DataFrame,
    phot_df: pd.DataFrame,
    phot_path: Path | None = None,
    rest_window: tuple[float,float] | None = None,  # e.g., (-10.0, +10.0) to ignore saturation outside ±10d
) -> pd.Series:
    """
    Return a boolean Series aligned to fit_df.index:
      True  -> SN passes the "no saturation" cut (ZERO saturated epochs in the selected window)
      False -> SN has >=1 saturated epoch in the selected window

    If rest_window is None (or PHOT has no MJD column), all epochs of the SN
    are inspected; otherwise only epochs with rest-frame phase
    ``(MJD - PKMJD) / (1 + z)`` inside ``[tmin, tmax]``. An SN whose ``z`` or
    ``PKMJD`` is not finite falls back to the all-epoch test.

    SNe that cannot be located (missing ID, ID absent from HEAD, empty pointer
    range) and runs with no HEAD/PHOT information are never vetoed.

    Epochs are addressed positionally, so the light-curve slice and the
    saturation mask always have the same length regardless of phot_df's index.
    """
    ok = pd.Series(True, index=fit_df.index, dtype=bool)

    if phot_df is None or phot_df.empty or head_df is None or head_df.empty:
        # No information -> don't veto anything
        return ok

    need_cols = ["ID_int", "PTROBS_MIN", "PTROBS_MAX"]
    if not set(need_cols).issubset(head_df.columns):
        return ok

    def _as_float(value) -> float:
        # NA/None/non-numeric -> NaN instead of an exception.
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("nan")

    # Build epoch-level saturation mask once, as a positional array
    sat_epoch = np.asarray(epoch_saturation_mask_df(phot_df, phot_path=phot_path), dtype=bool)
    n_rows = len(phot_df)

    # MJD as a positional array; only needed for the windowed test
    mjd = None
    if rest_window is not None and "MJD" in phot_df.columns:
        mjd = pd.to_numeric(phot_df["MJD"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)

    # Fast lookup for HEAD pointers by SN ID
    head_idx = (head_df[need_cols]
                .dropna()
                .astype({"PTROBS_MIN": "Int64", "PTROBS_MAX": "Int64"})
                .set_index("ID_int"))

    # Iterate SNe in this run
    for i, row in fit_df.iterrows():
        sid = row.get("ID_int", pd.NA)
        if pd.isna(sid) or sid not in head_idx.index:
            continue

        pmin = head_idx.at[sid, "PTROBS_MIN"]
        pmax = head_idx.at[sid, "PTROBS_MAX"]
        a, b = _slice_indices_from_pointers(pmin, pmax, n_rows)
        if b <= a:
            continue
        sat_slice = sat_epoch[a:b]

        # Optionally restrict to a rest-frame window around peak
        if mjd is not None:
            z = _as_float(row.get("z", np.nan))
            pk = _as_float(row.get("PKMJD", np.nan))
            if np.isfinite(z) and np.isfinite(pk):
                # Same half-open [a, b) slice as sat_slice, so lengths match.
                t_rest = (mjd[a:b] - pk) / (1.0 + z)
                in_window = (t_rest >= rest_window[0]) & (t_rest <= rest_window[1])
                if np.any(sat_slice[in_window]):
                    ok.at[i] = False
                continue  # done with this SN (windowed path)

        # Otherwise: any saturated epoch anywhere in the LC slice
        if np.any(sat_slice):
            ok.at[i] = False

    return ok

def _sn_pm10_snr5_ok_for_run(fit_df: pd.DataFrame, head_df: pd.DataFrame, phot_df: pd.DataFrame) -> pd.Series:
    """
    For a single run (matching HEAD<->PHOT pointers), return a boolean Series (indexed like fit_df)
    that is True if the SN has >=3 points with S/N>5 within ±10 rest-frame days of PKMJD.
    Uses HEAD PTROBS_MIN/MAX to slice the PHOT table.
    """
    ok = pd.Series(False, index=fit_df.index, dtype=bool)  # label-aligned

    if phot_df is None or phot_df.empty:
        return ok
    if {"FLUXCAL","FLUXCALERR"}.issubset(phot_df.columns):
        flux_col, err_col = "FLUXCAL", "FLUXCALERR"
    elif {"FLUX","FLUXERR"}.issubset(phot_df.columns):
        flux_col, err_col = "FLUX", "FLUXERR"
    else:
        return ok

    if not {"ID_int","PTROBS_MIN","PTROBS_MAX"}.issubset(head_df.columns):
        return ok

    head_idx = (head_df[["ID_int","PTROBS_MIN","PTROBS_MAX"]]
                .dropna()
                .astype({"PTROBS_MIN":"Int64","PTROBS_MAX":"Int64"})
                .set_index("ID_int"))

    for i, row in fit_df.iterrows():  # i is label
        sid = row.get("ID_int", pd.NA)
        if pd.isna(sid) or sid not in head_idx.index:
            continue
        pmin = head_idx.at[sid, "PTROBS_MIN"]
        pmax = head_idx.at[sid, "PTROBS_MAX"]
        if pd.isna(pmin) or pd.isna(pmax):
            continue

        lo = max(int(pmin) - 1, 0)  # 1-based inclusive -> 0-based slice
        hi = min(int(pmax), len(phot_df))
        if hi <= lo:
            continue

        sl = phot_df.iloc[lo:hi]
        z  = row.get("z", np.nan)
        pk = row.get("PKMJD", np.nan)
        if not (np.isfinite(z) and np.isfinite(pk)) or sl.empty:
            continue

        snr = sl[flux_col].to_numpy() / np.maximum(sl[err_col].to_numpy(), 1e-9)
        t_rest = (sl["MJD"].to_numpy() - pk) / (1.0 + z)
        m = (t_rest >= -10.0) & (t_rest <= 10.0) & np.isfinite(snr)
        if np.count_nonzero(snr[m] > 5.0) >= 3:
            ok.loc[i] = True
    return ok

def ross_qc_with_report(fit_all: pd.DataFrame,
                        head_hi: pd.DataFrame, phot_hi: pd.DataFrame,
                        head_lo: pd.DataFrame, phot_lo: pd.DataFrame,
                        phot_hi_path: Path | None = None,
                        phot_lo_path: Path | None = None,
                        sat_window_rest: tuple[float,float] | None = None) -> pd.DataFrame:
    """
    Apply Rosselli+ style cuts to fit_all (already merged with HEAD columns) **and** drop any SN
    with saturated epochs (any filter). If sat_window_rest is provided (e.g., (-10,+10)),
    only saturation within that rest-frame window causes a veto.

    Cuts, each reported on its own against the full input size:
      FITPROB > 0.05 (skipped if the column is absent), |x1| <= 3, |c| <= 0.3,
      PKMJDERR <= 1, x1ERR <= 1, cERR <= 0.05, >=3 epochs with S/N>5 within
      ±10 rest-frame days, and no saturated epoch.
    A missing x1/c/error column compares as NaN and therefore rejects every SN.
    The per-run cuts are evaluated separately for SNe whose ID_int appears in
    head_hi and in head_lo; SNe found in neither fail both per-run cuts.

    Returns a copy of the rows that pass every cut. An empty input is returned
    as an empty copy without running any cut (the per-cut percentages would
    otherwise divide by zero).
    """
    N0 = len(fit_all)
    if N0 == 0:
        print("[ROSS QC] empty input catalog - nothing to cut.")
        return fit_all.copy()

    def keep_and_print(mask, tag):
        kept = int(mask.sum())
        print(f"[ROSS QC] {tag:<18} keep={kept:6d}/{N0:6d} ({kept/N0:5.1%})")
        return mask

    def _num(name):
        # Numeric view of a column; scalar NaN when the column is absent.
        return pd.to_numeric(fit_all.get(name, np.nan), errors="coerce")

    m_all = pd.Series(True, index=fit_all.index)

    # Standard Ross+ cuts
    m_fitprob = fit_all["FITPROB"].fillna(0) > 0.05 if "FITPROB" in fit_all.columns else pd.Series(True, index=fit_all.index)
    keep_and_print(m_fitprob, "FITPROB cut"); m_all &= m_fitprob

    m_x1 = np.abs(_num("x1")) <= 3.0
    keep_and_print(m_x1, "x1 range"); m_all &= m_x1

    m_c = np.abs(_num("c")) <= 0.3
    keep_and_print(m_c, "c range"); m_all &= m_c

    m_pkmjd = _num("PKMJDERR") <= 1.0
    keep_and_print(m_pkmjd, "PKMJDERR<=1d"); m_all &= m_pkmjd

    m_x1e = _num("x1ERR") <= 1.0
    keep_and_print(m_x1e, "x1ERR<=1"); m_all &= m_x1e

    m_ce = _num("cERR") <= 0.05
    keep_and_print(m_ce, "cERR<=0.05"); m_all &= m_ce

    # Split the catalog by run so HEAD pointers are matched to the right PHOT table
    ids_hi = set(head_hi["ID_int"].dropna().astype("Int64"))
    ids_lo = set(head_lo["ID_int"].dropna().astype("Int64"))
    fit_hi = fit_all[fit_all["ID_int"].isin(ids_hi)]
    fit_lo = fit_all[fit_all["ID_int"].isin(ids_lo)]

    # >=3 obs ±10d with S/N>5, per run
    obs_ok_hi = _sn_pm10_snr5_ok_for_run(fit_hi, head_hi, phot_hi)
    obs_ok_lo = _sn_pm10_snr5_ok_for_run(fit_lo, head_lo, phot_lo)
    m_obs = pd.Series(False, index=fit_all.index)
    m_obs.loc[fit_hi.index] = obs_ok_hi.values
    m_obs.loc[fit_lo.index] = obs_ok_lo.values
    keep_and_print(m_obs, ">=3 obs ±10d"); m_all &= m_obs

    # No-saturation cut (any epoch unless sat_window_rest is set)
    nosat_hi = nosat_mask_for_run(fit_hi, head_hi, phot_hi, phot_path=phot_hi_path, rest_window=sat_window_rest)
    nosat_lo = nosat_mask_for_run(fit_lo, head_lo, phot_lo, phot_path=phot_lo_path, rest_window=sat_window_rest)
    m_nosat = pd.Series(False, index=fit_all.index)
    m_nosat.loc[fit_hi.index] = nosat_hi.values
    m_nosat.loc[fit_lo.index] = nosat_lo.values
    keep_and_print(m_nosat, "no saturation"); m_all &= m_nosat

    kept = int(m_all.sum())
    print(f"[ROSS QC] TOTAL          keep={kept:6d}/{N0:6d} ({kept/N0:5.1%})")
    return fit_all.loc[m_all].copy()

def nz_hist(z, z_edges):
    """Return the per-bin counts of ``z`` for the bin edges ``z_edges``."""
    return np.histogram(z, bins=z_edges)[0]

def show_fig_y1_nz(z_mid, N_det, N_cos, N_cos_tw, DZ, Z_TW_MIN, Z_TW_MAX, Z_MAX):
    """Plot the three N(z) curves against ``z_mid`` and show the figure.

    The legend reports each curve's total count, the span
    [Z_TW_MIN, Z_TW_MAX] is shaded, and the x-axis runs from 0 to Z_MAX.
    ``DZ`` is accepted but not used by the plot.
    """
    curves = [
        (N_det,    f"Detection (Y1, N={int(N_det.sum())})"),
        (N_cos,    f"Cosmology (Y1, N={int(N_cos.sum())})"),
        (N_cos_tw, f"Cosmo + Twilight (Y1, N={int(N_cos_tw.sum())})"),
    ]

    plt.figure(figsize=(10,6))
    for counts, label in curves:
        plt.plot(z_mid, counts, label=label)
    plt.axvspan(Z_TW_MIN, Z_TW_MAX, color="k", alpha=0.06, label="Twilight Promotion")

    plt.xlim(0.0, Z_MAX)
    plt.ylim(bottom=0)
    plt.xlabel("Redshift z")
    plt.ylabel("Count per Δz")
    plt.grid(alpha=0.2)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [None]:
# =========================
# Cell 2 — Load HIGHZ+LOWZ HEAD/FITRES, standardize, combine, Year-1 slice
# =========================

# HEAD tables (detection catalog info)
# head_hi / head_lo are kept separately because Cell 3 needs the per-run
# pointer columns; head_all is the combined detection catalog.
head_hi = read_head_fits(HEAD_HIGHZ)
head_lo = read_head_fits(HEAD_LOWZ)
head_all = pd.concat([head_lo, head_hi], ignore_index=True)

# Type Ia & z-range
# NOTE(review): assumes SNTYPE == 1 marks SNe Ia in these sims — confirm.
if "SNTYPE" in head_all.columns:
    head_all = head_all.loc[head_all["SNTYPE"] == 1].copy()
head_all["z"] = pd.to_numeric(head_all.get("z", np.nan), errors="coerce")
head_all = head_all.loc[(head_all["z"] > 0) & (head_all["z"] <= Z_MAX)].copy()

# FITRES (cosmology-fit outputs)
# HIGHZ: canonical ASCII file, else search the directory.
# LOWZ: ASCII file if it exists (else the directory), plus the ROOT file if it exists.
fit_hi_raw = read_fitres_any(FITRES_HIGHZ_ASCII, FITRES_HIGHZ_DIR)
fit_lo_raw = read_fitres_any(FITRES_LOWZ_ASCII if FITRES_LOWZ_ASCII.exists() else FITRES_LOWZ_DIR,
                              FITRES_LOWZ_ROOT if FITRES_LOWZ_ROOT.exists() else None)

fit_hi = standardize_fitres(fit_hi_raw)
fit_lo = standardize_fitres(fit_lo_raw)
fit_all = pd.concat([fit_lo, fit_hi], ignore_index=True)

# Same redshift window as the HEAD catalog; rows with NaN z drop out here.
fit_all["z"] = pd.to_numeric(fit_all.get("z", np.nan), errors="coerce")
fit_all = fit_all.loc[(fit_all["z"] > 0) & (fit_all["z"] <= Z_MAX)].copy()

# Merge HEAD columns into FITRES (carry pointers for ±10d cut)
# HEAD's "z" arrives as "z_HEAD" because of the suffixes below.
# NOTE(review): the join key is ID_int alone; if the LOWZ and HIGHZ runs reuse
# IDs this inner merge would cross-match rows between runs — confirm IDs are unique.
cols_keep = ["ID_int","z"] + [c for c in ("PEAKMJD","SNTYPE","RA","DEC","NOBS","PTROBS_MIN","PTROBS_MAX") if c in head_all.columns]
fit_all = fit_all.merge(head_all.loc[:, cols_keep], on="ID_int", how="inner", suffixes=("", "_HEAD"))

# Prefer z from FITRES when present, else HEAD
fit_all["z"] = np.where(np.isfinite(fit_all["z"].values),
                        fit_all["z"].values,
                        pd.to_numeric(fit_all.get("z_HEAD", np.nan), errors="coerce"))

# Year-1 densest window by PKMJD (from FITRES)
# The window is chosen on the pre-QC FITRES sample and reused for every Y1 slice.
if "PKMJD" not in fit_all.columns or fit_all["PKMJD"].notna().sum() == 0:
    raise SystemExit("PKMJD not found in FITRES — required for Year-1 densest-window selection.")
t0, t1, _ = densest_year_window(fit_all["PKMJD"].to_numpy(), 365.25)

# Slice Y1: detection (HEAD by PEAKMJD) and cosmology (FITRES by PKMJD)
# Both slices are half-open: [t0, t1).
if "PEAKMJD" in head_all.columns:
    head_y1 = head_all.loc[(head_all["PEAKMJD"] >= t0) & (head_all["PEAKMJD"] < t1)].copy()
else:
    # No PEAKMJD in HEAD: borrow PKMJD from FITRES (SNe without a fit drop out).
    head_y1 = head_all.merge(fit_all[["ID_int","PKMJD"]], on="ID_int", how="left")
    head_y1 = head_y1.loc[(head_y1["PKMJD"] >= t0) & (head_y1["PKMJD"] < t1)].copy()

fitres_y1 = fit_all.loc[(fit_all["PKMJD"] >= t0) & (fit_all["PKMJD"] < t1)].copy()

# Basic sanity prints
def _z_stats(z):
    z = np.asarray(z[np.isfinite(z)])
    if z.size == 0:
        return np.array([np.nan]*5)
    return np.array([np.min(z), np.percentile(z,1), np.percentile(z,50),
                     np.percentile(z,99), np.max(z)])

# Redshift summaries of the combined HEAD and (merged) FITRES catalogs,
# followed by the chosen Year-1 window and the sizes of both Y1 slices.
zs_head = _z_stats(head_all["z"])
zs_fit  = _z_stats(fit_all["z"])
print(f"[HEAD]   N={len(head_all):,}  z[min, p1, p50, p99, max] = [{zs_head[0]:.5f} {zs_head[1]:.5f} {zs_head[2]:.5f} {zs_head[3]:.5f} {zs_head[4]:.5f}]")
print(f"[FITRES] N={len(fit_all):,}  z[min, p1, p50, p99, max] = [{zs_fit[0]:.5f} {zs_fit[1]:.5f} {zs_fit[2]:.5f} {zs_fit[3]:.5f} {zs_fit[4]:.5f}]")
print(f"[Y1] window = [{t0:.1f}, {t1:.1f})  HEAD_Y1={len(head_y1):,}  FITRES_Y1={len(fitres_y1):,}")


In [None]:
# =========================
# Cell 3 — PHOT load, Ross QC, Y1 hist/figure, banded counts, WRITE binned CSVs
# =========================

# Load PHOT once (per run) for the ±10d cut
phot_hi_df = read_phot_fits(PHOT_HIGHZ)
phot_lo_df = read_phot_fits(PHOT_LOWZ)

# Rosselli-style QC on the FULL fit catalog (before Y1 slice)
# NOTE(review): no phot_*_path is passed, so the saturation cut cannot read a
# PHOTFLAG_SATURATE bit from the FITS headers and uses the flux sentinel instead.
fit_qc = ross_qc_with_report(
    fit_all,
    head_hi=head_hi, phot_hi=phot_hi_df,
    head_lo=head_lo, phot_lo=phot_lo_df
)

# Cosmology Y1 = QC + Y1 time slice
fit_qc_y1 = fit_qc.loc[(fit_qc["PKMJD"] >= t0) & (fit_qc["PKMJD"] < t1)].copy()

# Histograms
DZ = 0.01
Z_TW_MIN, Z_TW_MAX = 0.02, 0.22

z_edges = np.arange(0.0, Z_MAX + DZ + 1e-12, DZ)
z_mid   = 0.5*(z_edges[:-1] + z_edges[1:])

N_det = nz_hist(head_y1["z"].to_numpy(float), z_edges)
N_cos = nz_hist(fit_qc_y1["z"].to_numpy(float), z_edges)

# Twilight promotion at low-z: inside the band, the cosmology count of a bin
# is raised to the detection count when the latter is larger.
band = (z_edges[:-1] >= Z_TW_MIN) & (z_edges[1:] <= Z_TW_MAX)
N_cos_tw = N_cos.copy()
N_cos_tw[band] = np.maximum(N_cos_tw[band], N_det[band])

# Plot (figure)
show_fig_y1_nz(z_mid, N_det, N_cos, N_cos_tw, DZ, Z_TW_MIN, Z_TW_MAX, Z_MAX)

# Banded summary counts
low_band = (z_mid >= Z_TW_MIN) & (z_mid < Z_TW_MAX)
hi_band  = (z_mid >= Z_TW_MAX) & (z_mid < Z_MAX)

det_low = int(N_det[low_band].sum()); det_hi = int(N_det[hi_band].sum())
cos_low = int(N_cos[low_band].sum()); cos_hi = int(N_cos[hi_band].sum())
tw_low  = int(N_cos_tw[low_band].sum()); tw_hi = int(N_cos_tw[hi_band].sum())

# Labels are derived from the band limits actually used above, so they cannot
# drift from Z_TW_MIN / Z_TW_MAX / Z_MAX.
print("\n--- Summary Counts ---")
print(f"[LOW-Z {Z_TW_MIN:.2f}–{Z_TW_MAX:.2f}]  DET={det_low}  COS={cos_low}  COS_Tw={tw_low}")
print(f"[HI-Z {Z_TW_MAX:.2f}–{Z_MAX:.2f}]   DET={det_hi}  COS={cos_hi}  COS_Tw={tw_hi}")

# --- Build sigma_mu per bin from the QC Year-1 cosmology sample and WRITE CSVs ---

DERIVED = Path("/Users/tz/Documents/GitHub/Lsst-Twilight-SNe/Motivation/derived")
DERIVED.mkdir(parents=True, exist_ok=True)

# Per-SN distance-modulus uncertainty (robust to missing columns)
def _sigma_mu_per_sn(df: pd.DataFrame) -> pd.Series:
    alpha = pd.to_numeric(df.get("SIM_alpha", df.get("alpha", 0.14)), errors="coerce").fillna(0.14)
    beta  = pd.to_numeric(df.get("SIM_beta",  df.get("beta",  3.10)), errors="coerce").fillna(3.10)
    mBERR = pd.to_numeric(df.get("mBERR", np.nan), errors="coerce").fillna(0.12)
    x1ERR = pd.to_numeric(df.get("x1ERR", np.nan), errors="coerce").fillna(0.9)
    cERR  = pd.to_numeric(df.get("cERR",  np.nan), errors="coerce").fillna(0.04)
    cov   = pd.to_numeric(df.get("COV_x1_c", df.get("COV_x1c", 0.0)), errors="coerce").fillna(0.0)
    z     = pd.to_numeric(df.get("z", np.nan), errors="coerce").astype(float)

    mu2 = (mBERR**2) + (alpha*x1ERR)**2 + (beta*cERR)**2 - 2.0*alpha*beta*cov
    sig_lens = 0.055*z
    sig_vpec = (5.0/np.log(10.0))*(300.0/(299792.458*np.maximum(z, 1e-3)))
    mu2 = mu2 + (0.08**2) + (sig_lens**2) + (sig_vpec**2)
    return np.sqrt(np.maximum(mu2, 0.0))

# per-SN sigma_mu on the QC+Y1 cosmology sample
fit_qc_y1 = fit_qc_y1.copy()
fit_qc_y1["sigma_mu_sn"] = _sigma_mu_per_sn(fit_qc_y1)

# per-bin median sigma_mu (bins are half-open: [z_lo, z_hi))
sigma_bin = np.full_like(z_mid, np.nan, dtype=float)
for k, (z_lo, z_hi) in enumerate(zip(z_edges[:-1], z_edges[1:])):
    m = (fit_qc_y1["z"] >= z_lo) & (fit_qc_y1["z"] < z_hi)
    if m.any():
        sigma_bin[k] = float(fit_qc_y1.loc[m, "sigma_mu_sn"].median())

# fill empty bins from the nearest populated bin (forward, then backward);
# if no bin is populated at all, fall back to the global median, or to 0.12
# when the sample itself is empty / all-NaN.
if np.isnan(sigma_bin).any():
    # Series.ffill()/bfill() replace fillna(method=...), which newer pandas
    # deprecates and then removes.
    s = pd.Series(sigma_bin).ffill().bfill()
    sigma_sn = fit_qc_y1["sigma_mu_sn"].to_numpy(dtype=float)
    global_med = float(np.nanmedian(sigma_sn)) if np.isfinite(sigma_sn).any() else 0.12
    s = s.fillna(global_med)
    sigma_bin = s.to_numpy()

# write Fisher-ready BINNED CSVs (preferred & compat names)
df_base = pd.DataFrame({"z": z_mid, "N": N_cos,    "sigma_mu": sigma_bin})
df_tw   = pd.DataFrame({"z": z_mid, "N": N_cos_tw, "sigma_mu": sigma_bin})

df_base.to_csv(DERIVED / "y1_cat_bin_base_ep_lsst.csv", index=False)
df_tw.to_csv  (DERIVED / "y1_cat_bin_tw_ep_lsst.csv",   index=False)

df_base.to_csv(DERIVED / "y1_cat_bin_base_fix.csv", index=False)
df_tw.to_csv  (DERIVED / "y1_cat_bin_tw_fix.csv",   index=False)

# optional: combined histogram dump
pd.DataFrame({
    "z_lo": z_edges[:-1], "z_hi": z_edges[1:],
    "N_det": N_det, "N_cosmo_WFD": N_cos, "N_cosmo_WFD_Tw": N_cos_tw,
    "sigma_mu_bin": sigma_bin
}).to_csv(DERIVED / "y1_nz_hist_ep_lsst.csv", index=False)

print(f"[WRITE] wrote binned Fisher inputs to: {DERIVED.name}/")
print("        y1_cat_bin_base_ep_lsst.csv / y1_cat_bin_tw_ep_lsst.csv")
print("        y1_cat_bin_base_fix.csv  / y1_cat_bin_tw_fix.csv (compat)")


In [None]:
# =========================
# Cell 3b — Create & write binned Fisher CSVs (run BEFORE Cell 4)
# =========================
from pathlib import Path
import numpy as np
import pandas as pd

# ---- config / paths ----
# Output directory for the binned CSVs; created on the spot if missing.
DERIVED = Path("/Users/tz/Documents/GitHub/Lsst-Twilight-SNe/Motivation/derived")
DERIVED.mkdir(parents=True, exist_ok=True)

# Reuse the bin width / redshift cap from earlier cells when they exist,
# otherwise fall back to the same defaults those cells use.
DZ    = globals().get("DZ", 0.01)
Z_MAX = float(globals().get("Z_MAX", 1.20))

# ---- ensure we have Y1 samples ----
# Expect head_y1 (detection, from HEAD) and fit_qc_y1 (QC+Y1 cosmology) from earlier cells.
# Fallbacks are provided to make this cell self-contained.

def _require_y1_samples():
    global head_y1, fit_qc_y1, fitres_y1, t0, t1
    # head_y1
    if "head_y1" not in globals():
        if "head_all" in globals() and "fit_all" in globals():
            if "t0" not in globals() or "t1" not in globals():
                t0, t1, _ = densest_year_window(fit_all["PKMJD"].to_numpy(), 365.25)
            if "PEAKMJD" in head_all.columns:
                head_y1 = head_all.loc[(head_all["PEAKMJD"] >= t0) & (head_all["PEAKMJD"] < t1)].copy()
            else:
                head_y1 = head_all.merge(fit_all[["ID_int","PKMJD"]], on="ID_int", how="left")
                head_y1 = head_y1.loc[(head_y1["PKMJD"] >= t0) & (head_y1["PKMJD"] < t1)].copy()
        else:
            raise RuntimeError("head_y1 is missing and head_all/fit_all not found. Run Cell 2 first.")
    # fit_qc_y1
    if "fit_qc_y1" not in globals():
        if "fitres_y1" in globals():
            # fallback: no QC, use fitres_y1 directly
            print("[WARN] fit_qc_y1 not found — using fitres_y1 (no QC) to build binned files.")
            fit_qc_y1 = fitres_y1.copy()
        else:
            raise RuntimeError("fit_qc_y1/fitres_y1 missing. Run QC/Cell 3 first.")

_require_y1_samples()

# ---- build z-grid and histograms ----
z_edges = np.arange(0.0, Z_MAX + DZ + 1e-12, DZ)
z_mid   = 0.5*(z_edges[:-1] + z_edges[1:])

def _nz(zvals):
    """Histogram ``zvals`` (cast to float) on the module-level ``z_edges``."""
    return np.histogram(np.asarray(zvals, dtype=float), bins=z_edges)[0]

# detection & cosmology counts
N_det = _nz(head_y1["z"])
N_cos = _nz(fit_qc_y1["z"])

# twilight promotion
# Like DZ / Z_MAX, take the band limits from earlier cells when available and
# otherwise use Cell 3's values, so this cell does not fail with a NameError
# when Cell 3 was skipped.
Z_TW_MIN = float(globals().get("Z_TW_MIN", 0.02))
Z_TW_MAX = float(globals().get("Z_TW_MAX", 0.22))
band = (z_edges[:-1] >= Z_TW_MIN) & (z_edges[1:] <= Z_TW_MAX)
N_cos_tw = N_cos.copy()
N_cos_tw[band] = np.maximum(N_cos_tw[band], N_det[band])

# ---- per-SN sigma_mu and per-bin median sigma_mu ----
def _sigma_mu_per_sn(df: pd.DataFrame) -> pd.Series:
    alpha = pd.to_numeric(df.get("SIM_alpha", df.get("alpha", 0.14)), errors="coerce").fillna(0.14)
    beta  = pd.to_numeric(df.get("SIM_beta",  df.get("beta",  3.10)), errors="coerce").fillna(3.10)
    mBERR = pd.to_numeric(df.get("mBERR", np.nan), errors="coerce").fillna(0.12)
    x1ERR = pd.to_numeric(df.get("x1ERR", np.nan), errors="coerce").fillna(0.9)
    cERR  = pd.to_numeric(df.get("cERR",  np.nan), errors="coerce").fillna(0.04)
    cov   = pd.to_numeric(df.get("COV_x1_c", df.get("COV_x1c", 0.0)), errors="coerce").fillna(0.0)
    z     = pd.to_numeric(df.get("z", np.nan), errors="coerce").astype(float)
    mu2 = (mBERR**2) + (alpha*x1ERR)**2 + (beta*cERR)**2 - 2.0*alpha*beta*cov
    sig_lens = 0.055*z
    sig_vpec = (5.0/np.log(10.0))*(300.0/(299792.458*np.maximum(z, 1e-3)))
    mu2 = mu2 + (0.08**2) + (sig_lens**2) + (sig_vpec**2)
    return np.sqrt(np.maximum(mu2, 0.0))

fit_qc_y1 = fit_qc_y1.copy()
fit_qc_y1["sigma_mu_sn"] = _sigma_mu_per_sn(fit_qc_y1)

# per-bin median sigma_mu (bins are half-open: [z_lo, z_hi))
sigma_bin = np.full_like(z_mid, np.nan, dtype=float)
for k, (z_lo, z_hi) in enumerate(zip(z_edges[:-1], z_edges[1:])):
    m = (fit_qc_y1["z"] >= z_lo) & (fit_qc_y1["z"] < z_hi)
    if m.any():
        sigma_bin[k] = float(fit_qc_y1.loc[m, "sigma_mu_sn"].median())

# fill missing bins from the nearest populated bin (forward, then backward);
# if no bin is populated at all, fall back to the global median, or to 0.12
# when the sample itself is empty / all-NaN.
if np.isnan(sigma_bin).any():
    # Series.ffill()/bfill() replace fillna(method=...), which newer pandas
    # deprecates and then removes.
    s = pd.Series(sigma_bin).ffill().bfill()
    sigma_sn = fit_qc_y1["sigma_mu_sn"].to_numpy(dtype=float)
    global_med = float(np.nanmedian(sigma_sn)) if np.isfinite(sigma_sn).any() else 0.12
    sigma_bin = s.fillna(global_med).to_numpy()

# ---- write Fisher-ready binned CSVs (both preferred + compat names) ----
df_base = pd.DataFrame({"z": z_mid, "N": N_cos,    "sigma_mu": sigma_bin})
df_tw   = pd.DataFrame({"z": z_mid, "N": N_cos_tw, "sigma_mu": sigma_bin})

base_ep = DERIVED / "y1_cat_bin_base_ep_lsst.csv"
tw_ep   = DERIVED / "y1_cat_bin_tw_ep_lsst.csv"
base_fx = DERIVED / "y1_cat_bin_base_fix.csv"
tw_fx   = DERIVED / "y1_cat_bin_tw_fix.csv"

df_base.to_csv(base_ep, index=False); df_tw.to_csv(tw_ep, index=False)
df_base.to_csv(base_fx, index=False); df_tw.to_csv(tw_fx, index=False)

# "Guess" scenario: inside the twilight band the count is twice the BASE
# (WFD-only) count; outside the band it equals the twilight catalog.
df_tw_guess = df_tw.copy()
df_tw_guess.loc[band, 'N'] = (df_base.loc[band, 'N'] * 2).astype(int)
tw_guess_ep = DERIVED / "y1_cat_bin_tw_guess_ep_lsst.csv"
df_tw_guess.to_csv(tw_guess_ep, index=False)
df_tw_guess.to_csv(DERIVED / "y1_cat_bin_tw_guess_fix.csv", index=False)  # compat name for the guess scenario


# optional: combined histogram dump (useful for QA)
pd.DataFrame({
    "z_lo": z_edges[:-1], "z_hi": z_edges[1:],
    "N_det": N_det, "N_cosmo_WFD": N_cos, "N_cosmo_WFD_Tw": N_cos_tw,
    "N_cosmo_WFD_Tw_Guess": df_tw_guess['N'],  # guess scenario, for QA
    "sigma_mu_bin": sigma_bin
}).to_csv(DERIVED / "y1_nz_hist_ep_lsst.csv", index=False)

print(f"[WRITE] Created binned files in {DERIVED}:")
print("        - y1_cat_bin_base_ep_lsst.csv")
print("        - y1_cat_bin_tw_ep_lsst.csv")
print("        - y1_cat_bin_tw_guess_ep_lsst.csv (Your Guess scenario)")
print("        - y1_cat_bin_base_fix.csv  (compat)")
print("        - y1_cat_bin_tw_fix.csv    (compat)")
print(f"        Totals -> WFD N={int(df_base['N'].sum())}, WFD+Twilight N={int(df_tw['N'].sum())}, WFD+Twilight (Guess) N={int(df_tw_guess['N'].sum())}")


In [None]:
# =========================
# Cell 4 — Load binned catalogs (now that Cell 3 wrote them) + print dataset info
# =========================
from pathlib import Path

# Directory holding the binned CSVs; the file names below match the ones
# Cell 3 writes under its own DERIVED directory.
DERIVED = Path("/Users/tz/Documents/GitHub/Lsst-Twilight-SNe/Motivation/derived")

# Preferred filenames produced by Cell 3
BASE_FILE = DERIVED / "y1_cat_bin_base_ep_lsst.csv"   # WFD-only cosmology (combined)
TW_FILE   = DERIVED / "y1_cat_bin_tw_ep_lsst.csv"     # WFD+Twilight cosmology (combined)

def _latest(pattern: str) -> Path:
    """Return the most recently modified file in ``DERIVED`` matching ``pattern``.

    Raises:
        FileNotFoundError: if the glob pattern matches nothing.
    """
    candidates = list(DERIVED.glob(pattern))
    if candidates:
        # Stable sort by modification time; the newest file ends up last.
        candidates.sort(key=lambda candidate: candidate.stat().st_mtime)
        return candidates[-1]
    raise FileNotFoundError(f"No files match: {pattern}")

# Fallbacks if you kept a different suffix (compat names written in Cell 3)
if not BASE_FILE.exists():
    BASE_FILE = _latest("y1_cat_bin_base_*.csv")
if not TW_FILE.exists():
    # The glob "y1_cat_bin_tw_*.csv" also matches the "tw_guess" scenario
    # files, which Cell 3 writes *after* the plain twilight files.  Picking
    # the newest match would therefore silently load the guess scenario as
    # the WFD+Twilight catalogue, so those files are excluded here.
    _tw_candidates = sorted(
        (
            p for p in DERIVED.glob("y1_cat_bin_tw_*.csv")
            if not p.name.startswith("y1_cat_bin_tw_guess_")
        ),
        key=lambda p: p.stat().st_mtime,
    )
    if not _tw_candidates:
        raise FileNotFoundError("No files match: y1_cat_bin_tw_*.csv")
    TW_FILE = _tw_candidates[-1]

print(f"[binned] using BASE: {BASE_FILE.name}")
print(f"[binned] using TW  : {TW_FILE.name}")

# Read the two binned catalogues selected above (columns checked further down).
df_wfd_bin    = pd.read_csv(BASE_FILE)
df_twfull_bin = pd.read_csv(TW_FILE)


# Load your new "guess" scenario CSV
# Unlike BASE_FILE / TW_FILE there is no fallback name: a missing file makes
# pd.read_csv raise.
TW_GUESS_FILE = DERIVED / "y1_cat_bin_tw_guess_ep_lsst.csv"
print(f"[binned] using GUESS: {TW_GUESS_FILE.name}")
df_twguess_bin = pd.read_csv(TW_GUESS_FILE)

# Also apply the Z_MAX cut to it
# ("@Z_MAX" inside query() refers to the Python variable of that name; the cut
# is skipped entirely when Z_MAX has not been defined in this session.)
if "Z_MAX" in globals():
    df_twguess_bin = df_twguess_bin.query("0 < z <= @Z_MAX").copy()


# Safety checks and optional z-cut
# NOTE: only the WFD and WFD+Twilight tables are column-checked here; the
# guess table loaded above is not.
assert {"z","N","sigma_mu"}.issubset(df_wfd_bin.columns)
assert {"z","N","sigma_mu"}.issubset(df_twfull_bin.columns)
if "Z_MAX" in globals():
    df_wfd_bin    = df_wfd_bin.query("0 < z <= @Z_MAX").copy()
    df_twfull_bin = df_twfull_bin.query("0 < z <= @Z_MAX").copy()

# Print the exact binned data fed into Fisher
def _info_binned(name, df):
    """Print a one-line summary of a binned catalogue.

    Reports the number of bins with ``N > 0``, the total count, the min /
    N-weighted mean / max redshift and the median ``sigma_mu`` of those bins,
    and the sum of ``N / sigma_mu**2`` over all rows.  Returns nothing.
    """
    populated = df.loc[df["N"] > 0]
    n_populated = int(len(populated))
    n_total = int(df["N"].sum())

    # Redshift / sigma statistics are only defined when some bin is populated.
    z_lo = z_hi = z_mean = sigma_med = np.nan
    if n_populated:
        z_values = populated["z"]
        z_lo = float(z_values.min())
        z_hi = float(z_values.max())
        z_mean = float(np.average(z_values, weights=populated["N"]))
        sigma_med = float(populated["sigma_mu"].median())

    weight_sum = float((df["N"] / (df["sigma_mu"]**2)).sum())

    line = (f"[{name}] bins>0={n_populated:4d}  N_total={n_total:7,d}  "
            f"z[min,⟨z⟩,max]=[{z_lo:0.3f}, {z_mean:0.3f}, {z_hi:0.3f}]  "
            f"median σμ={sigma_med:0.3f}  Σ(N/σμ²)={weight_sum:0.1f}")
    print(line)

# Summarise the WFD and WFD+Twilight tables (the guess table is not summarised here).
_info_binned("WFD", df_wfd_bin)
_info_binned("WFD+Twilight", df_twfull_bin)

# (Optional) peek first/last few active bins
def _head_tail_nonzero(df, k=3):
    """Print the first and last ``k`` rows of ``df`` that have ``N > 0``.

    Only the ``z``, ``N`` and ``sigma_mu`` columns are shown, without the index.
    """
    populated = df.loc[df["N"] > 0, ["z", "N", "sigma_mu"]]
    sections = (
        ("  first bins:\n", populated.head(k)),
        ("  last  bins:\n", populated.tail(k)),
    )
    for caption, rows in sections:
        print(caption, rows.to_string(index=False))

# Show the default k=3 first/last populated bins of each table.
_head_tail_nonzero(df_wfd_bin)
_head_tail_nonzero(df_twfull_bin)


In [None]:

import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
import matplotlib.pyplot as plt

# Speed of light. DC_Mpc multiplies its integral by C_KMS / H0, which is in Mpc
# when H0 is in km/s/Mpc (presumed from the *_Mpc function names).
C_KMS = 299792.458  # km/s

@dataclass
class CosmoParams:
    """Parameter set read by the expansion, distance and Fisher helpers below.

    The defaults are also what ``FisherSetup`` uses as its fiducial point.
    """
    Om: float = 0.3   # matter density; the dark-energy density is taken as 1 - Om (flat)
    w0: float = -1.0  # dark-energy equation-of-state parameter used in Ez_flat_w0wa
    wa: float = 0.0   # second equation-of-state parameter used in Ez_flat_w0wa
    H0: float = 70.0  # Hubble constant; DC_Mpc scales distances by C_KMS / H0
    M: float = 0.0    # additive offset applied to the distance modulus in mu_theory

def Ez_flat_w0wa(z: np.ndarray, p: CosmoParams) -> np.ndarray:
    """Dimensionless expansion rate E(z) for a flat w0-wa dark-energy model.

    Args:
        z: redshift(s); converted to a float array.
        p: parameters; only ``Om``, ``w0`` and ``wa`` are read.

    Returns:
        ``sqrt(Om (1+z)^3 + (1-Om) (1+z)^(3(1+w0+wa)) exp(-3 wa z/(1+z)))``.
    """
    redshift = np.asarray(z, dtype=float)
    one_plus_z = 1.0 + redshift
    matter_frac = p.Om
    dark_energy_frac = 1.0 - matter_frac  # flatness: the two densities sum to one
    power_law = one_plus_z ** (3.0 * (1.0 + p.w0 + p.wa))
    damping = np.exp(-3.0 * p.wa * redshift / one_plus_z)
    dark_energy_evolution = power_law * damping
    return np.sqrt(matter_frac * one_plus_z ** 3 + dark_energy_frac * dark_energy_evolution)

def DC_Mpc(z: np.ndarray, p: CosmoParams, n_steps: int = 4096) -> np.ndarray:
    """Line-of-sight comoving distance, ``(C_KMS / H0) * integral_0^z dz'/E(z')``.

    The integral is tabulated once with the cumulative trapezoid rule on a
    uniform grid from 0 to ``max(z)`` and then linearly interpolated at every
    requested redshift.

    Args:
        z: redshift or array of redshifts.
        p: cosmological parameters (``H0`` and whatever ``Ez_flat_w0wa`` reads).
        n_steps: number of grid nodes for the tabulation (at least 2 are used).

    Returns:
        1-D float array of distances, one per input redshift.
    """
    z = np.atleast_1d(z).astype(float)
    if z.size == 0:
        return z  # nothing to integrate
    zmax = np.max(z)
    if zmax <= 0.0:
        # Degenerate grid: every requested redshift sits at (or below) the origin.
        return np.zeros_like(z)

    zz = np.linspace(0.0, zmax, max(int(n_steps), 2))
    invE = 1.0 / Ez_flat_w0wa(zz, p)
    # The grid must keep its last node (zz[-1] == zmax).  Trimming it to an
    # odd length is only a Simpson-rule concern; with the trapezoid rule it
    # would make np.interp clamp at zmax - h and underestimate the distance
    # of the highest-redshift object.
    primitive = np.concatenate(
        [[0.0], np.cumsum((invE[:-1] + invE[1:]) * 0.5 * np.diff(zz))]
    )
    integral = np.interp(z, zz, primitive)
    return (C_KMS / p.H0) * integral

def DL_Mpc(z: np.ndarray, p: CosmoParams) -> np.ndarray:
    """Luminosity distance: ``(1 + z)`` times the comoving distance ``DC_Mpc(z, p)``."""
    redshift = np.asarray(z, dtype=float)
    comoving = DC_Mpc(redshift, p)
    return (1.0 + redshift) * comoving

def mu_theory(z: np.ndarray, p: CosmoParams) -> np.ndarray:
    """Model distance modulus ``5 log10(DL_Mpc) + 25 + M`` at redshift(s) ``z``."""
    return 5.0 * np.log10(DL_Mpc(z, p)) + 25.0 + p.M

@dataclass
class FisherSetup:
    """Configuration for a Fisher forecast: which parameters vary and where."""
    # Names of the CosmoParams fields to differentiate with respect to.
    vary_params: Tuple[str, ...]
    # Fiducial parameter point; a fresh default CosmoParams per instance.
    fid: CosmoParams = field(default_factory=CosmoParams)
    # Finite-difference step as a fraction of the fiducial value, per parameter.
    # jacobian_mu never reads the 'M' entry: it fills the M column with ones directly.
    step_frac: Dict[str, float] = field(default_factory=lambda: {'Om': 1e-3, 'w0': 1e-3, 'wa': 1e-3, 'M': 1e-3})

def jacobian_mu(z: np.ndarray, p: CosmoParams, setup: FisherSetup) -> np.ndarray:
    """Derivatives of the distance modulus with respect to the model parameters.

    Args:
        z: redshifts at which to evaluate the derivatives.
        p: parameter point to differentiate around.
        setup: supplies ``vary_params`` (any sequence of field names) and the
            fractional step sizes in ``step_frac``.

    Returns:
        Array with one row per redshift and one column per entry of
        ``vary_params``, followed by a final column for the offset ``M``.
    """
    from dataclasses import replace  # local import: the cell imports only dataclass/field

    z = np.asarray(z, dtype=float)
    columns = []
    # tuple(...) lets vary_params be a list as well as a tuple.
    for name in tuple(setup.vary_params) + ('M',):
        if name == 'M':
            # mu is linear in M with unit slope, so no finite difference is needed.
            columns.append(np.ones_like(z, dtype=float))
            continue
        fiducial = getattr(p, name)
        step = setup.step_frac[name] * fiducial
        if step == 0.0:
            # A fractional step vanishes when the fiducial value is zero (e.g. wa = 0).
            step = 1e-5
        mu_plus = mu_theory(z, replace(p, **{name: fiducial + step}))
        mu_minus = mu_theory(z, replace(p, **{name: fiducial - step}))
        # Central difference; the sign of `step` cancels in the ratio.
        columns.append((mu_plus - mu_minus) / (2.0 * step))
    return np.vstack(columns).T

def fisher_SN(z: np.ndarray, sigma_mu: np.ndarray, setup: FisherSetup) -> Tuple[np.ndarray, List[str]]:
    """Fisher matrix from per-object distance-modulus errors, marginalised over M.

    Args:
        z: redshift of each object.
        sigma_mu: distance-modulus uncertainty of each object (same shape as ``z``).
        setup: forecast configuration (fiducial point and varied parameters).

    Returns:
        ``(F_marg, labels)``: the Fisher matrix of the varied parameters after
        marginalising over the offset ``M``, and the matching parameter names.

    Raises:
        ValueError: if ``z`` and ``sigma_mu`` differ in shape.
        numpy.linalg.LinAlgError: if the M-M Fisher element is zero.
    """
    z = np.asarray(z, dtype=float)
    sigma_mu = np.asarray(sigma_mu, dtype=float)
    if z.shape != sigma_mu.shape:
        raise ValueError(
            f"z and sigma_mu must have the same shape, got {z.shape} and {sigma_mu.shape}"
        )
    J = jacobian_mu(z, setup.fid, setup)
    # Diagonal inverse covariance applied row-wise; avoids materialising an
    # N x N matrix, which is prohibitive for a per-object catalogue.
    weights = 1.0 / (sigma_mu ** 2)
    F_full = (J * weights[:, None]).T @ J

    # M is the last column of the Jacobian: take the Schur complement over it.
    F_tt = F_full[:-1, :-1]
    F_tM = F_full[:-1, -1:]
    F_Mt = F_full[-1:, :-1]
    F_MM = F_full[-1, -1]
    if F_MM == 0.0:
        raise np.linalg.LinAlgError("Singular matrix")
    F_marg = F_tt - (F_tM @ F_Mt) / F_MM
    return F_marg, list(setup.vary_params)

def cov_to_sigmas(F: np.ndarray, labels: List[str]) -> Tuple[Dict[str, float], np.ndarray]:
    """Invert a Fisher matrix and read off the marginalised 1-sigma errors.

    Returns:
        ``(errs, C)`` where ``C`` is the inverse of ``F`` and ``errs`` maps
        each label to the square root of its diagonal element of ``C``.
    """
    covariance = np.linalg.inv(F)
    sigmas = {}
    for position, label in enumerate(labels):
        sigmas[label] = float(np.sqrt(covariance[position, position]))
    return sigmas, covariance

# Columns every per-object catalogue must provide.
EXPECTED_COLS = ['z', 'sigma_mu']

def load_catalog(path: str) -> pd.DataFrame:
    """Read a CSV catalogue and keep only the ``EXPECTED_COLS`` columns.

    Raises:
        ValueError: if any required column is absent from the file.
    """
    table = pd.read_csv(path)
    absent = [column for column in EXPECTED_COLS if column not in table.columns]
    if not absent:
        return table[EXPECTED_COLS].copy()
    raise ValueError(f"{path} is missing required columns: {absent}")

def combine_wfd_twilight(df_wfd: pd.DataFrame, df_twilight: pd.DataFrame) -> pd.DataFrame:
    """Stack the WFD rows and the twilight rows into one frame with a fresh 0..n-1 index."""
    frames = [df_wfd, df_twilight]
    stacked = pd.concat(frames, ignore_index=True)
    return stacked

def sigma_mu_from_salt2_row(
    mBERR: float,
    x1ERR: float,
    cERR: float,
    COV_mB_x1: float,
    COV_mB_c: float,
    COV_x1_c: float,
    z: float,
    alpha: float = 0.14,
    beta: float = 3.1,
    sigma_int: float = 0.10,
    sigma_vpec_kms: float = 300.0,
) -> float:
    """Distance-modulus uncertainty for one fitted light curve.

    Combines the fit errors and covariances of (mB, x1, c) with the weights
    (1, alpha, -beta), then adds ``sigma_int`` and a peculiar-velocity term
    ``(5/ln10) * sigma_vpec_kms / (C_KMS * z)`` in quadrature.

    A non-positive ``z`` is replaced by 1e-4 before the velocity term is
    evaluated.  Returns the square root of the total variance.
    """
    # Fit-error contribution, accumulated term by term.
    variance = mBERR**2 + (alpha**2) * (x1ERR**2) + (beta**2) * (cERR**2)
    variance = variance + 2.0 * alpha * COV_mB_x1
    variance = variance - 2.0 * beta * COV_mB_c
    variance = variance - 2.0 * alpha * beta * COV_x1_c
    # Intrinsic scatter.
    variance = variance + sigma_int**2

    # Floor the redshift so the velocity term stays finite.
    z_eff = 1e-4 if z <= 0 else z
    velocity_term = (5.0 / np.log(10.0)) ** 2 * (sigma_vpec_kms / (C_KMS * z_eff)) ** 2
    variance = variance + velocity_term
    return float(np.sqrt(variance))

def build_catalog_from_fitres(
    fitres_csv_path: str,
    alpha: float = 0.14,
    beta: float = 3.1,
    sigma_int: float = 0.10,
    sigma_vpec_kms: float = 300.0,
    z_col: str = 'zHD',
) -> pd.DataFrame:
    """Build a ``(z, sigma_mu)`` catalogue from a FITRES table stored as CSV.

    Args:
        fitres_csv_path: CSV with the fit-error and covariance columns plus ``z_col``.
        alpha, beta, sigma_int, sigma_vpec_kms: forwarded unchanged to
            ``sigma_mu_from_salt2_row`` for every row.
        z_col: name of the redshift column to use.

    Returns:
        DataFrame with columns ``z`` and ``sigma_mu``, one row per input row.

    Raises:
        ValueError: if any required column is missing.
    """
    df = pd.read_csv(fitres_csv_path)
    fit_cols = ['mBERR', 'x1ERR', 'cERR', 'COV_x1_c', 'COV_mB_c', 'COV_mB_x1']
    needed = fit_cols + [z_col]
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise ValueError(f"FITRES is missing required columns: {missing}")

    zs = df[z_col].values
    # Walk the columns in lockstep instead of doing a label lookup per cell:
    # far faster and independent of how the frame happens to be indexed.
    columns = {name: df[name].to_numpy() for name in fit_cols}
    sigmas = [
        sigma_mu_from_salt2_row(
            mBERR=mb_err,
            x1ERR=x1_err,
            cERR=c_err,
            COV_mB_x1=cov_mb_x1,
            COV_mB_c=cov_mb_c,
            COV_x1_c=cov_x1_c,
            z=z_i,
            alpha=alpha,
            beta=beta,
            sigma_int=sigma_int,
            sigma_vpec_kms=sigma_vpec_kms,
        )
        for mb_err, x1_err, c_err, cov_mb_x1, cov_mb_c, cov_x1_c, z_i in zip(
            columns['mBERR'],
            columns['x1ERR'],
            columns['cERR'],
            columns['COV_mB_x1'],
            columns['COV_mB_c'],
            columns['COV_x1_c'],
            zs,
        )
    ]
    return pd.DataFrame({'z': zs, 'sigma_mu': sigmas})

def run_forecast(df_wfd: pd.DataFrame, df_twilight: pd.DataFrame, model: str = 'lcdm') -> Dict[str, Dict]:
    """Per-object Fisher forecast for WFD alone and for WFD plus twilight.

    Args:
        df_wfd, df_twilight: catalogues with ``z`` and ``sigma_mu`` columns.
        model: ``'lcdm'`` varies ``Om``; ``'w0wa'`` varies ``Om``, ``w0``, ``wa``.

    Returns:
        Mapping from scenario name to a dict with ``errs``, ``cov``,
        ``labels`` and ``N`` (number of rows used).

    Raises:
        ValueError: if ``model`` is not one of the two supported names.
    """
    varied_by_model = {'lcdm': ('Om',), 'w0wa': ('Om', 'w0', 'wa')}
    if model not in varied_by_model:
        raise ValueError("model must be 'lcdm' or 'w0wa'")

    setup = FisherSetup(vary_params=varied_by_model[model], fid=CosmoParams())
    scenarios = {
        'WFD': df_wfd,
        'WFD+Twilight': combine_wfd_twilight(df_wfd, df_twilight),
    }

    results = {}
    for scenario, catalog in scenarios.items():
        fisher, names = fisher_SN(catalog['z'].values, catalog['sigma_mu'].values, setup)
        sigmas, covariance = cov_to_sigmas(fisher, names)
        results[scenario] = {'errs': sigmas, 'cov': covariance, 'labels': names, 'N': len(catalog)}
    return results

def pretty_print_results(res: Dict[str, Dict]):
    """Print, for each scenario in ``res``, its size and the 1-sigma error of every label."""
    for scenario, entry in res.items():
        sigmas = entry['errs']
        count = entry['N']
        print(f"\n=== {scenario} (N={count}) ===")
        for name in entry['labels']:
            print(f"σ({name}) = {sigmas[name]:.4f}")

def plot_pairwise_contours(res, pair=('w0','wa')):
    """Plot joint 1σ/2σ confidence ellipses of two parameters for both scenarios.

    Args:
        res: forecast results keyed by ``'WFD'`` and ``'WFD+Twilight'``; each
            entry provides ``labels`` and the parameter covariance ``cov``.
        pair: names of the two parameters (x axis, y axis).

    The ellipses are centred on the fiducial values listed in ``centers`` and
    scaled with the two-parameter Δχ² levels (2.30 for 1σ, 6.17 for 2σ), the
    same convention as ``ellipse_points_from_cov`` elsewhere in this notebook.
    """
    plt.figure(dpi=150)
    ax = plt.gca()
    xmin = ymin = +1e9
    xmax = ymax = -1e9
    centers = {'Om': 0.3, 'w0': -1.0, 'wa': 0.0}
    # Joint two-parameter Δχ² levels; sqrt(eigenvalue) alone would only enclose ~39%.
    delta_chi2 = {1.0: 2.30, 2.0: 6.17}
    t = np.linspace(0, 2*np.pi, 200)
    cx = centers.get(pair[0], 0.0)
    cy = centers.get(pair[1], 0.0)
    for label in ['WFD', 'WFD+Twilight']:
        labels = res[label]['labels']
        C = res[label]['cov']
        i = labels.index(pair[0])
        j = labels.index(pair[1])
        cov2 = C[np.ix_([i,j],[i,j])]
        # Principal axes of the 2x2 covariance, largest variance first.
        vals, vecs = np.linalg.eigh(cov2)
        order = vals.argsort()[::-1]
        vals = vals[order]
        vecs = vecs[:, order]
        angle = np.arctan2(vecs[1, 0], vecs[0, 0])
        for nsig in (1.0, 2.0):
            a = np.sqrt(delta_chi2[nsig] * vals[0])
            b = np.sqrt(delta_chi2[nsig] * vals[1])
            x = cx + (a*np.cos(t)*np.cos(angle) - b*np.sin(t)*np.sin(angle))
            y = cy + (a*np.cos(t)*np.sin(angle) + b*np.sin(t)*np.cos(angle))
            # Only the 1σ curve of each scenario gets a legend entry.
            ax.plot(x, y, label=f'{label} {nsig:.0f}σ' if nsig==1.0 else None)
            xmin, xmax = min(xmin, x.min()), max(xmax, x.max())
            ymin, ymax = min(ymin, y.min()), max(ymax, y.max())
    ax.set_xlabel(pair[0]); ax.set_ylabel(pair[1]); ax.legend()
    # Pad the axes by 10% of the span covered by all ellipses.
    padx = 0.1*(xmax-xmin); pady = 0.1*(ymax-ymin)
    ax.set_xlim(xmin-padx, xmax+padx); ax.set_ylim(ymin-pady, ymax+pady)
    plt.title(f'Forecast contours: {pair[0]} vs {pair[1]}')
    plt.show()


In [None]:
# =========================
# Cell 5 — Fisher: ΛCDM & w0waCDM plots, with dataset info already printed by Cell 4
# =========================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Fisher from binned
def fisher_from_binned(df_bin: pd.DataFrame, setup: FisherSetup):
    """Fisher matrix from a binned catalogue, marginalised over the offset M.

    Each bin contributes with weight ``N / sigma_mu**2`` at its redshift ``z``.
    Bins with ``N <= 0`` carry no information and are dropped up front, so an
    unusable ``z`` or ``sigma_mu`` in an empty bin cannot contaminate the sum.

    Args:
        df_bin: table with columns ``z``, ``N`` and ``sigma_mu``.
        setup: forecast configuration; derivatives are taken at ``setup.fid``.

    Returns:
        ``(F_marg, labels)``: Fisher matrix of the varied parameters and
        their names.

    Raises:
        ValueError: if no bin has a positive count.
        numpy.linalg.LinAlgError: if the M-M Fisher element is zero.
    """
    counts = df_bin["N"].to_numpy(float)
    populated = counts > 0
    if not populated.any():
        raise ValueError("binned catalogue has no bins with N > 0")
    z = df_bin["z"].to_numpy(float)[populated]
    sigma = df_bin["sigma_mu"].to_numpy(float)[populated]
    w = counts[populated] / (sigma ** 2)

    # Differentiate at the fiducial point carried by `setup`, not at a
    # freshly constructed default parameter set.
    J = jacobian_mu(z, setup.fid, setup)  # last column is dμ/dM
    # Row-wise weighting is equivalent to J.T @ diag(w) @ J without the dense diagonal.
    F_full = (J * w[:, None]).T @ J

    # marginalize over nuisance M (last)
    F_tt = F_full[:-1, :-1]
    F_tM = F_full[:-1, -1:]
    F_Mt = F_full[-1:, :-1]
    F_MM = F_full[-1, -1]
    if F_MM == 0.0:
        raise np.linalg.LinAlgError("Singular matrix")
    F_marg = F_tt - (F_tM @ F_Mt) / F_MM
    labels = list(setup.vary_params)
    return F_marg, labels

def run_binned_forecast(df_base, df_twfull, model="lcdm"):
    """Binned Fisher forecast for the baseline table and the twilight-augmented table.

    ``model="lcdm"`` varies only ``Om``; any other value varies ``Om``, ``w0``
    and ``wa``.  Returns a dict keyed by ``"WFD"`` and ``"WFD+Twilight"``, each
    holding ``errs``, ``cov`` and the (shared) ``labels`` list.
    """
    if model == "lcdm":
        vary = ("Om",)
    else:
        vary = ("Om", "w0", "wa")
    setup = FisherSetup(vary_params=vary, fid=CosmoParams())

    results = {}
    labels = None
    for scenario, table in (("WFD", df_base), ("WFD+Twilight", df_twfull)):
        fisher, table_labels = fisher_from_binned(table, setup)
        if labels is None:
            # Both scenarios report the labels returned for the baseline table.
            labels = table_labels
        sigmas, covariance = cov_to_sigmas(fisher, labels)
        results[scenario] = {"errs": sigmas, "cov": covariance, "labels": labels}
    return results

# Colors
# Per-scenario line colours: the plots below use "primary" for 1σ contours and
# 1-D curves, and "light" for 2σ contours.
COLORS = {
    "WFD": {"primary": "#1f77b4", "light": "#9ecae1"},
    "WFD+Twilight": {"primary": "#2ca02c", "light": "#98df8a"},
}
# Values on which the plotted Gaussians and ellipses are centred (same numbers
# as the CosmoParams defaults).
FID_CENTERS = {"Om": 0.3, "w0": -1.0, "wa": 0.0}

def ellipse_points_from_cov(C2, n_sigma=1.0, n_pts=400):
    """Sample the confidence ellipse of a 2x2 covariance, centred on the origin.

    ``n_sigma`` of 1 or 2 selects Δχ² = 2.30 or 6.17; any other value uses
    ``n_sigma**2``.  Returns the x and y coordinates as two arrays of length
    ``n_pts``.
    """
    if np.isclose(n_sigma, 1.0):
        chi2_level = 2.30
    elif np.isclose(n_sigma, 2.0):
        chi2_level = 6.17
    else:
        chi2_level = n_sigma**2

    # Eigen-decomposition, reordered so the largest variance comes first.
    eigvals, eigvecs = np.linalg.eigh(C2)
    descending = eigvals.argsort()[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]

    semi_major = np.sqrt(chi2_level * eigvals[0])
    semi_minor = np.sqrt(chi2_level * eigvals[1])
    theta = np.linspace(0, 2*np.pi, n_pts)
    axis_aligned = np.stack([semi_major*np.cos(theta), semi_minor*np.sin(theta)], axis=0)
    # Rotate the axis-aligned ellipse into the parameter basis.
    rotated = eigvecs @ axis_aligned
    return rotated[0], rotated[1]

def plot_corner(res, params=("Om","w0","wa"), title="w0waCDM Forecast"):
    """Corner plot comparing the WFD and WFD+Twilight forecasts.

    Diagonal panels show the 1-D Gaussian of each parameter; lower-triangle
    panels show the joint 1σ (primary colour) and 2σ (light colour) ellipses.
    Everything is centred on ``FID_CENTERS``.

    Args:
        res: results keyed by ``"WFD"`` and ``"WFD+Twilight"``, each with
            ``labels`` and ``cov``.
        params: parameters to display, in panel order.
        title: figure title.

    Raises:
        ValueError: if a requested parameter is not among the forecast labels.
    """
    from matplotlib.lines import Line2D

    labs = res["WFD"]["labels"]
    missing = [p for p in params if p not in labs]
    if missing:
        raise ValueError(f"parameters not in forecast labels: {missing}")
    n = len(params)
    scenarios = ["WFD", "WFD+Twilight"]
    # squeeze=False keeps `axes` two-dimensional even when only one parameter is shown.
    fig, axes = plt.subplots(n, n, figsize=(3.2*n, 3.2*n), dpi=150, squeeze=False)
    for r in range(n):
        for c in range(n):
            ax = axes[r, c]
            if r < c:
                # Upper triangle is left empty.
                ax.axis('off')
                continue
            p_x = params[c]
            i = labs.index(p_x)
            if r == c:
                for label in scenarios:
                    s = np.sqrt(res[label]["cov"][i, i])
                    x0 = FID_CENTERS[p_x]
                    xs = np.linspace(x0-5*s, x0+5*s, 600)
                    ys = np.exp(-0.5*((xs-x0)/s)**2)/(s*np.sqrt(2*np.pi))
                    ax.plot(xs, ys, color=COLORS[label]["primary"], lw=2)
                ax.set_ylabel("PDF"); ax.set_xlabel(p_x)
            else:
                p_y = params[r]
                j = labs.index(p_y)
                for label in scenarios:
                    C2 = res[label]["cov"][np.ix_([i,j],[i,j])]
                    ex, ey = ellipse_points_from_cov(C2, n_sigma=1.0)
                    ex2, ey2 = ellipse_points_from_cov(C2, n_sigma=2.0)
                    ax.plot(FID_CENTERS[p_x]+ex,  FID_CENTERS[p_y]+ey,  color=COLORS[label]["primary"], lw=2)
                    ax.plot(FID_CENTERS[p_x]+ex2, FID_CENTERS[p_y]+ey2, color=COLORS[label]["light"],   lw=1.5)
                ax.set_xlabel(p_x); ax.set_ylabel(p_y)
    # One shared legend, built from proxy lines, in the top-left panel.
    handles = [Line2D([0],[0], color=COLORS["WFD"]["primary"], lw=2, label="WFD 1σ"),
               Line2D([0],[0], color=COLORS["WFD"]["light"],   lw=1.5, label="WFD 2σ"),
               Line2D([0],[0], color=COLORS["WFD+Twilight"]["primary"], lw=2, label="WFD+Twilight 1σ"),
               Line2D([0],[0], color=COLORS["WFD+Twilight"]["light"],   lw=1.5, label="WFD+Twilight 2σ")]
    axes[0,0].legend(handles=handles, frameon=False, loc="upper right")
    fig.suptitle(title, y=0.93); fig.tight_layout(); plt.show()

def plot_lcdm_1d(res, param="Om", title="ΛCDM Forecast"):
    """Plot the 1-D Gaussian constraint on ``param`` for WFD and WFD+Twilight.

    Both curves are centred on ``FID_CENTERS[param]`` and drawn over a common
    range of ±15 times the larger of the two standard deviations.  The figure
    is shown whether or not a title is given.
    """
    plt.figure(dpi=150)
    ax = plt.gca()
    x0 = FID_CENTERS[param]
    scenarios = ["WFD", "WFD+Twilight"]
    sigs = {}
    for label in scenarios:
        i = res[label]["labels"].index(param)
        sigs[label] = np.sqrt(res[label]["cov"][i, i])
    span = 5.0*max(sigs.values())
    xs = np.linspace(x0-3*span, x0+3*span, 800)
    for label in scenarios:
        s = sigs[label]
        ys = np.exp(-0.5*((xs-x0)/s)**2)/(s*np.sqrt(2*np.pi))
        ax.plot(xs, ys, label=f"{label} (σ={s:.4g})", color=COLORS[label]["primary"])
    ax.set_xlabel(param); ax.set_ylabel("Gaussian (norm.)"); ax.legend(frameon=False)
    if title:
        plt.title(title)
    # Must not depend on `title`: an untitled figure still has to be displayed.
    plt.show()

# Run forecasts & plots
# Same pair of binned tables for both models; only the set of varied parameters differs.
res_lcdm  = run_binned_forecast(df_wfd_bin, df_twfull_bin, model="lcdm")
res_w0wa  = run_binned_forecast(df_wfd_bin, df_twfull_bin, model="w0wa")

def _print_errors(res, name):
    """Print the 1σ error of every parameter for both scenarios and their ratio.

    The errors are the square roots of the diagonal of each scenario's ``cov``;
    "improvement" is the WFD error divided by the WFD+Twilight error.
    """
    params = res["WFD"]["labels"]
    cov_wfd = res["WFD"]["cov"]
    cov_tw = res["WFD+Twilight"]["cov"]
    # Gather all values first so nothing is printed if an entry is unusable.
    rows = [
        (param, np.sqrt(cov_wfd[k, k]), np.sqrt(cov_tw[k, k]))
        for k, param in enumerate(params)
    ]
    print(f"\n[{name}] 1σ parameter uncertainties")
    for param, sigma_wfd, sigma_tw in rows:
        ratio = sigma_wfd / sigma_tw
        print(f"  {param:>3s}:  WFD={sigma_wfd:.4g}   WFD+Twilight={sigma_tw:.4g}   improvement x{ratio:.2f}")

# Text summary first, then one figure per model.
_print_errors(res_lcdm, "ΛCDM")
_print_errors(res_w0wa, "w0waCDM")

plot_lcdm_1d(res_lcdm, param="Om", title="ΛCDM: constraint on Ωm")
plot_corner(res_w0wa, params=("Om","w0","wa"), title="w0waCDM: Ωm–w0–wa constraints (1σ/2σ)")


In [None]:
# ===================================================================
# START: Replacement and new code block
# ===================================================================

# --- Run forecasts for all three scenarios ---
# Baseline vs. Twilight
# Recomputed here so this cell does not depend on Cell 5's results being in memory.
res_lcdm  = run_binned_forecast(df_wfd_bin, df_twfull_bin, model="lcdm")
res_w0wa  = run_binned_forecast(df_wfd_bin, df_twfull_bin, model="w0wa")

# Baseline vs. Your Guess
# The guess table is passed in the "twilight" slot, so its numbers come back
# under the "WFD+Twilight" key of res_guess_*.
res_guess_lcdm = run_binned_forecast(df_wfd_bin, df_twguess_bin, model="lcdm")
res_guess_w0wa = run_binned_forecast(df_wfd_bin, df_twguess_bin, model="w0wa")


# --- Print a comparison of all results ---
def _print_errors_comparison(res_orig, res_guess, name):
    """Print the 1σ errors of WFD, WFD+Twilight and the guess scenario side by side.

    ``res_orig`` supplies the WFD and WFD+Twilight errors; the guess errors are
    read from the ``"WFD+Twilight"`` entry of ``res_guess``.  Each improvement
    factor is the WFD error divided by the scenario's error.
    """
    params = res_orig["WFD"]["labels"]
    errs_wfd = res_orig["WFD"]["errs"]
    errs_tw = res_orig["WFD+Twilight"]["errs"]
    errs_guess = res_guess["WFD+Twilight"]["errs"]  # guess results live under this key
    # Look everything up before printing so a missing entry fails up front.
    rows = [(param, errs_wfd[param], errs_tw[param], errs_guess[param]) for param in params]

    print(f"\n[{name}] 1-sigma parameter uncertainty comparison")
    for param, s_wfd, s_tw, s_guess in rows:
        print(f"  {param:>3s}:")
        print(f"    WFD               = {s_wfd:.4g}")
        print(f"    WFD+Twilight      = {s_tw:.4g}   (improvement x{s_wfd/s_tw:.2f})")
        print(f"    WFD+Twilight(Guess)= {s_guess:.4g}   (improvement x{s_wfd/s_guess:.2f})")

# Three-way text comparison for each model.
_print_errors_comparison(res_lcdm, res_guess_lcdm, "ΛCDM")
_print_errors_comparison(res_w0wa, res_guess_w0wa, "w0waCDM")


# --- Update plotting functions to include the third scenario ---
# The key must equal the scenario label used in the result dicts built below,
# because the comparison plots look colours up by that label.
COLORS["WFD+Twilight (Guess)"] = {"primary": "#ff7f0e", "light": "#ffbb78"} # Define a color for the new scenario

def plot_lcdm_1d_comparison(res_list, param="Om", title="ΛCDM Forecast Comparison"):
    """Plot the 1-D Gaussian constraint on ``param`` for any number of scenarios.

    Args:
        res_list: mapping from scenario label to a result dict with ``labels``
            and ``cov``; each label must also be a key of ``COLORS``.
        param: parameter to plot, centred on ``FID_CENTERS[param]``.
        title: optional figure title.

    All curves share a range of ±15 times the largest standard deviation.  The
    figure is shown whether or not a title is given.
    """
    plt.figure(dpi=150)
    ax = plt.gca()
    x0 = FID_CENTERS[param]

    # One pass to collect every sigma; reused for the axis range and the curves.
    sigmas = {}
    for label, res in res_list.items():
        i = res["labels"].index(param)
        sigmas[label] = np.sqrt(res["cov"][i, i])
    max_s = max(sigmas.values(), default=0)

    span = 5.0*max_s
    xs = np.linspace(x0-3*span, x0+3*span, 800)

    for label, s in sigmas.items():
        ys = np.exp(-0.5*((xs-x0)/s)**2)/(s*np.sqrt(2*np.pi))
        ax.plot(xs, ys, label=f"{label} (σ={s:.4g})", color=COLORS[label]["primary"])

    ax.set_xlabel(param); ax.set_ylabel("Gaussian (norm.)"); ax.legend(frameon=False)
    if title:
        plt.title(title)
    # Must not depend on `title`: an untitled figure still has to be displayed.
    plt.show()

def plot_corner_comparison(res_list, params=("Om","w0","wa"), title="w0waCDM Forecast Comparison (1σ/2σ)"):
    """Corner plot overlaying any number of forecast scenarios.

    Diagonal panels show the 1-D Gaussian of each parameter; lower-triangle
    panels show the joint 1σ (primary colour) and 2σ (light colour) ellipses.
    Everything is centred on ``FID_CENTERS``.

    Args:
        res_list: mapping from scenario label to a result dict with ``labels``
            and ``cov``.  It must contain ``"WFD"``, whose labels define the
            parameter order, and each label must be a key of ``COLORS``.
        params: parameters to display, in panel order.
        title: figure title.

    Raises:
        ValueError: if a requested parameter is not among the forecast labels.
    """
    from matplotlib.lines import Line2D

    labs_ref = res_list["WFD"]["labels"]
    missing = [p for p in params if p not in labs_ref]
    if missing:
        raise ValueError(f"parameters not in forecast labels: {missing}")
    n = len(params)
    # squeeze=False keeps `axes` two-dimensional even when only one parameter is shown.
    fig, axes = plt.subplots(n, n, figsize=(3.2*n, 3.2*n), dpi=150, squeeze=False)
    for r in range(n):
        for c in range(n):
            ax = axes[r, c]
            if r < c:
                # Upper triangle is left empty.
                ax.axis('off')
                continue
            p_x = params[c]
            i = labs_ref.index(p_x)
            if r == c:
                for label, res in res_list.items():
                    s = np.sqrt(res["cov"][i, i])
                    x0 = FID_CENTERS[p_x]
                    xs = np.linspace(x0-5*s, x0+5*s, 600)
                    ys = np.exp(-0.5*((xs-x0)/s)**2)/(s*np.sqrt(2*np.pi))
                    ax.plot(xs, ys, color=COLORS[label]["primary"], lw=2)
                ax.set_ylabel("PDF"); ax.set_xlabel(p_x)
            else:
                p_y = params[r]
                j = labs_ref.index(p_y)
                for label, res in res_list.items():
                    C2 = res["cov"][np.ix_([i,j],[i,j])]
                    ex, ey = ellipse_points_from_cov(C2, n_sigma=1.0)
                    ex2, ey2 = ellipse_points_from_cov(C2, n_sigma=2.0)
                    ax.plot(FID_CENTERS[p_x]+ex,  FID_CENTERS[p_y]+ey,  color=COLORS[label]["primary"], lw=2, label=f"{label} 1σ")
                    ax.plot(FID_CENTERS[p_x]+ex2, FID_CENTERS[p_y]+ey2, color=COLORS[label]["light"],   lw=1.5, label=f"{label} 2σ")
                ax.set_xlabel(p_x); ax.set_ylabel(p_y)

    # Shared legend from proxy lines: both contour levels of every scenario,
    # so the light-coloured 2σ curves are identified as well.
    handles = []
    for label in res_list:
        handles.append(Line2D([0],[0], color=COLORS[label]["primary"], lw=2, label=f"{label} 1σ"))
        handles.append(Line2D([0],[0], color=COLORS[label]["light"], lw=1.5, label=f"{label} 2σ"))
    axes[0,0].legend(handles=handles, frameon=False, loc="upper right")
    fig.suptitle(title, y=0.93); fig.tight_layout(); plt.show()


# --- Generate final plots including all three scenarios ---
# Flatten the pairwise results into {scenario label: result}; the guess numbers
# come from the "WFD+Twilight" slot of the res_guess_* runs.
all_res_lcdm = {
    "WFD": res_lcdm["WFD"],
    "WFD+Twilight": res_lcdm["WFD+Twilight"],
    "WFD+Twilight (Guess)": res_guess_lcdm["WFD+Twilight"]
}
all_res_w0wa = {
    "WFD": res_w0wa["WFD"],
    "WFD+Twilight": res_w0wa["WFD+Twilight"],
    "WFD+Twilight (Guess)": res_guess_w0wa["WFD+Twilight"]
}

# One figure per model, each overlaying all three scenarios.
plot_lcdm_1d_comparison(all_res_lcdm, param="Om", title="ΛCDM: constraint on Ωm (Comparison)")
plot_corner_comparison(all_res_w0wa, params=("Om","w0","wa"), title="w0waCDM: Ωm–w0–wa constraints (Comparison)")

# ===================================================================
# END: Replacement and new code block
# ===================================================================