# Prior sensitivity


In [None]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm


# Prior scenarios for the sensitivity analysis.
# Each scenario maps hyperprior argument names to the values used when the
# model is built. Scenarios are written as overrides of one baseline so the
# differences between them are visible at a glance.
_BASELINE_PRIORS = {
    "mu_alpha_mu": 0, "mu_alpha_sigma": 1,
    "sigma_alpha_sigma": 1,
    "mu_rho_mu": 0, "mu_rho_sigma": 1,
    "sigma_rho_sigma": 1,
    "sigma_likelihood_sigma": 1,
}

prior_specs = {
    # Baseline: unit-scale priors centred at zero
    "default_priors": dict(_BASELINE_PRIORS),
    # Every scale parameter widened relative to the baseline
    "wider_priors": {
        **_BASELINE_PRIORS,
        "mu_alpha_sigma": 5,
        "sigma_alpha_sigma": 2,
        "mu_rho_sigma": 5,
        "sigma_rho_sigma": 2,
        "sigma_likelihood_sigma": 2,
    },
    # Every scale parameter tightened relative to the baseline
    "narrower_priors": {
        **_BASELINE_PRIORS,
        "mu_alpha_sigma": 0.5,
        "sigma_alpha_sigma": 0.5,
        "mu_rho_sigma": 0.5,
        "sigma_rho_sigma": 0.5,
        "sigma_likelihood_sigma": 0.5,
    },
    # Add more scenarios as needed, e.g., priors centered differently.
    # Baseline scales, but the prior mean for rho is shifted away from zero
    "shifted_rho_prior": {
        **_BASELINE_PRIORS,
        "mu_rho_mu": 0.5,
    },
}

# Inference data for each prior specification, keyed by the scenario name
idata_results = {}

print("--- Starting Prior Sensitivity Analysis ---")

# Fit the same hierarchical AR(1) model once per prior scenario. Only the
# hyperprior arguments taken from `priors` change between runs.
for prior_name, priors in prior_specs.items():
    print(f"\nRunning model with {prior_name}...")

    with pm.Model(coords=coords) as current_hierarchical_ar_model:
        # Population-level hyperparameters for the intercepts.
        # mu_alpha is a raw Normal draw scaled by sigma_alpha.
        # NOTE(review): this ties the scale of the population mean to the
        # between-state spread -- confirm that coupling is intended.
        mu_alpha_raw = pm.Normal('mu_alpha_raw', mu=priors["mu_alpha_mu"], sigma=priors["mu_alpha_sigma"])
        sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=priors["sigma_alpha_sigma"])
        mu_alpha = pm.Deterministic('mu_alpha', mu_alpha_raw * sigma_alpha)

        # Population-level hyperparameters for the AR coefficients,
        # constructed the same way as the intercept hyperparameters.
        mu_rho_raw = pm.Normal('mu_rho_raw', mu=priors["mu_rho_mu"], sigma=priors["mu_rho_sigma"])
        sigma_rho = pm.HalfNormal('sigma_rho', sigma=priors["sigma_rho_sigma"])
        mu_rho = pm.Deterministic('mu_rho', mu_rho_raw * sigma_rho)

        # Per-state intercepts (non-centered: standard-normal offsets that are
        # shifted by mu_alpha and scaled by sigma_alpha).
        alpha_raw = pm.Normal('alpha_raw', mu=0, sigma=1, dims='state')
        alpha = pm.Deterministic('alpha', mu_alpha + alpha_raw * sigma_alpha, dims='state')

        # Per-state AR(1) coefficients: non-centered offsets, then squashed
        # through tanh so every coefficient lies in (-1, 1).
        rho_raw = pm.Normal('rho_raw', mu=0, sigma=1, dims='state')
        rho_unconstrained = mu_rho + rho_raw * sigma_rho
        # Both names are registered so the stored variables stay the same as
        # before; 'rho' is the one carrying the 'state' dimension.
        rho_transformed = pm.Deterministic('rho_transformed', pm.math.tanh(rho_unconstrained))
        rho = pm.Deterministic('rho', rho_transformed, dims='state')

        # Innovation (error-term) standard deviation for each state
        sigma = pm.HalfNormal('sigma', sigma=priors["sigma_likelihood_sigma"], dims='state')

        # Split the padded series into (previous, current) pairs along time
        inflation_lagged = inflation_padded[:, :-1]
        inflation_current = inflation_padded[:, 1:]

        # Keep only the (state, time) cells the mask marks as observed.
        # np.nonzero walks the mask in the same row-major order as boolean
        # indexing, so state_idx_flat[k] is the state of the k-th kept value.
        valid_mask = mask[:, 1:]
        state_idx_flat = np.nonzero(valid_mask)[0]
        observed_values_flat = inflation_current[valid_mask]
        lagged_values_flat = inflation_lagged[valid_mask]

        # Expected value of each kept observation. Working on the flattened
        # valid cells keeps padded entries out of the graph entirely, so they
        # cannot leak NaN into gradients if the padding is NaN.
        mu_inflation_flat = alpha[state_idx_flat] + rho[state_idx_flat] * lagged_values_flat

        # Per-observation innovation scale, indexed symbolically from `sigma`.
        # (Calling sigma.eval() here would freeze one random draw into the
        # likelihood and cut `sigma` off from the data.)
        sigma_flat = sigma[state_idx_flat]

        # Likelihood for the observed inflation values
        pm.Normal('inflation_likelihood',
                  mu=mu_inflation_flat,
                  sigma=sigma_flat,
                  observed=observed_values_flat)

    with current_hierarchical_ar_model:
        idata_results[prior_name] = pm.sample(
            draws=1000,   # Reduced draws for faster testing, increase for final analysis
            tune=1000,    # Reduced tune for faster testing
            chains=2,
            cores=2,
            random_seed=42,  # Same seed for every scenario
            target_accept=0.9, # Can be slightly less stringent for testing
            max_treedepth=10   # Can be slightly less stringent for testing
        )

# --- Compare Results ---
# Print the summary table of the population-level hyperparameters for every
# fitted scenario, in the order the scenarios were run.
hyperparameter_names = ['mu_alpha', 'sigma_alpha', 'mu_rho', 'sigma_rho']

print("\n--- Comparing Posterior Summaries for Different Priors ---")
for scenario_label in idata_results:
    scenario_idata = idata_results[scenario_label]
    print(f"\n--- Results for {scenario_label} ---")
    summary_table = pm.summary(scenario_idata, var_names=hyperparameter_names)
    print(summary_table)

# --- Optional: Visualize Prior vs. Posterior ---
# You can also plot prior and posterior distributions to see the influence
# of the data vs. the prior. This requires defining the priors as PyMC distributions
# in a way that allows sampling them or plotting their PDFs.
# For simplicity here, we'll just compare the posterior summaries.

# --- Optional: Compare Forecasts ---
# One-step-ahead point forecast per state for every prior scenario:
#     forecast = mean(alpha) + mean(rho) * last observed value
# collected side by side in one table, with one forecast column per scenario.
print("\n--- Comparing Forecasts for Different Priors ---")

# The last observed value per state does not depend on the prior scenario, so
# it is computed once. groupby() returns groups sorted by key; reindexing on
# `states` puts the values in the same order as the 'State' column below.
# NOTE(review): assumes `states` holds the same labels as df_long['state'] and
# is the order used for the model's 'state' coordinate -- confirm.
last_observed_values_sens = (
    df_long.groupby('state')['value'].last().reindex(states).values
)
num_states = len(states)

all_forecast_dfs = []
for prior_name, idata in idata_results.items():
    # Posterior means over all chains and draws, one entry per state
    posterior_alpha_sens = idata.posterior["alpha"].mean(("chain", "draw")).values
    posterior_rho_sens = idata.posterior["rho"].mean(("chain", "draw")).values

    # AR(1) one-step-ahead forecast, evaluated for all states at once
    forecasted_inflation_sens = (
        posterior_alpha_sens + posterior_rho_sens * last_observed_values_sens
    )

    forecast_df_sens = pd.DataFrame({
        'State': states,
        f'Forecast_{prior_name}': forecasted_inflation_sens
    })
    all_forecast_dfs.append(forecast_df_sens)

# Merge all forecast DataFrames on 'State', keeping the first table's rows
if all_forecast_dfs:
    final_forecast_comparison = all_forecast_dfs[0]
    for scenario_forecast_df in all_forecast_dfs[1:]:
        final_forecast_comparison = pd.merge(
            final_forecast_comparison, scenario_forecast_df, on='State', how='left'
        )
    print(final_forecast_comparison)