In [None]:
from emulator_utils.pre_process import log_standard, minmax, unscale
from emulator_utils.read_data import readpowerspec
from emulator_utils.split import random_holdout
from emulator_utils.surrogates import simple_mlp, train_mlp, save_mlp, load_mlp

In [None]:
import matplotlib.pylab as plt
import numpy as np
from glob import glob

### Reading spectra

In [None]:
# Directory holding the power-spectrum files read below.
DATA_DIR = '../test/data/powerspec_LJ'

# The k grid is taken from one reference file; every spectrum is assumed to share it
# (checked by the assertion at the end of this cell).
k, _, _, _ = readpowerspec(f'{DATA_DIR}/m000p.pk.499')

# glob() returns paths in arbitrary, filesystem-dependent order. Sort them so the row
# order of pk_all -- and therefore the mock parameters and the train/test split that
# are derived from it -- is reproducible across machines and re-runs.
fnames = sorted(glob(f'{DATA_DIR}/*'))
assert fnames, f'no power-spectrum files found in {DATA_DIR}'

# One row per file: the second value returned by readpowerspec (presumably P(k)).
pk_all = np.array([readpowerspec(f)[1] for f in fnames])
# Later cells plot rows of pk_all against k, so every spectrum must have len(k) bins.
assert pk_all.ndim == 2 and pk_all.shape[1] == len(k), 'spectra do not share the reference k grid'

### Using a pre-processing pipeline

In [None]:
# Pre-process the spectra. The fitted scaler is kept so emulator output can be mapped
# back to the original scale later via `unscale(..., scaler)`.
log_standard_result = log_standard(pk_all)
scaled, scaler = log_standard_result

### Using mock input parameters (one integer index per spectrum); these are rescaled as well

In [None]:
# Mock inputs: tag each spectrum with its row index, shaped as a single-column matrix.
n_spectra = scaled.shape[0]
params = np.arange(n_spectra)[:, np.newaxis]
# Min-max rescale the mock parameters; the fitted scaler is returned alongside.
scaled_y, scaler_y = minmax(params)

In [None]:
# Fix NumPy's global RNG so the split below is reproducible across re-runs.
# NOTE(review): assumes random_holdout draws from np.random -- TODO confirm; if it
# accepts its own seed/rng argument, pass the seed there instead.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
# Split spectra and inputs together (split_fraction=0.2, presumably the held-out
# fraction). The inputs are the *rescaled* mock parameters (scaled_y from the previous
# cell): the original call passed the raw `params`, leaving that rescaling computed
# but unused.
pk_train, pk_test, param_train, param_test = random_holdout(
    scaled, scaled_y, split_fraction=0.2)

In [None]:
# Pre-processed training (black) and held-out test (red) spectra.
# plt.subplots() always creates a fresh figure. The original plt.figure(1) is reused
# by the prediction plot later, which (outside the inline backend) would draw onto
# this same figure.
fig_split, ax_split = plt.subplots()
train_lines = ax_split.plot(k, pk_train.T, 'k', alpha=0.3)
test_lines = ax_split.plot(k, pk_test.T, 'r', alpha=0.7)
# Label only the first line of each group so the legend has one entry per group.
if train_lines:
    train_lines[0].set_label('train')
if test_lines:
    test_lines[0].set_label('test')
ax_split.set(xlabel='k', ylabel='scaled Pk', title='Pre-processed spectra: train vs. test')
ax_split.legend()
plt.show()

### Example neural network

### Defining the network architecture (`hidden_dims` is configurable)

In [None]:
# Hidden-layer widths (narrow to wide); input/output sizes are taken from the data.
hidden_layer_widths = [16, 128, 1024, 4096]
n_inputs = param_train.shape[1]
n_outputs = pk_train.shape[1]
mlp = simple_mlp(input_shape=n_inputs, output_shape=n_outputs, hidden_dims=hidden_layer_widths)

### Set up hyper-parameters

In [None]:
# Training settings passed positionally to train_mlp in the next cell.
learning_rate, decay_rate = 0.01, 0.01
batch_size, num_epochs = 10, 10

In [None]:
# Fit the network. The held-out split is passed as validation data; its loss shows up
# as 'val_loss' in the returned history. Argument order: model, train inputs/targets,
# validation inputs/targets, learning rate, decay rate, epochs, batch size.
mlp, train_history = train_mlp(
    mlp,
    param_train, pk_train,
    param_test, pk_test,
    learning_rate, decay_rate, num_epochs, batch_size,
)

In [None]:
# Training vs. validation loss per epoch for the fit above. Uses an explicit
# figure/axes pair instead of the magic figure number 2111.
loss_history = train_history.history
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(loss_history['loss'], label='loss')
ax_loss.plot(loss_history['val_loss'], label='val loss')
ax_loss.set(xlabel='epoch', ylabel='loss', title='MLP training history')
ax_loss.legend()
plt.show()

In [None]:
# Round-trip the model through disk; the cells below use the reloaded copy, which
# exercises save_mlp/load_mlp.
model_path = '../test/model/trained_mlp'
save_mlp(mlp, model_path)
mlp = load_mlp(model_path)

In [None]:
# Network output for the held-out inputs, in the pre-processed (scaled) space ...
mlp_pred = mlp.predict(param_test)
# ... mapped back to the original spectrum scale with the pre-processing scaler.
pk_pred = unscale(mlp_pred, scaler)

In [None]:
# Emulator predictions (red) vs. the true held-out spectra (black), both mapped back
# to the original scale and shown on log-log axes.
# A fresh figure is used: plt.figure(1) would re-activate the earlier train/test
# figure and, outside the inline backend, draw on top of it.
fig_pred, ax_pred = plt.subplots()
pred_lines = ax_pred.plot(k, pk_pred.T, 'r', alpha=0.5)
true_lines = ax_pred.plot(k, unscale(pk_test, scaler).T, 'k', alpha=0.4)
# Colours are the reverse of the train/test plot (red was "test" there), so label them.
if pred_lines:
    pred_lines[0].set_label('emulator prediction')
if true_lines:
    true_lines[0].set_label('test truth')
ax_pred.set(xscale='log', yscale='log', xlabel='k', ylabel='Pk',
            title='Emulator predictions vs. held-out spectra')
ax_pred.legend()
plt.show()

## Explainers

In [None]:
## Local explanations: SHAP
## Global explanations: ALE (accumulated local effects) -- TODO

In [None]:
from emulator_utils.explainer import shap_estimate
from emulator_utils.explainer import plot_shap_summary_single, plot_shap_summary_multiple
from emulator_utils.explainer import plot_shap_force_single, plot_shap_force_multiple

In [None]:
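# NOTE(review): commented-out reference copy of the SHAP helpers that are imported
# from emulator_utils.explainer in the previous cell; nothing in this cell runs.
# Consider deleting it once the library version is the single source of truth.
# import alibi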
# from alibi.explainers import KernelShap
# from alibi.explainers import IntegratedGradients
# from alibi.explainers import ALE
# from alibi.explainers import plot_ale

# import shap

# def shap_estimate(model, training_data, test_data, input_names, output_names):  
#     predictor = model.predict
#     explainer = shap.KernelExplainer(predictor, training_data, features = input_names, out_names = output_names)
#     shap_values = explainer.shap_values(test_data)
#     expected_values = explainer.expected_value

#     return explainer, shap_values, expected_values

# def plot_shap_summary_single(shap_values_single, test_data, input_names, plot_type):   
#     p1 = shap.summary_plot(shap_values_single, test_data, feature_names = input_names, plot_type=plot_type)

    
# def plot_shap_summary_multiple(shap_values, test_data, input_names, plot_type):
#     p2 = shap.summary_plot(shap_values, test_data, feature_names = input_names, plot_type=plot_type)


# def plot_shap_force_single(expected_values, shap_values, input_names, output_names, out_id, test_id):
#     # predictor = model.predict
#     # explainer = shap.KernelExplainer(predictor, training_data, features = input_names, out_names = output_names)
#     p3 = shap.force_plot(expected_values[out_id], shap_values[out_id][test_id], feature_names = input_names, out_names = output_names[out_id])
#     return p3

# def plot_shap_force_multiple(expected_values, shap_values, input_names, output_names, out_id):
#     # predictor = model.predict
#     # explainer = shap.KernelExplainer(predictor, training_data, features = input_names, out_names = output_names)
#     # out_id = 0                                                                                                                                   
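#     # NOTE(review): this line originally read `explainer.expected_value[out_id]`, but
#     # `explainer` is not a parameter of this function (it could only resolve via a
#     # global). `expected_values[out_id]` matches plot_shap_force_single above; check
#     # that emulator_utils.explainer does not carry the same bug.
#     p4 = shap.force_plot(expected_values[out_id], shap_values[out_id], feature_names = input_names, out_names = output_names[out_id])  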
#     return p4

In [None]:
# Explain the held-out inputs using the training inputs as background data.
# A single input feature ('redshift'); outputs are tagged 'pk'.
feature_names = ['redshift']
target_names = ['pk']
explainer, shap_values, expected_values = shap_estimate(
    mlp, param_train, param_test, feature_names, target_names)

In [None]:
# SHAP summary plots for selected model outputs (k-bins).
# NOTE(review): this indexing assumes shap_values is a list with one array per output
# (k-bin), as shap's KernelExplainer has returned for multi-output models; some newer
# shap releases return a single array with outputs on the last axis instead -- verify
# against the installed version.
output_indx = 100  # k-bin whose SHAP values are summarised on their own
SHAP_MAX_BINS = 4000  # upper bound on k-bins considered in the multi-output summary
SHAP_BIN_STRIDE = 200  # summarise every 200th k-bin to keep the plot readable
assert 0 <= output_indx < len(shap_values), (
    f'output_indx={output_indx} is out of range for {len(shap_values)} outputs')
plot_shap_summary_single(shap_values[output_indx], param_test, ['redshift'], plot_type='violin')
plot_shap_summary_multiple(shap_values[:SHAP_MAX_BINS:SHAP_BIN_STRIDE], param_test, ['redshift'],
                           plot_type='bar')

In [None]:
# Force plot for output 0 over all explained test samples; whatever the helper
# returns is displayed as this cell's output.
force_plot_all = plot_shap_force_multiple(expected_values, shap_values, ['redshift'], ['pk'], 0)
force_plot_all

In [None]:
# Force plot for a single test sample at a single output (k-bin).
out_id, test_id = 0, 11  # output index, test-sample index
force_plot_one = plot_shap_force_single(
    expected_values, shap_values, ['redshift'], ['pk'], out_id, test_id)
force_plot_one