In [1]:
# Python 3.13
# If needed in a fresh env:
#   pip install "pymc>=5.21" "arviz>=0.17" "numpy>=2" "pandas>=2.2"
from pytensor.tensor.variable import TensorVariable
from pytensor.tensor import slinalg
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import arviz as az

# ---- 1) Get synthetic data from helpers ----
from bayes_tools.helpers.synthetic_data_helpers import (
    make_hierarchical_ou_dataset,
)

# Build the OU-level monthly panel from the synthetic-data helper.
# NOTE: wave_months here must agree with WAVE_MONTHS in the data-prep cell,
# which hard-codes the same (6, 12) survey schedule.
panel_config = {
    "n_regions": 1,
    "n_sites_per_region": 3,
    "n_ous_per_site": 3,
    "n_years": 1,
    "wave_months": (6, 12),
    "wave_missing_prob": 0.0,
    "seed": 7,
}
df_ou = make_hierarchical_ou_dataset(**panel_config)

# (Optional) aggregate to a parent level (e.g., 'site' or 'region'):
# df_site = aggregate_to_parent(df_ou, level="site")

# The model works at the OU level; keep df_ou untouched and work on a copy.
df = df_ou.copy()

In [None]:


# -------------------------------
# 0) Data prep + indices (with coords)
# -------------------------------
# Group rows by OU and order them by date within each OU. The positional row
# order after this sort is what the model's "obs" dimension refers to.
df = (
    df_ou.copy()
    .sort_values(["ou_code", "date"])
    .reset_index(drop=True)
)

# OU index: one integer code per row, plus the ordered OU labels (the "ou" coord).
ou_cat = df["ou_code"].astype("category")
ou_idx = ou_cat.cat.codes.to_numpy().astype("int32")
ou_labels = ou_cat.cat.categories.astype(str).to_numpy()
G = ou_labels.shape[0]

# Site mapping (fallback if not present)
if "site_code" in df.columns:
    # One (ou_code, site_code) pair per OU after de-duplication.
    ou_site_pairs = df[["ou_code","site_code"]].drop_duplicates()
    # Each OU must belong to exactly one site. Otherwise the .loc below returns
    # more than G rows, and site_of_ou silently stops lining up with the "ou" dim.
    multi_site_ous = ou_site_pairs.loc[ou_site_pairs["ou_code"].duplicated(), "ou_code"].unique()
    assert len(multi_site_ous) == 0, f"OUs mapped to more than one site: {list(multi_site_ous)}"
    # Index with the OU categories in their original dtype (ou_labels is str-cast),
    # so non-string OU codes still resolve. The order matches ou_labels.
    ou_to_site = ou_site_pairs.set_index("ou_code").loc[ou_cat.cat.categories, "site_code"]
    site_cat = ou_to_site.astype("category")
    site_labels = site_cat.cat.categories.astype(str).to_numpy()
    site_of_ou = site_cat.cat.codes.to_numpy().astype("int32")
else:
    # No site column: every OU goes into a single pseudo-site.
    site_labels = np.array(["site0"])
    site_of_ou = np.zeros(G, dtype="int32")
S = len(site_labels)

# Time index (monthly)
dates = pd.to_datetime(df["date"])
date_values = dates.to_numpy()
unique_months = np.sort(dates.unique())
T = len(unique_months)
month_to_idx = {m: i for i, m in enumerate(unique_months)}  # kept for interactive lookups
# Vectorized position lookup. unique_months is sorted and contains every date, so
# searchsorted returns exact positions. The earlier per-row `month_to_idx[d]` probed
# np.datetime64 keys with pd.Timestamp values; their hashes can differ (NumPy-version
# dependent), which surfaces as KeyError. It was also a pure-Python loop over all rows.
time_idx = np.searchsorted(unique_months, date_values).astype("int32")
assert (unique_months[time_idx] == date_values).all(), "some dates are missing from unique_months"

# Observed outcomes/predictors
y_raw = df["productivity"].to_numpy(dtype="float64")
# The outcome is modeled on the log scale. A zero/negative productivity would become
# -inf/NaN here and silently poison the likelihood, so fail fast instead.
# (NaN values pass through unchanged, as before.)
n_nonpositive = int(np.sum(y_raw <= 0))
assert n_nonpositive == 0, f"{n_nonpositive} rows have productivity <= 0; cannot take log"
y = np.log(y_raw).astype("float64")

x_obs = df["survey_score"].to_numpy(dtype="float64")   # NaN where missing
n_resp = df["n_respondents"].to_numpy(dtype="float64") # NaN where no survey

# Missing respondent counts become 0 (floored to 1 later when n_eff_np is built).
n_resp_filled = np.where(np.isnan(n_resp), 0.0, n_resp).astype("float64")
has_survey = ~np.isnan(x_obs)
idx_obs = np.flatnonzero(has_survey)

# (Optional) Known survey schedule: e.g., months 6 and 12 of each year.
# Keep in sync with wave_months passed to make_hierarchical_ou_dataset.
WAVE_MONTHS = (6, 12)
is_scheduled = np.isin(pd.DatetimeIndex(dates).month.to_numpy(), WAVE_MONTHS)
idx_sched = np.flatnonzero(is_scheduled)  # only these can be observed
responded_sched = has_survey[idx_sched].astype("int8")  # observed 1/0 on scheduled rows

# Standardize y and x (using observed x only)
def zscore(a: np.ndarray):
    """Standardize ``a`` to zero mean and unit (population) SD, ignoring NaNs.

    Returns ``(z, mean, sd)``. NaN entries stay NaN in ``z``. When the NaN-aware SD
    is zero or undefined (constant or all-NaN input), ``sd`` falls back to 1.0 so
    the division is always safe.
    """
    center = np.nanmean(a)
    spread = np.nanstd(a)
    if not spread > 0:  # also catches NaN
        spread = 1.0
    return (a - center) / spread, center, spread

y_z, y_mean, y_sd = zscore(y)
if np.isfinite(x_obs[idx_obs]).any():
    # zscore ignores NaNs, so x's mean/SD come from the observed survey rows only.
    x_z, x_mean, x_sd = zscore(x_obs)
else:
    # No finite survey values: identity scaling keeps downstream code running.
    x_mean, x_sd = 0.0, 1.0
    x_z = (x_obs - x_mean) / x_sd

# Optional: data-informed prior scales for latent survey process (EB flavor).
#   between_sd_z: spread of the per-OU mean survey z-scores
#   within_sd_z:  median of the per-OU survey z-score SDs
# Both fall back to 1.0 when undefined (e.g. no surveys, or one survey per OU).
# Group a standalone Series by OU instead of copying the whole frame.
# observed=True avoids pandas' FutureWarning when ou_code is categorical;
# empty groups were NaN and ignored by nanstd/nanmedian anyway.
x_z_by_ou = pd.Series(x_z, index=df.index).groupby(df["ou_code"], observed=True)
between_sd_z = np.nan_to_num(np.nanstd(x_z_by_ou.mean().to_numpy()), nan=1.0)
within_sd_z = np.nan_to_num(np.nanmedian(x_z_by_ou.std().to_numpy()), nan=1.0)

# Shared arrays
N = df.shape[0]
# Per-row precision weight for the survey measurement model: sqrt(respondents).
# Rows without a count are treated as 1 respondent; the weight is capped at 1000.
n_eff_np = np.sqrt(np.maximum(n_resp_filled, 1.0)).clip(1.0, 1000.0)

# -------------------------------
# 1) Build model with coords/dims
# -------------------------------
coords = dict(
    obs=np.arange(N),                      # one entry per row of df
    obs_svy=np.arange(idx_obs.size),       # rows with an observed survey value
    obs_sched=np.arange(idx_sched.size),   # rows in a scheduled survey month
    ou=ou_labels,
    site=site_labels,
    time=pd.to_datetime(unique_months),
    ab=["alpha","beta"],  # just for labeling when needed (no dims= below uses it)
)

# Model: hierarchical OU-level regression of z-scored log-productivity on a latent
# survey score x* (observed with measurement error). Survey response on scheduled
# months is modeled as depending on x* itself, i.e. missing-not-at-random.
with pm.Model(coords=coords) as model:
    # Index helpers (pm.Data so we can reuse)
    # Per-row integer indices into the "ou" and "time" coords.
    ou_of_obs   = pm.Data("ou_of_obs", ou_idx, dims="obs")
    time_of_obs = pm.Data("time_of_obs", time_idx, dims="obs")
    # Site index of each OU (maps the "ou" dim onto the "site" dim).
    site_of_ouD = pm.Data("site_of_ou", site_of_ou, dims="ou")
    # Per-row survey precision weight: sqrt(respondents), floored at 1, capped at 1000.
    n_eff       = pm.Data("n_eff", n_eff_np, dims="obs")
    # Row positions that have an observed survey value, and those z-scored values.
    obs_idx_svy = pm.Data("obs_idx_svy", idx_obs, dims="obs_svy")
    x_z_obs     = pm.Data("x_z_obs", x_z[idx_obs], dims="obs_svy")
    # Row positions in a scheduled survey month, and whether a survey came back (1/0).
    obs_idx_sched = pm.Data("obs_idx_sched", idx_sched, dims="obs_sched")
    responded_schedD = pm.Data("responded_sched", responded_sched, dims="obs_sched")

    # ------------------
    # Global means
    # ------------------
    # Population-level intercept and slope. The site- and OU-level offsets below
    # are added on top of these in alpha_ou / beta_ou.
    mu_alpha = pm.Normal("mu_alpha", 0.0, 1.0)
    mu_beta  = pm.Normal("mu_beta",  0.0, 0.5)

    # ------------------
    # Site-level (alpha,beta) correlated random effects via LKJ
    # ------------------
    # 2x2 covariance for the (alpha, beta) site offsets: LKJ(eta=2) prior on the
    # correlation, HalfNormal(1) on the two SDs. Unpacked as (chol, corr, sds).
    # NOTE(review): with only a few sites (the synthetic config asks for 3 per region)
    # this covariance is weakly identified; expect the prior to dominate.
    L_site, corr_site, sds_site = pm.LKJCholeskyCov(
        "L_site", n=2, eta=2.0, sd_dist=pm.HalfNormal.dist(1.0),
        compute_corr=True, store_in_trace=True
    )
    # Non-centered: standard-normal rows times L^T give rows with covariance L @ L.T.
    # Column 0 is the alpha offset, column 1 the beta offset.
    z_site = pm.Normal("z_site", 0.0, 1.0, size=(S, 2))
    # Runtime type narrowing (presumably for static type checkers); no effect on the model.
    assert isinstance(L_site, TensorVariable) and isinstance(corr_site, TensorVariable) and isinstance(sds_site, TensorVariable)
    ab_site = z_site @ L_site.T     # (S,2)
    alpha_site = pm.Deterministic("alpha_site", ab_site[:, 0], dims="site")
    beta_site  = pm.Deterministic("beta_site",  ab_site[:, 1], dims="site")

    # ------------------
    # OU-level (alpha,beta) residual random effects via LKJ
    # ------------------
    # Same construction as the site level, with its own covariance. These are each
    # OU's deviations around its site's coefficients.
    L_ou, corr_ou, sds_ou = pm.LKJCholeskyCov(
        "L_ou", n=2, eta=2.0, sd_dist=pm.HalfNormal.dist(1.0),
        compute_corr=True, store_in_trace=True
    )
    z_ou = pm.Normal("z_ou", 0.0, 1.0, size=(G, 2))
    assert isinstance(L_ou, TensorVariable) and isinstance(corr_ou, TensorVariable) and isinstance(sds_ou, TensorVariable)
    ab_ou = z_ou @ L_ou.T            # (G,2)
    alpha_ou_off = ab_ou[:, 0]
    beta_ou_off  = ab_ou[:, 1]

    # Combine levels to get OU-specific coeffs:
    #   coef[g] = global mean + offset of g's site + OU-specific offset
    alpha_ou = pm.Deterministic(
        "alpha_ou", mu_alpha + alpha_site[site_of_ouD] + alpha_ou_off, dims="ou"
    )
    beta_ou = pm.Deterministic(
        "beta_ou", mu_beta + beta_site[site_of_ouD] + beta_ou_off, dims="ou"
    )

    # Helpful deterministics
    # Off-diagonal alpha/beta correlation at each level. These are the names to pass to
    # az.summary / az.plot_trace (the model defines no "corr_alpha_beta").
    pm.Deterministic("corr_site_alpha_beta", corr_site[0, 1])
    pm.Deterministic("corr_ou_alpha_beta",   corr_ou[0, 1])

    # ------------------
    # Time effects with AR(1) structure (both outcome and survey processes)
    # ------------------
    def ar1_chol(phi: TensorVariable, sigma: TensorVariable, T: int):
        """Cholesky factor of a stationary AR(1) covariance over ``T`` time steps.

        Entry (i, j) of the covariance is ``sigma**2 / (1 - phi**2) * phi**|i - j|``
        (stationary variance times the AR(1) autocorrelation at that lag), plus a
        1e-6 diagonal jitter for numerical stability.

        Args:
            phi: AR(1) coefficient; |phi| < 1 is needed for a finite variance.
            sigma: innovation standard deviation.
            T: number of time steps (a static Python int).

        Returns:
            The factor from ``slinalg.cholesky``, passed to ``pm.MvNormal(chol=...)``.
        """
        steps = pt.arange(T)
        lags = pt.abs(steps[:, None] - steps[None, :])          # |i - j|
        stationary_var = sigma**2 / (1 - phi**2)
        cov = stationary_var * pt.power(phi, lags) + 1e-6 * pt.eye(T)
        return slinalg.cholesky(cov)

    # Outcome time effect λ_t
    # AR(1) prior over the T months. Bounding |phi| at 0.95 keeps the stationary
    # variance sigma^2 / (1 - phi^2) finite and the covariance well-conditioned.
    phi_lambda  = pm.Uniform("phi_lambda", lower=-0.95, upper=0.95)
    sigma_lambda = pm.HalfNormal("sigma_lambda", 1.0)
    # NOTE(review): centered MvNormal over all T months. If divergences track
    # sigma_lambda, a non-centered form (chol @ standard normal) may sample better.
    lambda_time = pm.MvNormal(
        "lambda_time",
        mu=pt.zeros(T), chol=ar1_chol(phi_lambda, sigma_lambda, T),
        dims="time"
    )
    # Mean-zero version used in the likelihood, so the overall level of the time
    # effect cannot trade off against the intercepts (presumably the intent).
    lambda_eff = pm.Deterministic("lambda_eff", lambda_time - pt.mean(lambda_time), dims="time")

    # Survey time effect τ_t (for x*)
    # Same AR(1) construction with independent parameters; feeds the latent survey mean.
    phi_tau  = pm.Uniform("phi_tau", lower=-0.95, upper=0.95)
    sigma_tau = pm.HalfNormal("sigma_tau", 1.0)
    tau_time = pm.MvNormal(
        "tau_time",
        mu=pt.zeros(T), chol=ar1_chol(phi_tau, sigma_tau, T),
        dims="time"
    )
    tau_eff = pm.Deterministic("tau_eff", tau_time - pt.mean(tau_time), dims="time")

    # ------------------
    # Latent survey process x* (z scale), with EB-scaled priors (optional)
    # ------------------
    # x* is defined for every row (dims="obs"). Observed rows are pinned down by the
    # measurement model below; the other rows are informed by this prior, the outcome
    # model and, for scheduled months, the response model.
    mu_x = pm.Normal("mu_x", 0.0, 1.0)
    # Prior scales = 1.5x the data-prep EB statistics, floored at 0.3.
    sigma_mu_x = pm.HalfNormal("sigma_mu_x", max(0.3, 1.5 * between_sd_z))  # EB-flavored scale
    mu_x_ou = pm.Normal("mu_x_ou", mu_x, sigma_mu_x, dims="ou")

    # NOTE(review): tau_time already carries its own scale (sigma_tau), so only the
    # product gamma_time_x * sigma_tau (and its sign) is identified by the data;
    # expect a ridge / possible bimodality between the two.
    gamma_time_x = pm.Normal("gamma_time_x", 0.05, 0.3)  # learn time wiggle magnitude
    sigma_x = pm.HalfNormal("sigma_x", max(0.3, 1.5 * within_sd_z))

    x_latent = pm.Normal(
        "x_latent",
        mu=mu_x_ou[ou_of_obs] + gamma_time_x * tau_eff[time_of_obs],
        sigma=sigma_x,
        dims="obs",
    )

    # ------------------
    # Survey measurement model (only at observed rows)
    # ------------------
    # Per-row noise SD shrinks with n_eff (sqrt of respondent count), so larger survey
    # samples are treated as more precise readings of x*.
    sigma_meas_base = pm.HalfNormal("sigma_meas_base", 1.0)
    sigma_meas = sigma_meas_base / n_eff
    pm.Normal(
        "survey_obs",
        mu=x_latent[obs_idx_svy],
        sigma=sigma_meas[obs_idx_svy],
        observed=x_z_obs,
    )

    # ------------------
    # Missingness mechanism (only on scheduled months)
    #   logit P(survey present | scheduled) = ρ0 + ρx * x* + ρ_ou + ρ_time
    # ------------------
    # The response probability depends on the latent x* itself (rho_x), so missing
    # surveys are modeled as not-at-random instead of being ignored.
    # NOTE(review): with wave_missing_prob=0. in the synthetic data, presumably every
    # scheduled row responds; then only the priors keep rho0 / rho_* finite.
    rho0 = pm.Normal("rho0", 0.0, 1.0)
    rho_x = pm.Normal("rho_x", 0.0, 1.0)
    # OU random intercept for response (non-centered: sigma * standard normal)
    sigma_rho_ou = pm.HalfNormal("sigma_rho_ou", 0.5)
    rho_ou_raw = pm.Normal("rho_ou_raw", 0.0, 1.0, dims="ou")
    rho_ou = pm.Deterministic("rho_ou", sigma_rho_ou * rho_ou_raw, dims="ou")
    # Time random intercept: non-centered (sigma * standard normal), shifted to mean
    # zero across all T months so it does not duplicate rho0. Only scheduled months
    # enter the likelihood below; entries for other months are prior-only.
    sigma_rho_t = pm.HalfNormal("sigma_rho_t", 0.5)
    rho_t_raw = pm.Normal("rho_t_raw", 0.0, 1.0, dims="time")
    rho_t = pm.Deterministic("rho_t", sigma_rho_t * (rho_t_raw - pt.mean(rho_t_raw)), dims="time")

    # Linear predictor evaluated on the scheduled rows only.
    logit_p_sched = (
        rho0
        + rho_x * x_latent[obs_idx_sched]
        + rho_ou[ou_of_obs[obs_idx_sched]]
        + rho_t[time_of_obs[obs_idx_sched]]
    )
    pm.Bernoulli(
        "survey_present_given_scheduled",
        logit_p=logit_p_sched,
        observed=responded_schedD,
    )

    # ------------------
    # Outcome model: Student-t likelihood for robustness
    # ------------------
    sigma_y = pm.HalfNormal("sigma_y", 0.5)
    nu_y_raw = pm.Exponential("nu_y_raw", 1/15)  # rate 1/15 -> nu_y_raw has mean 15
    nu_y = pm.Deterministic("nu_y", nu_y_raw + 2.0)  # shift so nu > 2 (finite variance); prior mean 17

    # Linear predictor on the z-scored log-productivity scale: OU intercept +
    # mean-zero month effect + OU slope times the latent survey score x*.
    mu_y = (
        alpha_ou[ou_of_obs]
        + lambda_eff[time_of_obs]
        + beta_ou[ou_of_obs] * x_latent
    )
    # y_like is not referenced again; the call registers the observed likelihood.
    y_like = pm.StudentT("y_like", nu=nu_y, mu=mu_y, sigma=sigma_y, observed=y_z)

    # (optional global deterministics)
    # Aliases of mu_beta / mu_alpha under friendlier names for summaries and plots.
    pm.Deterministic("beta_global", mu_beta)
    pm.Deterministic("alpha_global", mu_alpha)

    # Sampling (tune target_accept up if you see divergences)
    # Runs inside the model context: 4 chains x 1500 draws after 2000 tuning steps,
    # seeded for reproducibility. Re-running this cell rebuilds the model AND resamples.
    idata = pm.sample(
        draws=1500, tune=2000, chains=4, cores=4,
        target_accept=0.95, random_seed=7, progressbar=True
    )

In [None]:
# # TWO WAYS TO MITIGATE OVERFITTING
# # 1. Gaussian (ridge) shrinkage: beta ~ Normal(0, tau) with small/learned tau.
# # 2. [SHOWN BELOW] (Regularized) Horseshoe for sparse-ish signals (many small, few large):
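# tau = pm.HalfStudentT("tau", nu=3, sigma=0.5)      # global shrinkage
# lam = pm.HalfCauchy("lam", beta=1, dims="cov")     # local shrinkage
# c2  = pm.InverseGamma("c2", 2, 2)                  # slab variance c^2
# # Regularized local scale: lam_tilde^2 = c^2 lam^2 / (c^2 + tau^2 lam^2)
# lam_tilde = pt.sqrt((c2 * lam**2) / (c2 + tau**2 * lam**2))
# # lam_tilde already contains lam, so the prior SD is tau * lam_tilde
# # (multiplying by lam again would undo the slab's cap on large coefficients).
# beta = pm.Normal("beta", 0, tau * lam_tilde, dims="cov")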

In [None]:
# Sanity check: posterior-mean x* vs. the observed z-scored survey values.
x_post_mean = idata.posterior["x_latent"].mean(dim=("chain", "draw")).to_numpy()
np.corrcoef(x_post_mean[idx_obs], x_z[idx_obs])[0, 1]   # expect high

In [None]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, arviz as az

# Nice labels for OUs (sorted category order, same as the model's "ou" coord)
ou_labels = pd.Index(
    pd.Categorical(df["ou_code"]).categories.astype(str),
    name="ou_code",
)

# Quick tables: posterior summaries with 90% HDIs (alpha first, then beta)
summ_alpha, summ_beta = (
    az.summary(idata, var_names=[var_name], hdi_prob=0.9)
    for var_name in ("alpha_ou", "beta_ou")
)

# Attach OU codes to summaries. With coords set on the model, ArviZ names rows by
# coordinate value (e.g. 'beta_ou[<ou_code>]'); positional names such as
# 'beta_ou[0]' (idata built without coords) are still accepted.
def add_ou_labels(summ, ou_labels):
    """Re-index an ArviZ summary table by OU code.

    Args:
        summ: DataFrame from ``az.summary`` for a variable with an OU dimension;
            row names look like ``"<var>[<key>]"``.
        ou_labels: ordered OU labels (the model's "ou" coordinate).

    Returns:
        ``summ`` indexed by a new ``ou_code`` index (row order unchanged).

    Raises:
        ValueError: if a row name has no ``[...]`` part, or its key is neither a known
            OU label nor a valid integer position into ``ou_labels``.
    """
    known = {str(lab) for lab in ou_labels}
    codes = []
    for row_name in summ.index:
        _, sep, tail = str(row_name).partition("[")
        if not sep:
            raise ValueError(f"Summary row {row_name!r} has no '[...]' part")
        key = tail.rstrip("]")
        if key in known:
            codes.append(key)                      # label-style row name (coords present)
            continue
        try:
            codes.append(ou_labels[int(key)])      # positional row name, e.g. 'beta_ou[0]'
        except (ValueError, IndexError) as exc:
            raise ValueError(f"Cannot map summary row {row_name!r} to an OU label") from exc
    return summ.assign(ou_code=codes).set_index("ou_code")

summ_alpha = add_ou_labels(summ_alpha, ou_labels)
summ_beta = add_ou_labels(summ_beta, ou_labels)

# Show mean + 90% HDI per OU, intercepts first.
interval_cols = ["mean", "hdi_5%", "hdi_95%"]
for header, table in (
    ("Alpha per OU (posterior mean and 90% HDI):", summ_alpha),
    ("\nBeta per OU (posterior mean and 90% HDI):", summ_beta),
):
    print(header)
    display(table[interval_cols])

# Forest plots (slopes first, then intercepts); height grows with the number of OUs.
forest_figsize = (8, 0.35*len(ou_labels)+2)
for forest_var, forest_title in (
    ("beta_ou", "OU slopes β_g (x* → log-productivity, both z-scored)"),
    ("alpha_ou", "OU intercepts α_g (log-productivity z-score)"),
):
    az.plot_forest(idata, var_names=[forest_var], combined=True, hdi_prob=0.9, figsize=forest_figsize)
    plt.title(forest_title)
    plt.show()

In [None]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt, arviz as az
# assumes: idata, y_sd, x_sd already defined

# Pull posterior: dims (chain, draw, ou)
beta = idata.posterior["beta_ou"]

# Map to % change in raw productivity per +1 survey point.
# The model regresses z-scored log-productivity on the z-scored survey, so +1 raw
# survey point moves log-productivity by beta * y_sd / x_sd.
mult = np.exp((y_sd / x_sd) * beta)
pct  = (mult - 1.0) * 100.0  # dims: (chain, draw, ou)

# Posterior mean & central 90% interval over chain/draw, per OU
pct_mean = pct.mean(dim=("chain","draw")).values  # shape: (G,)
q = pct.quantile([0.05, 0.95], dim=("chain","draw"))          # dims: (quantile, ou)
pct_lo = q.sel(quantile=0.05).values
pct_hi = q.sel(quantile=0.95).values

# Label rows with the posterior's own "ou" coordinate. That is exactly the order of
# pct_mean / pct_lo / pct_hi, so labels cannot drift out of sync with df.
eff_df = pd.DataFrame({
    "ou_code": pd.Index(pct["ou"].values.astype(str)),
    "pct_mean": pct_mean, "pct_lo": pct_lo, "pct_hi": pct_hi
}).set_index("ou_code").sort_values("pct_mean")

display(eff_df)

# Plot interval per OU
fig, ax = plt.subplots(figsize=(8, 0.35*len(eff_df)+2))
ypos = np.arange(len(eff_df))
ax.hlines(y=ypos, xmin=eff_df["pct_lo"], xmax=eff_df["pct_hi"])
ax.plot(eff_df["pct_mean"], ypos, "o")
ax.axvline(0, ls="--", lw=1)
ax.set_yticks(ypos, eff_df.index)
ax.set_xlabel("% change in raw productivity per +1 survey point (90% interval)")
ax.set_title("OU-specific effect interpretation")
plt.tight_layout(); plt.show()

In [None]:
# Quick tables: posterior summaries with 90% HDIs (alpha first, then beta)
summ_alpha, summ_beta = (
    az.summary(idata, var_names=[name], hdi_prob=0.9) for name in ("alpha_ou", "beta_ou")
)

def add_ou_labels(summ, labels):
    """Re-index an ArviZ summary table by OU code.

    With coords set on the model, ArviZ names rows by coordinate value
    (e.g. ``'beta_ou[<ou_code>]'``), so the key is used directly when it is a known
    label. Positional keys such as ``'beta_ou[0]'`` are mapped through ``labels``.

    Args:
        summ: DataFrame from ``az.summary``; row names look like ``"<var>[<key>]"``.
        labels: ordered OU labels (the model's "ou" coordinate).

    Returns:
        ``summ`` indexed by a new ``ou_code`` index (row order unchanged).

    Raises:
        ValueError: if a row name has no ``[...]`` part, or its key is neither a known
            label nor a valid integer position into ``labels``.
    """
    known = {str(lab) for lab in labels}
    labs = []
    for idx in summ.index:
        _, sep, tail = str(idx).partition("[")
        if not sep:
            raise ValueError(f"Summary row {idx!r} has no '[...]' part")
        key = tail.rstrip("]")
        if key in known:
            labs.append(key)
            continue
        try:
            labs.append(labels[int(key)])
        except (ValueError, IndexError) as exc:
            raise ValueError(f"Cannot map summary row {idx!r} to an OU label") from exc
    return summ.assign(ou_code=labs).set_index("ou_code")

ou_labels = pd.Index(pd.Categorical(df["ou_code"]).categories.astype(str), name="ou_code")
shown_cols = ["mean", "hdi_5%", "hdi_95%"]
for table in (summ_alpha, summ_beta):
    display(add_ou_labels(table, ou_labels)[shown_cols])

# Forest plots (slopes first, then intercepts)
forest_figsize = (8, 0.35*len(ou_labels)+2)
for var_name, plot_title in (
    ("beta_ou", "OU slopes β_g (x* → log-productivity, both z-scored)"),
    ("alpha_ou", "OU intercepts α_g (log-productivity z-score)"),
):
    az.plot_forest(idata, var_names=[var_name], combined=True, hdi_prob=0.9,
                   figsize=forest_figsize)
    plt.title(plot_title); plt.show()

In [None]:
import matplotlib.dates as mdates

def plot_latent_for_ou(ou_code_str):
    """Plot the posterior of the latent survey score x* over time for one OU.

    Draws the posterior mean and central 90% interval of ``x_latent`` for every row
    of the OU (ordered by month), overlaid with the observed z-scored survey values.
    Relies on notebook globals: df, idata, time_idx, unique_months, has_survey, x_z.

    Args:
        ou_code_str: OU code as a string, compared with the str-cast categories of
            ``df["ou_code"]``.

    Raises:
        ValueError: if ``ou_code_str`` is not one of the OU codes in ``df``.
    """
    ou_codes = pd.Categorical(df["ou_code"])
    labels = ou_codes.categories.astype(str)
    hits = np.flatnonzero(labels == ou_code_str)
    if hits.size == 0:
        raise ValueError(f"Unknown OU code {ou_code_str!r}; known codes: {list(labels)}")
    g = int(hits[0])

    # rows for this OU, ordered in time
    row_idx_all = np.flatnonzero(ou_codes.codes == g)
    t = time_idx[row_idx_all]
    order = np.argsort(t)
    row_idx = row_idx_all[order]
    months = pd.to_datetime(unique_months[t[order]])

    # Posterior summaries for x_latent at those rows. Subset first, then reduce,
    # instead of computing quantiles for every row of the panel.
    x_da  = idata.posterior["x_latent"].isel(obs=row_idx)   # (chain, draw, rows of this OU)
    x_mu  = x_da.mean(("chain","draw")).values
    qx    = x_da.quantile([0.05, 0.95], dim=("chain","draw"))
    lo    = qx.sel(quantile=0.05).values
    hi    = qx.sel(quantile=0.95).values

    # observed standardized survey at those rows
    obs_mask = has_survey[row_idx]
    obs_months = months[obs_mask]
    obs_vals   = x_z[row_idx][obs_mask]

    fig, ax = plt.subplots(figsize=(10,4))
    ax.fill_between(months, lo, hi, alpha=0.25, label="x* 90% interval")
    ax.plot(months, x_mu, lw=2, label="x* posterior mean")
    ax.scatter(obs_months, obs_vals, marker="x", s=60, label="observed survey (z)")
    ax.set_title(f"Latent survey x* over time — OU {ou_code_str}")
    ax.set_ylabel("Standardized survey (z)")
    ax.set_xlabel("Month")
    ax.legend()
    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax.xaxis.get_major_locator()))
    plt.tight_layout(); plt.show()

# Example: plot the first OU in label order
first_ou = str(ou_labels[0])
plot_latent_for_ou(first_ou)

In [None]:
# Quick plot: OU intercept vs slope posterior means
alpha_post = idata.posterior["alpha_ou"]
a_mu = alpha_post.mean(("chain","draw")).to_numpy()
b_mu = idata.posterior["beta_ou"].mean(("chain","draw")).to_numpy()
# Annotate with the posterior's own "ou" coordinate so labels follow a_mu/b_mu order.
point_labels = alpha_post["ou"].values

plt.figure(figsize=(5,4))
plt.scatter(a_mu, b_mu)
for i, lab in enumerate(point_labels):
    plt.annotate(str(lab), (a_mu[i], b_mu[i]), fontsize=8, xytext=(3,3), textcoords="offset points")
plt.axhline(0, lw=1, ls="--"); plt.axvline(0, lw=1, ls="--")
plt.xlabel("α_g (mean log-productivity z)"); plt.ylabel("β_g (slope)")
plt.title("OU intercepts vs slopes (posterior means)")
plt.tight_layout(); plt.show()

# Some single-parameter posteriors.
# The model registers the alpha/beta correlations as "corr_site_alpha_beta" and
# "corr_ou_alpha_beta"; there is no "corr_alpha_beta" (ArviZ raises KeyError on it).
trace_vars = ["beta_global", "corr_site_alpha_beta", "corr_ou_alpha_beta",
              "sigma_y", "sigma_x", "sigma_meas_base"]
az.plot_trace(idata, var_names=trace_vars, compact=True, figsize=(10, 1.5 * len(trace_vars)))
plt.show()

In [None]:
# import os
# os.environ["PATH"] = "/opt/homebrew/bin:" + os.environ["PATH"]
# # Optional: be explicit
# os.environ["GRAPHVIZ_DOT"] = "/opt/homebrew/bin/dot"
# pm.model_to_graphviz(model)

In [None]:
g = pm.model_to_graphviz(model)
dot_preview = g.source[:1000]
print(dot_preview)               # first 1000 characters of the DOT text
g.save("model.dot")              # write DOT to model.dot

In [None]:
import shutil, os 
dot_executable = shutil.which("dot")  # None when Graphviz's `dot` is not on PATH
print("dot:", dot_executable)

In [None]:
import os, shutil
HOMEBREW_BIN = "/opt/homebrew/bin"
# Make Homebrew's bin dir (presumably where Graphviz's `dot` lives on this machine)
# searchable. Prepend only once: the unconditional prepend grew PATH on every re-run
# of this cell. os.pathsep replaces the hard-coded ":" separator.
_current_path = os.environ.get("PATH", "")
if HOMEBREW_BIN not in _current_path.split(os.pathsep):
    os.environ["PATH"] = HOMEBREW_BIN + os.pathsep + _current_path if _current_path else HOMEBREW_BIN
print(shutil.which("dot"))

In [None]:
g = pm.model_to_graphviz(model)
g.graph_attr["rankdir"] = "LR"  # optional: lay the graph out left-to-right
g

In [None]:
# import os
# os.environ["PATH"] = "/opt/homebrew/bin:" + os.environ["PATH"]
# # Optional: be explicit
# os.environ["GRAPHVIZ_DOT"] = "/opt/homebrew/bin/dot"
# pm.model_to_graphviz(model)

In [None]:
g = pm.model_to_graphviz(model)
dot_preview = g.source[:1000]
print(dot_preview)               # first 1000 characters of the DOT text
g.save("model.dot")              # write DOT to model.dot

In [None]:
import shutil, os 
dot_executable = shutil.which("dot")  # None when Graphviz's `dot` is not on PATH
print("dot:", dot_executable)

In [None]:

# ---- 5) Quick checks ----
# The model registers the alpha/beta correlations as "corr_site_alpha_beta" and
# "corr_ou_alpha_beta"; requesting the nonexistent "corr_alpha_beta" raised KeyError.
az.summary(idata, var_names=[
    "mu_alpha", "mu_beta", "sigma_y", "sigma_x", "sigma_mu_x",
    "sigma_meas_base", "corr_site_alpha_beta", "corr_ou_alpha_beta",
], kind="stats")