In [1]:
import os
# Move to the project root so the `from src...` imports in the next cell
# resolve. Guarded so that re-running this cell does not keep walking further
# up the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")

In [2]:
from src.data.utils import DiffReg
import pymc as pm
import pandas as pd
import polars as pl
import numpy as np
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import causalpy as cp

# Shared plotting theme applied to every ArviZ/matplotlib figure below.
ARVIZ_STYLE = "arviz-darkgrid"
az.style.use(ARVIZ_STYLE)

# Project data helper (from src/data/utils.py) used to build the panel.
dr = DiffReg()

In [None]:
# FIXME: the saved output shows this call raising
# ColumnNotFoundError (column "naics_code" not found). Every cell below depends
# on `master`, so Restart & Run All stops here until DiffReg.data_set is fixed.
master = dr.data_set(foreign=False, naics="72")
master

ColumnNotFoundError: unable to find column "naics_code"; valid columns: ["area_fips", "year", "qtr", "industry_code", "month1_emplvl", "month2_emplvl", "month3_emplvl", "avg_wkly_wage", "fips", "state_name", "min_wage", "geometry", "total_employment"]

In [None]:
# Build the modelling panel: one row per area (area_fips) and quarter, with
# treatment/period indicators and integer indices for the fixed effects.
TREATED_FIPS = "72"       # state fips whose areas are flagged treated=1
CONTROL_FIPS = "20"       # state fips used as the control arm
INTERVENTION_YEAR = 2023  # first year counted as post-treatment

df = (
    master
    .filter(pl.col("fips").is_in([TREATED_FIPS, CONTROL_FIPS]))
    # NOTE(review): area 20127 is excluded; the reason is not recorded here — TODO document.
    .filter(pl.col("area_fips") != "20127")
    .with_columns(
        treated=pl.when(pl.col("fips") == TREATED_FIPS).then(1).otherwise(0),
        date=pl.col("year").cast(pl.String) + "Q" + pl.col("qtr").cast(pl.String),
        # 1 from the intervention year onward, 0 before. (Previously inverted:
        # the post period was coded 0 and the pre period 1.)
        post_treatment=pl.when(pl.col("year") >= INTERVENTION_YEAR).then(1).otherwise(0),
        group=pl.col("area_fips").rank("dense").cast(pl.Int32),
        log_total_employment=pl.col("total_employment").log(),
    )
    .to_pandas()
)
df["date"] = pd.PeriodIndex(df["date"], freq="Q").start_time
df = df[(df["date"] >= pd.to_datetime("2014-01-01")) & (df["date"] < pd.to_datetime("2024-01-01"))]

# log(0) gives -inf; turn it into NaN so dropna removes those rows.
# NOTE(review): dropna() checks every column, not just the modelled ones, so a
# NaN in an unused column also drops the row — confirm that is intended.
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna()

# Contiguous 0..K-1 codes used to index beta_time / beta_area in the models.
df["group_idx"] = df["group"].astype("category").cat.codes
df["time_idx"] = df["date"].astype("category").cat.codes
n_groups = df["group_idx"].nunique()
n_times = df["time_idx"].nunique()
df = df.reset_index(drop=True)

assert df["treated"].nunique() == 2, "need both treated and control units"
assert df["post_treatment"].nunique() == 2, "need both pre- and post-intervention quarters"
assert (df["post_treatment"] == (df["year"] >= INTERVENTION_YEAR).astype(int)).all()

df.head()

In [None]:
# First two-way fixed-effects difference-in-differences model on log employment.
# NOTE(review): the next cell defines an almost identical model and overwrites
# both `model` and `idata`; consider deleting one of the two cells.
with pm.Model() as model:
    # Coordinates. Index labels are sorted so that label k sits at position k,
    # matching how beta_time / beta_area are indexed below.
    model.add_coord("obs_idx", df.index.values)
    model.add_coord("time_idx", np.sort(df["time_idx"].unique()))
    model.add_coord("group_idx", np.sort(df["group_idx"].unique()))

    # Data containers. `mu` is built from these (not from raw df arrays) so
    # that pm.set_data(...) actually changes the predictions.
    time_idx = pm.Data("time_idx", df["time_idx"].values, dims="obs_idx")
    group_idx = pm.Data("group_idx", df["group_idx"].values, dims="obs_idx")
    post_treatment = pm.Data("post_treatment", df["post_treatment"].values, dims="obs_idx")
    treated = pm.Data("treated", df["treated"].values, dims="obs_idx")

    # Priors
    beta_0 = pm.Normal("beta_0", mu=0, sigma=5)

    beta_time = pm.Normal("beta_year", mu=0, sigma=1, dims="time_idx")
    beta_area = pm.Normal("beta_area", mu=0, sigma=1, dims="group_idx")

    # NOTE(review): post_treatment depends on time only, so beta_p is not
    # separately identified from beta_time (nor beta_0 from the fixed effects);
    # only the priors pin them down. beta_delta is unaffected. Consider
    # dropping beta_p.
    beta_p = pm.Normal("beta_p", mu=0, sigma=1)
    beta_delta = pm.Normal("beta_delta", mu=0, sigma=1)

    sigma = pm.HalfNormal("sigma", sigma=1)

    # Linear predictor; beta_delta is the DiD effect (treated x post).
    mu = pm.Deterministic(
        "mu",
        beta_0
        + beta_time[time_idx]
        + beta_area[group_idx]
        + beta_p * post_treatment
        + beta_delta * treated * post_treatment,
        dims="obs_idx",
    )

    # Likelihood (on log employment)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=df["log_total_employment"].values, dims="obs_idx")

    # Sampling (seeded for reproducibility)
    idata = pm.sample(chains=10, cores=10, random_seed=42)


In [None]:
# Two-way fixed-effects difference-in-differences model on log employment.
# Every input is a pm.Data container and `mu` is built from those containers,
# so pm.set_data(...) in the prediction/counterfactual cells below actually
# changes `mu` (it is recomputed by sample_posterior_predictive).
with pm.Model() as model:
    # Coordinates. Index labels are sorted so that label k sits at position k,
    # matching how beta_time / beta_area are indexed below.
    model.add_coord("obs_id", df.index.values)
    model.add_coord("time_idx", np.sort(df["time_idx"].unique()))
    model.add_coord("group_idx", np.sort(df["group_idx"].unique()))

    # Data containers, resized along "obs_id" when new inputs are set.
    time_idx = pm.Data("time_idx", df["time_idx"].values, dims="obs_id")
    group_idx = pm.Data("group_idx", df["group_idx"].values, dims="obs_id")
    post_treatment = pm.Data("post_treatment", df["post_treatment"].values, dims="obs_id")
    treated = pm.Data("treated", df["treated"].values, dims="obs_id")

    # Priors
    beta_0 = pm.Normal("beta_0", mu=0, sigma=5)
    beta_time = pm.Normal("beta_time", mu=0, sigma=1, dims="time_idx")
    beta_area = pm.Normal("beta_area", mu=0, sigma=1, dims="group_idx")
    # NOTE(review): post_treatment depends on time only, so beta_p is not
    # separately identified from beta_time (nor beta_0 from the fixed effects);
    # only the priors pin them down. beta_delta is unaffected.
    beta_p = pm.Normal("beta_p", mu=0, sigma=1)
    beta_delta = pm.Normal("beta_delta", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Linear predictor; beta_delta is the DiD effect (treated x post).
    mu = pm.Deterministic(
        "mu",
        beta_0
        + beta_time[time_idx]
        + beta_area[group_idx]
        + beta_p * post_treatment
        + beta_delta * treated * post_treatment,
        dims="obs_id",
    )

    pm.Normal("obs", mu=mu, sigma=sigma, observed=df["log_total_employment"].values, dims="obs_id")

    idata = pm.sample(random_seed=42)


In [None]:
# Diagram of the most recently defined `model` (rendered via its rich repr).
model_graph = pm.model_to_graphviz(model)
model_graph

In [None]:
# Posterior summary of the difference-in-differences coefficient. (This cell
# previously re-rendered the same model graph as the cell above.)
# The likelihood is on log employment, so exp(beta_delta) is the implied
# multiplicative effect on employment for treated units in the post period.
az.summary(idata, var_names=["beta_delta"])

In [None]:
# Trace plots for every posterior variable except the per-observation `mu`;
# assigning the axes keeps the cell from echoing their repr.
trace_axes = az.plot_trace(idata, var_names=["~mu"])

In [None]:
def is_post_treatment(time_idx_array, intervention_time_idx):
    """Flag time indices as post-intervention (1) or pre-intervention (0).

    Returns an integer numpy array (a numpy integer scalar for scalar input)
    that is 1 wherever ``time_idx_array >= intervention_time_idx``.
    """
    indices = np.asarray(time_idx_array)
    return np.greater_equal(indices, intervention_time_idx).astype(int)


In [None]:
def pushforward_prediction(
    model, idata, time_idx_array, group_idx_val, treated_val, intervention_time_idx=None
):
    """Posterior of ``mu`` for a single unit over the given time indices.

    Replaces the model's data containers "time_idx", "group_idx",
    "post_treatment" and "treated" with one unit's inputs, resizes the
    "obs_id" dimension to match, and recomputes the ``mu`` Deterministic for
    the posterior draws in ``idata``.

    Parameters
    ----------
    model : pm.Model
        Model whose ``mu`` is built from the data containers named above.
    idata : arviz.InferenceData
        Posterior samples drawn from ``model``.
    time_idx_array : array-like of int
        ``time_idx`` codes to predict at.
    group_idx_val : int
        ``group_idx`` of the unit to predict for.
    treated_val : int
        Value for ``treated`` (1 = treated, 0 = untreated/counterfactual).
    intervention_time_idx : int, optional
        First post-intervention ``time_idx``. Defaults to the notebook-level
        global ``intervention_time`` (the original behaviour).

    Returns
    -------
    arviz.InferenceData
        Result of ``pm.sample_posterior_predictive``; ``mu`` is in its
        ``posterior_predictive`` group.

    Notes
    -----
    The model's data containers keep these new inputs afterwards; call
    ``pm.set_data`` with the original data to restore them.
    """
    if intervention_time_idx is None:
        intervention_time_idx = intervention_time  # notebook-level global
    time_idx_array = np.asarray(time_idx_array)
    n = len(time_idx_array)

    new_data = {
        "time_idx": time_idx_array,
        "group_idx": np.full(n, group_idx_val),
        "post_treatment": is_post_treatment(time_idx_array, intervention_time_idx),
        "treated": np.full(n, treated_val),
    }

    with model:
        # One call so every container and the "obs_id" coordinate are resized
        # together. Coordinates go through set_data: sample_posterior_predictive
        # has no `coords` argument.
        pm.set_data(new_data, coords={"obs_id": np.arange(n)})
        return pm.sample_posterior_predictive(idata, var_names=["mu"])


In [None]:
# Counterfactual for one treated unit over the post-intervention quarters,
# predicted as if it had never been treated: treated=0 switches off the
# beta_delta term while its unit and time effects are kept.
t_counterfactual = np.sort(df.loc[df["year"] >= 2023, "time_idx"].unique())  # only post-intervention

# NOTE(review): one representative treated unit (lowest group_idx). The
# previous hard-coded group 1 is not a treated unit under the dense ranking of
# area_fips. With many treated units you may want to loop/average instead.
cf_group = int(df.loc[df["treated"] == 1, "group_idx"].min())
group = [cf_group] * len(t_counterfactual)
treated = [0] * len(t_counterfactual)  # pretend they weren't treated

# Reuse the post_treatment coding the model was fitted with, per quarter.
post_by_time = df.groupby("time_idx")["post_treatment"].first()
post_counterfactual = post_by_time.loc[t_counterfactual].to_numpy()

with model:
    # Keys must match the model's pm.Data container names; "obs_id" is resized
    # to the number of counterfactual points.
    pm.set_data(
        {
            "time_idx": t_counterfactual,
            "group_idx": group,
            "post_treatment": post_counterfactual,
            "treated": treated,
        },
        coords={"obs_id": np.arange(len(t_counterfactual))},
    )
    ppc_counterfactual = pm.sample_posterior_predictive(idata, var_names=["mu"])


In [None]:
# Observed and modelled log employment for one control and one treated unit,
# plus the treated unit's no-treatment counterfactual from the cell above.
# Everything is on the log scale: the likelihood is on `log_total_employment`
# and `mu` is its linear predictor.

# The first counterfactual quarter is the first post-intervention quarter.
intervention_time = int(np.min(t_counterfactual))
ti = np.arange(n_times)  # every quarter index

# One representative unit per arm (lowest group_idx in each).
control_group = int(df.loc[df["treated"] == 0, "group_idx"].min())
treated_group = int(df.loc[df["treated"] == 1, "group_idx"].min())

# pushforward_prediction reads the notebook-level `intervention_time` set above.
ppc_control = pushforward_prediction(model, idata, ti, control_group, 0)
ppc_treatment = pushforward_prediction(model, idata, ti, treated_group, 1)

# Observed points for the two plotted units, labelled by arm.
df_plot = df[df["group_idx"].isin([control_group, treated_group])].assign(
    group=lambda d: np.where(d["treated"] == 1, "treatment", "control")
)

fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(
    data=df_plot,
    x="time_idx",
    y="log_total_employment",
    hue="group",
    palette={"control": "C0", "treatment": "C1"},
    ax=ax,
    alpha=0.6,
)

# HDI bands of mu: control, treatment, and treatment counterfactual.
for ppc, times, color, label in [
    (ppc_control, ti, "C0", "control HDI"),
    (ppc_treatment, ti, "C1", "treatment HDI"),
    (ppc_counterfactual, t_counterfactual, "C2", "counterfactual"),
]:
    az.plot_hdi(
        x=times,
        y=ppc.posterior_predictive["mu"],
        smooth=False,
        ax=ax,
        color=color,
        fill_kwargs={"label": label, "alpha": 0.3},
    )

# Vertical line for the intervention
ax.axvline(x=intervention_time, ls="--", color="red", label="Intervention", lw=2)

ax.set(
    xlabel="Quarter index (time_idx)",
    ylabel="log(Total Employment)",
    title="Difference-in-Differences with Counterfactual",
)
ax.legend()
plt.tight_layout()
plt.show()
