In [None]:
import numpy as np
import pandas as pd
from docs_plot_utils import plot_germany
from statsmodels.tsa.ar_model import AutoReg

from unit_averaging import InlineFocusFunction, OptimalUnitAverager


In [None]:
# Read the tutorial data, using the "period" column as a parsed date index
german_data = pd.read_csv(
    "data/tutorial_data.csv",
    index_col="period",
    parse_dates=True,
)

# Re-create the index with an explicit month-start ("MS") frequency
german_data.index = pd.DatetimeIndex(german_data.index.to_numpy(), freq="MS")

# Preview the last four periods for the first, third and last columns
german_data.tail(4).iloc[:, [0, 2, -1]]

In [None]:
# Every column except the last one is treated as a region
# (presumably the last column is the national aggregate -- verify against the CSV)
regions = german_data.iloc[:, :-1].columns
regions


In [None]:
# Regressor values used to forecast the target region (Frankfurt) for January 2020.
# The ARx(1) models in the estimation cell are fitted to first differences
# (`german_data.diff()`), so the forecast has to plug in the December 2019
# *changes* of the regional and national series rather than the values as loaded.
# This cell must run while `german_data` still holds the data as loaded, i.e.
# before the estimation cell overwrites it (top-to-bottom execution order).
target_data = (
    german_data.loc[:, ["Frankfurt", "Deutschland"]]
    .diff()
    .loc["2019-12"]
    .to_numpy()
    .squeeze()
    .copy()
)

# Focus function: one-step-ahead forecast as a function of a unit's coefficients.
# `coef` is expected to be ordered as (intercept, coefficient on the region's own
# lag, coefficient on the lagged national series); this presumably matches the
# parameter order of the fitted AutoReg models -- verify.
# NOTE: the variable name mentions Cologne, but the forecast is built from
# Frankfurt's data; the name is kept because a later cell refers to it.
forecast_cologne_jan_2020 = InlineFocusFunction(
    focus_function=lambda coef: (
        coef[0] + coef[1] * target_data[0] + coef[2] * target_data[1]
    ),
    # The forecast is linear in `coef`, so its gradient is the regressor vector
    gradient=lambda coef: np.array([1.0, target_data[0], target_data[1]]),
)


In [None]:
# Fit one ARx(1) model per region on first differences:
#   d(region)_t = const + phi * d(region)_{t-1} + beta * d(Deutschland)_{t-1} + error
#
# The guard keeps this cell idempotent: without it, re-running the cell would
# difference the already-differenced data again and append another lag.
if "Germany_lag" not in german_data.columns:
    german_data = german_data.diff()
    german_data["Germany_lag"] = german_data["Deutschland"].shift(1)
    # Differencing and lagging leave missing values in the first two rows
    german_data = german_data.iloc[2:]

# Per-region coefficient vectors and their covariance matrices, keyed by region
ind_estimates = {}
ind_covar_ests = {}

for region in regions:
    # One autoregressive lag plus the lagged national change as exogenous
    # regressor; HAC covariance with up to 5 lags
    ar_results = AutoReg(
        german_data[region], lags=1, exog=german_data["Germany_lag"]
    ).fit(cov_type="HAC", cov_kwds={"maxlags": 5})
    ind_estimates[region] = ar_results.params.to_numpy()
    ind_covar_ests[region] = ar_results.cov_params().to_numpy()

In [None]:
# Combine the unit-level estimates, their covariance estimates and the
# focus function into a single averager object
averager = OptimalUnitAverager(
    ind_estimates=ind_estimates,
    ind_covar_ests=ind_covar_ests,
    focus_function=forecast_cologne_jan_2020,
)

In [None]:
# Fit the averager with Frankfurt as the target unit; the fitted weights are
# read from `averager.weights` in the next cell
averager.fit(target_id="Frankfurt")

In [None]:
# Pair each unit key with its fitted weight
weight_dict = dict(zip(averager.keys, averager.weights, strict=False))

# Two-column table: unit key ("aab") and its weight.
# Presumably "aab" is the identifier column `plot_germany` expects -- verify.
weight_df = pd.Series(weight_dict).rename_axis("aab").reset_index(name="weights")

# Draw the weights on the map of Germany
fig, ax = plot_germany(
    weight_df,
    "Weight in Averaging Combination",
    cmap="Purples",
    vmin=-0.005,
)