In [None]:
import os

# Run the notebook from the project root so that the `src` package imported in
# the next cell resolves (assumes the notebook lives one level below the root —
# TODO confirm). The guard keeps the cell idempotent: re-running it does not
# climb one directory higher each time.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
from src.data.data_process import DataReg
import polars as pl
import requests
import bambi as bmb
import geopandas as gpd
from pysal.lib import weights
from shapely import wkt
import pandas as pd
import arviz as az
import matplotlib.dates as mdates
import numpy as np
from pysal.lib import cg as geometry
import causalpy as cp
import matplotlib.pyplot as plt
# Shared DataReg instance; later cells call its base_data(), pull_dp03() and
# make_spatial_table() methods to load the inputs.
dr = DataReg()

In [None]:
# Base table with a boolean `treatment` flag that is True from 2023 onwards.
is_treated = pl.when(pl.col("year") >= 2023).then(True).otherwise(False)
df_qcew = dr.base_data().with_columns(treatment=is_treated)

# DP03 table, tagged with qtr = 1 so it can be joined on (zipcode, year, qtr).
df_dp03 = dr.pull_dp03().with_columns(qtr=1)

# Spatial table: parse the WKT strings into geometries and key by string zipcode.
pr_zips = gpd.GeoDataFrame(dr.make_spatial_table().df())
pr_zips["geometry"] = pr_zips["geometry"].apply(wkt.loads)
pr_zips = pr_zips.set_geometry("geometry")
pr_zips["zipcode"] = pr_zips["zipcode"].astype(str)

# Left-join DP03 onto the base table, then attach every row to its geometry.
join_keys = ["zipcode", "year", "qtr"]
panel = df_qcew.join(df_dp03, on=join_keys, how="left").to_pandas().set_index("zipcode")
df = pr_zips.join(panel, on="zipcode", how="inner", validate="1:m").reset_index(drop=True)

# Keep 2012 onwards, ordered by ZIP code and then by period.
from_2012 = df["year"] >= 2012
df = df[from_2012].sort_values(by=join_keys).reset_index(drop=True)
pr_zips.plot()

In [None]:
params = ["inc_less_10k", "inc_10k_15k", "inc_15k_25k", "inc_25k_35k", "inc_35k_50k", "inc_50k_75k", "inc_75k_100k", "inc_100k_150k", "inc_150k_200k", "inc_more_200k"]


def _interpolate_within_zip(series):
    """Fill the gaps of one ZIP code's series.

    Cubic interpolation needs at least four observed points; with two or three
    the series is filled linearly, and with fewer it is returned unchanged.
    """
    observed = series.notna().sum()
    if observed >= 4:
        return series.interpolate(method="cubic")
    if observed >= 2:
        return series.interpolate(method="linear")
    return series


# Interpolate inside each ZIP code, in time order. Interpolating down the whole
# stacked panel would blend the last periods of one ZIP code with the first
# periods of the next one.
df = df.sort_values(by=["zipcode", "year", "qtr"]).reset_index(drop=True)
for col in params:
    df[col] = df.groupby("zipcode")[col].transform(_interpolate_within_zip)

# Downstream cells expect the panel ordered by period, then ZIP code.
df = df.sort_values(by=["year","qtr","zipcode"]).reset_index(drop=True)
df

In [None]:
# Keep 2012 onwards and order the panel by period, then ZIP code.
recent_rows = df["year"] >= 2012
df = df[recent_rows].sort_values(by=["year", "qtr", "zipcode"]).reset_index(drop=True)

In [None]:
# Binary distance-band weights built from the 2012 Q1 cross-section.
# The threshold is 20 x 1609.344 (20 miles if the geometry CRS is in metres —
# TODO confirm the CRS units).
first_period = df[(df["year"] == 2012) & (df["qtr"] == 1)]
band_threshold = 20 * 1609.344
w = weights.distance.DistanceBand.from_dataframe(first_period, band_threshold, binary=True)

In [None]:
def calculate_spatial_lag(df, w, column):
    """Return the spatial lag of ``df[column]`` under the weights ``w``.

    The column is handed to ``weights.lag_spatial`` as an (n, 1) column vector,
    and whatever that call produces is returned unchanged.
    """
    column_vector = df[column].values.reshape(-1, 1)
    return weights.lag_spatial(w, column_vector)

# One lagged frame is collected per (year, quarter) cross-section.
spatial_lag_results = []

# Each cross-section is selected with the 'year' and 'qtr' columns of `df`.
# `w` was built from a single cross-section, so every quarter is expected to
# hold the same ZIP codes in the same order (presumably guaranteed by the
# zipcode sort of `df` — verify).
for year in range(2012,2019):
    for qtr in range(1,5):
        group_df = df[(df["year"]== year) & (df["qtr"]== qtr)].reset_index(drop=True)
        spatial_lag = calculate_spatial_lag(group_df, w, 'total_employment')

        # Store the lag as a 1-D column. This has to happen inside the quarter
        # loop: done after it, only the last quarter of each year is kept.
        group_df['w_employment'] = spatial_lag.flatten()
        spatial_lag_results.append(group_df)

# Stack every quarter back into one panel; ignore_index avoids the duplicate
# 0..n-1 labels left by the per-quarter reset_index.
reg = pd.concat(spatial_lag_results, ignore_index=True)
reg

In [None]:
# Single-period check (2012 Q1) of the spatial lag using Queen contiguity.
test = df[(df["year"]== 2012) & (df["qtr"]== 1)].reset_index(drop=True)
y = test["total_employment"].values.reshape(-1,1)

# The weights have to be built from the same rows as `y`; building them from
# the full panel `df` gives a matrix whose size does not match `y`.
w_queen = weights.contiguity.Queen.from_dataframe(test)
# w_queen.transform = 'r'

# Kept under separate names so the distance-band `w` and the panel `reg`
# (which carries the `w_employment` column used by the model) are not replaced
# by this one-quarter check.
reg_queen = test.assign(w_employment=weights.lag_spatial(w_queen, y).flatten())
reg_queen

In [None]:
# Build the modelling table from `reg`: a plain copy without the geometry column.
data = reg.copy()
data = data.drop("geometry", axis=1)

# `date` is the calendar year only; the quarter is not encoded in it.
data['date'] = data['year']
income_columns = [
    'inc_25k_35k', 'inc_35k_50k', 'inc_50k_75k',
    'inc_75k_100k', 'inc_100k_150k', 'inc_150k_200k',
    'inc_more_200k'
]

# Step 1: order each ZIP code's rows chronologically. The year has to be part
# of the key, otherwise the interpolation below runs over rows that are out of
# time order.
data = data.sort_values(by=['zipcode', 'year', 'qtr'])

# Step 2: fill gaps linearly within each ZIP code, in both directions.
data[income_columns] = data.groupby('zipcode')[income_columns].transform(
    lambda group: group.interpolate(method='linear', limit_direction='both')
)
data

In [None]:
# Fixed sampler seed so that Restart & Run All reproduces the same draws.
SEED = 42

# total_employment regressed on k_index, date and the spatial lag w_employment;
# rows with missing values are dropped (dropna=True).
model = bmb.Model(
    "total_employment ~ k_index + date + w_employment",
    data, dropna=True
)
results = model.fit(target_accept=0.95, random_seed=SEED)

In [None]:
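# Alternative specification (not run): adds a group-specific intercept and k_index slope by zipcode, plus the income-bracket covariates.
# model = bmb.Model("total_employment ~ k_index + date + (1 + k_index|zipcode) + w_employment + inc_less_10k + inc_10k_15k + inc_15k_25k + inc_25k_35k + inc_35k_50k +  inc_50k_75k + inc_75k_100k + inc_100k_150k + inc_150k_200k", data, dropna=True)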

In [None]:
# Plot the priors of `model` as a check before reading the posterior.
model.plot_priors()

In [None]:
# Trace plots for the fitted results, using the compact layout.
az.plot_trace(results, compact=True)

In [None]:
# Tabulate the fitted results with ArviZ; `res` is reused by the cells below.
res = az.summary(results)
res

In [None]:
# Write the summary table to test.csv, relative to the current working directory.
res.to_csv("test.csv")

In [None]:
# Forest plot of the fitted results: chains combined, 94% HDI.
az.plot_forest(results, combined=True, hdi_prob=0.94)

In [None]:
# Summary rows to report. The spatial-lag label matches the `w_employment`
# term of the model formula.
rows_to_extract = [
    'inc_25k_35k', 'inc_35k_50k', 'inc_50k_75k',
    'inc_75k_100k', 'inc_100k_150k', 'inc_150k_200k',
    "Intercept",
    "date",
    "k_index",
    "sigma",
    "w_employment",
]

# `.loc` raises KeyError on labels that are not in the summary (for example
# covariates that are not part of the fitted formula), so select only the
# labels that exist and report the rest.
available_rows = [row for row in rows_to_extract if row in res.index]
missing_rows = [row for row in rows_to_extract if row not in res.index]
if missing_rows:
    print(f"Not in the model summary, skipped: {missing_rows}")

res.loc[available_rows]