<a href="https://colab.research.google.com/github/shaikh2010/seasonal-two-sex-matrix-caregiver/blob/main/Untitled39.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ==============================================================
# Polar bear (SBS; caregiver baseline; P=1) replication script
# --------------------------------------------------------------
# This Colab-ready script reproduces the polar-bear caregiver-baseline
# results reported in the manuscript (see Section X.X / Appendix Y; these are placeholder references to fill in before release):
#   1) Constructs the two-sex projection operator L (block lower-triangular)
#   2) Computes lambda = rho(L) and R0 = rho( R (I - U_f)^{-1} )
#   3) Recreates the manuscript figures:
#        - pb_projection_balanced.png
#        - pb_projection_log.png
#        - pb_elasticities.png
#   4) Writes machine-readable audit artifacts under ./out/
#
# Reproducibility contract:
#   - Given the fixed primitives in PB (below), outputs are deterministic.
#   - No external data are downloaded; no randomness is used.
# Environment:
#   - Tested in Google Colab (Python 3.x; numpy/pandas/matplotlib).
# ==============================================================


import numpy as np
import numpy.linalg as la
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
import shutil

# -----------------------------
# Paths / outputs
# -----------------------------
# Output layout: everything is written under ./out (relative to the current
# working directory), split by artifact type. The directories are created up
# front (idempotent: exist_ok=True) so that later writes into them, e.g.
# plt.savefig(FIGS / ...) or DataFrame.to_csv(PARAMS / ...), find them.
ROOT = Path(".")
OUT = ROOT / "out"
FIGS = OUT / "figs"          # PNG figures
PARAMS = OUT / "params"      # primitives, expanded inputs, mapping tables
SUMMARY = OUT / "summary"    # machine-readable results summary
TABLES = OUT / "tables"      # derived tables (e.g. elasticities)

for d in [FIGS, PARAMS, SUMMARY, TABLES]:
    d.mkdir(parents=True, exist_ok=True)


# -----------------------------
# Helpers
# -----------------------------
def spectral_radius(A: np.ndarray) -> float:
    """
    Return rho(A), the largest modulus among the eigenvalues of square matrix A.

    Works for any real or complex square array; the result is a plain float.
    """
    moduli = np.abs(la.eigvals(A))
    return float(moduli.max())


def perron_left_right(A: np.ndarray):
    """
    Dominant (Perron) eigen-triple of a nonnegative primitive matrix A.

    Returns (lam, v, w) such that
      A w   = lam w      (w: right eigenvector, taken from eig(A))
      v^T A = lam v^T    (v: left eigenvector, taken from eig(A^T))
      v^T w = 1          (v is rescaled; w keeps numpy's unit 2-norm)

    Both vectors are the real parts of numpy's eigenvectors for the
    largest-modulus eigenvalue, sign-flipped when their entries sum to a
    negative number. lam is the real part of that eigenvalue.

    Raises
    ------
    RuntimeError
        If v^T w is numerically zero, so the normalization is impossible.
    """
    def _dominant(M):
        # Largest-modulus eigenvalue of M and the real part of its eigenvector.
        eigvals_, eigvecs_ = la.eig(M)
        k = int(np.argmax(np.abs(eigvals_)))
        return eigvals_[k], np.real(eigvecs_[:, k])

    top_eig, w = _dominant(A)
    _, v = _dominant(A.T)
    lam = float(np.real(top_eig))

    # Orientation convention: flip a vector whose entries sum to < 0.
    w = -w if np.sum(w) < 0 else w
    v = -v if np.sum(v) < 0 else v

    pairing = float(v @ w)
    if abs(pairing) < 1e-14:
        raise RuntimeError("Normalization failed: v^T w ~ 0.")
    return lam, v / pairing, w


def check_column_sums_leq_one(U: np.ndarray, name: str, tol: float = 1e-12) -> None:
    """
    Validate that every column of a survival/transition block sums to <= 1.

    Under the to-from convention (row = To, column = From), each column holds
    the outgoing probabilities of one stage, so its sum may not exceed 1.

    Parameters
    ----------
    U : array-like, 2-D
        Survival/transition block (converted to a float array).
    name : str
        Label used in the error message.
    tol : float
        Slack allowed above 1 for floating-point rounding.

    Raises
    ------
    ValueError
        If any column sum exceeds 1 + tol or is NaN. A NaN sum would pass a
        plain `> 1 + tol` test, so the check is written as `not (<= 1 + tol)`.
    """
    col_sums = np.asarray(U, dtype=float).sum(axis=0)
    bad = np.flatnonzero(~(col_sums <= 1.0 + tol))
    if bad.size > 0:
        raise ValueError(
            f"{name}: column sums exceed 1 (to-from convention). "
            f"Bad columns (1-based): {(bad + 1).tolist()}; sums={col_sums[bad]}"
        )


def _json_default(obj):
    """
    Fallback for json.dump: convert NumPy data to native Python values.

    Arrays are converted with ``tolist()`` and NumPy scalars with ``item()``.
    Anything else raises TypeError, just as json itself would.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_json(path: Path, obj):
    """
    Write *obj* as indented (indent=2) UTF-8 JSON to *path*.

    Parent directories are created first. NumPy arrays and scalars nested
    anywhere inside *obj* are converted to plain lists/numbers instead of
    making ``json.dump`` raise TypeError. Output for objects that were
    already JSON-native is unchanged.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, default=_json_default)


# -----------------------------
# Polar-bear primitives (Table pbparams; SBS; P=1)
# -----------------------------
# Baseline primitive values for the caregiver-baseline model.
# build_polar_bear_blocks() reads each key and coerces it with float(). The
# dict is also saved verbatim to out/params/params_polarbear.json and embedded
# in the results summary JSON.
PB = {
    "alpha": 0.50,                 # female fraction among recruits (R uses alpha, B uses 1 - alpha)
    "s_f1": 0.62,                  # cub -> subadult survival (U_f[F2, F1])
    "s_f2": 0.88,                  # subadult -> solitary adult survival (U_f[F3, F2])
    "s_adult_prim": 0.93,          # solitary adult survival (primitive; split by p_breed into F3->F3 / F3->F4)
    "s_mother_COY_prim": 0.93,     # COY-mother survival (primitive; split by s_COY into F4->F5 / F4->F3)
    "s_mother_YRL_prim": 0.93,     # yearling-mother survival (primitive; F5 -> F3)
    "p_breed": 0.45,               # breeding probability of a solitary adult
    "s_COY": 0.65,                 # dependent COY survival
    "s_YRL": 0.80,                 # dependent yearling survival (reported; not used in mother-only bookkeeping)
    "litter_size": 1.6,            # mean litter size (fecundity b3 = p_breed * litter_size)
    # male survival proxies (do not govern lambda in caregiver baseline if rho(A_f) > rho(U_m))
    "s_m1": 0.90,                  # M1 -> M2
    "s_m2": 0.90,                  # M2 -> M3
    "s_m3": 0.96,                  # M3 -> M3 retention
}

# Stage definitions (fixed for this model). List order is the row/column order
# of U_f and U_m, and of the state vector of L (5 female stages, then 3 male).
female_stages = ["F1 (cub)", "F2 (subadult)", "F3 (solitary adult)", "F4 (COY-mother)", "F5 (YRL-mother)"]
male_stages   = ["M1 (juvenile)", "M2 (subadult)", "M3 (adult)"]


# -----------------------------
# Build matrices (caregiver baseline; P=1)
# -----------------------------
def build_polar_bear_blocks(prm: dict):
    """
    Build the survival blocks and fecundity vector for the caregiver baseline.

    Uses the mother-only caregiver bookkeeping (SBS; P=1).

    Parameters
    ----------
    prm : dict
        Primitive values keyed like ``PB``: alpha, s_f1, s_f2, s_adult_prim,
        s_mother_COY_prim, s_mother_YRL_prim, p_breed, s_COY, litter_size,
        s_m1, s_m2 and s_m3. Each value is coerced with ``float``.

    Returns
    -------
    (U_f, U_m, f, alpha, mapping_rows, derived)
        U_f : (5, 5) female block, to-from indexing (row = To, col = From).
        U_m : (3, 3) male proxy block, same convention.
        f : (5,) total recruits of both sexes per female stage; only F3 is nonzero.
        alpha : float female recruit fraction.
        mapping_rows : list of dicts auditing each composite entry.
        derived : dict holding b3 and the composite caregiver-loop entries.

    Raises
    ------
    ValueError
        Raised, in this order of checks, if:
        - a survival block has a column sum above 1;
        - alpha lies outside [0, 1];
        - any matrix or fecundity entry is negative.
    """
    alpha = float(prm["alpha"])
    s_f1, s_f2 = float(prm["s_f1"]), float(prm["s_f2"])
    s_adult, s_mCOY, s_mYRL = (
        float(prm[key]) for key in ("s_adult_prim", "s_mother_COY_prim", "s_mother_YRL_prim")
    )
    p_breed, s_COY, litter = (float(prm[key]) for key in ("p_breed", "s_COY", "litter_size"))

    # Annual fecundity: total recruits (both sexes) to next census per solitary adult female.
    b3 = p_breed * litter

    # Composite caregiver-loop entries (to-from indexing: row = To, col = From).
    u33 = s_adult * (1.0 - p_breed)      # F3 -> F3
    u43 = s_adult * p_breed              # F3 -> F4
    u54 = s_mCOY * s_COY                 # F4 -> F5
    u34 = s_mCOY * (1.0 - s_COY)         # F4 -> F3
    u35 = s_mYRL                         # F5 -> F3 (mother-only bookkeeping)

    female_entries = {
        (1, 0): s_f1,   # F1 -> F2
        (2, 1): s_f2,   # F2 -> F3
        (2, 2): u33,    # F3 -> F3
        (3, 2): u43,    # F3 -> F4
        (4, 3): u54,    # F4 -> F5
        (2, 3): u34,    # F4 -> F3
        (2, 4): u35,    # F5 -> F3
    }
    U_f = np.zeros((5, 5), dtype=float)
    for (to_idx, from_idx), value in female_entries.items():
        U_f[to_idx, from_idx] = value

    # Male proxy chain M1 -> M2 -> M3 with M3 retention.
    U_m = np.zeros((3, 3), dtype=float)
    for (to_idx, from_idx), key in (((1, 0), "s_m1"), ((2, 1), "s_m2"), ((2, 2), "s_m3")):
        U_m[to_idx, from_idx] = float(prm[key])

    # Only solitary adult females (F3) reproduce in this baseline.
    f = np.zeros(5, dtype=float)
    f[2] = b3

    # Validation, kept in the original order of checks.
    check_column_sums_leq_one(U_f, "U_f")
    check_column_sums_leq_one(U_m, "U_m")
    if not (0.0 <= alpha <= 1.0):
        raise ValueError("alpha must be in [0,1].")
    if np.any(U_f < -1e-15) or np.any(U_m < -1e-15):
        raise ValueError("Matrices must be nonnegative.")
    if np.any(f < -1e-15):
        raise ValueError("Fecundities must be nonnegative.")

    # Audit table of composite entries: (to, from, symbol, expression, value).
    audit_specs = [
        ("F2", "F1", "u^(f)_{2,1}", "s_f1", s_f1),
        ("F3", "F2", "u^(f)_{3,2}", "s_f2", s_f2),
        ("F3", "F3", "u^(f)_{3,3}", "s_adult*(1-p_breed)", u33),
        ("F4", "F3", "u^(f)_{4,3}", "s_adult*p_breed", u43),
        ("F5", "F4", "u^(f)_{5,4}", "s_mCOY*s_COY", u54),
        ("F3", "F4", "u^(f)_{3,4}", "s_mCOY*(1-s_COY)", u34),
        ("F3", "F5", "u^(f)_{3,5}", "s_mYRL", u35),
        ("F1", "F3", "(A_f)_{1,3}", "alpha*b3 = alpha*(p_breed*litter)", alpha * b3),
    ]
    mapping_rows = [
        {"to": to_stage, "from": from_stage, "symbol": symbol, "expression": expr, "value": value}
        for to_stage, from_stage, symbol, expr, value in audit_specs
    ]

    derived = {
        "b3_total_recruits_both_sexes": b3,
        "u33": u33, "u43": u43, "u54": u54, "u34": u34, "u35": u35,
    }

    return U_f, U_m, f, alpha, mapping_rows, derived


def build_caregiver_baseline_operator(U_f: np.ndarray, U_m: np.ndarray, f: np.ndarray, alpha: float):
    """
    Assemble the caregiver-baseline two-sex operator from its blocks.

    With nf = U_f.shape[0], nm = U_m.shape[0], and e1 denoting the first unit vector:
      R   = alpha * e1 f^T         female recruits enter stage 1      (nf x nf)
      A_f = U_f + R                female projection block            (nf x nf)
      B   = (1 - alpha) * e1 f^T   male recruits enter male stage 1   (nm x nf)
      L   = [[A_f, 0], [B, U_m]]   block lower-triangular operator    ((nf+nm) x (nf+nm))

    Returns (R, A_f, B, L).
    """
    n_female, n_male = U_f.shape[0], U_m.shape[0]

    first_female = np.eye(n_female, dtype=float)[:, 0]
    first_male = np.eye(n_male, dtype=float)[:, 0]

    R = alpha * np.outer(first_female, f)
    B = (1.0 - alpha) * np.outer(first_male, f)
    A_f = U_f + R

    size = n_female + n_male
    L = np.zeros((size, size), dtype=float)
    L[:n_female, :n_female] = A_f
    L[n_female:, :n_female] = B
    L[n_female:, n_female:] = U_m

    return R, A_f, B, L


def compute_R0(U_f: np.ndarray, R: np.ndarray) -> "tuple[float, np.ndarray]":
    """
    Next-generation reproduction number for the P=1 caregiver baseline.

        K  = R (I - U_f)^(-1)
        R0 = rho(K)

    K is obtained from the transposed linear system (I - U_f)^T K^T = R^T,
    so (I - U_f)^(-1) is never formed explicitly.

    Parameters
    ----------
    U_f : (nf, nf) female survival/transition block.
    R : (nf, nf) female recruitment matrix.

    Returns
    -------
    (R0, K)
        R0 is the spectral radius of K as a float; K is the (nf, nf)
        next-generation matrix.

    Raises
    ------
    numpy.linalg.LinAlgError
        If I - U_f is singular.
    """
    nf = U_f.shape[0]
    I = np.eye(nf)
    K = la.solve((I - U_f).T, R.T).T
    return spectral_radius(K), K


def entrywise_elasticities(A: np.ndarray):
    """
    Sensitivities and elasticities of the dominant eigenvalue of A.

    With (lam, v, w) from perron_left_right (normalized so v^T w = 1):
      S_ij = v_i w_j                 (sensitivity of lam to a_ij)
      E_ij = (a_ij / lam) * S_ij     (elasticity of lam to a_ij)

    Returns (lam, v, w, S, E).
    """
    lam, v, w = perron_left_right(A)
    sensitivity = v[:, None] * w[None, :]
    elasticity = (A / lam) * sensitivity
    return lam, v, w, sensitivity, elasticity


# -----------------------------
# Build and compute polar-bear results
# -----------------------------
# Assemble the model from the PB primitives and compute the headline metrics.
U_f, U_m, f, alpha, composite_mapping_rows, derived = build_polar_bear_blocks(PB)
R, A_f, B, L = build_caregiver_baseline_operator(U_f, U_m, f, alpha)

lambda_full = spectral_radius(L)         # = max(rho(A_f), rho(U_m)) because L is block lower-triangular
lambda_f = spectral_radius(A_f)          # female-block growth rate
rho_Um = spectral_radius(U_m)
rho_Uf = spectral_radius(U_f)

R0, K = compute_R0(U_f, R)               # K = R (I - U_f)^(-1); R0 = rho(K)

# Console report. The "Expected" lines are hard-coded strings and are only
# printed, not compared against the computed values.
# NOTE(review): consider asserting agreement to make the replication self-checking.
print("=== Polar bear (SBS; P=1) caregiver baseline ===")
print(f"rho(U_f)  = {rho_Uf:.6f}")
print(f"rho(U_m)  = {rho_Um:.6f}")
print(f"lambda    = {lambda_full:.6f}  (should equal rho(A_f) when rho(A_f) >= rho(U_m))")
print(f"rho(A_f)  = {lambda_f:.6f}")
print(f"R0        = {R0:.6f}")
print("")
print("Expected (from manuscript):")
print("  lambda ≈ 1.040848")
print("  R0     ≈ 1.678714")


# -----------------------------
# Write LaTeX macros for the manuscript (optional but convenient)
# -----------------------------
def write_metrics_macros(path: Path, pb_lambda: float, pb_R0: float):
    """
    Write a small LaTeX file that defines the pbRzero and pbLambda macros.

    Both macros use \\providecommand and hold values with 6 decimal places.
    Parent directories of *path* are created first; the file is UTF-8 and
    ends with a newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    content = "\n".join(
        (
            "% Auto-generated by polar-bear replication code",
            f"\\providecommand{{\\pbRzero}}{{{pb_R0:.6f}}}",
            f"\\providecommand{{\\pbLambda}}{{{pb_lambda:.6f}}}",
            "",
        )
    )
    path.write_text(content, encoding="utf-8")

# Write into ROOT, i.e. Path(".") = the current working directory, so that
# \IfFileExists{metrics_macros.tex}{...} can find it when LaTeX runs there.
# Also write a copy in out/ for auditability.
write_metrics_macros(ROOT / "metrics_macros.tex", lambda_full, R0)
write_metrics_macros(OUT / "metrics_macros.tex", lambda_full, R0)


# -----------------------------
# Save audit artifacts (optional but helpful)
# -----------------------------
# 1) primitives: the PB dict exactly as used above
save_json(PARAMS / "params_polarbear.json", PB)

# 2) expanded inputs (matrices + derived). NumPy arrays are converted with
#    .tolist() so the payload is JSON-native; the metrics are plain floats
#    returned by spectral_radius / compute_R0.
save_json(
    PARAMS / "polar_bear_input_full.json",
    {
        "stages": {"female": female_stages, "male": male_stages},
        "alpha": alpha,
        "fecundity_vector_f_total_recruits_both_sexes": f.tolist(),
        "derived": derived,
        "U_f": U_f.tolist(),
        "U_m": U_m.tolist(),
        "R_female_recruitment_matrix": R.tolist(),
        "A_f": A_f.tolist(),
        "B_male_recruit_block": B.tolist(),
        "L_full_two_sex_baseline": L.tolist(),
        "metrics": {
            "rho_Uf": rho_Uf,
            "rho_Um": rho_Um,
            "lambda": lambda_full,
            "rho_Af": lambda_f,
            "R0": R0,
        },
    },
)

# 3) composite mapping table: one CSV row per audited entry from build_polar_bear_blocks
pd.DataFrame(composite_mapping_rows).to_csv(PARAMS / "pb_composite_mapping.csv", index=False)

# 4) full nonzero entries table (U_f, U_m, and induced recruitment entries)
def nonzero_entries_table():
    """
    Build a long-format table of every nonzero entry of U_f, U_m, R and B.

    Reads the module-level matrices U_f, U_m, R and B and the stage-label
    lists female_stages / male_stages. Recruitment entries come from R (the
    fecundity part of A_f, labelled "A_f (recruitment part)") and from B.
    Indices are 1-based. Rows are sorted by (block, from_index, to_index).
    """
    specs = [
        ("U_f", U_f, female_stages, female_stages),
        ("U_m", U_m, male_stages, male_stages),
        ("A_f (recruitment part)", R, female_stages, female_stages),
        ("B (male recruits)", B, male_stages, female_stages),
    ]

    records = []
    for block_name, mat, to_labels, from_labels in specs:
        # argwhere returns indices in row-major order, matching a nested i/j scan.
        for row, col in np.argwhere(np.abs(mat) > 0):
            to_i, from_j = int(row), int(col)
            records.append({
                "block": block_name,
                "to_index": to_i + 1, "from_index": from_j + 1,
                "to_stage": to_labels[to_i], "from_stage": from_labels[from_j],
                "value": float(mat[to_i, from_j]),
            })

    table = pd.DataFrame(records)
    table = table.sort_values(["block", "from_index", "to_index"])
    return table.reset_index(drop=True)

# 4) write the nonzero-entries table (U_f, U_m, and recruitment entries from R and B)
nz_df = nonzero_entries_table()
nz_df.to_csv(PARAMS / "pb_U_complete_nonzero.csv", index=False)

# 5) summary JSON (machine-readable): headline metrics plus the PB primitives
save_json(
    SUMMARY / "RESULTS_SUMMARY_polar_bear.json",
    {
        "system": "polar_bear_SBS",
        "period_P": 1,
        "metrics": {
            "R0": R0,
            "lambda": lambda_full,
            "rho_Uf": rho_Uf,
            "rho_Um": rho_Um,
        },
        "primitives": PB,
    },
)


# -----------------------------
# Projections (30 years) and figures
# -----------------------------
def project_constant(L: np.ndarray, x0: np.ndarray, T: int) -> np.ndarray:
    """
    Iterate x(t+1) = L x(t) for T steps starting from x(0) = x0 (P=1).

    Returns a (T+1, n) float array whose row t is x(t), with n = L.shape[0].
    """
    trajectory = np.empty((T + 1, L.shape[0]), dtype=float)
    trajectory[0] = x0
    for step in range(1, T + 1):
        trajectory[step] = L @ trajectory[step - 1]
    return trajectory


def perron_right_vector(A: np.ndarray) -> np.ndarray:
    """
    Return the right eigenvector of the largest-modulus eigenvalue of A.

    The vector is processed as follows:
    - take the real part of numpy's eigenvector;
    - flip its sign when its entries sum to < 0;
    - clip remaining negative entries (numerical noise) to 0.

    numpy's eigenvector scaling is otherwise kept.
    """
    eigenvalues, eigenvectors = la.eig(A)
    dominant = np.real(eigenvectors[:, int(np.argmax(np.abs(eigenvalues)))])
    oriented = -dominant if np.sum(dominant) < 0 else dominant
    return np.maximum(oriented, 0.0)


# 30-year horizon. NOTE(review): `time` shadows the stdlib module name; it is
# harmless here because the time module is not imported in this cell.
T_years = 30
time = np.arange(T_years + 1)

# Log-scale panel: initialize at Perron right eigenvector of L (scaled for visualization)
wL = perron_right_vector(L)

# Scale so total females at t=0 are 1000 (arbitrary units).
# The first 5 state entries are the female stages (see female_stages); the
# 1e-15 floor only guards against division by zero.
female_total_target = 1000.0
scale = female_total_target / max(wL[:5].sum(), 1e-15)
x0_log = wL * scale

# Linear "balanced" panel: rescale male totals at t=0 to match female totals (for visibility).
# After this rescaling x0_bal is generally no longer proportional to wL, so
# X_bal can show transient dynamics before settling.
x0_bal = x0_log.copy()
female_total0 = x0_bal[:5].sum()
male_total0 = x0_bal[5:].sum()
if male_total0 > 0:
    x0_bal[5:] *= (female_total0 / male_total0)

X_log = project_constant(L, x0_log, T_years)
X_bal = project_constant(L, x0_bal, T_years)

# Figure 1: pb_projection_balanced.png (linear scale)
# Columns 0-4 of X_bal are female stages (solid lines); 5-7 are male stages (dashed).
plt.figure(figsize=(8, 4.8))
for i in range(5):
    plt.plot(time, X_bal[:, i], linestyle="-", label=female_stages[i])
for j in range(3):
    plt.plot(time, X_bal[:, 5 + j], linestyle="--", label=male_stages[j])

plt.xlabel("Year (t)")
plt.ylabel("Abundance (arbitrary units)")
plt.title("Polar bear projections (caregiver baseline; balanced start; linear scale)")
plt.grid(True, alpha=0.3)
plt.legend(ncol=2, fontsize=9)
plt.tight_layout()
pb_balanced_path = FIGS / "pb_projection_balanced.png"
plt.savefig(pb_balanced_path, dpi=300)
plt.close()  # release the figure before starting the next one

# Figure 2: pb_projection_log.png (log scale)
# Same layout as Figure 1, but starting from the Perron-vector initial state (X_log).
plt.figure(figsize=(8, 4.8))
for i in range(5):
    plt.plot(time, X_log[:, i], linestyle="-", label=female_stages[i])
for j in range(3):
    plt.plot(time, X_log[:, 5 + j], linestyle="--", label=male_stages[j])

plt.xlabel("Year (t)")
plt.ylabel("Abundance (log scale)")
plt.yscale("log")
plt.title("Polar bear projections (caregiver baseline; Perron start; log scale)")
plt.grid(True, which="both", alpha=0.3)
plt.legend(ncol=2, fontsize=9)
plt.tight_layout()
pb_log_path = FIGS / "pb_projection_log.png"
plt.savefig(pb_log_path, dpi=300)
plt.close()


# -----------------------------
# Elasticities figure: pb_elasticities.png
# -----------------------------
# Only EA (entrywise elasticities of rho(A_f)) is used below; the other
# returned values are kept for interactive inspection.
lamA, vA, wA, SA, EA = entrywise_elasticities(A_f)

# Collect nonzero entries of A_f for a clean bar plot
rows = []
for i in range(A_f.shape[0]):
    for j in range(A_f.shape[1]):
        if abs(A_f[i, j]) > 0:
            rows.append({
                "to_index": i + 1,
                "from_index": j + 1,
                "to_stage": female_stages[i],
                "from_stage": female_stages[j],
                "A_entry": float(A_f[i, j]),
                "elasticity": float(EA[i, j]),
                "label": f"{female_stages[j].split(' ')[0]} → {female_stages[i].split(' ')[0]}",
                "detail": f"{female_stages[j]} → {female_stages[i]}",
            })

# Ascending sort: barh draws the first row at the bottom, so the largest
# elasticity appears at the top of the chart.
elas_df = pd.DataFrame(rows).sort_values("elasticity", ascending=True).reset_index(drop=True)
elas_df.to_csv(TABLES / "pb_elasticities_entries.csv", index=False)

plt.figure(figsize=(9, 4.8))
plt.barh(elas_df["detail"], elas_df["elasticity"])
plt.xlabel("Elasticity of ρ(A_f) (sum over all entries = 1)")
plt.title("Polar bear: entrywise elasticities of female projection matrix A_f")
plt.grid(True, axis="x", alpha=0.3)
plt.tight_layout()
pb_elas_path = FIGS / "pb_elasticities.png"
plt.savefig(pb_elas_path, dpi=300)
plt.close()


# -----------------------------
# Convenience: also copy figures to the notebook root (optional)
# so LaTeX can find them via \graphicspath{{./}{out/figs/}}
# -----------------------------
# copy2 also copies file metadata; missing sources are skipped silently.
for fname in ["pb_projection_balanced.png", "pb_projection_log.png", "pb_elasticities.png"]:
    src = FIGS / fname
    if src.exists():
        shutil.copy2(src, ROOT / fname)

# NOTE(review): the "Also copied" lines below are printed unconditionally,
# even if a source figure was missing and therefore not copied.
print("\nFigures written to:")
print(f"  {pb_balanced_path}")
print(f"  {pb_log_path}")
print(f"  {pb_elas_path}")
print("\nAlso copied to current directory for convenience:")
print("  ./pb_projection_balanced.png")
print("  ./pb_projection_log.png")
print("  ./pb_elasticities.png")


=== Polar bear (SBS; P=1) caregiver baseline ===
rho(U_f)  = 0.930000
rho(U_m)  = 0.960000
lambda    = 1.040848  (should equal rho(A_f) when rho(A_f) >= rho(U_m))
rho(A_f)  = 1.040848
R0        = 1.678714

Expected (from manuscript):
  lambda ≈ 1.040848
  R0     ≈ 1.678714

Figures written to:
  out/figs/pb_projection_balanced.png
  out/figs/pb_projection_log.png
  out/figs/pb_elasticities.png

Also copied to current directory for convenience:
  ./pb_projection_balanced.png
  ./pb_projection_log.png
  ./pb_elasticities.png


In [2]:
# ============================================================
# PURPOSE AND MANUSCRIPT LINKAGE
#
# This script implements the uncertainty / parametric bootstrap
# procedure described in Section X.X of the manuscript (placeholder section number; fill in before release).
#
# Specifically:
#   - Beta-PERT sampling corresponds to Eq. (12) and Table S3
#   - The caregiver-baseline SBS model (P=1) corresponds to
#     Section 2.3 and Appendix B
#   - Lambda and R0 computations correspond to Propositions 1–2
#
# The script is intended for reproducibility and auditability,
# not for interactive model development.
# ============================================================

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd


# ----------------------------
# Utilities
# ----------------------------

def ensure_dir(p: Path) -> None:
    """Create directory *p* and any missing parents; do nothing if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def spectral_radius(A: np.ndarray) -> float:
    """Return the largest eigenvalue modulus of a small dense square matrix A, as a float."""
    moduli = np.abs(np.linalg.eigvals(A))
    return float(moduli.max())

def fmt(x: float, nd: int = 4) -> str:
    """Render *x* in fixed-point notation with exactly *nd* decimals (for LaTeX tables)."""
    return format(x, f".{nd}f")

def safe_float(x: Any) -> float:
    """
    Convert *x* to float, reporting failures uniformly as ValueError.

    Only the errors ``float()`` raises for unconvertible input are translated:
    - TypeError, e.g. None or a list;
    - ValueError, e.g. "abc";
    - OverflowError, e.g. an int too large for a float.

    Any other exception, such as one raised by a user-defined ``__float__``,
    propagates unchanged instead of being disguised as a conversion error.

    Raises
    ------
    ValueError
        If *x* cannot be converted; the original error is chained as __cause__.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError) as e:
        raise ValueError(f"Cannot convert to float: {x!r}") from e


# ----------------------------
# Beta-PERT sampler
# ----------------------------

def beta_pert_sample(
    rng: np.random.Generator,
    a: float,
    m: float,
    b: float,
    shape_lambda: float = 4.0,
    size: Optional[int] = None,
) -> np.ndarray:
    """
    Sample a Beta-PERT distribution on [a, b] with mode m.

    PERT shape parameters for weight lambda:
        shape_a = 1 + lambda * (m - a) / (b - a)
        shape_b = 1 + lambda * (b - m) / (b - a)
    A Beta(shape_a, shape_b) draw u is mapped to a + u * (b - a).

    When a and b are numerically equal (np.isclose), the distribution is
    degenerate and m is returned. This is a 0-d array when size is None,
    otherwise an array of length size.

    Raises
    ------
    ValueError
        If b < a, or m lies outside [a, b].
    """
    a, m, b = float(a), float(m), float(b)
    shape_lambda = float(shape_lambda)

    if b < a:
        raise ValueError(f"PERT bounds invalid: min={a} > max={b}")
    if not (a <= m <= b):
        raise ValueError(f"PERT mode must be within [min,max]: min={a}, mode={m}, max={b}")

    if np.isclose(b, a):
        return np.full(size, m, dtype=float) if size is not None else np.array(m, dtype=float)

    width = b - a
    shape_a = 1.0 + shape_lambda * (m - a) / width
    shape_b = 1.0 + shape_lambda * (b - m) / width
    unit_draw = rng.beta(shape_a, shape_b, size=size)
    return a + unit_draw * width


# ----------------------------
# Polar bear model builder (SBS; P=1)
# ----------------------------

@dataclass(frozen=True)
class PolarBearPrimitives:
    """
    Immutable set of primitive parameters for one polar-bear model draw.

    Consumed by build_polar_bear_blocks(). Field order is the positional
    constructor order, so do not reorder fields.
    """

    # Female fraction among recruits (R uses alpha, B uses 1 - alpha)
    alpha: float

    # Female primitives
    s_f1: float                 # F1 (cub) -> F2 (subadult)
    s_f2: float                 # F2 (subadult) -> F3 (solitary adult)
    s_adult_prim: float         # adult survival (primitive); split by p_breed into F3->F3 / F3->F4
    s_mother_COY_prim: float    # COY-mother survival (primitive); split by s_COY into F4->F5 / F4->F3
    s_mother_YRL_prim: float    # YRL-mother survival (primitive); F5 -> F3
    p_breed: float              # breeding probability (solitary adult)
    s_COY: float                # dependent COY survival
    s_YRL: float                # dependent yearling survival (reported; not used in mother-only bookkeeping)
    litter_size: float          # litter size (mean); fecundity b3 = p_breed * litter_size

    # Male primitives (proxy; lambda = max(rho(A_f), rho(U_m)), so these do not
    # affect lambda while the female block dominates)
    s_m1: float                 # M1 -> M2
    s_m2: float                 # M2 -> M3
    s_m3: float                 # M3 -> M3 retention


def _validate_caregiver_blocks(
    alpha: float,
    U_f: np.ndarray,
    U_m: np.ndarray,
    f: np.ndarray,
    tol: float = 1e-12,
) -> None:
    """
    Raise ValueError if the assembled caregiver-baseline blocks are invalid.

    The checks match those of the dict-based builder in the earlier cell, run in this order:
    1. survival-block column sums <= 1 + tol (to-from convention);
    2. alpha in [0, 1];
    3. nonnegative U_f/U_m entries;
    4. finite, nonnegative fecundities.

    The comparisons are written as `not (x <= bound)` / `x >= bound` so that
    NaN entries fail as well.
    """
    for name, U in (("U_f", U_f), ("U_m", U_m)):
        col_sums = U.sum(axis=0)
        bad = np.flatnonzero(~(col_sums <= 1.0 + tol))
        if bad.size > 0:
            raise ValueError(
                f"{name}: column sums exceed 1 (to-from convention). "
                f"Bad columns (1-based): {(bad + 1).tolist()}; sums={col_sums[bad]}"
            )
    if not (0.0 <= alpha <= 1.0):
        raise ValueError("alpha must be in [0,1].")
    if not (np.all(U_f >= -1e-15) and np.all(U_m >= -1e-15)):
        raise ValueError("Matrices must be nonnegative.")
    if not np.all(np.isfinite(f) & (f >= -1e-15)):
        raise ValueError("Fecundities must be finite and nonnegative.")


def build_polar_bear_blocks(pr: PolarBearPrimitives) -> Dict[str, np.ndarray]:
    """
    Build the caregiver-baseline (SBS; P=1) matrices from one primitive set.

    Note: this redefines the dict-based builder of the same name from the
    earlier notebook cell; this version takes a PolarBearPrimitives object.

    Returns
    -------
    dict with keys:
        "U_f" : (5, 5) female survival/transition block, to-from (row = To, col = From)
        "U_m" : (3, 3) male proxy block, same convention
        "f"   : (5,) total recruits (both sexes) per female stage; only F3 nonzero
        "R"   : (5, 5) female recruitment matrix alpha * e1 f^T
        "B"   : (3, 5) male recruit block (1 - alpha) * e1 f^T
        "A_f" : U_f + R
        "L"   : (8, 8) block lower-triangular two-sex operator [[A_f, 0], [B, U_m]]

    Raises
    ------
    ValueError
        Raised when a survival column sum exceeds 1, alpha lies outside
        [0, 1], or entries are negative/NaN (e.g. p_breed > 1 or
        s_COY > 1). These draws previously passed silently into lambda/R0.
    """
    # Unpack
    alpha = pr.alpha

    # Derived fecundity (as in manuscript): b3 = p_breed * litter_size
    b3_total_recruits = pr.p_breed * pr.litter_size

    # Female stages: F1 cub, F2 subadult, F3 solitary adult, F4 COY-mother, F5 YRL-mother
    nf = 5
    U_f = np.zeros((nf, nf), dtype=float)

    # Primitive survivals / early transitions
    U_f[1, 0] = pr.s_f1  # F1 -> F2
    U_f[2, 1] = pr.s_f2  # F2 -> F3

    # Composite caregiver loop (mother-only bookkeeping; matches manuscript mapping)
    u33 = pr.s_adult_prim * (1.0 - pr.p_breed)     # F3 -> F3
    u43 = pr.s_adult_prim * pr.p_breed             # F3 -> F4
    u54 = pr.s_mother_COY_prim * pr.s_COY          # F4 -> F5
    u34 = pr.s_mother_COY_prim * (1.0 - pr.s_COY)  # F4 -> F3
    u35 = pr.s_mother_YRL_prim                     # F5 -> F3 (independent of s_YRL under mother-only bookkeeping)

    U_f[2, 2] = u33
    U_f[3, 2] = u43
    U_f[4, 3] = u54
    U_f[2, 3] = u34
    U_f[2, 4] = u35

    # Male stages: M1 juvenile, M2 subadult, M3 adult
    nm = 3
    U_m = np.zeros((nm, nm), dtype=float)
    U_m[1, 0] = pr.s_m1
    U_m[2, 1] = pr.s_m2
    U_m[2, 2] = pr.s_m3

    # Fecundity vector f: total recruits (both sexes) to next census per female stage
    f = np.zeros(nf, dtype=float)
    f[2] = b3_total_recruits  # only solitary adults reproduce in this baseline

    # Reject invalid primitive combinations before deriving anything from them.
    _validate_caregiver_blocks(alpha, U_f, U_m, f)

    # Recruitment placement
    e1f = np.zeros(nf, dtype=float); e1f[0] = 1.0
    e1m = np.zeros(nm, dtype=float); e1m[0] = 1.0
    R = alpha * np.outer(e1f, f)          # female recruits enter F1
    B = (1.0 - alpha) * np.outer(e1m, f)  # male recruits enter M1

    A_f = U_f + R

    # Full baseline L (two-sex, block lower-triangular)
    L = np.zeros((nf + nm, nf + nm), dtype=float)
    L[:nf, :nf] = A_f
    L[nf:, :nf] = B
    L[nf:, nf:] = U_m

    return {
        "U_f": U_f,
        "U_m": U_m,
        "f": f,
        "R": R,
        "B": B,
        "A_f": A_f,
        "L": L,
    }


def compute_lambda_and_R0(blocks: Dict[str, np.ndarray]) -> Dict[str, float]:
    """
    Compute the caregiver-baseline metrics (P=1) from build_polar_bear_blocks output.

    Returns a dict with keys, in this order:
        rho_Uf, rho_Um, rho_Af : spectral radii of U_f, U_m and A_f
        lambda                 : rho(L); since L is block lower-triangular this
                                 equals max(rho(A_f), rho(U_m))
        R0                     : rho(K), where K = R (I - U_f)^(-1)

    K is computed from (I - U_f)^T K^T = R^T, so the inverse is never formed;
    numpy.linalg.LinAlgError is raised if I - U_f is singular.
    """
    U_f, U_m, R, A_f, L = (blocks[key] for key in ("U_f", "U_m", "R", "A_f", "L"))

    metrics = {
        "rho_Uf": spectral_radius(U_f),
        "rho_Um": spectral_radius(U_m),
        "rho_Af": spectral_radius(A_f),
        "lambda": spectral_radius(L),
    }

    identity = np.eye(U_f.shape[0], dtype=float)
    next_generation = np.linalg.solve((identity - U_f).T, R.T).T
    metrics["R0"] = spectral_radius(next_generation)
    return metrics


# ----------------------------
# Uncertainty config I/O
# ----------------------------

def try_import_yaml():
    """
    Return the PyYAML module if it can be imported, else None.

    Only ImportError (which includes ModuleNotFoundError) means "not
    available". Any other exception raised while importing an installed but
    broken PyYAML now propagates. Previously it was swallowed and
    surfaced later as a misleading "PyYAML not available" error.
    """
    try:
        import yaml  # type: ignore
    except ImportError:
        return None
    return yaml

def load_uncertainty_config(path: Path) -> Dict[str, Any]:
    """
    Load uncertainty config from YAML or JSON.

    The format is chosen by file suffix (.yaml/.yml or .json, case-insensitive).
    Expected schema:
      {
        "B": 5000,
        "seed": 12345,
        "shape_lambda": 4.0,
        "distributions": {
            "s_f1": {"min":..., "mode":..., "max":...},
            ...
        }
      }

    Raises
    ------
    FileNotFoundError
        If *path* does not exist.
    RuntimeError
        If a YAML file is given but PyYAML cannot be imported.
    ValueError
        If the suffix is unsupported, or the file's top level is not a
        mapping. An empty YAML file loads as None and a list loads as a list;
        both previously slipped through and failed later with an unrelated
        error.
    """
    if not path.exists():
        raise FileNotFoundError(str(path))

    suffix = path.suffix.lower()
    if suffix in (".yaml", ".yml"):
        yaml = try_import_yaml()
        if yaml is None:
            raise RuntimeError("PyYAML not available; either install pyyaml or use JSON config.")
        with path.open("r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    elif suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise ValueError(f"Unsupported config format: {path.suffix}")

    if not isinstance(config, dict):
        raise ValueError(
            f"Uncertainty config {path} must contain a mapping at the top level; "
            f"got {type(config).__name__}."
        )
    return config

def save_uncertainty_config(config: Dict[str, Any], yaml_path: Path, json_path: Path) -> None:
    """
    Persist *config* to disk.

    A JSON copy (indent=2, keys sorted) is always written to *json_path*.
    A YAML copy (insertion order preserved) is written to *yaml_path* only
    when PyYAML can be imported. Both parent directories are created first.
    """
    for directory in (yaml_path.parent, json_path.parent):
        ensure_dir(directory)

    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(config, handle, indent=2, sort_keys=True)

    yaml = try_import_yaml()
    if yaml is None:
        return
    with yaml_path.open("w", encoding="utf-8") as handle:
        yaml.safe_dump(config, handle, sort_keys=False)


def default_uncertainty_config_template() -> Dict[str, Any]:
    """
    Return a template uncertainty config with Beta-PERT ranges.

    Replace min/mode/max with your exact values to reproduce an earlier
    table. "baseline_modes" is assembled from the same numbers used in
    "fixed" and the distribution modes (alpha first, then the sampled
    primitives, then the male proxies).
    """
    def pert(lo: float, mode: float, hi: float) -> Dict[str, float]:
        # One PERT spec in the schema expected by load_uncertainty_config.
        return {"min": lo, "mode": mode, "max": hi}

    distributions = {
        # Female primitives (probabilities)
        "s_f1": pert(0.45, 0.62, 0.75),
        "s_f2": pert(0.75, 0.88, 0.95),
        "s_adult_prim": pert(0.90, 0.93, 0.97),
        "s_mother_COY_prim": pert(0.90, 0.93, 0.97),
        "s_mother_YRL_prim": pert(0.90, 0.93, 0.97),
        "p_breed": pert(0.25, 0.45, 0.70),
        "s_COY": pert(0.40, 0.65, 0.85),
        # Reported but not used in U_f (mother-only bookkeeping); keep fixed if desired.
        # If you sampled it previously, include it here.
        "s_YRL": pert(0.60, 0.80, 0.90),
        # Litter size (positive; not necessarily bounded by 1)
        "litter_size": pert(1.20, 1.60, 2.00),
    }

    # Fixed parameters (not sampled)
    fixed = {"alpha": 0.5, "s_m1": 0.90, "s_m2": 0.90, "s_m3": 0.96}

    # Baseline modes (for audit)
    baseline_modes = {
        "alpha": fixed["alpha"],
        **{name: spec["mode"] for name, spec in distributions.items()},
        **{name: fixed[name] for name in ("s_m1", "s_m2", "s_m3")},
    }

    return {
        "system": "polar_bear_SBS",
        "period_P": 1,
        "B": 5000,
        "seed": 12345,
        "shape_lambda": 4.0,  # common PERT choice; set to your run's value
        "note": "Edit distributions min/mode/max to match your archived uncertainty_config_used.yaml.",
        "distributions": distributions,
        "fixed": fixed,
        "baseline_modes": baseline_modes,
    }


# ----------------------------
# Monte Carlo driver
# ----------------------------

def run_uncertainty(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run the Monte Carlo (Beta-PERT) uncertainty analysis for the polar-bear model.

    For each of ``config["B"]`` draws (default 5000) the parameter dict starts
    from ``config["fixed"]`` and is then extended/overwritten with one sample per
    entry of ``config["distributions"]`` (each spec has "min"/"mode"/"max").
    The primitives and blocks are built and lambda, R0 and the spectral radii
    returned by ``compute_lambda_and_R0`` are recorded. Sampling uses
    ``np.random.default_rng(config["seed"])``, so a fixed config is reproducible.

    Returns a dict with:
      - "draws_df": pandas DataFrame with one row per draw (primitives,
        male proxies, and metrics). It is returned as a DataFrame, not a
        dict; callers serialize it themselves (e.g. ``to_csv``).
      - "summary": JSON-serializable dict with mean, sd, 2.5/50/97.5%
        quantiles and Pr(>1) for lambda and R0, plus diagnostic counts.

    Raises:
      ValueError: if B < 1 (there would be no draws to summarize).
    """
    B = int(config.get("B", 5000))
    if B < 1:
        raise ValueError(f"Number of Monte Carlo draws B must be >= 1; got {B}.")
    seed = int(config.get("seed", 12345))
    # Passed straight to beta_pert_sample; presumably the PERT shape parameter.
    shape_lambda = float(config.get("shape_lambda", 4.0))

    dists: Dict[str, Dict[str, float]] = config["distributions"]
    fixed: Dict[str, float] = {k: safe_float(v) for k, v in config.get("fixed", {}).items()}

    rng = np.random.default_rng(seed)

    rows = []
    bad_subcritical_count = 0
    mismatch_equivalence = 0

    for draw_id in range(1, B + 1):
        # Start with fixed values
        prm = dict(fixed)

        # Draw PERT samples (dict order is insertion order, so draws are deterministic)
        for name, spec in dists.items():
            a = safe_float(spec["min"])
            m = safe_float(spec["mode"])
            b = safe_float(spec["max"])
            prm[name] = float(beta_pert_sample(rng, a, m, b, shape_lambda))

        # Build primitives
        pr = PolarBearPrimitives(
            alpha=prm["alpha"],
            s_f1=prm["s_f1"],
            s_f2=prm["s_f2"],
            s_adult_prim=prm["s_adult_prim"],
            s_mother_COY_prim=prm["s_mother_COY_prim"],
            s_mother_YRL_prim=prm["s_mother_YRL_prim"],
            p_breed=prm["p_breed"],
            s_COY=prm["s_COY"],
            s_YRL=prm["s_YRL"],
            litter_size=prm["litter_size"],
            s_m1=prm["s_m1"],
            s_m2=prm["s_m2"],
            s_m3=prm["s_m3"],
        )

        blocks = build_polar_bear_blocks(pr)
        mets = compute_lambda_and_R0(blocks)

        # Diagnostics: subcritical rho(U_f)<1 requirement for K validity (P=1)
        if mets["rho_Uf"] >= 1.0:
            bad_subcritical_count += 1

        # Threshold equivalence check, draw by draw:
        # (Under subcritical rho(U_f)<1, for P=1 caregiver baseline, R0>1 iff lambda>1.)
        # In practice floating-point tolerances can cause rare mismatches near 1.
        lam_gt = mets["lambda"] > 1.0
        r0_gt = mets["R0"] > 1.0
        if lam_gt != r0_gt:
            mismatch_equivalence += 1

        # Record draw
        row = {
            "draw_id": draw_id,
            # primitives drawn
            "alpha": pr.alpha,
            "s_f1": pr.s_f1,
            "s_f2": pr.s_f2,
            "s_adult_prim": pr.s_adult_prim,
            "s_mother_COY_prim": pr.s_mother_COY_prim,
            "s_mother_YRL_prim": pr.s_mother_YRL_prim,
            "p_breed": pr.p_breed,
            "s_COY": pr.s_COY,
            "s_YRL": pr.s_YRL,  # reported / optional
            "litter_size": pr.litter_size,
            # male proxies
            "s_m1": pr.s_m1,
            "s_m2": pr.s_m2,
            "s_m3": pr.s_m3,
            # metrics
            "rho_Uf": mets["rho_Uf"],
            "rho_Um": mets["rho_Um"],
            "rho_Af": mets["rho_Af"],
            "lambda": mets["lambda"],
            "R0": mets["R0"],
        }
        rows.append(row)

    df = pd.DataFrame(rows)

    def summarize(series: pd.Series) -> Dict[str, float]:
        x = series.to_numpy(dtype=float)
        # The sample SD (ddof=1) is undefined for a single draw; report NaN
        # explicitly instead of triggering a degrees-of-freedom RuntimeWarning.
        sd = float(np.std(x, ddof=1)) if x.size > 1 else float("nan")
        return {
            "mean": float(np.mean(x)),
            "sd": sd,
            "q025": float(np.quantile(x, 0.025)),
            "q50": float(np.quantile(x, 0.50)),
            "q975": float(np.quantile(x, 0.975)),
            "pr_gt_1": float(np.mean(x > 1.0)),
        }

    summ_lambda = summarize(df["lambda"])
    summ_R0 = summarize(df["R0"])

    summary = {
        "system": config.get("system", "polar_bear_SBS"),
        "period_P": int(config.get("period_P", 1)),
        "B": B,
        "seed": seed,
        "shape_lambda": shape_lambda,
        "summary": {
            "lambda": summ_lambda,
            "R0": summ_R0,
        },
        "diagnostics": {
            "count_rho_Uf_ge_1": int(bad_subcritical_count),
            "count_threshold_mismatch": int(mismatch_equivalence),
        },
    }

    return {
        "draws_df": df,
        "summary": summary,
    }


# ----------------------------
# LaTeX writers (table + macros)
# ----------------------------

def write_latex_uncertainty_table(summary: Dict[str, Any], out_tex_path: Path, nd: int = 4) -> None:
    r"""
    Write a LaTeX tabular body matching your manuscript \input usage:
      out_full_with_uncertaintypolarbear/tables/pb_uncertainty_metrics.tex

    The table has one row each for lambda and R0 and the columns Mean, SD,
    2.5%, 50%, 97.5% and Pr(>1). Every value is rendered with
    ``fmt(value, nd)``. The parent directory of ``out_tex_path`` is created if
    needed. The file is written as UTF-8 and ends with a newline.

    This is a raw docstring: in a plain string ``\i`` is an invalid escape
    sequence, which Python reports as a SyntaxWarning at compile time.
    """
    ensure_dir(out_tex_path.parent)

    columns = ("mean", "sd", "q025", "q50", "q975", "pr_gt_1")

    def _row(label: str, stats: Dict[str, float]) -> str:
        # One data row; the trailing double backslash is the LaTeX row terminator.
        cells = " & ".join(f"{fmt(stats[key], nd)}" for key in columns)
        return rf"{label} & {cells}\\"

    lines = [
        r"\begin{tabular}{@{}lrrrrr r@{}}",
        r"\toprule",
        r"Metric & Mean & SD & 2.5\% & 50\% & 97.5\% & $\Pr(>1)$\\",
        r"\midrule",
        _row(r"$\lambda$", summary["summary"]["lambda"]),
        _row(r"$R_0$", summary["summary"]["R0"]),
        r"\bottomrule",
        r"\end{tabular}",
        "",  # trailing newline
    ]

    out_tex_path.write_text("\n".join(lines), encoding="utf-8")


def write_uncertainty_macros(summary: Dict[str, Any], out_macro_path: Path, nd: int = 6) -> None:
    r"""
    Write LaTeX macros compatible with your \PBset/\PBuse wrapper and your manuscript defs:
      \PBset{pbLambdaP025}{...}, etc.

    The output starts with two comment lines and a blank line. It then has one
    ``\PBset{name}{value}`` line per macro: first the lambda quantiles and Pr(>1),
    then the same for R0. Values are rendered with ``fmt(value, nd)``. The parent
    directory is created if needed and the file ends with a newline.

    This is a raw docstring: in a plain string ``\P`` is an invalid escape
    sequence, which Python reports as a SyntaxWarning at compile time.
    """
    ensure_dir(out_macro_path.parent)

    sL = summary["summary"]["lambda"]
    sR = summary["summary"]["R0"]

    # (macro name, statistics dict, key), in output order. The manuscript expects:
    #   pbLambdaP025, pbLambdaP50, pbLambdaP975,
    #   pbRzeroP025, pbRzeroP50, pbRzeroP975,
    #   pbPrLambdaGTone, pbPrRzeroGTone
    # Append entries here to expose further statistics (e.g. mean/sd) to TeX.
    macros = (
        ("pbLambdaP025", sL, "q025"),
        ("pbLambdaP50", sL, "q50"),
        ("pbLambdaP975", sL, "q975"),
        ("pbPrLambdaGTone", sL, "pr_gt_1"),
        ("pbRzeroP025", sR, "q025"),
        ("pbRzeroP50", sR, "q50"),
        ("pbRzeroP975", sR, "q975"),
        ("pbPrRzeroGTone", sR, "pr_gt_1"),
    )

    lines = [
        "% Auto-generated by polar bear uncertainty script",
        "% Do not edit by hand; edit the Python config and regenerate.",
        "",
    ]
    lines.extend(rf"\PBset{{{name}}}{{{fmt(stats[key], nd)}}}" for name, stats, key in macros)
    lines.append("")  # trailing newline

    out_macro_path.write_text("\n".join(lines), encoding="utf-8")


# ----------------------------
# Main (Colab-friendly)
# ----------------------------

def main() -> None:
    """Run the polar-bear uncertainty pipeline end to end and report the artifacts written."""
    # ---- Output roots (these paths are referenced by the manuscript) ----
    out_root = Path("out")
    out_unc = out_root / "uncertainty"
    out_params = out_root / "params"

    out_full = Path("out_full_with_uncertaintypolarbear")
    out_tables = out_full / "tables"
    out_tex = out_full / "tex"
    out_summary = out_full / "summary"

    for directory in (out_unc, out_params, out_tables, out_tex, out_summary):
        ensure_dir(directory)

    # ---- Config: the first archived file found wins; otherwise use the template ----
    # Search order: out/params/*.yaml, out/params/*.json, ./*.yaml, ./*.json
    stem = "uncertainty_config_used"
    search_order = [
        out_params / f"{stem}.yaml",
        out_params / f"{stem}.json",
        Path(f"{stem}.yaml"),
        Path(f"{stem}.json"),
    ]
    archived = next((candidate for candidate in search_order if candidate.exists()), None)
    config = load_uncertainty_config(archived) if archived is not None else None
    if config is None:
        # Also covers a loader that returns None for an unusable file.
        config = default_uncertainty_config_template()

    # ---- Monte Carlo ----
    result = run_uncertainty(config)
    df_draws: pd.DataFrame = result["draws_df"]
    summary: Dict[str, Any] = result["summary"]

    # ---- Artifact locations ----
    draws_csv_path = out_unc / "pb_draws.csv"
    summary_json_path = out_summary / "UNCERTAINTY_SUMMARY_polar_bear.json"
    config_used_yaml = out_params / "uncertainty_config_used.yaml"
    config_used_json = out_params / "uncertainty_config_used.json"
    table_tex_path = out_tables / "pb_uncertainty_metrics.tex"
    macros_tex_path = out_tex / "uncertainty_macros_polar_bear.tex"

    # ---- Write artifacts: draws, summary JSON, config used, LaTeX table/macros ----
    df_draws.to_csv(draws_csv_path, index=False)
    with summary_json_path.open("w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2, sort_keys=True)
    save_uncertainty_config(config, config_used_yaml, config_used_json)
    write_latex_uncertainty_table(summary, table_tex_path, nd=4)
    write_uncertainty_macros(summary, macros_tex_path, nd=6)

    # ---- Compact console report ----
    def _stats_line(stats: Dict[str, float]) -> str:
        fields = (
            ("mean", "mean"),
            ("sd", "sd"),
            ("q025", "q025"),
            ("q50", "q50"),
            ("q975", "q975"),
            ("Pr(>1)", "pr_gt_1"),
        )
        return "  " + ", ".join(f"{label}={stats[key]:.6f}" for label, key in fields)

    diag = summary["diagnostics"]
    report = [
        "=== Polar bear uncertainty (caregiver baseline; P=1) ===",
        f"B={summary['B']}  seed={summary['seed']}  shape_lambda={summary['shape_lambda']}",
        "",
        "lambda summary:",
        _stats_line(summary["summary"]["lambda"]),
        "R0 summary:",
        _stats_line(summary["summary"]["R0"]),
        "",
        "Diagnostics:",
        f"  count rho(U_f) >= 1: {diag['count_rho_Uf_ge_1']}",
        f"  count threshold mismatches (lambda>1 vs R0>1): {diag['count_threshold_mismatch']}",
        "",
        "Wrote:",
        f"  {draws_csv_path}",
        f"  {table_tex_path}",
        f"  {macros_tex_path}",
        f"  {summary_json_path}",
        f"  {config_used_yaml}  (if PyYAML available)",
        f"  {config_used_json}",
    ]
    for line in report:
        print(line)


# Run in notebook / Colab
if __name__ == "__main__":
    main()


  Write a LaTeX tabular body matching your manuscript \input usage:
  Write LaTeX macros compatible with your \PBset/\PBuse wrapper and your manuscript defs:


=== Polar bear uncertainty (caregiver baseline; P=1) ===
B=5000  seed=12345  shape_lambda=4.0

lambda summary:
  mean=1.040094, sd=0.018490, q025=1.006436, q50=1.039642, q975=1.077403, Pr(>1)=0.990600
R0 summary:
  mean=1.711993, sd=0.379889, q025=1.095210, q50=1.673300, q975=2.539195, Pr(>1)=0.990600

Diagnostics:
  count rho(U_f) >= 1: 0
  count threshold mismatches (lambda>1 vs R0>1): 0

Wrote:
  out/uncertainty/pb_draws.csv
  out_full_with_uncertaintypolarbear/tables/pb_uncertainty_metrics.tex
  out_full_with_uncertaintypolarbear/tex/uncertainty_macros_polar_bear.tex
  out_full_with_uncertaintypolarbear/summary/UNCERTAINTY_SUMMARY_polar_bear.json
  out/params/uncertainty_config_used.yaml  (if PyYAML available)
  out/params/uncertainty_config_used.json


In [4]:
# ============================================================
# Prairie dog (USGS; P=12) replication script (Google Colab-ready)
# Caregiver (female-driven) baseline + seasonal (periodic) R0_per
#
# Manuscript linkage:
#   - Data ingestion and pooling rule: Methods Section X.X
#   - Monthly caregiver-baseline model (newborns not a state): Section X.X / Appendix Y
#   - Neutrality calibration (rho(M)=lambda^12=1): Section X.X
#   - Periodic next-generation operator K and R0_per: Appendix Z
#
# Workflow implemented:
#   - Ingest USGS survival + reproduction CSVs (Control/NONE treatment)
#   - If survival CSV has no month field: pooled survival -> per-month via s_month = s_pooled^(1/m)
#   - Build monthly two-sex caregiver-baseline matrices (newborns not a state)
#   - Calibrate breeding-window fecundity scale s so that rho(M)=1 (annual neutrality)
#   - Compute periodic next-generation operator K and R0_per = rho(K)
#   - Produce figures under out/figs and audit artifacts under out/
#
# Inputs expected (upload to Colab runtime):
#   - Survival CSV: columns {TREATMENT, AGE, SURVIVE} with AGE in {J,A}; MONTH optional
#   - Reproduction CSV: columns {TREATMENT, JUVENILES, ADULTS}
# ============================================================


from __future__ import annotations

import json
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ----------------------------
# Optional YAML support
# ----------------------------

def _try_import_yaml():
    try:
        import yaml  # type: ignore
        return yaml
    except Exception:
        return None


# ----------------------------
# Helpers
# ----------------------------

def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents; no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)

def spectral_radius(A: np.ndarray) -> float:
    """Return rho(A): the largest modulus among the eigenvalues of the square matrix ``A``."""
    moduli = np.abs(np.linalg.eigvals(A))
    return float(moduli.max())

def perron_right(A: np.ndarray) -> np.ndarray:
    """Return a nonnegative (approximately) Perron right eigenvector for a nonnegative matrix.

    For a nonnegative matrix the Perron root rho is real, and every eigenvalue
    mu satisfies Re(mu) <= |mu| <= rho, with equality only when mu = rho.
    Selecting the eigenvalue with the largest *real part* therefore always
    picks the Perron root. Selecting by largest modulus can instead land on a
    complex eigenvalue when several share the maximal modulus (imprimitive,
    e.g. periodic, matrices). The real part of that eigenvector is then
    meaningless.

    The returned vector is the elementwise absolute value of the real part of
    that eigenvector. It is not normalised; callers scale it. If it is
    numerically zero, a vector of ones is returned instead.
    """
    vals, vecs = np.linalg.eig(A)
    idx = int(np.argmax(np.real(vals)))
    # np.real drops any numerically negligible imaginary part; abs fixes the
    # arbitrary overall sign returned by the eigensolver.
    v = np.abs(np.real(vecs[:, idx]))
    if np.allclose(v, 0):
        v = np.ones(A.shape[0], dtype=float)
    return v

def induced_matrix_1norm(A: np.ndarray) -> float:
    """Induced matrix 1-norm ||A||_1: the largest absolute column sum (the paper's ||T_k||_1)."""
    column_sums = np.abs(A).sum(axis=0)
    return float(column_sums.max())

def moving_average(x: np.ndarray, window: int = 3, *, normalize_edges: bool = False) -> np.ndarray:
    """Centered moving average of a 1-D series, always returned with the same length as ``x``.

    Args:
        x: 1-D series (array-like).
        window: number of points averaged. If ``window <= 1``, a copy of ``x`` is returned.
        normalize_edges: if False (default), points outside the series count as
            zeros, as ``np.convolve(..., mode="same")`` does. Edge values are then
            biased toward 0. If True, each output is the mean over only the
            in-range points of its window.

    For ``len(x) >= window`` and ``normalize_edges=False`` the result equals
    ``np.convolve(x, ones(window)/window, mode="same")``. Unlike "same" mode, a
    series shorter than the window keeps its own length: "same" returns
    ``max(len(x), window)`` points. An empty series returns an empty array
    instead of raising.
    """
    if window <= 1:
        return np.array(x, copy=True)
    arr = np.asarray(x, dtype=float)
    if arr.size == 0:
        return arr.copy()

    w = np.ones(window, dtype=float) / window
    # Offset of the centered window inside the "full" convolution; identical to
    # the offset numpy applies for mode="same" when len(x) >= window.
    start = (window - 1) // 2
    smoothed = np.convolve(arr, w, mode="full")[start:start + arr.size]
    if normalize_edges:
        # Fraction of each window that lies inside the series (1.0 in the interior).
        coverage = np.convolve(np.ones_like(arr), w, mode="full")[start:start + arr.size]
        smoothed = smoothed / coverage
    return smoothed


# ----------------------------
# Auto-detect input CSVs
# ----------------------------

def _lower_set(cols: List[str]) -> set:
    return {c.strip().lower() for c in cols}

def detect_survival_vs_reproduction_csv(csv_paths: List[Path]) -> Tuple[Path, Path]:
    """
    Pick the survival CSV and the reproduction CSV out of ``csv_paths``.

    Only the first 5 rows of each file are read, and files pandas cannot read
    are skipped. Column names are compared after stripping and lower-casing:
      - survival: some column is "age", contains "age_", or ends with "age",
        AND some column contains "survive" or "survival";
      - reproduction: some column contains "juv" (which covers "juvenile")
        AND some column contains "adult".

    Returns:
        (survival_path, reproduction_path)

    Raises:
        RuntimeError: unless exactly one file qualifies for each role.
    """
    def _is_age(name: str) -> bool:
        return name == "age" or "age_" in name or name.endswith("age")

    survival_hits: List[Path] = []
    repro_hits: List[Path] = []

    for csv_path in csv_paths:
        try:
            header_rows = pd.read_csv(csv_path, nrows=5)
        except Exception:
            # Unreadable / non-CSV content: not a candidate for either role.
            continue
        names = _lower_set(list(header_rows.columns))

        if any(map(_is_age, names)) and any(("survive" in n) or ("survival" in n) for n in names):
            survival_hits.append(csv_path)
        if any("juv" in n for n in names) and any("adult" in n for n in names):
            repro_hits.append(csv_path)

    if len(survival_hits) == 1 and len(repro_hits) == 1:
        return survival_hits[0], repro_hits[0]

    raise RuntimeError("\n".join([
        "Could not uniquely identify survival and reproduction CSVs.",
        f"Found CSVs: {[str(p) for p in csv_paths]}",
        f"Survival candidates: {[str(p) for p in survival_hits]}",
        f"Reproduction candidates: {[str(p) for p in repro_hits]}",
        "",
        "Fix by setting CONFIG.survival_csv and CONFIG.reproduction_csv explicitly.",
    ]))


# ----------------------------
# Config
# ----------------------------

@dataclass
class PrairieDogConfig:
    """Run settings for the prairie-dog (P=12) caregiver-baseline analysis.

    Field order matters: ``@dataclass`` turns it into the positional order of
    the generated constructor.
    """

    # If None, auto-detect from *.csv in current working directory.
    survival_csv: Optional[str] = None
    reproduction_csv: Optional[str] = None

    # Treatment label used as Control baseline in the paper. The ingest
    # functions compare it case- and whitespace-insensitively against the
    # TREATMENT column.
    treatment_label: str = "NONE"

    # Breeding months (Apr-Jun) as month indices 1..12
    breeding_months: Tuple[int, ...] = (4, 5, 6)

    # If survival CSV has no month field: interpret SURVIVE as interval survival over m months
    # (converted per month as s_pooled ** (1/m); presumably passed as interval_m to ingest_survival)
    interval_months_when_month_missing: int = 12

    # Sex ratio at recruitment. In build_monthly_blocks, alpha is the female share of recruits
    # (female recruits = alpha * f_k, male recruits = (1 - alpha) * f_k).
    alpha: float = 0.5

    # Yearling fecundity fraction: b2,k = theta * b3,k in breeding months (default 0)
    theta_yearling_fraction: float = 0.0

    # Projection horizon (months); presumably passed as T to project_months
    projection_months: int = 30

    # Invasion curve settings (Beverton–Holt male availability psi(eta) = eta / (h + eta)):
    # presumably h, the upper end of the eta grid, and the number of grid points.
    invasion_hm: float = 1.0
    invasion_eta_max: float = 10.0
    invasion_eta_n: int = 250

    # Numerical tolerance for neutrality calibration (presumably iters/tol of calibrate_scale_s)
    bisect_iters: int = 80
    bisect_tol: float = 1e-12


# Module-level default configuration; the CSV auto-detection error message refers to it as CONFIG.
CONFIG = PrairieDogConfig()


# ----------------------------
# Ingestion: survival CSV
# ----------------------------

def find_column(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Return the first column of ``df`` matching one of ``candidates``, else None.

    Candidates are tried in priority order. For each candidate:
      1. a column whose stripped, lower-cased name *equals* the candidate wins;
      2. otherwise the first column (in DataFrame order) whose lower-cased
         name *contains* the candidate is returned.

    The exact-match pass stops a longer name from shadowing the intended one.
    For example, with columns ["STAGE", "AGE"] the candidate "age" resolves to
    "AGE", whereas a pure substring scan would return "STAGE".
    """
    # str() tolerates non-string column labels (e.g. integer headers).
    normalised = [(col, str(col).strip().lower()) for col in df.columns]
    for cand in candidates:
        needle = cand.lower()
        for col, name in normalised:
            if name == needle:
                return col
        for col, name in normalised:
            if needle in name:
                return col
    return None

def ingest_survival(
    path: Path,
    treatment_label: str,
    interval_m: int = 12,
) -> Dict[str, Any]:
    """
    Load the survival CSV and return per-month juvenile (J) and adult (A) survival.

    Implements the paper's rule:
      - If month field exists: compute monthwise mean SURVIVE by AGE class.
      - If month missing: pooled survival by AGE class, then convert to per-month via s_month = s_pooled^(1/m),
        and replicate across months.

    Rows are first restricted to ``treatment_label`` (case/whitespace
    insensitive) and to AGE in {J, A}. SURVIVE is coerced to numeric, so
    non-numeric entries become NaN and are ignored by the means and counts.

    Returns a dict with the source path, the detected columns, length-12
    lists ``sJ_month``, ``sA_month``, ``nJ_month`` and ``nA_month``, and under
    ``"pooled"`` either the pooled estimates or ``None`` when a month field
    exists.

    Raises:
        RuntimeError: required columns are missing, no rows remain after
            filtering, a month lacks a J or A estimate, or a survival estimate
            is not a finite probability in [0, 1].
        ValueError: pooled data with ``interval_m <= 0``, for which the
            per-month conversion exponent 1/m is undefined.
    """
    df = pd.read_csv(path)

    # Detect columns
    col_treat = find_column(df, ["treat"])
    col_age = find_column(df, ["age"])
    col_surv = find_column(df, ["survive", "survival"])

    # Optional month field
    col_month = find_column(df, ["month"])

    if col_treat is None or col_age is None or col_surv is None:
        raise RuntimeError(
            "Survival CSV missing required columns. "
            f"Detected treat={col_treat}, age={col_age}, survive={col_surv}. "
            f"Columns={list(df.columns)}"
        )

    def _require_probabilities(label: str, values: Any) -> None:
        # SURVIVE is coerced with errors="coerce". A wrongly coded column
        # (e.g. "Y"/"N") therefore yields NaN means, and other codings can leave
        # [0, 1]. Fail loudly here instead of letting such rates reach the matrices.
        arr = np.asarray(values, dtype=float)
        if not np.all(np.isfinite(arr)) or np.any((arr < 0.0) | (arr > 1.0)):
            raise RuntimeError(
                f"{label} survival estimates must be finite probabilities in [0, 1]; "
                f"got {arr.tolist()} from column {col_surv!r}."
            )

    # Filter to treatment
    tr = df[col_treat].astype(str).str.upper().str.strip()
    df = df.loc[tr.eq(str(treatment_label).upper().strip())].copy()

    # Standardize AGE
    df[col_age] = df[col_age].astype(str).str.upper().str.strip()
    # Standardize survive to numeric 0/1
    df[col_surv] = pd.to_numeric(df[col_surv], errors="coerce")

    # Valid ages: J (juvenile) and A (adult), per your paper extract
    df = df.loc[df[col_age].isin(["J", "A"])].copy()

    if df.empty:
        raise RuntimeError(f"No rows left after filtering treatment={treatment_label} and AGE in {{J,A}}.")

    out = {
        "source_path": str(path),
        "treatment_label": str(treatment_label),
        "detected_columns": {
            "treatment": col_treat,
            "age": col_age,
            "survive": col_surv,
            "month": col_month,
        },
        "month_field_present": col_month is not None,
        "interval_m_used_if_pooled": int(interval_m),
    }

    if col_month is not None:
        # Monthwise aggregation (expects month index 1..12)
        df[col_month] = pd.to_numeric(df[col_month], errors="coerce")
        df = df.loc[df[col_month].between(1, 12, inclusive="both")].copy()

        g = df.groupby([col_month, col_age])[col_surv].agg(["mean", "count"]).reset_index()
        # Initialize arrays
        sJ = np.full(12, np.nan, dtype=float)
        sA = np.full(12, np.nan, dtype=float)
        nJ = np.zeros(12, dtype=int)
        nA = np.zeros(12, dtype=int)

        for _, r in g.iterrows():
            m = int(r[col_month])  # 1..12
            age = str(r[col_age])
            if age == "J":
                sJ[m-1] = float(r["mean"])
                nJ[m-1] = int(r["count"])
            elif age == "A":
                sA[m-1] = float(r["mean"])
                nA[m-1] = int(r["count"])

        # Basic validation
        if np.any(np.isnan(sJ)) or np.any(np.isnan(sA)):
            raise RuntimeError("Month field present but some months missing AGE=J or AGE=A survival estimates.")
        _require_probabilities("Monthly AGE=J", sJ)
        _require_probabilities("Monthly AGE=A", sA)
        out.update({
            "sJ_month": sJ.tolist(),
            "sA_month": sA.tolist(),
            "nJ_month": nJ.tolist(),
            "nA_month": nA.tolist(),
            "pooled": None,
        })
        return out

    # Pooled survival -> per-month conversion s_month = s_pooled^(1/m)
    if interval_m <= 0:
        raise ValueError(
            f"interval_m must be positive to convert pooled survival to monthly; got {interval_m}."
        )

    pooled = df.groupby(col_age)[col_surv].agg(["mean", "count"]).reset_index()
    pooled_dict = {}
    for _, r in pooled.iterrows():
        pooled_dict[str(r[col_age])] = {"mean": float(r["mean"]), "count": int(r["count"])}

    if "J" not in pooled_dict or "A" not in pooled_dict:
        raise RuntimeError("Expected pooled survival for AGE=J and AGE=A.")

    sJ_pooled = pooled_dict["J"]["mean"]
    sA_pooled = pooled_dict["A"]["mean"]
    nJ = pooled_dict["J"]["count"]
    nA = pooled_dict["A"]["count"]

    # The month branch already rejects NaN; apply the same guard to pooled rates.
    _require_probabilities("Pooled AGE=J", [sJ_pooled])
    _require_probabilities("Pooled AGE=A", [sA_pooled])

    # Conversion rule used in your manuscript:
    sJ_month = float(sJ_pooled ** (1.0 / interval_m))
    sA_month = float(sA_pooled ** (1.0 / interval_m))

    sJ = np.full(12, sJ_month, dtype=float)
    sA = np.full(12, sA_month, dtype=float)

    out.update({
        "sJ_month": sJ.tolist(),
        "sA_month": sA.tolist(),
        "nJ_month": [nJ]*12,
        "nA_month": [nA]*12,
        "pooled": {
            "sJ_pooled": float(sJ_pooled),
            "sA_pooled": float(sA_pooled),
            "nJ": int(nJ),
            "nA": int(nA),
            "sJ_month_converted": float(sJ_month),
            "sA_month_converted": float(sA_month),
        }
    })
    return out


# ----------------------------
# Ingestion: reproduction CSV
# ----------------------------

def ingest_reproduction(
    path: Path,
    treatment_label: str,
    breeding_months: Tuple[int, ...],
) -> Dict[str, Any]:
    """
    Load the reproduction CSV and derive the monthly base adult fecundity.

    Implements the paper's rule:
      - Aggregate (sum) JUVENILES and ADULTS across all rows for treatment.
      - ratio = JUVENILES_tot / ADULTS_tot
      - base monthly adult fecundity in each breeding month:
        b3_base = ratio / len(breeding_months)   (ratio / 3 for Apr-Jun)

    ``breeding_months`` is validated before any file I/O. It must be a
    non-empty collection of distinct integer months in 1..12. A month of 0
    would otherwise index the vector at -1 and silently set December, and a
    duplicate would spread the annual ratio over too many months.

    NOTE(review): pandas ``sum`` skips NaN independently per column, so a row
    with juveniles but a missing adult count still contributes to the
    numerator only. Confirm the data has no such partial rows.

    Raises:
        ValueError: invalid ``breeding_months``.
        RuntimeError: missing columns, no rows for the treatment, or a
            non-positive ADULTS total.
    """
    months = list(breeding_months)
    if not months:
        raise ValueError("breeding_months must contain at least one month (1..12).")
    invalid = [k for k in months if not (1 <= k <= 12) or int(k) != k]
    if invalid:
        raise ValueError(f"breeding_months must be integers in 1..12; got {months}.")
    if len({int(k) for k in months}) != len(months):
        raise ValueError(f"breeding_months must not contain duplicates; got {months}.")

    df = pd.read_csv(path)

    col_treat = find_column(df, ["treat"])
    col_juv = find_column(df, ["juvenile", "juv"])
    col_ad = find_column(df, ["adult"])

    if col_treat is None or col_juv is None or col_ad is None:
        raise RuntimeError(
            "Reproduction CSV missing required columns. "
            f"Detected treat={col_treat}, juveniles={col_juv}, adults={col_ad}. "
            f"Columns={list(df.columns)}"
        )

    tr = df[col_treat].astype(str).str.upper().str.strip()
    df = df.loc[tr.eq(str(treatment_label).upper().strip())].copy()
    if df.empty:
        raise RuntimeError(f"No reproduction rows left after filtering treatment={treatment_label}.")

    df[col_juv] = pd.to_numeric(df[col_juv], errors="coerce")
    df[col_ad] = pd.to_numeric(df[col_ad], errors="coerce")

    juv_tot = float(df[col_juv].sum())
    ad_tot = float(df[col_ad].sum())
    if ad_tot <= 0:
        raise RuntimeError("Total ADULTS must be positive to form juveniles/adults ratio.")

    ratio = juv_tot / ad_tot
    b3_base = ratio / float(len(months))

    # Monthly base fecundities (length 12); zero outside the breeding window
    b3_base_k = np.zeros(12, dtype=float)
    for k in months:
        b3_base_k[int(k) - 1] = b3_base

    return {
        "source_path": str(path),
        "treatment_label": str(treatment_label),
        "detected_columns": {
            "treatment": col_treat,
            "juveniles": col_juv,
            "adults": col_ad,
        },
        "juveniles_total": juv_tot,
        "adults_total": ad_tot,
        "ratio_juveniles_over_adults": float(ratio),
        "breeding_months": list(breeding_months),
        "b3_base_per_breeding_month": float(b3_base),
        "b3_base_monthly_vector": b3_base_k.tolist(),
    }


# ----------------------------
# Model builder (monthly; P=12)
# ----------------------------

def build_monthly_blocks(
    s_scale: float,
    sJ_month: np.ndarray,
    sA_month: np.ndarray,
    b3_base_monthly: np.ndarray,
    alpha: float,
    theta: float,
) -> Dict[str, Any]:
    """
    Assemble the 12 monthly operators of the two-sex, 3-stage caregiver model.

    Uses the to-from (column-vector) convention; state = [F1, F2, F3, M1, M2, M3].

      Female survival/transition U_f,k (F1 juvenile, F2 yearling, F3 adult):
        (2,1) = sJ_k                  juvenile -> yearling
        (3,2) = sA_k                  yearling -> adult (adult rate as proxy)
        (3,3) = sA_k                  adult retention
      Male block U_m,k is a copy of U_f,k (sex not disaggregated).

      Recruits f_k = [0, b2_k, b3_k]^T (both sexes, to the next census):
        b3_k = s_scale * b3_base_k
        b2_k = theta * b3_k when b3_k > 0, else 0
      Female recruits R_k = alpha * e1 f_k^T enter F1; male recruits
      B_k = (1 - alpha) * e1 f_k^T enter M1.

      L_k = [[U_f,k + R_k, 0], [B_k, U_m,k]]  (block lower-triangular).

    Only the first 12 entries of each monthly input are read.
    """
    P = 12
    nf, nm = 3, 3

    e1f = np.zeros(nf, dtype=float)
    e1f[0] = 1.0
    e1m = np.zeros(nm, dtype=float)
    e1m[0] = 1.0

    # Per-month outputs, collected in one dict keyed by their result names.
    seq_names = ("U_f_list", "U_m_list", "f_list", "alpha_list", "R_list", "B_list", "A_f_list", "L_list")
    seqs: Dict[str, list] = {name: [] for name in seq_names}

    for idx in range(P):
        juv_surv = float(sJ_month[idx])
        ad_surv = float(sA_month[idx])

        surv_f = np.zeros((nf, nf), dtype=float)
        surv_f[1, 0] = juv_surv
        surv_f[2, 1] = ad_surv
        surv_f[2, 2] = ad_surv
        surv_m = surv_f.copy()

        adult_fec = float(s_scale) * float(b3_base_monthly[idx])
        yearling_fec = float(theta) * adult_fec if adult_fec > 0 else 0.0
        fec = np.array([0.0, yearling_fec, adult_fec], dtype=float)

        female_share = float(alpha)
        recruit_f = female_share * np.outer(e1f, fec)
        recruit_m = (1.0 - female_share) * np.outer(e1m, fec)
        female_op = surv_f + recruit_f

        full_op = np.zeros((nf + nm, nf + nm), dtype=float)
        full_op[:nf, :nf] = female_op
        full_op[nf:, :nf] = recruit_m
        full_op[nf:, nf:] = surv_m

        month_values = (surv_f, surv_m, fec, female_share, recruit_f, recruit_m, female_op, full_op)
        for name, value in zip(seq_names, month_values):
            seqs[name].append(value)

    return {
        "P": P,
        "nf": nf,
        "nm": nm,
        "stages_female": ["F1 (juvenile)", "F2 (yearling)", "F3 (adult)"],
        "stages_male": ["M1 (juvenile)", "M2 (subadult)", "M3 (adult)"],
        "U_f_list": seqs["U_f_list"],
        "U_m_list": seqs["U_m_list"],
        "f_list": seqs["f_list"],
        "alpha_list": seqs["alpha_list"],
        "R_list": seqs["R_list"],
        "B_list": seqs["B_list"],
        "A_f_list": seqs["A_f_list"],
        "L_list": seqs["L_list"],
        "b3_monthly": [float(vec[2]) for vec in seqs["f_list"]],
        "b2_monthly": [float(vec[1]) for vec in seqs["f_list"]],
    }


# ----------------------------
# Monodromy (period product) and lambda
# ----------------------------

def monodromy(blocks: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compose one full period of monthly operators, with month 1 applied first.

    Returns:
      - the period products M = L_P ... L_1 (full two-sex system),
        Mf = A_{f,P} ... A_{f,1} (female block) and Mm = U_{m,P} ... U_{m,1}
        (male survival block);
      - their spectral radii;
      - the equivalent per-month growth factor rho(M)^(1/P).
    """
    period = blocks["P"]
    n_f = blocks["nf"]
    n_m = blocks["nm"]
    L_seq = blocks["L_list"]
    Af_seq = blocks["A_f_list"]
    Um_seq = blocks["U_m_list"]

    M = np.eye(n_f + n_m, dtype=float)
    Mf = np.eye(n_f, dtype=float)
    Mm = np.eye(n_m, dtype=float)
    for month in range(period):
        # Left-multiply so each later month acts after the earlier ones.
        M = L_seq[month] @ M
        Mf = Af_seq[month] @ Mf
        Mm = Um_seq[month] @ Mm

    rho_M, rho_Mf, rho_Mm = (spectral_radius(X) for X in (M, Mf, Mm))

    return {
        "M": M,
        "Mf": Mf,
        "Mm": Mm,
        "rho_M": rho_M,
        "rho_Mf": rho_Mf,
        "rho_Mm": rho_Mm,
        "lambda_per_month": rho_M ** (1.0 / period),
    }


# ----------------------------
# Neutrality calibration (solve for s such that rho(M)=1)
# ----------------------------

def calibrate_scale_s(
    sJ_month: np.ndarray,
    sA_month: np.ndarray,
    b3_base_monthly: np.ndarray,
    alpha: float,
    theta: float,
    target_rho: float = 1.0,
    iters: int = 80,
    tol: float = 1e-12,
) -> float:
    """
    Find the fecundity scale s >= 0 with rho(M(s)) = target_rho, by bisection.

    M(s) is the full two-sex monodromy built from ``build_monthly_blocks``
    with ``s_scale = s``. With nonnegative survival and fecundity inputs, the
    entries of M(s) are nondecreasing in s. The spectral radius of a
    nonnegative matrix is monotone in its entries, so rho(M(s)) is
    nondecreasing and a root exists on [0, inf) only if rho(M(0)) <= target_rho.

    Procedure: check s = 0, then double an upper bracket ``hi`` from 1 until
    rho(M(hi)) >= target_rho, then bisect for at most ``iters`` steps. The
    first midpoint with |rho - target| <= tol is returned; otherwise the final
    upper bracket is returned.

    Raises:
        RuntimeError: if rho(M(0)) already exceeds target_rho (survival alone
            overshoots, so no nonnegative s reaches the target; previously the
            bisection silently collapsed to ~0), or if no bracket is found with
            s <= 1e9 (e.g. all base fecundities are zero).
    """
    def rho_of_s(s: float) -> float:
        blk = build_monthly_blocks(s, sJ_month, sA_month, b3_base_monthly, alpha, theta)
        return monodromy(blk)["rho_M"]

    lo = 0.0
    rho_lo = rho_of_s(lo)
    if abs(rho_lo - target_rho) <= tol:
        return lo
    if rho_lo > target_rho:
        raise RuntimeError(
            "Neutrality calibration has no nonnegative solution: "
            f"rho(M) = {rho_lo!r} at s = 0 already exceeds target_rho = {target_rho!r}."
        )

    hi = 1.0
    # Ensure bracket: rho(lo) < target <= rho(hi)
    while rho_of_s(hi) < target_rho:
        hi *= 2.0
        if hi > 1e9:
            raise RuntimeError("Failed to bracket neutrality root; hi exploded.")

    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        val = rho_of_s(mid)
        if abs(val - target_rho) <= tol:
            return mid
        if val >= target_rho:
            hi = mid
        else:
            lo = mid
    return hi


# ----------------------------
# Periodic next-generation operator K (Eq. Kper in your paper)
# ----------------------------

def periodic_next_generation_operator(blocks: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the periodic next-generation operator K and R0_per = rho(K).

    With U_k = U_{f,k} (female survival/transition only),
    R_k = alpha_k e1 f_k^T (female recruits only) and U_prod = U_P ... U_1:

      K = sum_k ( R_k * U_{k-1}...U_1 * U_P...U_{k+1} ) * (I - U_prod)^(-1)

    (I - U_prod)^(-1) is never formed explicitly: X A^{-1} is obtained by
    solving A^T X^T = B^T. The per-month kernel terms
    T_k = term_k (I - U_prod)^(-1) are returned as well. Their induced 1-norms,
    normalised to sum to 1, are returned as ``c_k`` for diagnostics.
    """
    period = blocks["P"]
    n = blocks["nf"]
    surv = blocks["U_f_list"]
    recruit = blocks["R_list"]

    # Full-cycle survival product U_prod = U_P ... U_1.
    cycle_surv = np.eye(n, dtype=float)
    for month in range(period):
        cycle_surv = surv[month] @ cycle_surv

    ident = np.eye(n, dtype=float)

    # head[j] = U_j ... U_1 (head[0] = I): survival before month j+1 starts.
    head = [np.eye(n, dtype=float)]
    for j in range(1, period):
        head.append(surv[j - 1] @ head[j - 1])

    # tail[j] = U_P ... U_{j+2} (tail[P-1] = I): survival after month j+1 to the cycle end.
    tail = [None] * period
    tail[period - 1] = np.eye(n, dtype=float)
    for j in range(period - 2, -1, -1):
        tail[j] = tail[j + 1] @ surv[j + 1]

    # term_j = R_{j+1} * U_j...U_1 * U_P...U_{j+2}
    terms = [recruit[j] @ (head[j] @ tail[j]) for j in range(period)]
    H = np.zeros((n, n), dtype=float)
    for term in terms:
        H += term

    # Right-division by (I - U_prod) via transposed linear solves.
    lhs = (ident - cycle_surv).T
    K = np.linalg.solve(lhs, H.T).T
    T_terms = [np.linalg.solve(lhs, term.T).T for term in terms]

    R0_per = spectral_radius(K)

    # Norm-based diagnostics c_k ∝ ||T_k||_1 (guard against an all-zero kernel).
    norms = np.array([induced_matrix_1norm(T) for T in T_terms], dtype=float)
    norm_total = float(np.sum(norms))
    c_k = (norms / (norm_total if norm_total > 0 else 1.0)).tolist()

    return {
        "U_prod": cycle_surv,
        "K": K,
        "R0_per": R0_per,
        "T_terms": T_terms,
        "c_k": c_k,
        "norms": norms.tolist(),
    }


# ----------------------------
# Projections (30 months) and newborn output
# ----------------------------

def project_months(blocks: Dict[str, Any], T: int = 30, x0: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    Project the two-sex monthly model forward for ``T`` months.

    Iterates x(t+1) = L_k x(t) with k = t mod P (month-1 operator first). For
    each step it records the newborn output N[t] = f_k^T F(t), where
    F(t) = x(t)[:nf] is the female sub-vector (caregiver baseline; recruits of
    both sexes).

    Args:
        blocks: dict with "P", "nf", "nm", "L_list" and "f_list" (e.g. the
            output of ``build_monthly_blocks``).
        T: number of monthly steps; must be >= 0.
        x0: initial state of length nf + nm (any 1-D array-like). If None, the
            Perron right eigenvector of the full monodromy, scaled to a total
            of 1000 (arbitrary units), is used to minimise transients.

    Returns:
        {"X": (T+1, nf+nm) trajectory, "N": (T,) newborn outputs,
         "x0": the initial state used, as a float ndarray}

    Raises:
        ValueError: if ``T < 0`` or ``x0`` is not a 1-D vector of length nf + nm.
    """
    P = blocks["P"]
    nf = blocks["nf"]
    nm = blocks["nm"]
    L_list = blocks["L_list"]
    f_list = blocks["f_list"]

    if T < 0:
        raise ValueError(f"T must be >= 0; got {T}.")

    if x0 is None:
        # Use Perron start from full monodromy to minimize transients
        M = monodromy(blocks)["M"]
        v = perron_right(M)
        # Scale to arbitrary total abundance (paper uses arbitrary units)
        total = 1000.0
        x0 = v / np.sum(v) * total

    # Accept any array-like (lists included); the original .copy().astype path
    # failed for plain lists.
    x0 = np.asarray(x0, dtype=float)
    if x0.ndim != 1 or x0.size != nf + nm:
        raise ValueError(f"x0 must be a 1-D vector of length {nf + nm}; got shape {x0.shape}.")

    x = x0.copy()  # never mutate the caller's array
    X = np.zeros((T + 1, nf + nm), dtype=float)
    N = np.zeros(T, dtype=float)

    X[0, :] = x
    for t in range(T):
        k = t % P  # 0..P-1
        F = x[:nf]
        f_k = f_list[k]
        N[t] = float(np.dot(f_k, F))  # total recruits to next census (both sexes)
        x = L_list[k] @ x
        X[t + 1, :] = x

    return {"X": X, "N": N, "x0": x0}


# ----------------------------
# Periodic stable stage distribution table (female only)
# ----------------------------

def periodic_stable_female_distribution(blocks: Dict[str, Any]) -> np.ndarray:
    """
    Compute the periodic stable stage distribution sequence for females.

    Steps:
      - Take the Perron right eigenvector w1 of the female monodromy Mf.
      - Propagate it through the period: w_{k+1} = A_f,k w_k.
      - Normalize each month's vector to sum to 1.

    Eigenvectors are only determined up to a scalar factor, which may be
    negative (or complex). The magnitude of w1 is therefore taken *before*
    it is checked. A uniform vector is substituted only when w1 is zero or
    non-finite. Previously, a negatively signed Perron vector was discarded
    in favour of the uniform fallback.

    Parameters
    ----------
    blocks : dict
        Must provide "P" (period length), "nf" (number of female stages) and
        "A_f_list" (the P monthly female matrices). It is also passed to
        monodromy() to obtain "Mf".

    Returns
    -------
    np.ndarray
        Array of shape (P, nf). Row k is the normalized female distribution
        at the start of month k+1.
    """
    P = blocks["P"]
    nf = blocks["nf"]
    Af_list = blocks["A_f_list"]
    Mf = monodromy(blocks)["Mf"]

    # For a nonnegative primitive Mf the Perron vector has entries of one
    # sign, so |w1| recovers the positive orientation whichever sign the
    # eigensolver returned.
    w = np.abs(perron_right(Mf))
    # `not (> 0)` also catches a NaN sum
    if not float(np.sum(w)) > 0:
        w = np.ones(nf, dtype=float)

    W = np.zeros((P, nf), dtype=float)
    for k in range(P):
        # Guard against tiny negative round-off picked up during propagation
        w = np.abs(w)
        s = float(np.sum(w))
        W[k, :] = w / (s if s > 0 else 1.0)
        w = Af_list[k] @ w
    return W


# ----------------------------
# Invasion curve (Beverton–Holt)
# ----------------------------

def beverton_holt_psi(eta: np.ndarray, h: float = 1.0) -> np.ndarray:
    """Beverton–Holt saturation psi(eta) = eta / (h + eta); psi(h) = 1/2."""
    denominator = h + eta
    return eta / denominator

def invasion_curve(R0_per: float, eta_max: float, n: int, h: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Female invasion curve R(eta) = R0_per * psi(eta) under Beverton–Holt saturation.

    Returns (eta, R), where eta is an evenly spaced grid of n points on
    [0, eta_max] and R is the curve evaluated on that grid.
    """
    grid = np.linspace(0.0, float(eta_max), int(n))
    saturation = beverton_holt_psi(grid, h=h)
    return grid, saturation * float(R0_per)

def invasion_fingerprint_table(R0_values: Dict[str, float], etas: List[float], h: float = 1.0) -> pd.DataFrame:
    """
    Tabulate R(eta) = psi(eta) * R0 for each scenario at selected eta values.

    There is one row per (label, R0) entry in R0_values, in insertion order.
    There is a "Scenario" column plus one "R(eta=<e>)" column per e in etas.
    """
    def scenario_row(label: str, r0: float) -> Dict[str, Any]:
        record: Dict[str, Any] = {"Scenario": label}
        for e in etas:
            psi_e = beverton_holt_psi(np.array([e]), h=h)[0]
            record[f"R(eta={e:g})"] = float(psi_e * r0)
        return record

    return pd.DataFrame([scenario_row(label, r0) for label, r0 in R0_values.items()])


# ----------------------------
# Writers
# ----------------------------

def write_yaml_or_json(obj: Dict[str, Any], yaml_path: Path, json_path: Path) -> None:
    """
    Serialize obj to JSON unconditionally and to YAML when a YAML module is available.

    The JSON is written with indent=2 and sorted keys. The YAML is written only
    if _try_import_yaml() returns a module (not None), via safe_dump with key
    order preserved. Parent directories of both paths are created first.
    """
    for target_dir in (yaml_path.parent, json_path.parent):
        ensure_dir(target_dir)

    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, sort_keys=True)

    yaml_mod = _try_import_yaml()
    if yaml_mod is None:
        return
    with open(yaml_path, "w", encoding="utf-8") as handle:
        yaml_mod.safe_dump(obj, handle, sort_keys=False)

def write_metrics_macros(pdS: float, pdR0: float, pdLambda: float, pdRhoMm: float, out_path: Path) -> None:
    r"""
    Write the prairie-dog headline metrics as LaTeX ``\def`` macros.

    The file contains one header comment line, followed by one line per macro
    (\pdS, \pdRzeroPer, \pdLambda, \pdRhoMm), each value printed with 6
    decimals. It ends with a trailing newline.
    """
    ensure_dir(out_path.parent)
    macro_table = (
        ("pdS", pdS),
        ("pdRzeroPer", pdR0),
        ("pdLambda", pdLambda),
        ("pdRhoMm", pdRhoMm),
    )
    text_lines = ["% Auto-generated prairie-dog macros"]
    text_lines.extend(f"\\def\\{name}{{{value:.6f}}}" for name, value in macro_table)
    text_lines.append("")
    out_path.write_text("\n".join(text_lines), encoding="utf-8")

def format_pd_periodic_ssd_tex(W: np.ndarray) -> str:
    r"""
    Render the periodic stable female distribution as a LaTeX ``tabular``.

    W has shape (n_months, n_stages). The table has one row per month,
    numbered from 1, and one column $\hat F_{j,k}$ per female stage. Values
    are printed with 6 decimals. The returned text ends with a newline.

    Raises
    ------
    ValueError
        If W is not two-dimensional.
    """
    W = np.asarray(W, dtype=float)
    if W.ndim != 2:
        raise ValueError(f"W must be 2-D (months x female stages); got shape {W.shape}")
    n_stages = W.shape[1]

    col_spec = "l" + "c" * n_stages
    header_cells = " & ".join(rf"$\hat F_{{{j},k}}$" for j in range(1, n_stages + 1))

    lines = [
        r"\begin{tabular}{@{}" + col_spec + r"@{}}",
        r"\toprule",
        "Month & " + header_cells + r"\\",
        r"\midrule",
    ]
    for k, row in enumerate(W):
        lines.append(f"{k + 1} & " + " & ".join(f"{v:.6f}" for v in row) + r"\\")
    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    lines.append("")
    return "\n".join(lines)


def write_pd_periodic_ssd_tex(W: np.ndarray, out_path: Path) -> None:
    r"""
    Writes pd_periodic_ssd.tex compatible with your supplement \IfFileExists{pd_periodic_ssd.tex}{...}.

    This is a raw docstring, so ``\I`` is not treated as an (invalid) escape
    sequence, which made Python emit a SyntaxWarning. The table size follows
    W.shape; the usual (12, 3) input gives the original 12-month,
    3-stage table.
    """
    ensure_dir(out_path.parent)
    out_path.write_text(format_pd_periodic_ssd_tex(W), encoding="utf-8")


# ----------------------------
# Plotting
# ----------------------------

def plot_newborns(N: np.ndarray, out_path: Path) -> None:
    """
    Plot monthly recruit output N(t+1) with its 3-month moving average.

    Months on the x-axis are numbered 1..len(N). The figure is saved to
    out_path at 200 dpi and then closed.
    """
    ensure_dir(out_path.parent)
    months = np.arange(1, len(N) + 1)
    smoothed = moving_average(N, window=3)

    fig, ax = plt.subplots()
    ax.plot(months, N, label="Newborn/recruit output N(t+1)")
    ax.plot(months, smoothed, linestyle="--", label="3-month moving avg")
    ax.set_xlabel("Month (t)")
    ax.set_ylabel("Recruits to next census (both sexes)")
    ax.set_title("Prairie dog: monthly recruit output (caregiver baseline)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)

def plot_projection(X: np.ndarray, out_path: Path) -> None:
    """
    Plot the projected abundance of each stage over time.

    X has one row per time step (0..T). Columns 0..2 are female stages,
    drawn solid; columns 3..5 are male stages, drawn dashed. The figure is
    saved to out_path at 200 dpi and then closed.
    """
    ensure_dir(out_path.parent)
    n_steps = X.shape[0] - 1
    time = np.arange(0, n_steps + 1)

    dashed = {"linestyle": "--"}
    series = (
        (0, "F1 (juvenile)", {}),
        (1, "F2 (yearling)", {}),
        (2, "F3 (adult)", {}),
        (3, "M1 (juvenile)", dashed),
        (4, "M2 (subadult)", dashed),
        (5, "M3 (adult)", dashed),
    )

    fig, ax = plt.subplots()
    for col, label, style in series:
        ax.plot(time, X[:, col], label=label, **style)

    ax.set_xlabel("Month (t)")
    ax.set_ylabel("Abundance (arbitrary units)")
    ax.set_title("Prairie dog projections (caregiver baseline; P=12)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)

def plot_R0_contrib(c_k: List[float], out_path: Path) -> None:
    """
    Bar chart of the norm-based monthly diagnostic weights c_k.

    The month axis is sized from len(c_k), with one bar per monthly kernel
    term, instead of being hard-coded to 12. Any period P therefore plots
    correctly, where previously a non-12 length failed with a shape mismatch
    in plt.bar. For P=12 the figure is unchanged. It is saved to out_path at
    200 dpi and then closed.
    """
    ensure_dir(out_path.parent)
    n_months = len(c_k)
    months = np.arange(1, n_months + 1)
    plt.figure()
    plt.bar(months, c_k)
    plt.xlabel(f"Month k (1..{n_months})")
    plt.ylabel(r"Diagnostic weight $c_k$")
    plt.title(r"Prairie dog: norm-based monthly diagnostics for $R_{0,\mathrm{per}}$")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()

def plot_invasion_curve(eta: np.ndarray, R: np.ndarray, out_path: Path) -> None:
    """
    Plot the female invasion curve R(eta) with a dashed reference line at R = 1.

    The figure is saved to out_path at 200 dpi and then closed.
    """
    ensure_dir(out_path.parent)
    fig, ax = plt.subplots()
    ax.plot(eta, R)
    ax.axhline(1.0, linestyle="--")
    ax.set_xlabel(r"Stocking intensity $\eta$")
    ax.set_ylabel(r"$R_{F\mid M^\star}(\eta)$")
    ax.set_title("Prairie dog: female invasion curve under Beverton–Holt male availability")
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


# ----------------------------
# Main
# ----------------------------

def _resolve_input_csvs(cfg: PrairieDogConfig) -> Tuple[Path, Path]:
    """
    Resolve the (survival, reproduction) CSV paths for a run.

    If both cfg.survival_csv and cfg.reproduction_csv are set, they are used
    as-is. If neither is set, the *.csv files in the working directory are
    passed to detect_survival_vs_reproduction_csv(). Setting only one of the
    two is rejected; previously it was silently ignored and both files were
    auto-detected.

    Raises
    ------
    ValueError
        If exactly one of the two config paths is set.
    RuntimeError
        If auto-detection finds fewer than two CSV files.
    """
    surv_cfg = cfg.survival_csv
    repro_cfg = cfg.reproduction_csv
    if surv_cfg is not None and repro_cfg is not None:
        return Path(surv_cfg), Path(repro_cfg)
    if surv_cfg is not None or repro_cfg is not None:
        raise ValueError(
            "CONFIG.survival_csv and CONFIG.reproduction_csv must be set together "
            "(or both left as None to auto-detect the CSVs in the working directory)."
        )

    # Auto-detect from CSVs in working dir
    csvs = sorted(Path(".").glob("*.csv"))
    if len(csvs) < 2:
        raise RuntimeError(
            "Need survival and reproduction CSV files in the working directory. "
            "Upload them to Colab (e.g., /content) or set CONFIG.survival_csv and CONFIG.reproduction_csv."
        )
    survival_path, reproduction_path = detect_survival_vs_reproduction_csv(csvs)
    return survival_path, reproduction_path


def main(cfg: PrairieDogConfig) -> None:
    """
    Run the prairie-dog (USGS; P=12) caregiver-baseline pipeline end to end.

    Steps:
      1. Resolve and ingest the survival and reproduction CSVs.
      2. Calibrate the scale s* so that rho(M)=1, then build the monthly blocks.
      3. Compute monodromy and periodic next-generation metrics, projections,
         the periodic stable female distribution, and the invasion curve/table.
      4. Write params, summary, tables, LaTeX macros and figures under ./out,
         with convenience copies in ./.
      5. Print a console report.

    Raises
    ------
    ValueError
        If exactly one of cfg.survival_csv / cfg.reproduction_csv is set.
    RuntimeError
        If auto-detection finds fewer than two CSV files in the working directory.
    """
    # Output dirs (match your paper's conventions)
    out_root = Path("out")
    out_params = out_root / "params"
    out_figs = out_root / "figs"
    out_summary = out_root / "summary"
    out_tables = out_root / "tables"
    out_tex = out_root / "tex"
    for d in (out_params, out_figs, out_summary, out_tables, out_tex):
        ensure_dir(d)

    # Resolve input files (explicit config paths, or auto-detection)
    survival_path, reproduction_path = _resolve_input_csvs(cfg)

    # Ingest data
    surv_info = ingest_survival(
        survival_path,
        treatment_label=cfg.treatment_label,
        interval_m=cfg.interval_months_when_month_missing,
    )
    repro_info = ingest_reproduction(
        reproduction_path,
        treatment_label=cfg.treatment_label,
        breeding_months=cfg.breeding_months,
    )

    # Extract monthly survivals (length 12)
    sJ_month = np.array(surv_info["sJ_month"], dtype=float)
    sA_month = np.array(surv_info["sA_month"], dtype=float)
    b3_base_monthly = np.array(repro_info["b3_base_monthly_vector"], dtype=float)

    # Calibrate s so that rho(M)=1 (annual neutrality)
    s_star = calibrate_scale_s(
        sJ_month=sJ_month,
        sA_month=sA_month,
        b3_base_monthly=b3_base_monthly,
        alpha=cfg.alpha,
        theta=cfg.theta_yearling_fraction,
        target_rho=1.0,
        iters=cfg.bisect_iters,
        tol=cfg.bisect_tol,
    )

    # Build final blocks with calibrated s
    blocks = build_monthly_blocks(
        s_scale=s_star,
        sJ_month=sJ_month,
        sA_month=sA_month,
        b3_base_monthly=b3_base_monthly,
        alpha=cfg.alpha,
        theta=cfg.theta_yearling_fraction,
    )

    # Monodromy metrics
    mono = monodromy(blocks)

    # Periodic next-generation operator + diagnostics
    ng = periodic_next_generation_operator(blocks)

    # Projections
    proj = project_months(blocks, T=cfg.projection_months, x0=None)

    # Periodic stable female distribution table
    W = periodic_stable_female_distribution(blocks)

    # Invasion curve (paper illustration uses Beverton–Holt with h_m=1)
    eta, R_inv = invasion_curve(
        R0_per=float(ng["R0_per"]),
        eta_max=cfg.invasion_eta_max,
        n=cfg.invasion_eta_n,
        h=cfg.invasion_hm,
    )

    # Invasion fingerprint table (as in paper)
    # Baseline uses R0_per from the calibrated run (should be ~1).
    R0_values = {
        "Calibrated baseline (R0_per=1)": float(ng["R0_per"]),
        "Supercritical illustration (R0_per=1.250)": 1.250,
        "Treated fecundity (R0_per=0.672)": 0.672,
    }
    fingerprint = invasion_fingerprint_table(R0_values, etas=[0, 1, 5, 10], h=cfg.invasion_hm)

    # ----------------------------
    # Write outputs (params/audit/schema/summary/macros)
    # ----------------------------

    # Audit aggregation report (CSV)
    # (Matches your paper narrative: pooled survivals + converted per-month survivals + reproduction ratio)
    pooled = surv_info.get("pooled", None)
    audit_rows = []
    if pooled is not None:
        audit_rows.append({
            "treatment": cfg.treatment_label,
            "age_class": "J",
            "n": pooled["nJ"],
            "s_pooled": pooled["sJ_pooled"],
            "m_interval": cfg.interval_months_when_month_missing,
            "s_month": pooled["sJ_month_converted"],
        })
        audit_rows.append({
            "treatment": cfg.treatment_label,
            "age_class": "A",
            "n": pooled["nA"],
            "s_pooled": pooled["sA_pooled"],
            "m_interval": cfg.interval_months_when_month_missing,
            "s_month": pooled["sA_month_converted"],
        })
    else:
        # Month field present: store monthwise counts (compact)
        audit_rows.append({
            "treatment": cfg.treatment_label,
            "note": "Month field present; see surv_info in prairie_dog_input_full.yaml for monthwise means/counts."
        })

    # Reproduction audit
    audit_rows.append({
        "treatment": cfg.treatment_label,
        "juveniles_total": repro_info["juveniles_total"],
        "adults_total": repro_info["adults_total"],
        "ratio_juveniles_over_adults": repro_info["ratio_juveniles_over_adults"],
        "breeding_months": ",".join(str(m) for m in cfg.breeding_months),
        "b3_base_per_breeding_month": repro_info["b3_base_per_breeding_month"],
        "s_star_neutrality_scale": float(s_star),
    })

    audit_df = pd.DataFrame(audit_rows)
    audit_csv_path = out_params / "pd_usgs_aggregation_report.csv"
    audit_df.to_csv(audit_csv_path, index=False)

    # Schema/provenance
    schema = {
        "system": "prairie_dog_USGS",
        "period_P": 12,
        "census": "monthly pre-breeding",
        "state_convention": "column-vector; to-from indexing; newborns not a state; recruits enter F1/M1 at next census",
        "treatment_label": cfg.treatment_label,
        "breeding_months": list(cfg.breeding_months),
        "interval_m_when_month_missing": cfg.interval_months_when_month_missing,
        "sex_ratio_alpha": cfg.alpha,
        "theta_yearling_fraction": cfg.theta_yearling_fraction,
        "detected_columns": {
            "survival_csv": surv_info["detected_columns"],
            "reproduction_csv": repro_info["detected_columns"],
        },
        "input_files": {
            "survival_csv": str(survival_path),
            "reproduction_csv": str(reproduction_path),
        },
    }
    write_yaml_or_json(
        schema,
        yaml_path=out_params / "pd_schema.yaml",
        json_path=out_params / "pd_schema.json",
    )

    # Full input bundle (machine-readable)
    input_full = {
        "stages": {
            "female": blocks["stages_female"],
            "male": blocks["stages_male"],
        },
        "alpha": cfg.alpha,
        "breeding_months": list(cfg.breeding_months),
        "theta_yearling_fraction": cfg.theta_yearling_fraction,
        "survival_inputs": {
            "sJ_month": sJ_month.tolist(),
            "sA_month": sA_month.tolist(),
            "survival_ingestion": surv_info,
        },
        "reproduction_inputs": {
            "b3_base_monthly_vector": b3_base_monthly.tolist(),
            "reproduction_ingestion": repro_info,
        },
        "calibration": {
            "s_star": float(s_star),
            "target_rho_M": 1.0,
        },
        "monthly_fecundities_used": {
            "b2_monthly": blocks["b2_monthly"],
            "b3_monthly": blocks["b3_monthly"],
        },
        # Matrices (small; included for audit)
        "U_f_list": [U.tolist() for U in blocks["U_f_list"]],
        "U_m_list": [U.tolist() for U in blocks["U_m_list"]],
        "f_list": [f.tolist() for f in blocks["f_list"]],
        "A_f_list": [A.tolist() for A in blocks["A_f_list"]],
    }
    write_yaml_or_json(
        input_full,
        yaml_path=out_params / "prairie_dog_input_full.yaml",
        json_path=out_params / "prairie_dog_input_full.json",
    )

    # Summary JSON (results)
    results_summary = {
        "system": "prairie_dog_USGS",
        "period_P": 12,
        "metrics": {
            "s_star": float(s_star),
            "R0_per": float(ng["R0_per"]),
            "rho_M": float(mono["rho_M"]),                 # = lambda^12
            "lambda_per_month": float(mono["lambda_per_month"]),
            "rho_Mf": float(mono["rho_Mf"]),               # female block
            "rho_Mm": float(mono["rho_Mm"]),               # male survival product block
            "rho_Uf_period": float(spectral_radius(ng["U_prod"])),
        },
        "diagnostics": {
            "c_k": ng["c_k"],   # norm-based monthwise diagnostics
        },
        "inputs": {
            "treatment_label": cfg.treatment_label,
            "breeding_months": list(cfg.breeding_months),
            "interval_m_when_month_missing": cfg.interval_months_when_month_missing,
            "alpha": cfg.alpha,
            "theta": cfg.theta_yearling_fraction,
        },
    }
    with (out_summary / "RESULTS_SUMMARY_prairie_dog.json").open("w", encoding="utf-8") as f:
        json.dump(results_summary, f, indent=2, sort_keys=True)

    # LaTeX macros (for your manuscript): pdS, pdRzeroPer, pdLambda, pdRhoMm.
    # One copy under out/ and a convenience copy at the repo root.
    macro_values = dict(
        pdS=float(s_star),
        pdR0=float(ng["R0_per"]),
        pdLambda=float(mono["lambda_per_month"]),
        pdRhoMm=float(mono["rho_Mm"]),
    )
    macros_path = out_root / "metrics_macros.tex"
    for macro_target in (macros_path, Path("metrics_macros.tex")):
        write_metrics_macros(out_path=macro_target, **macro_values)

    # Periodic SSD table (supplement hook)
    write_pd_periodic_ssd_tex(W, Path("pd_periodic_ssd.tex"))
    write_pd_periodic_ssd_tex(W, out_tables / "pd_periodic_ssd.tex")

    # Invasion fingerprint outputs
    fingerprint_csv = out_tables / "pd_invasion_fingerprint.csv"
    fingerprint.to_csv(fingerprint_csv, index=False)

    # ----------------------------
    # Figures
    # ----------------------------
    plot_newborns(proj["N"], out_figs / "pd_newborns.png")
    plot_projection(proj["X"], out_figs / "pd_projection.png")
    plot_R0_contrib(ng["c_k"], out_figs / "pd_R0_contrib.png")
    plot_invasion_curve(eta, R_inv, out_figs / "pd_invasion_curve.png")

    # Also copy figures to current directory for convenience (like your polar-bear runs)
    for fn in ["pd_newborns.png", "pd_projection.png", "pd_R0_contrib.png", "pd_invasion_curve.png"]:
        src = out_figs / fn
        if src.exists():
            Path(fn).write_bytes(src.read_bytes())

    # ----------------------------
    # Console report (should match paper numbers)
    # ----------------------------
    print("=== Prairie dog (USGS; P=12) caregiver baseline ===")
    print("Input files:")
    print(f"  survival:      {survival_path}")
    print(f"  reproduction:  {reproduction_path}")
    print("")
    if pooled is not None:
        print("Survival ingestion (month missing => pooled -> per-month conversion):")
        print(f"  pooled s_J = {pooled['sJ_pooled']:.6f}  (n={pooled['nJ']})")
        print(f"  pooled s_A = {pooled['sA_pooled']:.6f}  (n={pooled['nA']})")
        print(f"  per-month s_J = {pooled['sJ_month_converted']:.6f}")
        print(f"  per-month s_A = {pooled['sA_month_converted']:.6f}")
        print("")
    print("Reproduction ingestion:")
    print(f"  juveniles_total = {repro_info['juveniles_total']:.6f}")
    print(f"  adults_total    = {repro_info['adults_total']:.6f}")
    print(f"  ratio J/A       = {repro_info['ratio_juveniles_over_adults']:.6f}")
    print(f"  b3_base (Apr-Jun each) = {repro_info['b3_base_per_breeding_month']:.6f}")
    print("")
    print("Calibrated neutrality:")
    print(f"  s* = {s_star:.6f}")
    print(f"  rho(M)=lambda^12 = {mono['rho_M']:.6f}")
    print(f"  lambda (per month) = {mono['lambda_per_month']:.6f}")
    print(f"  rho(M_f) = {mono['rho_Mf']:.6f}")
    print(f"  rho(M_m) = {mono['rho_Mm']:.6f}")
    print("")
    print("Periodic next-generation operator:")
    print(f"  R0_per = rho(K) = {ng['R0_per']:.6f}")
    print("")
    print("Wrote key outputs:")
    print(f"  {out_params / 'prairie_dog_input_full.yaml'} (and .json)")
    print(f"  {audit_csv_path}")
    print(f"  {out_params / 'pd_schema.yaml'} (and .json)")
    print(f"  {out_summary / 'RESULTS_SUMMARY_prairie_dog.json'}")
    print(f"  {macros_path} and ./metrics_macros.tex")
    print("  ./pd_periodic_ssd.tex")
    print(f"  Figures in {out_figs} (also copied to ./)")
    print(f"  Fingerprint table: {fingerprint_csv}")


# Run (Colab): execute the full pipeline with the module-level CONFIG
if __name__ == "__main__":
    main(cfg=CONFIG)


  Writes pd_periodic_ssd.tex compatible with your supplement \IfFileExists{pd_periodic_ssd.tex}{...}.


=== Prairie dog (USGS; P=12) caregiver baseline ===
Input files:
  survival:      Fipronil 2018 Monthly Survival.csv
  reproduction:  Fipronil 2020 Reproduction.csv

Survival ingestion (month missing => pooled -> per-month conversion):
  pooled s_J = 0.582278  (n=79)
  pooled s_A = 0.471074  (n=121)
  per-month s_J = 0.955933
  per-month s_A = 0.939199

Reproduction ingestion:
  juveniles_total = 1.000000
  adults_total    = 24.000000
  ratio J/A       = 0.041667
  b3_base (Apr-Jun each) = 0.013889

Calibrated neutrality:
  s* = 49.731746
  rho(M)=lambda^12 = 1.000000
  lambda (per month) = 1.000000
  rho(M_f) = 1.000000
  rho(M_m) = 0.471074

Periodic next-generation operator:
  R0_per = rho(K) = 1.000000

Wrote key outputs:
  out/params/prairie_dog_input_full.yaml (and .json)
  out/params/pd_usgs_aggregation_report.csv
  out/params/pd_schema.yaml (and .json)
  out/summary/RESULTS_SUMMARY_prairie_dog.json
  out/metrics_macros.tex and ./metrics_macros.tex
  ./pd_periodic_ssd.tex
  Figu