In [None]:
import os 
# Move to the project root so that `src` is importable and relative paths
# (such as the database file) resolve from there.
# Guarded so that re-running this cell, or starting the kernel at the root,
# does not climb a second directory level.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
import polars as pl
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from src.data_process import DataReg
from scipy.spatial import distance

# Notebook-wide plotting defaults.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": [10, 6],
        "figure.dpi": 100,
        "figure.facecolor": "white",
    }
)

# Project data accessor, pointed at the `data.ddb` database file.
dr = DataReg(database_file="data.ddb")

In [None]:
# Yearly employment per area for industry code 72: average the three monthly
# employment levels of each row, then average those values within each
# (area, year) group.
df = dr.data_set().filter(pl.col("industry_code") == "72")
df = df.group_by(["area_fips", "year"]).agg(
    employment=(
        (pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")) / 3
    ).mean()
)

# DP03 table, re-keyed so it shares the `area_fips` join key.
df_dp03 = dr.pull_dp03().with_columns(area_fips=pl.col("geoid"))

# Left join keeps every (area, year) employment row; `validate="m:1"` asserts
# that the DP03 side has at most one row per key.
data = df.join(
    df_dp03, on=["area_fips", "year"], how="left", validate="m:1"
).sort(by=["area_fips", "year"])
selected_cols = ["commute_car", "employment", "total_population"]

# 2015 cross-section used to rank areas. The left join can leave nulls in the
# DP03 columns; a single missing value would turn the mean, the covariance and
# therefore every distance into NaN, so incomplete rows are dropped first.
data2 = data.filter(pl.col("year") == 2015).drop_nulls(subset=selected_cols)

data_np = data2.select(selected_cols).to_numpy()

# Compute the mean and covariance matrix
mean_vec = np.mean(data_np, axis=0)
cov_matrix = np.cov(data_np, rowvar=False)
inv_cov_matrix = np.linalg.inv(cov_matrix)

# Compute Mahalanobis distance of each row from the mean
mahalanobis_distances = [
    distance.mahalanobis(row, mean_vec, inv_cov_matrix) for row in data_np
]

# Add distances to the DataFrame. The list is wrapped in a Series so polars
# treats it as one value per row; depending on the polars version a bare
# Python list may instead be read as a single List-typed literal.
# The "i" prefix matches the area codes built in later cells.
data2 = data2.with_columns(
    pl.Series("mahalanobis", mahalanobis_distances),
    area_fips="i" + pl.col("area_fips"),
)

In [None]:
# Column-wise means of the covariate matrix (columns follow `selected_cols`).
mean_vec = data_np.mean(axis=0)
mean_vec

In [None]:
# Candidate donor pool: the 200 areas with the smallest Mahalanobis distance,
# returned as a plain list of prefixed area codes.
closest_areas = data2.sort("mahalanobis").head(200)
controls = closest_areas["area_fips"].to_list()

In [None]:
# Rows for industry code 72, restricted to years before 2020.
is_industry_72 = pl.col("industry_code") == "72"
is_before_2020 = pl.col("year") < 2020
df = dr.data_set().filter(is_industry_72 & is_before_2020)

# Mean of the three monthly employment levels reported on each row.
mean_monthly_employment = (
    pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")
) / 3

df = df.with_columns(
    # Quarter label such as "2015Q1".
    date=pl.col("year").cast(pl.String) + "Q" + pl.col("qtr").cast(pl.String),
    dummy=pl.lit(1),
    # Same "i" prefix as the donor-pool codes.
    area_fips="i" + pl.col("area_fips"),
    # Outcome variable: natural log of the mean monthly employment.
    total_employment=mean_monthly_employment.log(),
)

In [None]:
# Build the pandas panel used by the rest of the notebook.
# NOTE: the boolean column named `controls` is True for the *treated* area
# (i06081) and False for every other area; later cells query it by that name.
data = (
    df.select(pl.col("area_fips", "date", "total_employment", "avg_wkly_wage"))
    .with_columns(
        controls=pl.when(pl.col("area_fips") == "i06081").then(True).otherwise(False)
    )
    .to_pandas()
)

# Parse the quarter labels from this frame's own column rather than from the
# polars frame `df`, so the result never depends on the two frames sharing
# the same row order.
data["date"] = pd.PeriodIndex(data["date"], freq="Q").to_timestamp()
# Strict comparison: a quarter stamped exactly 2016-01-01 is pre-treatment.
data["after_treatment"] = data["date"] > pd.to_datetime("2016-01-01")

# Keep the treated area plus the selected donor pool, and drop rows whose
# average weekly wage is zero or missing.
in_sample = data["area_fips"].isin(controls) | (data["area_fips"] == "i06081")
has_wage = (data["avg_wkly_wage"] != 0) & data["avg_wkly_wage"].notnull()
data = data[in_sample & has_wage].reset_index(drop=True)
data = data.sort_values("date")
data

In [None]:
features = ["total_employment"]


def _period_panel(frame: pd.DataFrame, condition: str, values: list) -> pd.DataFrame:
    """Pivot the rows of `frame` matching `condition` into a wide panel.

    The result has one row per (feature, date) pair and one column per
    `area_fips`; areas with a missing value anywhere in the period are dropped.
    """
    return (
        frame.query(condition)
        .pivot(index="area_fips", columns="date", values=values)
        .T
        .dropna(axis=1)
    )


pre_df = _period_panel(data, "~after_treatment", features)
post_df = _period_panel(data, "after_treatment", features)

# Keep only areas that are complete in BOTH periods. Sorted so the column
# order is reproducible: iteration order of a set of strings can change from
# one interpreter run to the next.
controls = sorted(set(pre_df.columns) & set(post_df.columns))

pre_df = pre_df[controls]
post_df = post_df[controls]


In [None]:
# Restrict the panel to the area codes currently listed in `controls`.
keep_rows = data["area_fips"].isin(controls)
data = data.loc[keep_rows].reset_index(drop=True)
data

In [None]:
fig, ax = plt.subplots()

# Mean of the logged employment outcome per date, split by the `controls`
# flag (True = treated area i06081, False = all other areas).
mean_by_group = data.groupby(["date", "controls"], as_index=False).agg(
    {"total_employment": "mean"}
)
sns.lineplot(
    data=mean_by_group,
    x="date",
    y="total_employment",
    hue="controls",
    marker="o",
    ax=ax,
)

# Marker for the treatment date used to build `after_treatment`.
ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)

ax.legend(loc="upper right")
ax.set(
    title="Employment",
    ylabel="Mean log employment",
);


In [None]:
features = ["total_employment"]

# Pre-treatment panel: one row per (feature, date), one column per area.
pre_treatment = data.query("~after_treatment")
inverted = pre_treatment.pivot(
    index="area_fips", columns="date", values=features
).transpose()
inverted

In [None]:
# Pre-treatment outcome vector of the treated area, and the matrix of all
# remaining areas (one column per area) used to approximate it.
treated_area = "i06081"
y = inverted[treated_area].to_numpy()
X = inverted.drop(columns=treated_area).to_numpy()

In [None]:
from typing import List
from operator import add
from toolz import reduce, partial

def loss_w(W, X, y) -> float:
    """Root-mean-squared error between `y` and the weighted combination `X.dot(W)`.

    Args:
        W: weight vector, one entry per column of `X`.
        X: matrix whose columns are combined by `W`.
        y: target vector, one entry per row of `X`.

    Returns:
        The square root of the mean squared residual.
    """
    residuals = y - X.dot(W)
    mean_squared_error = np.mean(np.square(residuals))
    return np.sqrt(mean_squared_error)

In [None]:
from scipy.optimize import fmin_slsqp

def get_w(X, y):
    """Find weights over the columns of `X` that minimise `loss_w` against `y`.

    The search uses SLSQP, starts from equal weights, bounds every weight to
    [0, 1] and constrains the weights to sum to 1.

    Args:
        X: matrix with one column per candidate unit.
        y: target vector, one entry per row of `X`.

    Returns:
        The weight array returned by the optimiser, one entry per column of `X`.
    """
    n_units = X.shape[1]
    initial_guess = np.full(n_units, 1 / n_units)

    def objective(w):
        return loss_w(w, X=X, y=y)

    def weights_sum_to_one(w):
        # Equality constraint: SLSQP drives this value to zero.
        return np.sum(w) - 1

    return fmin_slsqp(
        objective,
        initial_guess,
        f_eqcons=weights_sum_to_one,
        bounds=[(0.0, 1.0)] * n_units,
        disp=False,
    )

In [None]:
# Fit the weights for the treated area and show them rounded to 4 decimals,
# together with their total.
calif_weights = get_w(X, y)
print("Sum:", calif_weights.sum())
calif_weights.round(4)

In [None]:
# Synthetic series: for every date, the weighted combination of the outcome
# across all rows where the `controls` flag is False (i.e. not the treated area).
donor_panel = data.query("~controls").pivot(index="date", columns="area_fips")
donor_employment = donor_panel["total_employment"].to_numpy()
calif_synth = donor_employment @ calif_weights
calif_synth

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

# Rows where the `controls` flag is True belong to the treated area (i06081).
# Sorted by date so they line up with `calif_synth`, which comes from a pivot
# indexed (and therefore ordered) by date.
treated = data.query("controls").sort_values("date")

ax.plot(treated["date"], treated["total_employment"], label="Treated area (i06081)")
ax.plot(treated["date"], calif_synth, label="Synthetic Control")

# Same treatment-date marker as the cut-off used for `after_treatment`.
ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",
)

# `total_employment` holds the natural log of mean monthly employment, so the
# axis is labelled accordingly.
ax.set(
    title="Employment: treated area vs. synthetic control",
    ylabel="Log employment",
)
ax.legend();