# In [None]:
# spatio-temporal_cai_workflow.py
# ---------------------------------------------------------------------
# 3.7.2 Spatial Aggregation Analysis for CAI (Italy, NUTS-3)
# - Build region-year CAI panel (survey-weighted)
# - Cross-sectional maps (pooled mean; dominant ecoscheme)
# - Local stats: Getis–Ord Gi* hotspots/coldspots (2021), Moran's I (LISA)
# - Bayesian STAR (CmdStanPy) spatio-temporal fit + 2022–2026 forecasts
# - National time-track plot with 95% credible interval
#
# Requirements (install as needed):
#   pip install geopandas libpysal esda matplotlib numpy pandas cmdstanpy arviz
#   (and have CmdStan installed & set up; see: https://mc-stan.org/cmdstanpy/)
# ---------------------------------------------------------------------

from __future__ import annotations
import os
from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt

from libpysal.weights import Queen, Rook, KNN
from esda.getisord import G_Local
from esda.moran import Moran_Local

# Bayesian tooling
try:
    from cmdstanpy import CmdStanModel
    import arviz as az
    HAVE_STAN = True
except Exception:
    HAVE_STAN = False

# ----------------------------- CONFIG ---------------------------------
# INPUTS you should point to your files
# Farm- (or holding-) level table: CAI_farm + survey weight + NUTS-3 code + year.
PANEL_FARMS = "data/processed/normalised_with_cai.csv"
# NUTS-3 polygons; the file must carry the NUTS3_ID_COL column ('nuts3').
GEOM_NUTS3 = "italy_nuts3.geojson"
# Human-readable region label in the polygon file.
REGION_NAME_COL = "name"

# Column names in the farm table.
CAI_COL_FARM = "CAI_farm"      # per-farm CAI (bounded [0,1]) from your CAI pipeline
SURVEY_WEIGHT = "rica_weight"  # RICA survey weight
NUTS3_ID_COL = "nuts3"         # NUTS-3 code, present in both datasets
YEAR_COL = "year"

# Per-ecoscheme CAI columns, used to pick each region's "dominant ecoscheme".
ES_COLS_RAW = ["CAI_ES1", "CAI_ES2", "CAI_ES3", "CAI_ES4", "CAI_ES5"]

# STAR forecast horizon (inclusive range 2022..2026).
FORECAST_YEARS = list(range(2022, 2027))

# Every figure and table is written below this folder (created if absent).
OUT = Path("spatial_outputs")
OUT.mkdir(parents=True, exist_ok=True)

# Seed NumPy's global RNG so permutation tests and forecast noise are repeatable.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# ----------------------- UTILITIES ------------------------------------
def wavg(x, w):
    """Return the weighted mean of *x* using weights *w*, ignoring missing pairs.

    Pairs where either the value or its weight is NaN are dropped first, so a
    single missing farm value does not turn a whole region-year into NaN.

    Args:
        x: values (array-like, coerced to float).
        w: weights, same length as *x* (array-like, coerced to float).

    Returns:
        The weighted mean as a Python float. Two fallbacks apply:
        - If the remaining weights sum to <= 0, the unweighted mean of the
          remaining values is returned.
        - If no valid pairs remain (including empty input), NaN is returned.
    """
    x = np.asarray(x, float)
    w = np.asarray(w, float)
    valid = ~(np.isnan(x) | np.isnan(w))
    if not valid.any():
        return float("nan")
    x, w = x[valid], w[valid]
    sw = w.sum()
    return float((x * w).sum() / sw) if sw > 0 else float(x.mean())

def standardize_w_row(w):
    """Switch a libpysal ``W`` to row-standardised weights and return it densely.

    Side effect: sets ``w.transform = "R"`` on the object passed in.

    Returns:
        The dense matrix from ``w.full()``. Its rows/columns follow the id list
        that ``w.full()`` returns alongside it; that id list is discarded here.
    """
    w.transform = "R"
    dense, _ids = w.full()
    return dense

def common_color_limits(values_list):
    """Return a shared ``(vmin, vmax)`` across several arrays for one map legend.

    NaNs inside each array are ignored (``np.nanmin``/``np.nanmax`` per array).
    The per-array extremes are then reduced with ``min``/``max``. Both results
    are returned as Python floats.
    """
    arrays = [np.asarray(v) for v in values_list]
    lows = np.array([np.nanmin(a) for a in arrays])
    highs = np.array([np.nanmax(a) for a in arrays])
    return float(lows.min()), float(highs.max())

# -------------------- 1) MAKE REGION-YEAR PANEL -----------------------
print("Building region-year panel (survey-weighted CAI)...")
df = pd.read_csv(PANEL_FARMS)

need_cols = {NUTS3_ID_COL, YEAR_COL, CAI_COL_FARM}
missing = need_cols - set(df.columns)
if missing:
    raise KeyError(f"Missing required columns in {PANEL_FARMS}: {missing}")

if SURVEY_WEIGHT not in df.columns:
    df[SURVEY_WEIGHT] = 1.0  # fallback to unweighted if not present

# Rows without a region code or year can be neither mapped nor modelled. Drop
# them up front: a NaN group key would otherwise survive into the panel, and
# later sorting of the (string) region ids would fail on it.
n_rows = len(df)
df = df.dropna(subset=[NUTS3_ID_COL, YEAR_COL])
if len(df) < n_rows:
    print(f"  Dropped {n_rows - len(df)} row(s) with missing {NUTS3_ID_COL} or {YEAR_COL}.")
# Years may have been read as float because of the (now dropped) NaNs;
# downstream code compares them with integer years (<= 2021, == 2021).
df[YEAR_COL] = df[YEAR_COL].astype(int)

group_keys = [NUTS3_ID_COL, YEAR_COL]

# Weighted mean CAI per NUTS-3 × year. Only the value/weight columns are
# passed to apply, so the grouping columns are not handed to the lambda.
panel = (df.groupby(group_keys)[[CAI_COL_FARM, SURVEY_WEIGHT]]
           .apply(lambda g: wavg(g[CAI_COL_FARM], g[SURVEY_WEIGHT]))
           .reset_index(name="CAI_region_year"))

# Per-scheme weighted means to compute "dominant ecoscheme"
have_es = all(c in df.columns for c in ES_COLS_RAW)
if have_es:
    es_panel = (df.groupby(group_keys)[ES_COLS_RAW + [SURVEY_WEIGHT]]
                  .apply(lambda g: pd.Series({c: wavg(g[c], g[SURVEY_WEIGHT]) for c in ES_COLS_RAW}))
                  .reset_index())
    panel = panel.merge(es_panel, on=group_keys, how="left")

panel.to_csv(OUT / "panel_region_year.csv", index=False)

# Pooled 2014–2021 cross-section: unweighted mean across years of the
# survey-weighted annual region means.
obs_years = [y for y in sorted(panel[YEAR_COL].unique()) if y <= 2021]
pooled = (panel[panel[YEAR_COL].isin(obs_years)]
          .groupby(NUTS3_ID_COL)["CAI_region_year"].mean()
          .reset_index(name="CAI_pooled_2014_2021"))

# Dominant ecoscheme per region.
if have_es:
    es_cols = ES_COLS_RAW
    es_region = (panel[panel[YEAR_COL].isin(obs_years)]
                 .groupby(NUTS3_ID_COL)[es_cols].mean().reset_index())
    # Min-max normalise each scheme ACROSS REGIONS (column-wise), so schemes
    # measured on different scales compete on relative standing. Per-row
    # min-max scaling would be a monotone transform of each row and could
    # never change the argmax.
    Z = es_region[es_cols].to_numpy(dtype=float)
    col_min = np.nanmin(Z, axis=0, keepdims=True)
    col_rng = np.nanmax(Z, axis=0, keepdims=True) - col_min
    col_rng[~(col_rng > 0)] = 1.0  # constant (or all-NaN) scheme: avoid 0/0
    Z_n = (Z - col_min) / col_rng
    # Missing scheme values must never "win"; rows with no scheme data get NaN.
    has_any = ~np.isnan(Z_n).all(axis=1)
    winner_idx = np.where(np.isnan(Z_n), -np.inf, Z_n).argmax(axis=1)
    es_region["dominant_es"] = [f"ES{i + 1}" if ok else np.nan
                                for i, ok in zip(winner_idx, has_any)]
    dom_tbl = es_region[[NUTS3_ID_COL, "dominant_es"]]
else:
    dom_tbl = None

# -------------------- 2) LOAD GEOMETRY & MERGE -----------------------
print("Merging with NUTS-3 geometry...")
gdf = gpd.read_file(GEOM_NUTS3)
if NUTS3_ID_COL not in gdf.columns:
    raise KeyError(f"'{NUTS3_ID_COL}' not found in geometry file {GEOM_NUTS3}")
if REGION_NAME_COL not in gdf.columns:
    raise KeyError(f"'{REGION_NAME_COL}' (REGION_NAME_COL) not found in geometry file {GEOM_NUTS3}")

# Later steps select/drop the geometry column by the literal name "geometry".
if gdf.geometry.name != "geometry":
    gdf = gdf.rename_geometry("geometry")

gdf = gdf[[NUTS3_ID_COL, REGION_NAME_COL, "geometry"]].copy()
gdf_panel = gdf.merge(pooled, on=NUTS3_ID_COL, how="left")

# Surface id mismatches (e.g. code format or dtype) instead of silently blank maps.
n_unmatched = int(gdf_panel["CAI_pooled_2014_2021"].isna().sum())
if n_unmatched:
    print(f"  Note: {n_unmatched} of {len(gdf_panel)} polygon(s) have no pooled CAI value.")

if dom_tbl is not None:
    gdf_panel = gdf_panel.merge(dom_tbl, on=NUTS3_ID_COL, how="left")

# -------------------- 3) CROSS-SECTIONAL MAPS -------------------------
print("Rendering cross-sectional maps...")
NO_DATA_COLOR = "#aaaaaa"

# (A) Pooled CAI 2014–2021. Regions without data are drawn grey instead of
# being omitted from the map.
fig, ax = plt.subplots(1, 1, figsize=(6, 7))
gdf_panel.plot(column="CAI_pooled_2014_2021", cmap="viridis", legend=True, ax=ax,
               missing_kwds={"color": NO_DATA_COLOR})
ax.set_title("Composite CAI (cross-sectional, 2014–2021 pooled)\nWeighted by survey weights; Regions")
ax.axis("off")
fig.tight_layout()
fig.savefig(OUT / "map_pooled_2014_2021.png", dpi=200)
plt.close(fig)

# (B) Dominant ecoscheme
if dom_tbl is not None:
    import matplotlib.patches as mpatches

    es_colors = {
        "ES1": "#1f77b4",
        "ES2": "#2ca02c",
        "ES3": "#d62728",
        "ES4": "#ff7f0e",
        "ES5": "#9467bd",
    }
    # Series.map(dict) yields NaN for missing/unknown labels; paint those grey
    # explicitly rather than relying on a NaN dictionary key.
    face_colors = gdf_panel["dominant_es"].map(es_colors).fillna(NO_DATA_COLOR)
    fig, ax = plt.subplots(1, 1, figsize=(6, 7))
    gdf_panel.plot(color=face_colors.tolist(), ax=ax)
    # Legend: one patch per scheme, plus "No data" only when it actually occurs.
    patches = [mpatches.Patch(color=c, label=k) for k, c in es_colors.items()]
    if gdf_panel["dominant_es"].isna().any():
        patches.append(mpatches.Patch(color=NO_DATA_COLOR, label="No data"))
    ax.legend(handles=patches, title="Ecoscheme", loc="lower left")
    ax.set_title("Dominant Ecoscheme (normalized across regions) — Cross-sectional")
    ax.axis("off")
    fig.tight_layout()
    fig.savefig(OUT / "map_dominant_ecoscheme.png", dpi=200)
    plt.close(fig)

# -------------------- 4) LOCAL STATS (Gi* and LISA) -------------------
print("Computing local spatial statistics for 2021...")
year_2021 = panel[panel[YEAR_COL] == 2021][[NUTS3_ID_COL, "CAI_region_year"]]
g_2021 = (gdf.merge(year_2021, on=NUTS3_ID_COL, how="left")
             .dropna(subset=["CAI_region_year"])
             .reset_index(drop=True))

# Build spatial weights (KNN for Gi* and LISA). KNN needs k < number of
# regions, so shrink k on small samples instead of failing inside libpysal.
K_NEIGHBOURS = 6
if len(g_2021) < 3:
    raise ValueError(
        f"Need at least 3 regions with 2021 CAI for local statistics; got {len(g_2021)}."
    )
k_nn = min(K_NEIGHBOURS, len(g_2021) - 1)
w_knn = KNN.from_dataframe(g_2021, k=k_nn)
w_knn.transform = "R"

y = g_2021["CAI_region_year"].values

# ---- Getis-Ord Gi* (999 permutations) ----
# esda's p_sim is a folded (one-sided) pseudo p-value: it counts permutations
# at least as extreme in the direction of the observed statistic. The sign of
# Zs then classifies each significant region as hot or cold.
gi = G_Local(y, w_knn, permutations=999, star=True)  # Gi*
sig = gi.p_sim < 0.05
hot = (gi.Zs > 0) & sig
cold = (gi.Zs < 0) & sig

g_2021["GiZ"] = gi.Zs
g_2021["GiP"] = gi.p_sim
g_2021["Gi_category"] = np.where(hot, "Hotspot", np.where(cold, "Coldspot", "Not sig."))

# Plot Gi*
import matplotlib.patches as mpatches

fig, ax = plt.subplots(1, 1, figsize=(5.5, 7))
colors = {"Hotspot": "#b2182b", "Coldspot": "#2166ac", "Not sig.": "#f0f0f0"}
g_2021.plot(color=g_2021["Gi_category"].map(colors), ax=ax, edgecolor="white", linewidth=0.2)
ax.legend(handles=[mpatches.Patch(color=colors[k], label=k) for k in ["Hotspot", "Coldspot", "Not sig."]],
          loc="lower left", title="Gi* (2021)")
ax.set_title(f"Getis-Ord Gi* hotspots/coldspots — CAI, 2021\nKNN(k={k_nn}), 999 permutations")
ax.axis("off")
fig.tight_layout()
fig.savefig(OUT / "map_gistar_2021.png", dpi=220)
plt.close(fig)

# ---- Local Moran's I (LISA) ----
lisa = Moran_Local(y, w_knn, permutations=999)
sig_lisa = lisa.p_sim < 0.05  # same folded pseudo p-value convention as Gi*

# esda quadrant encoding: 1 HH, 2 LH, 3 LL, 4 HL; only significant ones labelled.
quad = lisa.q.copy()
lab = np.array(["Not significant"] * len(quad), dtype=object)
lab[(quad == 1) & sig_lisa] = "High-High"
lab[(quad == 2) & sig_lisa] = "Low-High"
lab[(quad == 3) & sig_lisa] = "Low-Low"
lab[(quad == 4) & sig_lisa] = "High-Low"
g_2021["LISA_cat"] = lab

# Plot LISA
fig, ax = plt.subplots(1, 1, figsize=(5.5, 7))
lisa_colors = {
    "High-High": "#d73027",
    "Low-Low": "#2166ac",
    "High-Low": "#1a9850",
    "Low-High": "#762a83",
    "Not significant": "#d9d9d9"
}
g_2021.plot(color=g_2021["LISA_cat"].map(lisa_colors), ax=ax, edgecolor="white", linewidth=0.2)
ax.legend(handles=[mpatches.Patch(color=lisa_colors[k], label=k) for k in
                   ["High-High", "Low-Low", "High-Low", "Low-High", "Not significant"]],
          loc="lower left", title="LISA (2021)")
ax.set_title("Local Moran’s I (LISA) — Composite CAI, 2021")
ax.axis("off")
fig.tight_layout()
fig.savefig(OUT / "map_lisa_2021.png", dpi=220)
plt.close(fig)

# Save tables
g_2021.drop(columns="geometry").to_csv(OUT / "local_stats_2021.csv", index=False)

# -------------------- 5) BAYESIAN STAR (CmdStan) ----------------------
def write_star_stan(path: Path):
    """Write the Bayesian STAR (spatio-temporal autoregressive) Stan program to *path*.

    Model (N regions, row-standardised W, T observed years):
        y_1 ~ N(alpha + u, sigma^2 I)
        (I - rho W) y_t = alpha + u + phi * y_{t-1} + eps_t,  eps_t ~ N(0, sigma^2 I),  t >= 2

    y_t appears on both sides of the spatial-lag equation (a simultaneous
    autoregression). Its density therefore includes the Jacobian
    |det(I - rho W)|, so the model block adds (T - 1) * log|det(I - rho W)|
    to the target. Without this term the posterior of rho is not that of
    the model above. This is the same (I - rho W) y_t recursion the
    forecast step solves.

    Args:
        path: any object with a ``write_text(str)`` method (normally a ``Path``).
    """
    path.write_text(r"""
data {
  int<lower=1> N;               // regions
  int<lower=2> T;               // time points
  matrix[N, N] W;               // row-standardized spatial weights
  matrix[N, T] Y;               // observed CAI
}
transformed data {
  matrix[N, N] I_N = diag_matrix(rep_vector(1.0, N));
}
parameters {
  real alpha;                   // drift / intercept
  real<lower=-1, upper=1> phi;  // temporal AR(1)
  real<lower=-1, upper=1> rho;  // spatial lag
  vector[N] u;                  // region random effects
  real<lower=0> sigma;          // obs noise
  real<lower=0> tau_u;          // RE scale
}
model {
  matrix[N, N] A = I_N - rho * W;   // spatial filter (I - rho W)

  // Priors (weakly-informative / as described)
  rho ~ normal(0, 0.3);
  phi ~ normal(0, 0.3);
  alpha ~ normal(0, 2);
  tau_u ~ cauchy(0, 2);
  sigma ~ cauchy(0, 2);
  u ~ normal(0, tau_u);

  // Likelihood
  // t=1 (no lag)
  Y[:, 1] ~ normal(alpha + u, sigma);

  // t>=2 (STAR): (I - rho W) y_t ~ N(alpha + u + phi * y_{t-1}, sigma),
  // plus the Jacobian of y_t -> (I - rho W) y_t for each of the T-1 years.
  target += (T - 1) * log_determinant(A);
  for (t in 2:T) {
    target += normal_lpdf(A * Y[:, t] | alpha + u + phi * Y[:, t - 1], sigma);
  }
}
generated quantities {
  // Posterior checks: not needed here; forecasts handled in Python
}
""".strip())

# Prepare matrix Y (N × T) and W (Queen contiguity) ---------------------
years_all = sorted(panel[YEAR_COL].dropna().unique())
years_obs = [y for y in years_all if y <= 2021]
T_obs = len(years_obs)
years_future = FORECAST_YEARS
if T_obs == 0:
    raise ValueError("No observed years <= 2021 in the panel; nothing to model or forecast.")

# Consistent region order. A region needs a polygon to get a row in W, so keep
# only ids that also exist in the geometry file (otherwise .loc below raises).
geom_ids = set(gdf[NUTS3_ID_COL])
regions = sorted(r for r in panel[NUTS3_ID_COL].dropna().unique() if r in geom_ids)
n_no_geom = panel[NUTS3_ID_COL].nunique() - len(regions)
if n_no_geom:
    print(f"  Excluding {n_no_geom} region(s) from STAR: no matching polygon in {GEOM_NUTS3}.")

N = len(regions)
Y = np.full((N, T_obs), np.nan)
for t_idx, yy in enumerate(years_obs):
    s = panel[panel[YEAR_COL] == yy].set_index(NUTS3_ID_COL)["CAI_region_year"]
    Y[:, t_idx] = [s.get(r, np.nan) for r in regions]

# Stan cannot take missing values: drop every region with a gap in ANY observed year.
keep = ~np.isnan(Y).any(axis=1)
regions_kept = [r for r, k in zip(regions, keep) if k]
if len(regions_kept) < N:
    print(f"  Dropping {N - len(regions_kept)} region(s) with an incomplete {years_obs[0]}–{years_obs[-1]} CAI series.")
Y = Y[keep, :]
N = len(regions_kept)
if N == 0:
    raise ValueError("No region has a complete observed CAI series; cannot fit STAR.")

gdf_kept = gdf.set_index(NUTS3_ID_COL).loc[regions_kept].reset_index()

# Spatial weights (Queen). W.full() orders rows by Wq.id_order, which libpysal
# does not promise to take from the `ids` list. Re-index explicitly so that
# row/column i of W_mat is regions_kept[i] (= row i of Y).
Wq = Queen.from_dataframe(gdf_kept, ids=gdf_kept[NUTS3_ID_COL].tolist())
if Wq.islands:
    print(f"  Note: {len(Wq.islands)} region(s) have no Queen neighbours (zero rows in W).")
W_dense = standardize_w_row(Wq)
pos = {rid: k for k, rid in enumerate(Wq.id_order)}
order = [pos[r] for r in regions_kept]
W_mat = W_dense[np.ix_(order, order)]

# Fit Bayesian STAR
fit_summ = None
posterior = None
draws = None

if not HAVE_STAN:
    print("CmdStanPy not available; skipping Bayesian fit. (Forecasts below will be heuristic.)")
elif T_obs < 2:
    print("Fewer than 2 observed years; skipping Bayesian fit. (Forecasts below will be heuristic.)")
else:
    print("Fitting Bayesian STAR with CmdStanPy (4 chains × 2000 post-warmup)...")
    stan_file = OUT / "star_model.stan"
    write_star_stan(stan_file)
    fit = None
    try:
        model = CmdStanModel(stan_file=str(stan_file))
        data = {"N": N, "T": T_obs, "W": W_mat, "Y": Y}
        fit = model.sample(data=data, chains=4, parallel_chains=4,
                           seed=RANDOM_SEED, iter_sampling=2000, iter_warmup=1000,
                           show_progress=False)
    except Exception as err:  # top-level boundary: missing toolchain, compile or sampling failure
        print(f"Bayesian STAR fit failed ({type(err).__name__}: {err}); forecasts below will be heuristic.")

    if fit is not None:
        # Summaries are best-effort, but report failures instead of hiding them.
        try:
            fit_summ = fit.summary()
            fit_summ.to_csv(OUT / "bstar_summary_full.csv")
        except Exception as err:
            print(f"  Could not write bstar_summary_full.csv: {err}")

        try:
            idata = az.from_cmdstanpy(fit)
            az.to_netcdf(idata, OUT / "bstar_idata.nc")
            table = az.summary(idata, var_names=["rho", "phi", "alpha", "sigma", "tau_u"], hdi_prob=0.94)
            table.to_csv(OUT / "bstar_posterior_summaries.csv")
            print("Saved posterior summaries to bstar_posterior_summaries.csv")
        except Exception as err:
            print(f"  Could not write ArviZ summaries: {err}")

        # Posterior draws used by the forecast simulation.
        draws = {
            "rho": fit.stan_variable("rho"),
            "phi": fit.stan_variable("phi"),
            "alpha": fit.stan_variable("alpha"),
            "sigma": fit.stan_variable("sigma"),
            "u": fit.stan_variable("u"),         # shape: (S, N)
        }

# -------------------- 6) FORECASTS 2022–2026 --------------------------
print("Forecasting 2022–2026...")
T_future = len(years_future)
Y_fore_mean = np.full((N, T_future), np.nan)
Y_fore_lo = np.full((N, T_future), np.nan)
Y_fore_hi = np.full((N, T_future), np.nan)

if HAVE_STAN and draws is not None:
    # Push every posterior draw forward through
    #   (I - rho W) y_t = alpha + u + phi * y_{t-1} + eps_t,   eps_t ~ N(0, sigma^2 I)
    # starting from the last observed year; each horizon collects one predictive
    # vector per draw.
    I = np.eye(N)
    y_prev = Y[:, -1].copy()
    samples_store = [[] for _ in range(T_future)]  # per horizon: S arrays of length N

    per_draw = zip(draws["rho"], draws["phi"], draws["alpha"], draws["sigma"], draws["u"])
    for rho_s, phi_s, alpha_s, sigma_s, u_s in per_draw:
        A = I - rho_s * W_mat
        noise_sd = max(1e-6, sigma_s)
        y_curr = y_prev.copy()
        for bucket in samples_store:
            mu = alpha_s + u_s + phi_s * y_curr
            eps = np.random.normal(0.0, noise_sd, size=N)
            y_curr = np.linalg.solve(A, mu + eps)
            bucket.append(y_curr.copy())

    # Collapse each horizon's (S × N) sample into mean and central 95% interval
    for h, bucket in enumerate(samples_store):
        smat = np.vstack(bucket)
        Y_fore_mean[:, h] = smat.mean(axis=0)
        Y_fore_lo[:, h], Y_fore_hi[:, h] = np.quantile(smat, [0.025, 0.975], axis=0)

else:
    # Fallback: partial reversion toward each region's observed mean plus a
    # common drift; the ±0.002 band is a fixed placeholder, not an uncertainty.
    print("Heuristic forecast: simple AR(1) per region with small drift (no uncertainty).")
    phi = 0.2
    drift = float(np.nanmean(Y[:, -1] - Y[:, -2])) if Y.shape[1] >= 2 else 0.001
    region_mean = np.nanmean(Y, axis=1)
    y_prev = Y[:, -1].copy()
    for h in range(T_future):
        y_prev = (1 - phi) * y_prev + phi * region_mean + drift
        Y_fore_mean[:, h] = y_prev
        Y_fore_lo[:, h] = y_prev - 0.002
        Y_fore_hi[:, h] = y_prev + 0.002

# Long-format region × year forecast table
forecast_tbl = pd.DataFrame([
    {
        NUTS3_ID_COL: reg,
        "year": yr,
        "cai_mean": Y_fore_mean[i, k],
        "cai_lo95": Y_fore_lo[i, k],
        "cai_hi95": Y_fore_hi[i, k],
    }
    for i, reg in enumerate(regions_kept)
    for k, yr in enumerate(years_future)
])
forecast_tbl.to_csv(OUT / "bstar_region_forecasts.csv", index=False)

# -------------------- 7) MULTI-PANEL MAPS (2014–2026) -----------------
print("Drawing uniform-legend maps for 2014–2026...")
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

MAP_VALUE_COL = "CAI_region_year"
gdf_idx = gdf.set_index(NUTS3_ID_COL)


def _year_series(yr):
    """Return the CAI values to map for *yr* (observed value or forecast mean).

    The Series is always named MAP_VALUE_COL: DataFrame.join requires a named
    Series, and plot(column=...) needs that name.
    """
    if yr in years_obs:
        obs = panel[panel[YEAR_COL] == yr].set_index(NUTS3_ID_COL)["CAI_region_year"]
        return obs.rename(MAP_VALUE_COL)
    h = years_future.index(yr)
    return pd.Series(Y_fore_mean[:, h], index=regions_kept, name=MAP_VALUE_COL)


# One polygon frame per panel; the same frames feed the shared colour limits.
years_all_for_plot = years_obs + years_future
frames = {yr: gdf_idx.join(_year_series(yr), how="left") for yr in years_all_for_plot}
vmin, vmax = common_color_limits([f[MAP_VALUE_COL].values for f in frames.values()])

# Grid: 4 columns, as many rows as needed (13 years 2014..2026 -> 4 rows).
ncols = 4
nrows = int(np.ceil(len(frames) / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 11),
                         subplot_kw=dict(aspect="equal"), squeeze=False)
axes = axes.ravel()

for ax, (yr, frame) in zip(axes, frames.items()):
    frame.plot(column=MAP_VALUE_COL, cmap="YlGnBu", vmin=vmin, vmax=vmax,
               linewidth=0.2, edgecolor="white", ax=ax,
               missing_kwds={"color": "#d9d9d9"})
    ax.set_title(str(yr))
    ax.axis("off")
# Hide the leftover grid cells so they don't render as empty framed axes.
for ax in axes[len(frames):]:
    ax.axis("off")

fig.suptitle("Composite CAI — Observed (2014–2021) and provisional forecast (2022–2026)\nUniform legend", y=0.98)
# Lay out the map grid first; the manually placed colorbar axes below is not
# tight_layout-compatible and is kept clear of the grid via rect's bottom margin.
fig.tight_layout(rect=[0, 0.08, 1, 0.96])

sm = ScalarMappable(cmap="YlGnBu", norm=Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
cax = fig.add_axes([0.25, 0.04, 0.5, 0.02])
cb = fig.colorbar(sm, cax=cax, orientation="horizontal")
cb.set_label("Composite CAI (observed & provisional forecast)")

fig.savefig(OUT / "map_series_2014_2026.png", dpi=200)
plt.close(fig)

# -------------------- 8) NATIONAL TIME-TRACK (MEAN + CI) --------------
print("Plotting national CAI trajectory...")
# Observed national mean: unweighted average over every region present in the panel.
nat_obs = (panel[panel[YEAR_COL].isin(years_obs)]
           .groupby(YEAR_COL)["CAI_region_year"].mean()
           .reindex(years_obs))

# Forecast national mean + 95% band, averaged over the forecast regions
# (regions_kept). NOTE(review): this set can be smaller than the set behind
# nat_obs, so a small step at 2021/2022 may reflect composition, not dynamics.
nat_fore_mean = np.nanmean(Y_fore_mean, axis=0)
if HAVE_STAN and draws is not None:
    # Quantiles are not additive: the band for the national average must come
    # from the per-draw national average. Averaging region-level 2.5%/97.5%
    # quantiles generally overstates its width.
    nat_draws = np.column_stack(
        [np.vstack(per_h).mean(axis=1) for per_h in samples_store]
    )  # (S × T_future)
    nat_fore_lo = np.quantile(nat_draws, 0.025, axis=0)
    nat_fore_hi = np.quantile(nat_draws, 0.975, axis=0)
else:
    # Heuristic path: every region has the same fixed ±band, so averaging is exact.
    nat_fore_lo = np.nanmean(Y_fore_lo, axis=0)
    nat_fore_hi = np.nanmean(Y_fore_hi, axis=0)

fig, ax = plt.subplots(1, 1, figsize=(9.5, 5.5))
ax.plot(years_obs, nat_obs.values, "-o", label="Observed mean (2014–2021)")
ax.axvline(2021.5, color="k", linestyle=":", alpha=0.6)
ax.plot(years_future, nat_fore_mean, "--", label="Forecast mean")
ax.fill_between(years_future, nat_fore_lo, nat_fore_hi, alpha=0.2, label="95% CI (mean forecast)")
ax.set_title("Italy overall — CAI trajectory (2014–2026)")
ax.set_xlabel("Year")
ax.set_ylabel("Composite CAI")
ax.legend(loc="upper left")
fig.tight_layout()
fig.savefig(OUT / "national_cai_trend.png", dpi=220)
plt.close(fig)

print("\nDone. Key outputs in:", OUT.resolve())
# Build the checklist first (the ecoscheme map only exists when per-scheme CAIs
# were present), then emit one line per artefact.
report_lines = [
    " - panel_region_year.csv",
    " - map_pooled_2014_2021.png",
]
if dom_tbl is not None:
    report_lines.append(" - map_dominant_ecoscheme.png")
report_lines += [
    " - map_gistar_2021.png, map_lisa_2021.png, local_stats_2021.csv",
    " - bstar_posterior_summaries.csv (if Stan ran), bstar_region_forecasts.csv",
    " - map_series_2014_2026.png, national_cai_trend.png",
]
for line in report_lines:
    print(line)
