In [None]:
import os 
# The notebook imports `src.data_process` below, so the working directory must be the
# one that contains `src`. Only step up when `src` is not already visible: an
# unconditional chdir("..") climbs one level higher on every re-run of this cell.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
import polars as pl
import seaborn as sns
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
from src.data_process import DataReg

# Global matplotlib look: "bmh" style sheet first, then figure defaults on top of it.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": [10, 6],
        "figure.dpi": 100,
        "figure.facecolor": "white",
    }
)

# Project data-access object used by the loading cells below.
dr = DataReg(database_file="data.ddb")

In [None]:
# Industry "72" rows with the three monthly employment levels averaged into one value.
df = (
    dr.data_set()
    .filter(pl.col("industry_code") == "72")
    .with_columns(
        total_employment=(
            pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")
        )
        / 3
    )
)

# area_fips codes that have at least one row whose average employment is exactly zero.
# NOTE(review): in this notebook `remove` is only referenced by commented-out filters.
remove = (
    df.filter(pl.col("total_employment") == 0)
    .get_column("area_fips")
    .unique()
    .to_list()
)

In [None]:
# Rebuild the working panel from the raw data set: industry "72", years before 2020.
# NOTE(review): the zero-employment exclusion (`~is_in(remove)`) is switched off here,
# so any row whose three-month average employment is 0 becomes -inf after `.log()`.
# Confirm that no such row reaches the model inputs.
df = (
    dr.data_set()
    .filter((pl.col("industry_code") == "72") & (pl.col("year") < 2020))
    .with_columns(
        # Quarter label such as "2015Q3"; parsed into a timestamp later on.
        date=pl.col("year").cast(pl.String) + "Q" + pl.col("qtr").cast(pl.String),
        dummy=pl.lit(1),
        # Prefix the FIPS code with "i" (e.g. "06081" -> "i06081").
        area_fips="i" + pl.col("area_fips"),
        # Log of the average of the three monthly employment levels.
        total_employment=(
            (
                pl.col("month1_emplvl")
                + pl.col("month2_emplvl")
                + pl.col("month3_emplvl")
            )
            / 3
        ).log(),
    )
)
# df.filter(pl.col("area_fips") == "i06081")

In [None]:
# df = dr.data_set()
# df = df.filter(
#     (pl.col("industry_code") == "72") &
#     (~pl.col("area_fips").is_in(l)) & 
#     (~pl.col("area_fips").is_in(remove)) & 
#     (pl.col("year") < 2020)

# )


# df = df.with_columns(
#     date=pl.col("year").cast(pl.String) + "Q" + pl.col("qtr").cast(pl.String),
#     dummy=pl.lit(1),
#     area_fips= "i" + pl.col("area_fips"),
#     total_employment=((pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")) /
#     3).log(),
#     # after_treatment=pl.when((pl.col("year") >= 2016) & (pl.col("qtr") > 1)).then(True).otherwise(False)
# )
# df.filter(pl.col("area_fips") == "i06081")

In [None]:
# Show the prepared panel (the last expression of a cell is rendered as its output).
df

In [None]:
# pandas frame used for modelling and plotting.
data = (
    df.select(pl.col("area_fips", "date", "total_employment", "avg_wkly_wage"))
    .with_columns(
        # NOTE: despite its name, `controls` is True only for the target county i06081
        # and False for every other county.
        controls=pl.when(pl.col("area_fips") == "i06081").then(True).otherwise(False)
    )
    .to_pandas()
)

# Parse the "YYYYQn" labels from `data` itself rather than from the polars frame, so the
# result cannot depend on the two frames happening to share the same row order.
# `to_timestamp()` maps each quarter to its first day (2016Q1 -> 2016-01-01).
data["date"] = pd.PeriodIndex(data["date"], freq="Q").to_timestamp()

# Strictly after 2016-01-01, so 2016Q1 itself still counts as pre-treatment.
data["after_treatment"] = data["date"] > pd.to_datetime("2016-01-01")

# Keep only codes starting with "i06"; "i06081" already matches this prefix, so no
# separate clause is needed for it.
data = data[data["area_fips"].str.startswith("i06")].reset_index(drop=True)
data

In [None]:
# data  = df.pivot(on="area_fips", index="date", values="total_employment").to_pandas()
# data["date"] = pd.PeriodIndex(data['date'], freq='Q').to_timestamp()
# data = data.set_index("date")

In [None]:
# target_county = "i06081"
# all_counties = data.columns
# other_counties = all_counties.difference({target_county})
# all_counties = list(all_counties)
# other_counties = list(other_counties)

In [None]:
# from scipy import stats
# y = data["i06081"].values
# x = data["i01007"].values
# l = []
# res = stats.pearsonr(x, y)
# for i in other_counties:
#     x = data[i].values
#     res = stats.pearsonr(x, y)
#     if res.pvalue < 0.001:
#         l.append(i)
# len(l)

In [None]:
fig, ax = plt.subplots()

# Mean of `total_employment` (log scale) per date, one line per value of the
# `controls` flag (True = county i06081, False = all other counties in `data`).
(
    data.groupby(["date", "controls"], as_index=False)
    .agg({"total_employment": "mean"})
    .pipe(
        (sns.lineplot, "data"),
        x="date",
        y="total_employment",
        hue="controls",
        marker="o",
        ax=ax,
    )
)

# Vertical marker at the treatment date used throughout the notebook.
ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)

ax.legend(loc="upper right")
ax.set(
    title="Employment",
    ylabel="Mean log employment",
)


In [None]:
features = ["total_employment"]

# Wide matrices with one column per area_fips and one row per (feature, date),
# split by the `after_treatment` flag.
is_post = data["after_treatment"]

pre_df = data.loc[~is_post].pivot(index="area_fips", columns="date", values=features).T

post_df = data.loc[is_post].pivot(index="area_fips", columns="date", values=features).T

In [None]:
# Target unit: its column is the outcome, every other column is a predictor.
idx = "i06081"

# Pre-treatment window.
pre_years = pre_df.reset_index()["date"].unique()
n_pre = pre_years.size
y_pre = pre_df[idx].to_numpy()
x_pre = pre_df.drop(columns=idx).to_numpy()

# Post-treatment window.
post_years = post_df.reset_index()["date"].unique()
n_post = post_years.size
y_post = post_df[idx].to_numpy()
x_post = post_df.drop(columns=idx).to_numpy()

# Number of predictor columns (all columns except the target).
k = len(pre_df.columns) - 1

In [None]:
# Synthetic-control regression: the target series is modelled as a weighted combination
# of the other columns, with the weights drawn from a Dirichlet prior.
with pm.Model() as model:
    # Data containers, so the post-treatment arrays can be swapped in via pm.set_data.
    x = pm.Data("x", x_pre)
    y = pm.Data("y", y_pre)

    # Symmetric Dirichlet with concentration 1/k on each of the k weights.
    beta = pm.Dirichlet("beta", a=np.full(k, 1 / k))
    sigma = pm.HalfNormal("sigma", sigma=5)

    mu = pm.Deterministic("mu", pm.math.dot(x, beta))
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y)

pm.model_to_graphviz(model)

In [None]:
with model:
    # Fixed seed so that Restart & Run All reproduces the same draws.
    idata = pm.sample(nuts_sampler="blackjax", random_seed=42)
    # Posterior predictive on the pre-treatment data the model was fitted to.
    posterior_predictive_pre = pm.sample_posterior_predictive(
        trace=idata, random_seed=42
    )

In [None]:
with model:
    # Swap the post-treatment arrays into the data containers; the posterior stored in
    # `idata` is reused as-is to predict the post-treatment outcome.
    pm.set_data(new_data={"x": x_post, "y": y_post})
    # Fixed seed so that the predictive draws are reproducible across runs.
    posterior_predictive_post = pm.sample_posterior_predictive(
        trace=idata, var_names=["likelihood"], random_seed=42
    )

In [None]:
# Posterior-predictive mean per period, averaged over every (chain, draw) sample.
pre_posterior_mean = (
    posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
    .stack(samples=("chain", "draw"))
    .mean("samples")
)

post_posterior_mean = (
    posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
    .stack(samples=("chain", "draw"))
    .mean("samples")
)


fig, ax = plt.subplots()

# Observed group means. The boolean `controls` flag (True = i06081) is mapped to readable
# legend labels; `data` only holds area_fips codes starting with "i06".
(
    data.groupby(["date", "controls"], as_index=False)
    .agg({"total_employment": "mean"})
    .assign(
        county=lambda frame: frame["controls"].map(
            {True: "San Mateo County", False: "Other i06 counties (mean)"}
        )
    )
    .pipe(
        (sns.lineplot, "data"),
        x="date",
        y="total_employment",
        hue="county",
        alpha=0.5,
        ax=ax,
    )
)
ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Incremental MW",
)

# Synthetic series: posterior-predictive mean before and after treatment.
sns.lineplot(
    x=pre_years,
    y=pre_posterior_mean,
    color="C1",
    marker="o",
    label="pre-treatment posterior predictive mean",
    ax=ax,
)
sns.lineplot(
    x=post_years,
    y=post_posterior_mean,
    color="C2",
    marker="o",
    label="post-treatment posterior predictive mean",
    ax=ax,
)

# Uncertainty bands around the synthetic series.
az.plot_hdi(
    x=pre_years,
    y=posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre],
    smooth=False,
    color="C1",
    fill_kwargs={"label": "pre-treatment posterior predictive (94% HDI)"},
    ax=ax,
)

az.plot_hdi(
    x=post_years,
    y=posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post],
    smooth=False,
    color="C2",
    fill_kwargs={"label": "post-treatment posterior predictive (94% HDI)"},
    ax=ax,
)
ax.legend(loc="lower left")
ax.set(
    title="Synthetic control on San Mateo County", ylabel="Log employment"
)

In [None]:
# Mean of `total_employment` per date and `controls` flag.
data_grouped = (
    data.groupby(["date", "controls"])
    .agg({"total_employment": "mean"})
    .reset_index()
)
# Readable label for the target county; the other group gets an empty label so it can be
# filtered out of the plot below.
data_grouped["is_county"] = data_grouped.controls.map({True: "San Mateo", False: ""})

fig, ax = plt.subplots()

# Observed series for the target county only.
sns.lineplot(
    data=data_grouped[data_grouped["is_county"] != ""],
    x="date",
    y="total_employment",
    hue="is_county",
    alpha=0.5,
    ax=ax,
)

ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Incremental MW",
)

# Synthetic series: posterior-predictive mean before and after treatment.
sns.lineplot(
    x=pre_years,
    y=pre_posterior_mean,
    color="C1",
    marker="o",
    label="Pre-treatment posterior predictive mean",
    ax=ax,
)

sns.lineplot(
    x=post_years,
    y=post_posterior_mean,
    color="C2",
    marker="o",
    label="Post-treatment posterior predictive mean",
    ax=ax,
)

# Uncertainty bands around the synthetic series.
az.plot_hdi(
    x=pre_years,
    y=posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre],
    smooth=False,
    color="C1",
    fill_kwargs={"label": "Pre-treatment posterior predictive (94% HDI)"},
    ax=ax,
)

az.plot_hdi(
    x=post_years,
    y=posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post],
    smooth=False,
    color="C2",
    fill_kwargs={"label": "Post-treatment posterior predictive (94% HDI)"},
    ax=ax,
)

ax.legend(loc="upper left")
ax.set(
    title="Synthetic control on San Mateo County",
    ylabel="Employment"
)

plt.show()


In [None]:

# Point estimate of the gap per period: observed value minus posterior-predictive mean.
effect_pre = y_pre[:n_pre] - pre_posterior_mean
effect_post = y_post[:n_post] - post_posterior_mean

# Per-draw gaps (observed minus each posterior-predictive draw) for the HDI bands.
gap_draws_pre = (
    y_pre[:n_pre]
    - posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
)
gap_draws_post = (
    y_post[:n_post]
    - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)

fig, ax = plt.subplots()

treatment_date = pd.to_datetime("2016-01-01")
ax.axvline(x=treatment_date, linestyle=":", lw=2, color="C2", label="Incremental MW")

# Mean gap before and after treatment.
sns.lineplot(
    x=pre_years,
    y=effect_pre,
    color="C1",
    marker="o",
    label="Pre-treatment posterior predictive effect mean",
    ax=ax,
)
sns.lineplot(
    x=post_years,
    y=effect_post,
    color="C2",
    marker="o",
    label="Post-treatment posterior predictive effect mean",
    ax=ax,
)

# HDI bands of the gap before and after treatment.
az.plot_hdi(
    x=pre_years,
    y=gap_draws_pre,
    smooth=False,
    color="C1",
    fill_kwargs={"label": "Pre-treatment posterior predictive effect (94% HDI)"},
    ax=ax,
)
az.plot_hdi(
    x=post_years,
    y=gap_draws_post,
    smooth=False,
    color="C2",
    fill_kwargs={"label": "Post-treatment posterior predictive effect (94% HDI)"},
    ax=ax,
)

# Reference line: a gap of zero means observed and synthetic coincide.
ax.axhline(y=0.0, color="black", linestyle="--", label="Zero effect")

ax.legend(loc="lower left")
ax.set(
    title="San Mateo County - Synthetic Control Effect Over Time",
    ylabel="Gap in total employment",
)

plt.show()


In [None]:
# Posterior draws of the gap (observed minus synthetic) for the LAST post-treatment
# period: index -1 along the time axis.
effect_distribution = (
    y_post[:n_post]
    - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
)[:, :, -1]

# Quarter label of the period actually plotted, taken from the same position in
# `post_years`, so the title always matches the data.
last_post_quarter = pd.Timestamp(post_years[-1]).to_period("Q")

# Flatten (chain, draw) into one sample axis and plot the histogram with a KDE overlay.
g = (
    effect_distribution
    .stack(samples=("chain", "draw"))
    .pipe((sns.displot, "data"), kde=True, height=4.5, aspect=1.5)
)

g.set(title=f"Employment effect in the last post-treatment quarter ({last_post_quarter})")

g.set_axis_labels("Estimated treatment effect", "Density")


In [None]:

def run_synthetic_control(
    pre_df: pd.DataFrame, post_df: pd.DataFrame, idx: str, random_seed=None
) -> tuple:
    """Fit the synthetic-control model for one unit and return its prediction gaps.

    The column ``idx`` is treated as the outcome and every other column as a predictor.
    The model is fitted on ``pre_df``, then the same posterior is used to predict the
    outcome on ``post_df``.

    Parameters
    ----------
    pre_df, post_df
        Wide frames with one column per unit and a ``date`` index level, covering the
        pre- and post-treatment periods respectively.
    idx
        Column label of the unit to model (callers pass ``area_fips`` strings).
    random_seed
        Optional seed forwarded to ``pm.sample`` and ``pm.sample_posterior_predictive``.
        ``None`` (the default) leaves seeding to PyMC, as before.

    Returns
    -------
    tuple
        ``(error_pre, error_post)``: observed outcome minus each posterior-predictive
        draw of ``likelihood``, for the pre- and post-treatment periods.
    """
    # prepare data
    y_pre = pre_df[idx].to_numpy()
    x_pre = pre_df.drop(columns=idx).to_numpy()
    n_pre = pre_df.reset_index(inplace=False).date.unique().size

    y_post = post_df[idx].to_numpy()
    x_post = post_df.drop(columns=idx).to_numpy()
    n_post = post_df.reset_index(inplace=False).date.unique().size

    # number of predictor columns
    k = pre_df.shape[1] - 1

    # specify the model
    with pm.Model():
        # pm.Data (not the deprecated pm.MutableData), matching the model cell above;
        # the containers are replaced with the post-treatment arrays further down.
        x = pm.Data(name="x", value=x_pre)
        y = pm.Data(name="y", value=y_pre)

        beta = pm.Dirichlet(name="beta", a=(1 / k) * np.ones(k))
        sigma = pm.HalfNormal(name="sigma", sigma=5)
        mu = pm.Deterministic(name="mu", var=pm.math.dot(x, beta))
        pm.Normal(name="likelihood", mu=mu, sigma=sigma, observed=y)

        # fit the model
        idata = pm.sample(nuts_sampler="blackjax", random_seed=random_seed)

        # pre-treatment predictive distribution
        posterior_predictive_pre = pm.sample_posterior_predictive(
            trace=idata, random_seed=random_seed
        )
        # post-treatment predictive distribution
        pm.set_data(new_data={"x": x_post, "y": y_post})
        posterior_predictive_post = pm.sample_posterior_predictive(
            trace=idata, var_names=["likelihood"], random_seed=random_seed
        )

    # compute errors (plain array arithmetic, so no model context is needed)
    error_pre = (
        y_pre[:n_pre]
        - posterior_predictive_pre.posterior_predictive["likelihood"][:, :, :n_pre]
    )
    error_post = (
        y_post[:n_post]
        - posterior_predictive_post.posterior_predictive["likelihood"][:, :, :n_post]
    )

    return error_pre, error_post

In [None]:
from tqdm.notebook import tqdm
# Run the synthetic-control fit once per area_fips code, treating each one in turn as
# the target unit. Maps each code to its (error_pre, error_post) pair.
results = dict(
    (fips, run_synthetic_control(pre_df=pre_df, post_df=post_df, idx=fips))
    for fips in tqdm(data["area_fips"].unique())
)

In [None]:
fig, ax = plt.subplots()

# One line per county: mean gap (observed minus synthetic) before and after treatment.
# The target county i06081 is highlighted; all others are drawn in translucent gray.
for idx in data["area_fips"].unique():
    error_pre, error_post = results[idx]
    pre_samples = error_pre.stack(samples=("chain", "draw"))
    post_samples = error_post.stack(samples=("chain", "draw"))

    # Smallest per-period standard deviation of the pre-treatment gap across samples.
    sigma_pre = pre_samples.std(axis=1).min().item()
    # NOTE(review): the cut-off of 10 is a hard-coded magic number; `total_employment`
    # is on a log scale here, so confirm this filter excludes anything at all.
    if sigma_pre < 10:
        is_target = idx == "i06081"
        color = "C6" if is_target else "gray"
        alpha = 1 if is_target else 0.3
        label = "San Mateo County" if is_target else None
        sns.lineplot(
            x=pre_years,
            y=pre_samples.mean(axis=1),
            color=color,
            alpha=alpha,
            ax=ax,
        )
        sns.lineplot(
            x=post_years,
            y=post_samples.mean(axis=1),
            color=color,
            alpha=alpha,
            label=label,
            ax=ax,
        )

ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
    title="County - Synthetic Across Time",
    ylabel="Gap in log employment",
)

In [None]:
fig, ax = plt.subplots()

# One HDI band per county for the gap (observed minus synthetic), before and after
# treatment. The target county i06081 is opaque; all others are nearly transparent.
for idx in data["area_fips"].unique():
    error_pre, error_post = results[idx]

    # Smallest per-period standard deviation of the pre-treatment gap across samples.
    sigma_pre = error_pre.stack(samples=("chain", "draw")).std(axis=1).min().item()
    # NOTE(review): the cut-off of 10 is a hard-coded magic number; `total_employment`
    # is on a log scale here, so confirm this filter excludes anything at all.
    if sigma_pre < 10:
        is_target = idx == "i06081"
        color = "C6" if is_target else "gray"
        alpha = 1 if is_target else 0.05
        label = "San Mateo County" if is_target else None
        az.plot_hdi(
            x=pre_years,
            y=error_pre,
            smooth=False,
            color=color,
            fill_kwargs={"alpha": alpha},
            ax=ax,
        )
        az.plot_hdi(
            x=post_years,
            y=error_post,
            smooth=False,
            color=color,
            fill_kwargs={"alpha": alpha, "label": label},
            ax=ax,
        )

ax.axhline(y=0.0, color="black", linestyle="--", label="zero")
ax.legend(loc="lower left")
ax.set(
    title="County - Synthetic Across Time",
    ylabel="Gap in log employment",
)