In [None]:
import os
# Work from the directory that contains the `src` package (imported further
# down). The guard makes the cell idempotent: without it, every re-run of this
# cell would climb one more directory up and break the later imports.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
import numpy as np
import spreg
import pymc as pm
import arviz as az
import bambi as bmb
from pysal.lib import weights
import polars as pl

from src.data.data_reg import SpatialReg
# Apply the ArviZ "arviz-darkgrid" plotting style to the figures drawn below.
az.style.use("arviz-darkgrid")

# Data generator from src.data.data_reg; the meaning of `n` is defined there
# (presumably a grid/size parameter -- TODO confirm against SpatialReg).
sr = SpatialReg(n=10)


In [None]:
# Generate the panel used throughout the notebook. The cells below read the
# columns "id", "time", "X1", "y_d", "w_d" and "geometry" from it.
# NOTE(review): rho here is presumably the true spatial-lag coefficient that
# the later HDI check compares against -- confirm in SpatialReg.spatial_panel.
gdf = sr.spatial_panel(n=10, time=5, rho=0.8)
gdf

In [None]:
# Map of X1 restricted to the rows where time == 0.
gdf.loc[gdf["time"].eq(0)].plot(column="X1")

In [None]:
# Map of y_d restricted to the rows where time == 0.
gdf.loc[gdf["time"].eq(0)].plot(column="y_d")

In [None]:
# Per the spreg documentation, a long-format panel is read as consecutive
# blocks of N rows, one block per period (y[0:N] is the first period, and so
# on). Order the rows by period and then by unit explicitly instead of relying
# on whatever order the generator returned.
panel = gdf.sort_values(["time", "id"]).reset_index(drop=True)

# Rook-contiguity weights built from the time == 0 slice of the sorted panel,
# so the unit order of the weights matches the unit order inside each period
# block (assumes a balanced panel -- every unit present in every period).
wr = weights.contiguity.Rook.from_dataframe(panel[panel["time"] == 0])
wr.transform = "r"  # row-standardise the weights

# Column vectors of shape (rows, 1) for the dependent and explanatory variable.
y_d = panel["y_d"].values.reshape(-1, 1)
xb = panel["X1"].values.reshape(-1, 1)

fe_lag = spreg.Panel_FE_Lag(y=y_d, x=xb, w=wr)
print(fe_lag.summary)

In [None]:
# Copy of the panel without its "geometry" column, for use with Bambi.
df = gdf.drop(columns="geometry")

# Formula: intercept and X1 as common terms, plus a w_d term grouped by `id`
# with the group-level intercept suppressed (the leading 0).
# dropna=True is passed so that rows with missing values are discarded.
model = bmb.Model(
    formula="y_d ~ 1 + X1 + (0 + w_d|id)",
    data=df,
    dropna=True,
)
results = model.fit()

In [None]:
# Attach each geometry's centroid and split it into coordinate columns.
gdf["centroid"] = gdf.centroid

# Point coordinates are (x, y): x runs east-west (longitude) and y runs
# north-south (latitude). The two assignments were previously swapped.
gdf["lat"] = gdf["centroid"].y
gdf["lon"] = gdf["centroid"].x

# Only the "geometry" column is dropped here; the "centroid" column added
# above is still present in df.
df = gdf.drop("geometry", axis=1)
df

In [None]:
# Design matrix with one column per listed feature, and the response as a
# single-column 2-D array.
feature_cols = ["X1", "lat", "lon"]
X = df[feature_cols].to_numpy().reshape(-1, len(feature_cols))
y = df["y_d"].to_numpy().reshape(-1, 1)
X


In [None]:
# Order the rows by period and then by unit, renumbering the index from zero.
gdf = gdf.sort_values(by=["time", "id"], ignore_index=True)

# Integer code for every distinct spatial unit id (0..N-1).
gdf["unit_id"] = gdf["id"].astype("category").cat.codes

# Panel dimensions: number of distinct units and number of distinct periods.
N, T = (gdf[name].nunique() for name in ("unit_id", "time"))

# Flat 1-D arrays, one entry per row of the sorted panel.
y, X1, Wy, unit_idx = (
    gdf[name].to_numpy() for name in ("y_d", "X1", "w_d", "unit_id")
)


In [None]:
# Fixed seed so the posterior draws are reproducible across kernel restarts.
RANDOM_SEED = 42

# Model:
#   y ~ Normal(rho[unit] * Wy + beta * X1 + mu[unit], sigma)
# with a separate rho and a separate intercept mu for each of the N units,
# both drawn from zero-centred normals whose scales are themselves estimated.
with pm.Model() as model:
    # Hyperpriors
    sigma = pm.HalfNormal("sigma", 1.0)      # observation noise scale
    tau_rho = pm.HalfNormal("tau_rho", 1.0)  # spread of the unit-level rho
    tau_mu = pm.HalfNormal("tau_mu", 1.0)    # spread of the unit-level intercepts

    # Priors
    beta = pm.Normal("beta", mu=0, sigma=5)
    rho_i = pm.Normal("rho", mu=0, sigma=tau_rho, shape=N)     # one rho per unit
    mu_i = pm.Normal("mu", mu=0, sigma=tau_mu, shape=N)         # one intercept per unit

    # Create shared inputs
    X_data = pm.Data("X1", X1)
    Wy_data = pm.Data("Wy", Wy)
    unit_idx_data = pm.Data("unit_idx", unit_idx)

    # Compute mu_y: unit_idx_data picks each row's own rho and intercept.
    mu_y = rho_i[unit_idx_data] * Wy_data + beta * X_data + mu_i[unit_idx_data]

    # Likelihood
    y_obs = pm.Normal("y_obs", mu=mu_y, sigma=sigma, observed=y)

    trace = pm.sample(
        1000,
        tune=1000,
        target_accept=0.9,
        random_seed=RANDOM_SEED,
        return_inferencedata=True,
    )

In [None]:
# Trace plots followed by the summary table for the same set of parameters.
shown_vars = ["rho", "beta", "sigma"]
az.plot_trace(trace, var_names=shown_vars)
az.summary(trace, var_names=shown_vars)


In [None]:
# Value each unit-level rho interval is checked against.
rho_true = 0.8

# One summary row per element of "rho"; with hdi_prob=0.94 the interval bounds
# are read from the "hdi_3%" and "hdi_97%" columns.
summary = az.summary(trace, var_names=["rho"], hdi_prob=0.94)
hdi_low = summary["hdi_3%"]
hdi_high = summary["hdi_97%"]
within_hdi = (hdi_low <= rho_true) & (hdi_high >= rho_true)

# Report results
num_total = len(within_hdi)
num_pass = within_hdi.sum()
all_contain = within_hdi.all()

print(f"True rho = {rho_true}")
print(f"{num_pass}/{num_total} HDIs contain true rho.")
if all_contain:
    print("All HDIs contain true rho.")
else:
    print("Not all HDIs contain true rho.")
    # Positions (0-based rows of the summary) whose interval misses rho_true.
    failed_units = np.flatnonzero((~within_hdi).to_numpy())
    print(f"Units failing HDI test: {failed_units}")