In [None]:
import os

# Move from the notebook's folder to its parent so that `src` is importable and
# relative paths resolve from there. The guard keeps the cell idempotent:
# re-running it must not climb one directory further each time.
if not os.path.isdir("src"):
    os.chdir("..")

In [None]:
import polars as pl
import seaborn as sns
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
from src.data_process import DataReg
from scipy.spatial import distance

# Global plotting defaults for every figure in this notebook.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": [10, 6],
        "figure.dpi": 100,
        "figure.facecolor": "white",
    }
)

# Data-access helper bound to the database file "data.ddb".
dr = DataReg(database_file="data.ddb")

In [None]:
# --- Rank areas by how typical they were in 2015 (Mahalanobis distance) ---
# Keep industry code "72", average the three monthly employment levels of each
# record, then average that value per area and year.
df = dr.data_set()
df = df.filter(pl.col("industry_code") == "72")
df_dp03 = dr.pull_dp03()
# The DP03 table keys areas by `geoid`; expose it as `area_fips` for the join.
df_dp03 = df_dp03.with_columns(
    area_fips=pl.col("geoid"),
)
df = df.group_by(["area_fips", "year"]).agg(
    employment=(
        (pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")) / 3
    ).mean()
)
data = df.join(
    df_dp03, on=["area_fips", "year"], how="left", validate="m:1"
).sort(by=["area_fips", "year"])
selected_cols = ["commute_car", "employment", "total_population"]

# The left join leaves nulls for any area/year without a DP03 match. A single
# null becomes NaN in NumPy and propagates into the mean vector and covariance
# matrix, turning every distance into NaN, so incomplete rows are dropped first.
data2 = data.filter(pl.col("year") == 2015).drop_nulls(subset=selected_cols)

data_np = data2.select(selected_cols).to_numpy()

# Compute the mean and covariance matrix
mean_vec = np.mean(data_np, axis=0)
cov_matrix = np.cov(data_np, rowvar=False)
inv_cov_matrix = np.linalg.inv(cov_matrix)

# Mahalanobis distance of each row from the mean, vectorised:
# d_i = sqrt((x_i - mu)^T  S^-1  (x_i - mu))
centered = data_np - mean_vec
mahalanobis_distances = np.sqrt(
    np.einsum("ij,jk,ik->i", centered, inv_cov_matrix, centered)
)

# Add distances to the DataFrame. An explicit Series guarantees one value per
# row; a bare Python sequence may be interpreted as a single List literal by
# some polars versions. `area_fips` gets an "i" prefix.
data2 = data2.with_columns(
    mahalanobis=pl.Series("mahalanobis", mahalanobis_distances),
    area_fips="i" + pl.col("area_fips"),
)

In [None]:
# Control pool: the 200 rows with the smallest Mahalanobis distance.
closest_rows = data2.sort("mahalanobis").head(200)
controls = closest_rows.get_column("area_fips").to_list()

In [None]:
# Quarterly panel: industry code "72", years before 2020.
in_scope = (pl.col("industry_code") == "72") & (pl.col("year") < 2020)

# Mean of the three monthly employment levels reported on each record.
quarter_employment = (
    pl.col("month1_emplvl") + pl.col("month2_emplvl") + pl.col("month3_emplvl")
) / 3

df = (
    dr.data_set()
    .filter(in_scope)
    .with_columns(
        # Quarter label of the form "<year>Q<qtr>".
        date=pl.col("year").cast(pl.String) + "Q" + pl.col("qtr").cast(pl.String),
        dummy=pl.lit(1),
        # Same "i" prefix that the control list uses for its area codes.
        area_fips="i" + pl.col("area_fips"),
        # Natural log of the quarterly mean employment level.
        total_employment=quarter_employment.log(),
    )
)

In [None]:
# Build the pandas panel used for plotting.
# NOTE(review): despite its name, the `controls` column is True only for area
# "i06081" (presumably the treated area - confirm) and False for every other row.
data = (
    df.select("area_fips", "date", "total_employment", "avg_wkly_wage")
    .with_columns(
        controls=pl.when(pl.col("area_fips") == "i06081").then(True).otherwise(False)
    )
    .to_pandas()
)

# Parse the quarter labels from `data` itself rather than from the polars frame
# `df`, so the result never depends on the two frames sharing row order.
data["date"] = pd.PeriodIndex(data["date"], freq="Q").to_timestamp()
# Quarters start-stamped strictly after 2016-01-01 count as post-treatment,
# so the quarter stamped exactly 2016-01-01 is still "before".
data["after_treatment"] = data["date"] > pd.to_datetime("2016-01-01")

# Keep the selected control areas plus "i06081", and drop rows whose average
# weekly wage is zero or missing.
in_sample = data["area_fips"].isin(controls) | (data["area_fips"] == "i06081")
has_wage = (data["avg_wkly_wage"] != 0) & data["avg_wkly_wage"].notnull()
data = data[in_sample & has_wage].reset_index(drop=True)

# Displayed only; `data` itself keeps its current row order.
data.sort_values("date")

In [None]:
# Plot the mean of `total_employment` per date for each value of the `controls` flag.
fig, ax = plt.subplots()

group_means = data.groupby(["date", "controls"], as_index=False).agg(
    {"total_employment": "mean"}
)
sns.lineplot(
    data=group_means,
    x="date",
    y="total_employment",
    hue="controls",
    marker="o",
    ax=ax,
)

# Vertical marker at the same 2016-01-01 cut-off used for `after_treatment`.
ax.axvline(
    x=pd.to_datetime("2016-01-01"),
    linestyle=":",
    lw=2,
    color="C2",
    label="Implementation of minimum wage",  # fixed typo ("Iplementation")
)

ax.legend(loc="upper right")
ax.set(
    title="Employment",
    # The plotted value is a group mean of log employment, so the label says so
    # (replaces the garbled "total employment trend Trend").
    ylabel="Mean log total employment",
)


In [None]:
# Show the filtered pandas panel as the cell output.
data