In [None]:
import os

# Run from the project root so the `src.*` imports in the next cell resolve.
# Only step up when the current directory has no `src/` package, so re-running
# this cell (or starting the kernel at the root) does not climb past the root.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
from src.data.data_process import DataReg
import polars as pl
import requests
import geopandas as gpd
from pysal.lib import weights
from shapely import wkt
import pandas as pd
import arviz as az
import matplotlib.dates as mdates
import numpy as np
import causalpy as cp
import matplotlib.pyplot as plt
# Project data-access helper. Used below via .base_data(), .make_spatial_table()
# and .conn.sql(...) — presumably a DuckDB-backed connection; confirm in
# src/data/data_process.py.
dr = DataReg()

In [None]:
# QCEW base table plus the zip-code geometries (WKT parsed into shapely objects).
# NOTE(review): the DP03 join (dr.pull_dp03(), joined on zipcode + year) is
# disabled; confirm that columns used later, such as inc_less_10k, still come
# from dr.base_data().
df_qcew = dr.base_data()

zip_table = dr.make_spatial_table().df()
zip_table["geometry"] = zip_table["geometry"].apply(wkt.loads)
zip_table["zipcode"] = zip_table["zipcode"].astype(str)
pr_zips = gpd.GeoDataFrame(zip_table, geometry="geometry")

# One geometry row per zipcode joined to many QCEW rows per zipcode.
qcew_by_zip = df_qcew.to_pandas().set_index("zipcode")
df = pr_zips.join(qcew_by_zip, on="zipcode", how="inner", validate="1:m").reset_index(drop=True)
df

In [None]:
# Education-sector rows (sector code "61") with a distance-band spatial lag of
# total employment.
SECTOR_CODE = "61"
# 80,467 m is about 50 miles. NOTE(review): this threshold is only meaningful if
# the geometries are in a projected CRS measured in metres; pr_zips is built
# from WKT without setting a CRS — confirm the coordinate units.
DISTANCE_BAND_M = 80467

sector_df = df[df["sector"] == SECTOR_CODE].reset_index(drop=True)


def _spatial_lag_within_year(year_rows):
    """Return the spatial lag (W @ total_employment) for one year's rows.

    W is a libpysal DistanceBand built from these rows only, so neighbours are
    restricted to the same year. The result is a 1-D Series aligned on the
    rows' index.
    """
    w = weights.distance.DistanceBand.from_dataframe(year_rows, DISTANCE_BAND_M)
    lag = weights.lag_spatial(w, year_rows["total_employment"].to_numpy())
    return pd.Series(np.ravel(lag), index=year_rows.index)


# Compute the lag year by year: sector_df presumably holds one row per zipcode
# per year (the QCEW join above is 1:m on zipcode), and a single band over all
# rows would count other years of the same zipcode (distance 0) and of nearby
# zipcodes as neighbours. With a single year this equals a one-shot lag.
# `.copy()` keeps sector_df itself unmodified. The column name "w_emplyment"
# is kept as-is because the model formula in the next cell refers to it.
reg = sector_df.copy()
reg["w_emplyment"] = pd.concat(
    [_spatial_lag_within_year(rows) for _, rows in sector_df.groupby("year")]
)
reg

In [None]:
# Synthetic-control fit on the sector table, with 2023 as the first treated year.
treatment_time = 2023

# causalpy splits pre/post periods by comparing the frame's index with
# treatment_time (the monthly model below relies on this via a DatetimeIndex).
# `reg` carries a 0..n-1 RangeIndex from reset_index, which would split at row
# position 2023, so index it by year instead.
# NOTE(review): `reg` is still long format (many zipcodes per year), whereas
# synthetic control normally expects one row per period and one column per
# unit — confirm this specification, and that inc_less_10k exists given the
# DP03 join is commented out.
result = cp.SyntheticControl(
    reg.set_index("year"),
    treatment_time,
    formula="total_employment ~ 0 + k_index + w_emplyment + inc_less_10k",
    model=cp.skl_models.WeightedProportion(),
)
fig, ax = result.plot(plot_predictors=True)

In [None]:
# Load the zip-code geometries again for the monthly analysis section.
# NOTE(review): this repeats the load in the first data cell; consider reusing
# that result instead of querying the spatial table twice.
_spatial = dr.make_spatial_table().df()
pr_zips = gpd.GeoDataFrame(
    _spatial.assign(
        geometry=_spatial["geometry"].apply(wkt.loads),
        zipcode=_spatial["zipcode"].astype(str),
    ),
    geometry="geometry",
)

In [None]:
# Empty accumulator for the monthly wide table: just a string "date" column.
# The per-zip columns arrive through the diagonal concat of the quarterly
# frames. They are named "zip_<zipcode>" (see the aggregation cell), so
# placeholder columns named after the bare zipcode would never be filled and
# would only be dropped again by the null-column filter before modelling.
df_master = pl.DataFrame(schema={"date": pl.String})
df_master

In [None]:
# Raw monthly QCEW employment by 5-digit zip (ui_addr_5_zip), year and quarter,
# as a polars frame. NOTE(review): this rebinds `df`, which earlier held the
# joined geometry/QCEW table.
qcew_query = (
    "SELECT first_month_employment, second_month_employment, third_month_employment, "
    "ui_addr_5_zip, qtr, year FROM qcewtable"
)
df = dr.conn.sql(qcew_query).pl()

In [None]:
# Quarterly monthly-employment totals per Puerto Rico zip code. Zips are
# relabelled "zip_<zipcode>" because bare numeric zip codes would not parse as
# variable names in the model formula built later.
pr_zip_codes = pr_zips["zipcode"].tolist()
month_columns = ["first_month_employment", "second_month_employment", "third_month_employment"]
tmp = (
    df.drop_nulls()  # drops rows with a null in any column
    .filter(pl.col("ui_addr_5_zip").is_in(pr_zip_codes))
    .group_by(["year", "qtr", "ui_addr_5_zip"])
    .agg(pl.col(month_columns).sum())
    # group_by does not guarantee output order; sort so the per-quarter zip
    # column order derived from these rows (and hence the predictor order in
    # the seeded model) is reproducible across runs.
    .sort(["year", "qtr", "ui_addr_5_zip"])
    .with_columns(ui_addr_5_zip="zip_" + pl.col("ui_addr_5_zip"))
)

In [None]:
def foo(df: pl.DataFrame, year: int, qtr: int) -> pl.DataFrame:
    """Reshape one quarter of zip-level totals into one row per month.

    Parameters
    ----------
    df : pl.DataFrame
        Quarterly totals with columns ``year``, ``qtr``, ``ui_addr_5_zip`` and
        ``first_month_employment`` / ``second_month_employment`` /
        ``third_month_employment``; expected to hold one row per zip per quarter.
    year, qtr : int
        The quarter to extract; ``qtr`` must be 1-4.

    Returns
    -------
    pl.DataFrame
        The (empty) filtered frame if the quarter has no rows. Otherwise one row
        per monthly column, one column per zip label, plus a ``date`` column of
        ``"YYYY-MM-01"`` strings.

    Raises
    ------
    ValueError
        If the quarter has rows but ``qtr`` is not 1-4, instead of emitting
        unparseable dates.
    """
    month_columns = (
        "first_month_employment",
        "second_month_employment",
        "third_month_employment",
    )
    df = df.filter((pl.col("year") == year) & (pl.col("qtr") == qtr))
    if df.is_empty():
        return df
    if qtr not in (1, 2, 3, 4):
        raise ValueError(f"qtr must be between 1 and 4, got {qtr!r}")

    # Transposing the single zip column gives one row {"column_0": zip, ...};
    # that dict renames the positional columns produced by the next transpose.
    names = df.select(pl.col("ui_addr_5_zip")).transpose().to_dicts().pop()
    df = df.drop("year", "qtr", "ui_addr_5_zip").transpose(include_header=True)
    df = df.rename(names)

    # Months of this quarter: 1-3, 4-6, 7-9 or 10-12.
    first_month = 3 * (int(qtr) - 1) + 1
    date_expr = pl.lit("ERROR")  # sentinel for any unexpected header value
    for offset, column_name in enumerate(month_columns):
        date_expr = (
            pl.when(pl.col("column") == column_name)
            .then(pl.lit(f"{year}-{first_month + offset:02d}-01"))
            .otherwise(date_expr)
        )
    df = df.with_columns(date=date_expr)
    return df.drop("column")

# Spot check on a single quarter.
foo(tmp, 2021, 1)

In [None]:
# Stack the monthly rows of every non-empty quarter of the study years.
# Starting from an empty copy of df_master's schema makes the cell idempotent:
# re-running it no longer appends the same months a second time. Concatenating
# once also avoids the quadratic concat-in-a-loop pattern.
quarter_frames = [
    frame
    for frame in (foo(tmp, year, qtr) for year in (2023, 2024) for qtr in range(1, 5))
    if not frame.is_empty()
]
df_master = pl.concat([df_master.head(0), *quarter_frames], how="diagonal")
df_master

In [None]:
# Keep only zips observed in every month, then convert to a pandas frame indexed
# by month start (causalpy compares this index with treatment_time).
columns_with_nulls = [
    name for name in df_master.columns if df_master.get_column(name).is_null().any()
]
data = (
    df_master.drop(columns_with_nulls)
    .to_pandas()
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .set_index("date")
)
treatment_time = pd.to_datetime("2024-01-01")
data

In [None]:
# Donor-pool formula: the treated zip regressed, without intercept, on every
# other zip column that survived the null filter.
treated_unit = "zip_00791"
# The null-column filter can drop the treated zip itself; fail here with a clear
# message instead of an opaque error when the model is built.
assert treated_unit in data.columns, f"{treated_unit} was dropped (missing months?)"
formula = f"{treated_unit} ~ 0" + "".join(
    f" + {col}" for col in data.columns if col != treated_unit
)
formula

In [None]:
# Fit the Bayesian synthetic control (causalpy's WeightedSumFitter) using the
# formula above, with the pre/post split at treatment_time. Seeded for
# reproducibility.
sampler_settings = {"target_accept": 0.95, "random_seed": 787}
result = cp.SyntheticControl(
    data,
    treatment_time,
    formula=formula,
    model=cp.pymc_models.WeightedSumFitter(sample_kwargs=sampler_settings),
)

In [None]:
# Posterior summary table and trace plots (all variables except mu). The summary
# is not the cell's last expression, so it must be displayed explicitly or its
# table is silently discarded.
display(az.summary(result.idata, round_to=2))
az.plot_trace(result.idata, var_names=["~mu"], compact=False);

In [None]:
fig, ax = result.plot(plot_predictors=False)

# Third panel: rotated labels, one tick per year shown as the 4-digit year.
third_panel = ax[2]
third_panel.tick_params(axis="x", labelrotation=-90)
third_panel.xaxis.set_major_locator(mdates.YearLocator())
third_panel.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))

for panel in ax[:3]:
    panel.set(ylabel="Employment")