In [None]:
from pathlib import Path

import pandas as pd
from catboost import Pool, CatBoostClassifier
import shap
shap.initjs()
import matplotlib.pyplot as plt

from var import DATA_OUT, IMAGE_OUT, FORECAST_HOURS_IN_ADVANCE

In [None]:
# Load the dataset frame written by an earlier pipeline step.
# pd.read_pickle can execute arbitrary code while unpickling, so this path must
# only ever point at a file produced by this project (never external input).
dataset_path = Path(DATA_OUT) / 'df_dataset.pickle'
df = pd.read_pickle(dataset_path)

In [None]:
def _columns_with_prefix(frame, prefix):
    """Return the names of ``frame``'s columns starting with ``prefix``, in frame order."""
    return [name for name in frame.columns if name.startswith(prefix)]


# Assemble the feature column list. Order matters for the resulting frame:
# leading prefix groups, then individually named features, then trailing groups.
feature_columns = []
for prefix in ('ie_', 'il_', 'iu_', 'io_'):
    feature_columns.extend(_columns_with_prefix(df, prefix))
feature_columns.extend([
    'hf',
    'hf_mav_2h',
    'f_107_adj',
    'hp_30',
    'smr',
    'solar_zenith_angle',
    'bz',
    'vx',
    'rho',
])
for prefix in ('local_warning_', 'spectral_contribution_', 'azimuth_', 'velocity_'):
    feature_columns.extend(_columns_with_prefix(df, prefix))

# Shorter names for the spectral-contribution columns (presumably so they fit
# as labels in the SHAP plots). Keys absent from the frame are ignored by rename.
spectral_renames = {
    'spectral_contribution_at': 'spct_cont_at',
    'spectral_contribution_ff': 'spct_cont_ff',
    'spectral_contribution_jr': 'spct_cont_jr',
    'spectral_contribution_pq': 'spct_cont_pq',
    'spectral_contribution_ro': 'spct_cont_ro',
    'spectral_contribution_vt': 'spct_cont_vt',
}

X = df[feature_columns].rename(columns=spectral_renames).copy()

# Binary target for the configured forecast horizon.
y = df[f'tid_within_{FORECAST_HOURS_IN_ADVANCE}h'].copy()

In [None]:
# Categorical inputs for CatBoost. Derived from X (the frame actually passed to
# fit) rather than df: CatBoost rejects cat_features that are not columns of the
# training data, and df contains columns that were never selected into X.
cat_features = [
    *[col_ for col_ in X.columns if col_.endswith('_variation')],
    *[col_ for col_ in X.columns if col_.startswith('local_warning_')],
]

# Parameters shared by every CatBoostClassifier built in this notebook.
static_params = {
    "eval_metric": 'F1',
    "random_seed": 42,
    "auto_class_weights": "SqrtBalanced",
    "cat_features": cat_features,
    # Overfitting detector: stop once od_wait iterations pass without the eval
    # metric improving. NOTE(review): the training cell fits only 5 iterations
    # per call, so od_wait=200 cannot trigger within a single fit.
    "od_type": "Iter",
    # Shrink each fitted model to its best iteration on the eval set.
    "use_best_model": True,
    # Use the input row order instead of random permutations (rows are
    # presumably chronological — the split below slices by year).
    "has_time": True,
    "od_wait": 200,
}

In [None]:
# Chronological split by year on the (presumably datetime) index:
# 2014-2021 inclusive for training, calendar year 2022 held out.
# NOTE(review): the 2022 split is also used as eval_set for early stopping /
# use_best_model, so it is not an untouched test set.
train_years = slice('2014', '2021')
X_train = X.loc[train_years].copy()
y_train = y.loc[train_years].copy()
X_test = X.loc['2022'].copy()
y_test = y.loc['2022'].copy()

In [None]:
N_ROUNDS = 20  # number of incremental training rounds
ITERATIONS_PER_ROUND = 5  # boosting iterations added per round

model = CatBoostClassifier(
    loss_function='Logloss',
    iterations=ITERATIONS_PER_ROUND,
    **static_params,
)

# The training data does not change between rounds, so build its Pool once.
train_pool = Pool(X_train, label=y_train, cat_features=cat_features)

shap_dict, explainer_dict = {}, {}
previous_model = None  # frozen snapshot of the last round (None before round 0)
for i in range(N_ROUNDS):
    # Fit: continue boosting from the previous round's snapshot instead of
    # passing `model` as its own init_model.
    model.fit(
        X_train,
        y_train,
        eval_set=(X_test, y_test),
        init_model=previous_model,
        silent=True,
    )
    # Evaluate SHAP: ShapValues has one column per feature plus a trailing
    # expected-value column, which is dropped here.
    shap_values = model.get_feature_importance(
        train_pool,
        type="ShapValues",
    )
    shap_dict[i] = shap_values[:, :-1]
    # Snapshot the fitted model. `model` is refit in place next round, so an
    # explainer holding `model` itself would end up describing only the final
    # round; each explainer needs its own copy.
    previous_model = model.copy()
    explainer_dict[i] = shap.TreeExplainer(previous_model)

In [None]:
# Fixed x-range shared by every round's plot so the saved images are comparable.
SHAP_BAR_XLIM = (0, 0.65)

for key_, shap_values_ in shap_dict.items():
    # shap.summary_plot(show=False) draws onto pyplot's *current* figure; without
    # a fresh figure per round, each saved image would also contain the bars of
    # every earlier round drawn underneath.
    plt.figure()
    shap.summary_plot(
        shap_values_,
        X_train,
        plot_type='bar',
        show=False,
        max_display=8,
    )
    ax = plt.gca()
    ax.set_xlim(*SHAP_BAR_XLIM)
    plt.savefig(
        Path(IMAGE_OUT, f'shap_{key_}.png'), dpi=400,
    )
    # Release the figure so the loop does not accumulate open figures.
    plt.close()

In [None]:
# for key_ in explainer_dict.keys():
#     shap.plots.bar(
#         explainer_dict[key_](X_train),
#         max_display=6,
#     )