# In [None]:
from typing import Optional, Tuple

import numpy as np

def _generate_spatial_basis(locations: np.ndarray) -> np.ndarray:
    # Spatial basis φ(s): fixed over time, later scaled by time-varying amplitude w_t
    spatial_basis = np.exp(-(locations ** 2))[:, None]
    spatial_basis /= (np.linalg.norm(spatial_basis) + 1e-12)
    return spatial_basis


def generate_time_NAR_synthetic_data(
    locs: np.ndarray,
    n_time_steps: int,
    noise_std: float,
    eigenvalue: float,
    eta_rho: float = 0.8,
    f_rho: float = 0.6,
    global_mean: float = 50.0,
    feature_noise_std: float = 0.0,
    non_linear_strength: float = 0.0,
    seed: Optional[int] = None,
    # ===================== NAR term =====================
    y_rho: float = 0.4,                 # base linear AR coefficient ρ
    y_gamma: float = 1.2,               # nonlinear NAR strength γ (tanh term)
    nar_scale_window: int = 200,        # rolling window size for (mu, scale) used in gate + normalization
    gate_type: str = "A2_global_v2",    # gate variant name (only used in the debug print)
    gate_threshold: float = 0.0,        # threshold τ for global standardized state z_global
    gate_low: float = 0.5,              # low-regime multiplier for ρ
    gate_high: float = 1.0,             # high-regime multiplier for ρ
    debug_gate_ratio: bool = True,      # print fraction of time steps in high-regime
    # ====================================================
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate spatio-temporal data whose target follows a gated NAR recursion.

    For sites ``i`` (positions ``locs``) and steps ``t = 0..T-1`` with
    ``T = n_time_steps``::

        y[0, i] = m_0 + s_0[i] + h(x[0, i]) + e[0, i]
        y[t, i] = m_t + s_t[i] + h(x[t, i])
                  + rho_eff_t * y[t-1, i] + y_gamma * tanh(y[t-1, i] / scale_t)
                  + e[t, i]                                         (t >= 1)

    where

    * ``m_t = global_mean + d_t`` and ``s_t[i] = eigenvalue * w_t * phi_i``.
      ``d_t`` (coefficient ``f_rho``) and ``w_t`` (coefficient ``eta_rho``) are
      AR(1) processes started from N(0, 1) with N(0, 0.1**2) innovations, and
      ``phi`` is the flattened output of ``_generate_spatial_basis(locs)``.
    * ``h`` is ``0.3*x0 + 0.4*x1 + 0.2*x2``, plus nonlinear terms weighted by
      ``non_linear_strength`` when it is positive.
    * ``e ~ N(0, noise_std**2)``, drawn independently per site and step.
    * ``mu_t`` and ``scale_t`` are the mean and the std (+1e-6) of all entries
      of ``y[max(0, t - nar_scale_window):t]``.
    * ``gate_t = 1`` if ``(mean_i y[t-1, i] - mu_t) / scale_t > gate_threshold``,
      else 0. Then ``rho_eff_t = y_rho * (gate_low + (gate_high - gate_low) * gate_t)``.

    Args:
        locs: 1-D site coordinates, shape ``(N,)``; cast to float32.
        n_time_steps: number of time steps ``T``; must be >= 1.
        noise_std: std of the observation noise ``e``.
        eigenvalue: multiplier of the spatial-basis term.
        eta_rho: AR(1) coefficient of the spatial amplitude ``w_t``.
        f_rho: AR(1) coefficient of the drift ``d_t``.
        global_mean: constant level added at every step.
        feature_noise_std: std of extra Gaussian noise added to every covariate
            (skipped when <= 0).
        non_linear_strength: weight of the nonlinear covariate terms
            (skipped when <= 0).
        seed: if not None, passed to ``np.random.seed``.
        y_rho, y_gamma, nar_scale_window, gate_threshold, gate_low, gate_high:
            NAR recursion and gate parameters; see the equations above.
            ``nar_scale_window`` must be >= 1.
        gate_type: label shown in the debug print only.
        debug_gate_ratio: if True, print the fraction of recursive steps
            (t >= 1) with ``gate_t = 1``. This is NaN when ``T == 1``.

    Returns:
        ``(cat, cont, y)``:
        * ``cat``: empty categorical tensor, shape ``(T, N, 0)``, int64.
        * ``cont``: continuous covariates, shape ``(T, N, 3)``, float32.
        * ``y``: target, shape ``(T, N)``, float32.

    Raises:
        ValueError: if ``locs`` is not 1-D, ``n_time_steps < 1`` or
            ``nar_scale_window < 1``.

    Note:
        When ``seed`` is given, this reseeds NumPy's *global* legacy RNG, which
        also affects any later ``np.random`` calls made by the caller.
    """
    locs = np.asarray(locs, dtype=np.float32)
    if locs.ndim != 1:
        raise ValueError(f"expects 1D locs (N,), got {locs.shape}")
    if n_time_steps < 1:
        # The t = 0 initialisation below needs at least one time step.
        raise ValueError(f"n_time_steps must be >= 1, got {n_time_steps}")
    if nar_scale_window < 1:
        # An empty rolling window yields NaN (mu, scale), which would silently
        # turn every y[t >= 1] into NaN.
        raise ValueError(f"nar_scale_window must be >= 1, got {nar_scale_window}")

    if seed is not None:
        np.random.seed(seed)

    N = len(locs)
    p = 3  # number of continuous covariates

    # Continuous covariates x(t, i, k): standard-normal base draws
    cont = np.random.randn(n_time_steps, N, p).astype(np.float32)

    # Covariates with explicit spatio-temporal structure
    spatial_pattern = np.sin(locs * np.pi / 5.0).astype(np.float32)
    temporal_trend = np.linspace(0, 2 * np.pi, n_time_steps).astype(np.float32)

    # x0: trend × spatial pattern
    cont[:, :, 0] += (np.outer(temporal_trend, spatial_pattern) * 0.5).astype(np.float32)
    # x1: trend shared by all sites
    cont[:, :, 1] += (np.outer(temporal_trend, np.ones(N, dtype=np.float32)) * 0.3).astype(np.float32)
    # x2: interaction of the (already updated) x0 and x1
    cont[:, :, 2] += (cont[:, :, 0] * cont[:, :, 1] * 0.2).astype(np.float32)

    # Optional extra noise injected into features
    if feature_noise_std > 0:
        cont += (np.random.randn(*cont.shape) * feature_noise_std).astype(np.float32)

    # Spatial basis φ_i (fixed), later multiplied by time-varying spatial_weights[t]
    spatial_basis = _generate_spatial_basis(locs).reshape(-1).astype(np.float32)

    # Latent temporal processes:
    # - spatial_weights[t]: amplitude of the spatial basis (time-varying spatial effect)
    # - trend_drift[t]: global temporal drift (time-varying mean component)
    spatial_weights = np.zeros(n_time_steps, dtype=np.float32)
    trend_drift = np.zeros(n_time_steps, dtype=np.float32)

    spatial_weights[0] = float(np.random.randn() * 1.0)
    trend_drift[0] = float(np.random.randn() * 1.0)

    # AR(1) evolution for latent temporal processes
    for t in range(1, n_time_steps):
        spatial_weights[t] = eta_rho * spatial_weights[t - 1] + float(np.random.randn() * 0.1)
        trend_drift[t] = f_rho * trend_drift[t - 1] + float(np.random.randn() * 0.1)

    y = np.zeros((n_time_steps, N), dtype=np.float32)
    eps = 1e-6

    def _feat_term(t: int) -> np.ndarray:
        # Covariate-driven effect h(x_t,i): linear + optional nonlinear augmentation
        feat = (0.3 * cont[t, :, 0] + 0.4 * cont[t, :, 1] + 0.2 * cont[t, :, 2]).astype(np.float32)
        if non_linear_strength > 0:
            # Optional nonlinear feature-target relationships (still "instantaneous" in time)
            feat = (
                feat
                + non_linear_strength * (cont[t, :, 0] ** 2)
                + non_linear_strength * 0.5 * cont[t, :, 0] * cont[t, :, 1]
                + non_linear_strength * 0.3 * np.sin(cont[t, :, 2])
            ).astype(np.float32)
        return feat

    # t = 0 initialization (no AR recursion yet)
    trend0 = global_mean + float(trend_drift[0])
    spatial0 = eigenvalue * float(spatial_weights[0]) * spatial_basis
    noise0 = (np.random.randn(N) * noise_std).astype(np.float32)
    y[0] = (trend0 + spatial0 + _feat_term(0) + noise0).astype(np.float32)

    # Counts how often gate == 1 over t = 1..T-1 (diagnostic only)
    gate_cnt = 0.0

    for t in range(1, n_time_steps):
        # Baseline components
        trend = global_mean + float(trend_drift[t])                         # global mean + temporal drift
        spatial = eigenvalue * float(spatial_weights[t]) * spatial_basis    # spatial basis with time-varying amplitude
        noise = (np.random.randn(N) * noise_std).astype(np.float32)         # observation noise

        prev = y[t - 1]

        # ---- Rolling normalisation statistics ----
        # Computed over all sites of the last `nar_scale_window` steps; used for
        #   (1) the gate decision (global regime switching)
        #   (2) stabilising the nonlinear term tanh(prev / scale)
        w0 = max(0, t - nar_scale_window)
        window = y[w0:t]
        mu = float(np.mean(window))
        scale = float(np.std(window) + eps)  # +eps keeps the division finite

        # Global standardized state (mean over sites):
        # z_global > threshold -> high regime (gate = 1)
        z_global = float((np.mean(prev) - mu) / scale)

        gate = 1.0 if (z_global > gate_threshold) else 0.0
        gate_cnt += gate

        # Regime-dependent AR coefficient:
        #   low regime:  ρ_eff = y_rho * gate_low
        #   high regime: ρ_eff = y_rho * gate_high
        rho_eff = y_rho * (gate_low + (gate_high - gate_low) * gate)

        # NAR recursion term (y-driven temporal dependence):
        #   linear AR part:        rho_eff * prev
        #   nonlinear saturation:  y_gamma * tanh(prev / scale)
        rec = (rho_eff * prev + y_gamma * np.tanh(prev / scale)).astype(np.float32)

        # Full observation equation:
        # y_t = trend + spatial + covariate effect + recursion + noise
        y[t] = (trend + spatial + _feat_term(t) + rec + noise).astype(np.float32)

    if debug_gate_ratio:
        n_recursive_steps = n_time_steps - 1
        # With a single time step no gate decision is made; report NaN
        # instead of dividing by zero.
        gate_ratio = gate_cnt / n_recursive_steps if n_recursive_steps > 0 else float("nan")
        print(
            f"[Gate {gate_type}] gate_ratio={gate_ratio:.4f} | "
            f"thr={gate_threshold} | low={gate_low} | high={gate_high} | win={nar_scale_window}"
        )

    cat = np.zeros((n_time_steps, N, 0), dtype=np.int64)
    return cat, cont, y
