In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from pyprojroot import here

from darts.models import LightGBMModel, TFTModel
from darts.explainability.shap_explainer import ShapExplainer
from darts.explainability.tft_explainer import TFTExplainer

import torch

# Reset to matplotlib's stock style first (this also resets rcParams), then
# raise the on-screen figure resolution.
plt.style.use('default')
mpl.rcParams.update({'figure.dpi': 300})

In [None]:
# Identifiers of the saved model artefact; they are interpolated into the
# model path as f'{TARGET}-{MODEL}-{FS}-{HPO}.pkl'.
TARGET, MODEL, FS, HPO = 'occ', 'lgbm', 'a', 1

# Forecast horizon whose SHAP explanation is extracted, plotted and written
# into the output file name.
H = 24

In [None]:
# Hex colours per feature group.
# NOTE(review): not used by the active plotting code; kept for optional
# group-based bar colouring.
colordict = dict([
    ('Target', '#CCEEBC'),
    ('Traffic', '#8CC0DE'),
    ('TA', '#FFD9C0'),
    ('Website visits', '#FAF0D7'),
])

In [None]:
# Load the persisted model and compute SHAP explanations for it.
# This cell is hard-wired to a LightGBM model with a tree SHAP explainer, so
# fail loudly if the config cell selects a different model type.
if MODEL != 'lgbm':
    raise ValueError(
        f"This cell loads a LightGBMModel with a tree SHAP explainer; got MODEL={MODEL!r}"
    )

# NOTE(review): the artefact is a .pkl file; unpickling can execute arbitrary
# code, so only load trusted, locally produced models.
inpath = here() / f'data/processed/models/{TARGET}-{MODEL}-{FS}-{HPO}.pkl'

model = LightGBMModel.load(str(inpath))
explainer = ShapExplainer(model, shap_method='tree')

# Explain horizon 1 plus the configured horizon H. Deriving the list from H
# keeps get_explanation(horizon=H) valid when H is changed in the config cell
# (H presumably must not exceed the model's output chunk length — TODO confirm).
result = explainer.explain(horizons=sorted({1, H}))
raw_values = result.get_explanation(horizon=H).pd_dataframe()

In [None]:
def _prettify_feature_name(name: str) -> str:
    """Turn a SHAP feature name into a matplotlib (mathtext) axis label.

    Applies the same substitutions as before: 'past_cov_lag' -> 't',
    capitalise the first character, '_' -> ' ', drop 'pastcov',
    'lag' -> '$t', close the math span with '$', and 'PO' -> 'AO'.
    An input such as 'TA_pastcov_lag-1' therefore becomes 'TA $t-1$'.

    Args:
        name: raw feature/column name from the SHAP explanation.

    Returns:
        The display label.
    """
    label = name.replace('past_cov_lag', 't')
    # Slicing instead of label[0] so an empty name does not raise IndexError.
    label = label[:1].upper() + label[1:]
    label = label.replace('_', ' ').replace('pastcov', '').replace('lag', '$t')
    # Close the mathtext span only when one was opened: matplotlib renders a
    # string with an odd number of '$' as plain text, showing a stray '$'.
    if label.count('$') % 2:
        label += '$'
    # Dropping 'pastcov' leaves double spaces ('TA  $t-1$'); collapse them.
    label = ' '.join(label.split())
    return label.replace('PO', 'AO')


# Mean |SHAP value| per feature as a percentage of the total over ALL features.
# Normalising before truncating to the top 20 means the shown shares need not
# sum to 100. Ascending order puts the most important feature last.
mean_abs_shap = raw_values.abs().mean()
values = (mean_abs_shap / mean_abs_shap.sum() * 100).sort_values(ascending=True).tail(20)

values.index = values.index.map(_prettify_feature_name)

# RGB colour constants (0-1 scale).
# NOTE(review): not referenced by the plotting code below; remove if nothing
# else in the notebook uses them.
g = (37/255, 194/255, 104/255)
r = (245/255, 50/255, 88/255)
b = (18/255, 137/255, 223/255)

# Horizontal bar chart of the selected features. `values` is in ascending
# order, so the largest bar is drawn at the top.
fig, ax = plt.subplots(figsize=(3, 5))
values.plot.barh(ax=ax,
                 width=.8,
                 edgecolor='k',
                 lw=1,
                 color='.8')

ax.set_xlabel('Proportional Mean \n |SHAP value| (%)')
ax.set_ylabel('')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.grid(axis='x', lw=.5, ls='--', which='both')
ax.set_axisbelow(True)

# Write a generic copy and a horizon-specific copy of the figure. Create the
# output folder first so savefig does not fail on a fresh checkout.
plot_dir = here() / 'output/plots'
plot_dir.mkdir(parents=True, exist_ok=True)
for filename in ('importance.jpg', f'importance-{H}.jpg'):
    fig.savefig(plot_dir / filename, bbox_inches='tight', dpi=300)