In [None]:
import numpy as np
import pandas as pd
from statsmodels.tsa.ar_model import AutoReg

from unit_averaging import InlineFocusFunction, OptimalUnitAverager


In [None]:
import os

# Show the working directory: the data path read in the next cell
# ("docs/examples/data/tutorial_data.csv") is relative, so the kernel must be
# running from the directory that contains `docs/` for that read to succeed.
os.getcwd()

In [None]:
# Load the tutorial panel, indexed by the "period" column parsed as dates, then
# re-attach the index with an explicit month-start ("MS") frequency. Building a
# DatetimeIndex with freq="MS" raises if the dates are not a regular monthly grid.
german_data = pd.read_csv(
    "docs/examples/data/tutorial_data.csv", index_col="period", parse_dates=True
).pipe(
    lambda frame: frame.set_axis(
        pd.DatetimeIndex(frame.index.values, freq="MS"), axis="index"
    )
)

# Peek at the last four months for the first, third and last columns.
german_data.iloc[-4:, [0, 2, -1]]

In [None]:
# Every column except the national "Deutschland" series is modelled as a region.
# Dropping it by name (rather than slicing off the last column) does not depend
# on the column order of the CSV, and raises KeyError if the column is absent.
regions = german_data.columns.drop("Deutschland")
regions


In [None]:
ind_estimates = {}
ind_covar_ests = {}

for region in regions:
    # The region's own series alongside the national "Deutschland" series
    ind_data = german_data.loc[:, [region, "Deutschland"]]
    # ARX(1): AutoReg builds the single lag of the region's series itself (no
    # manual lagging needed); "Deutschland" is passed as the exogenous regressor.
    # Covariance is HAC (heteroskedasticity- and autocorrelation-robust), maxlags=5.
    ar_results = AutoReg(
        ind_data[region], lags=1, exog=ind_data["Deutschland"]
    ).fit(cov_type="HAC", cov_kwds={"maxlags": 5})
    # Keep coefficients and their covariance as plain NumPy arrays, keyed by region.
    # Presumably ordered [const, own lag, Deutschland] -- the focus function
    # defined later in this notebook indexes them in that order.
    ind_estimates[region], ind_covar_ests[region] = (
        ar_results.params.to_numpy(),
        ar_results.cov_params().to_numpy(),
    )

In [None]:
target_region = "Hamburg"
# December 2019 values of the target region and of the national series, as a
# flat array [target region, Deutschland]; the focus function conditions on them.
target_data = (
    german_data.loc["2019-12", [target_region, "Deutschland"]].to_numpy().squeeze()
)
# squeeze() only yields shape (2,) when exactly one row matches "2019-12"; fail
# here with a clear message instead of later inside the focus function.
assert target_data.shape == (2,), (
    f"expected exactly one row for 2019-12, got array of shape {target_data.shape}"
)

In [None]:
# Display the conditioning values: [target region, Deutschland] for 2019-12.
target_data

In [None]:
def _hamburg_forecast(coef):
    """Focus function: linear prediction from the values stored in ``target_data``.

    ``coef`` is read as [constant, coefficient on the target region's value,
    coefficient on the "Deutschland" value].
    """
    own_value, national_value = target_data[0], target_data[1]
    return coef[0] + coef[1] * own_value + coef[2] * national_value


def _hamburg_forecast_gradient(x):
    """Gradient of ``_hamburg_forecast`` w.r.t. the coefficients.

    The focus function is linear in the coefficients, so the gradient does not
    depend on ``x``.
    """
    return np.array([1, target_data[0], target_data[1]])


# NOTE(review): AutoReg includes exog contemporaneously, so predicting the month
# after 2019-12 would strictly use that month's "Deutschland" value; this plugs in
# the 2019-12 value instead -- confirm that is intended.
forecast_hamburg = InlineFocusFunction(
    focus_function=_hamburg_forecast,
    gradient=_hamburg_forecast_gradient,
)

In [None]:
# Value of the focus function at the target region's own estimates -- the
# no-averaging benchmark to compare with the averaged estimate.
forecast_hamburg.focus_function(ind_estimates[target_region])

In [None]:
# Combine the per-region estimates and their covariance matrices (both dicts
# keyed by region name) with respect to the forecast focus function.
averager = OptimalUnitAverager(
    focus_function=forecast_hamburg,
    ind_estimates=ind_estimates,
    ind_covar_ests=ind_covar_ests,
)

In [None]:
# Reuse target_region instead of repeating the literal "Hamburg", so the fitted
# target cannot drift apart from the region the focus function was built for.
averager.fit(target_id=target_region)

In [None]:
# Only a cell's last expression is rendered, so a bare `averager.weights_` line
# placed before this one would never be shown; the weights are displayed in the
# following cell instead.
averager.estimate_

In [None]:
# Estimated averaging weights, rounded to 3 decimals for readability.
averager.weights_.round(3)

In [None]:
def plot_germany(data_df, aab_shp, **kwargs):
    """Plot a choropleth map of the last column of ``data_df`` over ``aab_shp``.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Must have a string column ``"aab"`` naming each region; the frame's
        last column is the series that gets plotted. The caller's frame is
        not modified.
    aab_shp : GeoDataFrame-like
        Region shapes with a ``"region"`` column. Must support ``merge`` and
        ``to_crs`` (presumably a geopandas ``GeoDataFrame`` -- confirm).
    **kwargs
        Forwarded unchanged to ``.plot`` on the merged, reprojected frame.

    Returns
    -------
    The object returned by ``.plot`` (a matplotlib ``Axes`` for geopandas),
    with its axis decorations switched off.
    """
    # Normalise hyphen spacing so "A-B" and "A - B" both become "A - B";
    # presumably this matches the spelling in aab_shp["region"] -- confirm.
    # assign() works on a copy, so the caller's data_df is left untouched.
    cleaned_df = data_df.assign(
        aab=data_df["aab"]
        .str.replace(" - ", "-", regex=False)
        .str.replace("-", " - ", regex=False)
    )

    # Overwriting an existing column keeps its position, so this is still the
    # last column of the caller's data_df.
    series_name = cleaned_df.columns[-1]

    # Outer join keeps regions that appear on only one side, with missing
    # values for the side they lack.
    merged_geo_data_df = aab_shp.merge(
        cleaned_df,
        left_on="region",
        right_on="aab",
        how="outer",
    )

    # Reproject to EPSG:25832 (ETRS89 / UTM zone 32N) before drawing.
    ax = merged_geo_data_df.to_crs("EPSG:25832").plot(series_name, **kwargs)
    ax.axis("off")
    return ax
