In [None]:
import os


def ensure_project_root(marker="src"):
    """Make the project root the working directory so ``src.*`` imports resolve.

    The notebook lives one level below the project root. The original cell did an
    unconditional ``os.chdir("..")``, which walked one level further up every
    time the cell was re-run. This version only moves up when ``marker`` is not
    already a directory in the current working directory, so re-running is a no-op.

    Parameters
    ----------
    marker : str, default "src"
        Directory name expected at the project root.

    Returns
    -------
    str
        The working directory after the (possible) change.
    """
    if not os.path.isdir(marker):
        os.chdir("..")
    return os.getcwd()


PROJECT_ROOT = ensure_project_root()

In [None]:
import world_bank_data as wb
import pandas as pd
import polars as pl
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from src.data.data_pull import DataClean

# Project data helper (src/data/data_pull.py); used below via `pull_wb()` and
# `wb_data(...)` — presumably to fetch World Bank series, per those method names.
dc = DataClean()

# TODO
- Collapse the yearly data into 5-year averages.
- Keep the balance data.

In [None]:
# Countries that have at least one row with no missing value in any column
# (`drop_nulls()` without a subset drops rows containing any null).
# Polars' `unique()` does not preserve order, so sort for a stable display.
dc.pull_wb().drop_nulls().select(pl.col("country")).unique().sort("country")

In [None]:
# World Bank country metadata. Converting a pandas frame to polars discards the
# pandas index, which presumably holds the country id (ISO3 code) here — reset it
# into a column so it survives. Named `countries` so it doesn't masquerade as the
# `df` indicator frame loaded in the next cell.
countries = pl.from_pandas(wb.get_countries().reset_index())
countries

In [None]:
# Indicator frame from the project helper: indicator code GFDD.OI.01 with
# year=1980 (see DataClean.wb_data for how `year` is applied).
df = dc.wb_data(
    params=["GFDD.OI.01"],
    year=1980,
)
df

In [None]:
# Columns the optional per-country interpolation step (disabled below) would fill.
columns = ["gdp_growth", "spending", "gini", "expenses"]
COUNTRY = "United States"  # single-country analysis; change to study another

data = df.to_pandas()

# Step 1: keep one country's rows and order them chronologically.
# Selecting a single country already excludes the "World" aggregate.
data = (
    data[data["country"] == COUNTRY]
    .sort_values(by=["year", "country"])
    .reset_index(drop=True)
)

# Optional: fill gaps by linear interpolation within each country.
# data[columns] = data.groupby("country")[columns].transform(
#     lambda group: group.interpolate(method="linear", limit_direction="both")
# )

# Step 2: keep only years where both the response (gdp_growth) and the
# spline predictor (expenses) are observed — one dropna is equivalent to
# dropping on each column in turn.
data = data.dropna(subset=["gdp_growth", "expenses"]).reset_index(drop=True)
assert not data.empty, f"no rows with gdp_growth and expenses for {COUNTRY!r}"
data

In [None]:
def plot_knots(knots, ax):
    """Mark each spline knot with a faint vertical line.

    Parameters
    ----------
    knots : iterable of float
        x-positions at which to draw the lines.
    ax : matplotlib Axes (anything with an ``axvline`` method)
        Axes to draw on; modified in place.

    Returns
    -------
    The same ``ax``, so calls can be chained.
    """
    knot_line_style = {"color": "0.1", "alpha": 0.4}
    for knot_x in knots:
        ax.axvline(knot_x, **knot_line_style)
    return ax

In [None]:
# Knots at evenly spaced quantiles of expenses. The first and last quantiles
# (min and max) are the boundary knots, so only the interior ones go to bs().
num_knots = 5
knots = np.quantile(data["expenses"], np.linspace(0, 1, num_knots))
iknots = knots[1:-1]  # referenced by name inside the model formula below

# NOTE(review): the original Intercept prior Normal(mu=100, sigma=10) looks carried
# over from Bambi's cherry-blossom spline example (response ≈ day-of-year 100).
# gdp_growth is nowhere near 100, so centre the intercept prior on the observed
# mean instead of letting a mis-located prior fight the likelihood.
gdp_growth_mean = float(data["gdp_growth"].mean())
priors = {
    "Intercept": bmb.Prior("Normal", mu=gdp_growth_mean, sigma=10),
    # Applies to every common term: the `year` slope and the spline weights.
    "common": bmb.Prior("Normal", mu=0, sigma=10),
    "sigma": bmb.Prior("Exponential", lam=1),
}
# intercept=True makes the spline basis overlap with the model Intercept; the
# "common" prior above is what keeps that trade-off regularised.
model = bmb.Model(
    "gdp_growth ~ year + bs(expenses, knots=iknots, intercept=True)",
    data,
    dropna=True,
    priors=priors,
)

In [None]:
def plot_spline_basis(basis, expenses, figsize=(10, 6), ax=None):
    """Draw every column of a spline design matrix as a curve over ``expenses``.

    Within each curve the points are sorted by ``expenses`` before plotting, so
    the lines run left-to-right even when the rows are not in x order (the
    notebook's rows are in year order, which made the curves zig-zag).

    Parameters
    ----------
    basis : 2-D array-like
        One row per observation, one column per basis function (assumed to be
        the design-matrix slice of the ``bs(...)`` term, possibly re-weighted).
    expenses : 1-D array-like
        x-values aligned with the rows of ``basis``.
    figsize : tuple, default (10, 6)
        Size of the new figure; ignored when ``ax`` is given.
    ax : matplotlib Axes, optional
        Existing axes to draw on. A new figure is created when omitted.

    Returns
    -------
    The axes the curves were drawn on.
    """
    long_form = (
        pd.DataFrame(basis)
        .assign(expenses=expenses)
        .melt("expenses", var_name="basis_idx", value_name="basis_value")
    )

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # sort=False keeps the basis columns in their original order (and colours).
    for _, curve in long_form.groupby("basis_idx", sort=False):
        curve = curve.sort_values("expenses", kind="mergesort")
        ax.plot(curve["expenses"], curve["basis_value"])

    return ax

In [None]:
SPLINE_TERM = "bs(expenses, knots=iknots, intercept=True)"
# Design-matrix columns of the spline term; rows presumably align with `data`
# (which is in year order), and the posterior plot below relies on that.
B = model.components["mu"].design.common[SPLINE_TERM]

# Plot in expenses order: drawing rows in year order would connect points out
# of x-order and make each basis curve zig-zag.
expenses_values = data["expenses"].to_numpy()
expenses_order = np.argsort(expenses_values, kind="mergesort")
ax = plot_spline_basis(np.asarray(B)[expenses_order], expenses_values[expenses_order])
plot_knots(knots, ax);

In [None]:
# Fixed seed so Restart & Run All reproduces the same posterior draws
# (the original comment promised this but no seed was passed).
RANDOM_SEED = 1234
# log_likelihood is stored so the fit can later be compared with LOO/WAIC.
idata = model.fit(random_seed=RANDOM_SEED, idata_kwargs={"log_likelihood": True})

In [None]:
# Per-parameter trace plots of the posterior; binding the returned axes
# keeps the notebook from echoing their repr under the figure.
trace_axes = az.plot_trace(idata)

In [None]:
# Posterior-mean spline weights; az.extract stacks chain and draw into a single
# "sample" dimension, so averaging over it gives one weight per basis column.
posterior_stacked = az.extract(idata)
wp = posterior_stacked["bs(expenses, knots=iknots, intercept=True)"].mean("sample").values

# Rows of `B` and `data` are in year order; sort by expenses so every curve is
# drawn left-to-right instead of zig-zagging across the x-axis.
expenses_values = data["expenses"].to_numpy()
expenses_order = np.argsort(expenses_values, kind="mergesort")
B_sorted = np.asarray(B)[expenses_order]
expenses_sorted = expenses_values[expenses_order]

# Each basis column scaled by its weight (wp is 1-D, so it broadcasts per column).
ax = plot_spline_basis(B_sorted * wp, expenses_sorted)
# Their sum is the spline term's contribution to mu — it excludes the
# Intercept and the `year` effect in the model formula.
ax.plot(expenses_sorted, B_sorted @ wp, color="black", lw=3)
plot_knots(knots, ax);