In [None]:
import os 
# Move to the parent directory so that the `src` package imported below is
# importable and relative paths resolve from there.
# Guarded: an unconditional chdir("..") climbs one more level every time this
# cell is re-run in the same kernel, which breaks every later relative path.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from src.data_process import DataReg

# Figure appearance shared by every plot in the notebook.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": [10, 6],
        "figure.dpi": 100,
        "figure.facecolor": "white",
    }
)

# Project data-access object; all `dr.*` calls below go through it.
dr = DataReg(database_file="data.ddb")

In [None]:
# Analysis parameters; they are passed as target=, date= and naics= to the
# DataReg calls in the following cells.
target = "i24031"  # area_fips identifier of the unit under study
date = "2017-01-01"  # later drawn as the vertical marker on the employment plot
naics = "72"  # NOTE(review): overridden to "11" by later cells before any fit is used

In [None]:
# County shapes from the project data layer.
gdf = dr.pull_county_shapes()
# Prefix the codes with "i" — presumably to match the identifier format used
# for `target` (e.g. "i24031"); confirm against DataReg.
gdf["area_fips"] = "i" + gdf["area_fips"]

In [None]:
# area_fips codes of the rows whose `fips` column equals "17".
# NOTE(review): `controls2` is not referenced anywhere else in the visible notebook.
controls2 = gdf[gdf["fips"] == "17"]["area_fips"].to_list()

In [None]:
# Ask DataReg for a list of controls (amount=200) for the target, build the
# panel for those controls, then run DataReg.synth_freq with the same arguments.
controls = dr.controls_list(target=target, amount=200, naics=naics)
data = dr.synth_data(controls=controls, target=target, date=date, naics=naics)
dr.synth_freq(controls=controls, target=target, date=date, naics=naics)

In [None]:
# Same pipeline as the previous cell with `naics` switched to "11".
# `controls` and `data` from the previous cell are replaced here, so everything
# below operates on the naics == "11" panel.
target = "i24031"
date = "2017-01-01"
naics = "11"
controls = dr.controls_list(target=target, amount=200, naics=naics)
data = dr.synth_data(controls=controls, target=target, date=date, naics=naics)
dr.synth_freq(controls=controls, target=target, date=date, naics=naics)

In [None]:
# NOTE(review): this cell is a line-for-line repeat of the cell directly above
# it (same target, date and naics); it recomputes the same `controls` and
# `data` and is a candidate for removal.
target = "i24031"
date = "2017-01-01"
naics = "11"
controls = dr.controls_list(target=target, amount=200, naics=naics)
data = dr.synth_data(controls=controls, target=target, date=date, naics=naics)
dr.synth_freq(controls=controls, target=target, date=date, naics=naics)

In [None]:
# Run DataReg.synth_bayes with the same arguments that were given to synth_freq.
dr.synth_bayes(controls=controls, target=target, date=date, naics=naics)


In [None]:
features = ["total_employment"]

# Wide panels: rows are (feature, date) pairs, columns are area_fips codes.
# An area with any missing observation in the period is dropped from that panel.
pre_df = (
    data
    .query("~after_treatment")
    .pivot(index="area_fips", columns="date", values=features)
    .T
    .dropna(axis=1)
)
post_df = (
    data
    .query("after_treatment")
    .pivot(index="area_fips", columns="date", values=features)
    .T
    .dropna(axis=1)
)

# Areas with complete data in BOTH periods. Sorted so the column order is the
# same on every kernel restart (iteration order of a set of strings depends on
# the per-process hash seed).
# NOTE(review): despite the name, this list is every complete area, which may
# include the target itself; the name is kept because later cells use it.
controls = sorted(set(pre_df.columns) & set(post_df.columns))

pre_df = pre_df[controls]
post_df = post_df[controls]


In [None]:
# Keep only rows whose area_fips is listed in `controls`, then renumber the rows.
data = (
    data
    .loc[lambda frame: frame["area_fips"].isin(controls)]
    .reset_index(drop=True)
)
data

In [None]:
fig, ax = plt.subplots()

# Mean total employment per date, one line per value of the boolean `controls` flag.
(
    data.groupby(["date", "controls"], as_index=False)
    .agg({"total_employment": "mean"})
    .pipe(
        (sns.lineplot, "data"),
        x="date",
        y="total_employment",
        hue="controls",
        marker="o",
        ax=ax,
    )
)
# Vertical marker at the configured `date`.
ax.axvline(
    x=pd.to_datetime(date),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)

ax.legend(loc="upper right")
ax.set(
    title="Employment",
    ylabel="Total employment trend",
);


In [None]:
features = ["total_employment"]

# Pre-treatment outcomes reshaped so that each area_fips is a column and each
# (feature, date) pair is a row.
inverted = (data.query("~after_treatment") # filter pre-intervention period
            .pivot(index='area_fips', columns="date")[features] # one row per area_fips, one column per (feature, date)
            .T) # flip the table to have one column per area_fips

inverted.head()

In [None]:
# Pre-treatment outcome vector of the target unit.
y = inverted[target].to_numpy()
# Pre-treatment outcomes of every other area, one column per area.
X = inverted.drop(columns=[target]).to_numpy()

In [None]:
from typing import List
from operator import add
from toolz import reduce, partial

def loss_w(W, X, y) -> float:
    """Root-mean-squared error between `y` and the weighted combination `X.dot(W)`.

    Args:
        W: weight vector, one entry per column of `X`.
        X: matrix whose columns are combined by `W`.
        y: vector the combination is compared against.

    Returns:
        The square root of the mean squared residual.
    """
    residuals = y - X.dot(W)
    mean_squared = np.mean(residuals ** 2)
    return np.sqrt(mean_squared)

In [None]:
from scipy.optimize import fmin_slsqp

def get_w(X, y):
    """Fit weights over the columns of `X` that minimise `loss_w` against `y`.

    The search starts from equal weights and is run with scipy's SLSQP under
    two restrictions: every weight lies in [0, 1] and the weights sum to 1.

    Args:
        X: matrix with one column per candidate unit.
        y: vector to approximate.

    Returns:
        The weight vector returned by `fmin_slsqp`.
    """
    n_units = X.shape[1]
    equal_weights = np.full(n_units, 1 / n_units)

    def objective(w):
        return loss_w(w, X=X, y=y)

    def weights_sum_to_one(w):
        return np.sum(w) - 1

    return fmin_slsqp(
        objective,
        equal_weights,
        f_eqcons=weights_sum_to_one,
        bounds=[(0.0, 1.0) for _ in range(n_units)],
        disp=False,
    )

In [None]:
# Fit the weights for the target; get_w constrains them to sum to 1, so the
# printed sum is a quick check that the optimiser respected the constraint.
calif_weights = get_w(X, y)
print("Sum:", calif_weights.sum())
calif_weights

In [None]:
# Synthetic series over all dates: rows with `controls` False are pivoted to a
# (date x area_fips) table of total_employment and multiplied by the fitted weights.
# NOTE(review): relies on the pivot's area_fips column order matching the
# column order of X used when fitting — confirm.
calif_synth = data.query("~controls").pivot(index='date', columns="area_fips")["total_employment"].values.dot(calif_weights)
calif_synth

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Rows flagged `controls` are the observed series that calif_synth is compared
# with (presumably the target unit — confirm against DataReg.synth_data).
treated = data.query("controls")

ax.plot(treated["date"], treated["total_employment"], label=f"Observed ({target})")
ax.plot(treated["date"], calif_synth, label="Synthetic Control")
# The plotted column is total_employment, so label the axis accordingly.
ax.set(
    title="Observed vs. synthetic total employment",
    xlabel="Date",
    ylabel="Total employment",
)
ax.legend();

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Rows flagged `controls` are the observed series compared with calif_synth.
treated = data.query("controls")

ax.plot(
    treated["date"],
    treated["total_employment"] - calif_synth,
    label="Observed - synthetic",
)
# Marker at the configured `date`, styled like the marker on the employment plot.
ax.axvline(
    x=pd.to_datetime(date),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)
# Zero reference across the full x-range instead of a hard-coded 2014-2020 span.
ax.axhline(y=0, lw=2)
ax.set(
    title="Observed - Synthetic Across Time",
    xlabel="Date",
    ylabel="Gap in total employment",
)
ax.legend();

In [None]:
def synthetic_control(area_fips: str, data: pd.DataFrame) -> pd.DataFrame:
    """Build a synthetic series for one area from all other areas in `data`.

    Weights are fitted with `get_w` on the pre-treatment rows
    (`after_treatment` False) and then applied to the other areas'
    `total_employment` over every date.

    Args:
        area_fips: identifier (value of the `area_fips` column) of the unit to
            treat as the target.
        data: long panel with columns `area_fips`, `date`, `total_employment`
            and boolean `after_treatment`.

    Returns:
        The target unit's rows (`area_fips`, `date`, `total_employment`,
        `after_treatment`) ordered by date, with a `synthetic` column added.

    Raises:
        KeyError: if `area_fips` has no pre-treatment rows in `data`.
    """
    features = ["total_employment"]

    # Rows: (feature, date) for the pre-treatment period; columns: area_fips.
    inverted = (data.query("~after_treatment")
                .pivot(index="area_fips", columns="date")[features]
                .T)

    y = inverted[area_fips].values  # target unit
    X = inverted.drop(columns=area_fips).values  # all other units

    weights = get_w(X, y)

    # Boolean mask instead of an f-string query: no quoting issues with the id.
    is_unit = data["area_fips"] == area_fips
    others_wide = (data.loc[~is_unit]
                   .pivot(index="date", columns="area_fips")["total_employment"])
    synthetic = others_wide.values.dot(weights)

    # `synthetic` follows the pivot's date-sorted index and is attached by
    # position, so the target rows must be put in date order first.
    return (data.loc[is_unit, ["area_fips", "date", "total_employment", "after_treatment"]]
            .sort_values("date")
            .assign(synthetic=synthetic)
            .reset_index(drop=True))

In [None]:
# Quick check: first rows of the synthetic-control result for the target unit.
synthetic_control(target, data).head()

In [None]:
from joblib import Parallel, delayed

# Every distinct area_fips in the filtered panel; each one is treated as the
# target in turn.
control_pool = data["area_fips"].unique()

# synthetic_control with `data` bound, wrapped for joblib.
parallel_fn = delayed(partial(synthetic_control, data=data))

# One result frame per area, computed with 8 joblib workers.
synthetic_states = Parallel(n_jobs=8)(parallel_fn(area_fips) for area_fips in control_pool)

In [None]:
# Peek at the first result frame to confirm its columns.
synthetic_states[0].head()

In [None]:
fig, ax = plt.subplots(figsize=(12, 7))

# One faint line per result frame: observed minus synthetic for that area.
for placebo in synthetic_states:
    ax.plot(
        placebo["date"],
        placebo["total_employment"] - placebo["synthetic"],
        color="C5",
        alpha=0.4,
    )

# Highlighted gap for the rows flagged `controls`, using the weights fitted earlier.
treated = data.query("controls")
ax.plot(
    treated["date"],
    treated["total_employment"] - calif_synth,
    label=f"Target ({target})",
)

# Marker at the configured `date`, spanning the full y-range (the previous
# marker was fixed at 2016-01-01 and only covered -0.25..0.25).
ax.axvline(
    x=pd.to_datetime(date),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)
ax.axhline(y=0, lw=3)
ax.set(
    title="Observed - Synthetic Across Time",
    xlabel="Date",
    ylabel="Gap in total employment",
)
ax.legend();

In [None]:
def pre_treatment_error(area_fips):
    """Mean squared gap between observed and synthetic values before treatment.

    Args:
        area_fips: result frame for one area (despite the name, a DataFrame)
            with columns `after_treatment`, `total_employment` and `synthetic`.

    Returns:
        Mean of (total_employment - synthetic) ** 2 over rows where
        `after_treatment` is False.
    """
    before = area_fips.query("~after_treatment")
    squared_gap = (before["total_employment"] - before["synthetic"]) ** 2
    return squared_gap.mean()

In [None]:
fig, ax = plt.subplots(figsize=(12, 7))

# Result frames whose pre-treatment mean squared gap reaches this value are not drawn.
# NOTE(review): 80 is in squared total_employment units — confirm it suits this data's scale.
max_pre_treatment_mse = 80

for placebo in synthetic_states:
    if pre_treatment_error(placebo) < max_pre_treatment_mse:
        ax.plot(
            placebo["date"],
            placebo["total_employment"] - placebo["synthetic"],
            color="C5",
            alpha=0.4,
        )

# Highlighted gap for the rows flagged `controls`, using the weights fitted earlier.
treated = data.query("controls")
ax.plot(
    treated["date"],
    treated["total_employment"] - calif_synth,
    label=f"Target ({target})",
)

# Marker at the configured `date`, spanning the full y-range (the previous
# marker was fixed at 2016-01-01 and only covered -0.25..0.25).
ax.axvline(
    x=pd.to_datetime(date),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)
ax.axhline(y=0, lw=3)
# Single title: of the two consecutive plt.title calls, only the second was ever shown.
ax.set(
    title="Observed - Synthetic Across Time (Large Pre-Treatment Errors Removed)",
    xlabel="Date",
    ylabel="Gap in total employment",
)
ax.legend();