In [None]:
import numpy as np
import pandas as pd

import xgboost as xgb

from shapflow.flow import CausalLinks
from shapflow.flow import build_feature_graph
from shapflow.flow import GraphExplainer
from shapflow.flow import edge_credits2edge_credit
from shapflow.flow import translator, create_xgboost_f

from shap_flow_util import read_csv_incl_timeindex

import time
import dill
import tqdm
import multiprocess as mp
import os

In [None]:
# Version tag: selects the ./data/<version>, ./models/<version> and
# ./credit_flow/<version> directories used further down.
version = 'v2'

# (start_date, end_date) windows; one model is explained per window and target.
# The third window spans the first two combined.
periods = [
    ('2018-01-01', '2021-09-30'),
    ('2021-10-01', '2023-12-31'),
    ('2018-01-01', '2023-12-31'),
]

# Compute Shapley-flow edge credits for every (target, period) model and pickle
# the resulting credit-flow object to ./credit_flow/<version>/flow_<model_name>.pkl.

# Imported once here instead of on every loop iteration.
from shap_flow_util import calculate_edge_credit

# Name of the model-output node in the causal graph for each target.
target_names = {'price': 'price_da', 'export': 'agg_net_export'}

# Settings shared by all runs (loop-invariant, so defined once).
seed = 7
n_bg = 100  # number of sampled background samples
nsamples = 1000  # number of foreground samples to explain
nruns = 500  # forwarded to calculate_edge_credit and GraphExplainer
# change this to a suitable value, depending on machine (e.g. 6, 12; on cluster 20)
num_processes = 20

targets = ['price', 'export']
for target in targets:
    # Fail fast on an unknown target instead of silently continuing with an
    # undefined (or stale) target_name.
    if target not in target_names:
        raise ValueError('target unknown: {!r}'.format(target))
    target_name = target_names[target]

    for start_date, end_date in periods:
        # The version is encoded in the directory, not in the file name.
        model_name = 'xgb_{}_start_{}_end_{}'.format(target, start_date, end_date)
        X_test = read_csv_incl_timeindex('./data/{}/X_test_{}.csv'.format(version, model_name))
        # Multiply by 1.0 so the working-day flag is stored as a float column.
        X_test['isworkingday'] = X_test['isworkingday'] * 1.0

        model = xgb.Booster()
        model.load_model("./models/{}/{}_best.json".format(version, model_name))

        # NOTE(review): bg and fg are drawn with the same seed, so the two
        # samples are not independent and may overlap -- confirm this is intended.
        bg = X_test.sample(n=n_bg, random_state=seed)  # background samples
        fg = X_test.sample(n=nsamples, random_state=seed)  # foreground samples (samples to explain)

        # Persist both samples so the explained rows can be looked up later.
        bg.to_csv('./data/{}/bg_{}.csv'.format(version, model_name), sep=',', index=True)
        fg.to_csv('./data/{}/fg_{}.csv'.format(version, model_name), sep=',', index=True)

        categorical_feature_names = []
        display_translator = translator(X_test.columns, X_test, X_test)
        feature_names = list(X_test.columns)

        # --- assumed causal structure among the input features ---
        causal_links = CausalLinks()

        year_features = ['day_of_year_sin', 'day_of_year_cos']
        causal_links.add_causes_effects(year_features, ['gas_price'])

        year_hour_features = year_features + ['hour_sin', 'hour_cos']
        wind_solar_da = ['wind_da', 'solar_da']
        load_rl = ['load_da', 'rl_BE', 'rl_ES', 'rl_DE_LU', 'rl_IT_NORD']
        nuc_ror = ['nuclear_avail', 'run_off_gen']
        causal_links.add_causes_effects(year_hour_features, wind_solar_da + load_rl + nuc_ror)

        river_temp_flow = ['river_temp', 'river_flow_mean']
        causal_links.add_causes_effects(['temp_mean'], river_temp_flow)

        causal_links.add_causes_effects(river_temp_flow, nuc_ror)

        causal_links.add_causes_effects(['temp_mean'], wind_solar_da + load_rl)

        causal_links.add_causes_effects(['isworkingday'], ['rl_FR_ramp'] + load_rl + nuc_ror)

        # Every feature feeds the target node through the trained XGBoost model.
        causal_links.add_causes_effects(feature_names,
                                        target_name,
                                        create_xgboost_f(feature_names, model))

        causal_graph = build_feature_graph(X_test,
                                           causal_links=causal_links,
                                           categorical_feature_names=categorical_feature_names,
                                           display_translator=display_translator,
                                           target_name=target_name,
                                           method='xgboost')
        # Kept for interactive inspection of the graph layout; not used below.
        g = causal_graph.to_graphviz('LR')

        # calculate multiple background result (same as in income.ipynb):
        # one task per single-row background sample, run in parallel.
        start = time.time()

        model.set_param('n_jobs', -1)
        model.set_param('device', 'cpu')

        _args = [(causal_graph, bg[i:i+1], fg, nruns) for i in range(len(bg))]
        pool = mp.Pool(num_processes)
        try:
            # NOTE(review): tqdm wraps the argument list, so the bar presumably
            # tracks task submission rather than task completion.
            edge_credits = pool.starmap(calculate_edge_credit, tqdm.tqdm(_args, total=len(_args)))
        finally:
            # Always release the worker processes, even if a task raised.
            pool.close()
            pool.join()

        end = time.time()
        print(end - start)

        # need this for being able to draw shapley flow (need to call shap_values
        # for one bg sample redundantly); its edge credits are overwritten below.
        model.set_param('n_jobs', 40)
        explainer = GraphExplainer(causal_graph, bg[0:1], nruns, silent=False)
        cf = explainer.shap_values(fg)
        # Replace the single-background credits with the combined result from
        # the per-background runs before saving the credit flow to file.
        cf.edge_credit = edge_credits2edge_credit(edge_credits, cf.graph)

        directory = './credit_flow/{}'.format(version)
        os.makedirs(directory, exist_ok=True)
        with open('{}/flow_{}.pkl'.format(directory, model_name), 'wb') as file:
            dill.dump(cf, file)