# Logging Reptilia Workflow

In [1]:
import pandas as pd

## MCMC Commands

Key:
- "*" = Done
- "**" = In progress


File Name Notes:
- All Gamma (-mG models) have _G in the file name
- Gibbs models have Gibbs in the name

Dates after each run name:
- First date = when the .pkl and sum.txt files are output
- Second date = when the ex_rates, sp_rates, per_species_rates, and mcmc files are output

CoVar Model (not BDNN): 8/15, 8/19
- *python PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -trait_file data/reptilia/Reptilia_species_traits.txt -mCov 5 -logT 1 -pC 0 -fixShift data/Time_bins_CrossStage.txt -qShift data/Time_bins_ByStages.txt -mG -A 0 -n 20000000 -s 2000
    - This is: a Covar BD model with fixed times of rate shifts, log transformed traits, TPP and Gamma preservation model, parameter estimation MCMC

BDNN run 1: 8/15, 8/19
- *python PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_CrossStage.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_ByStages.txt -mG -A 0 -n 20000000 -s 2000
    - Traits file needed to be: normalized continuous variables, no nulls, consistent data types, tab separated .txt
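The run 1 note lists the requirements the traits file had to meet. Below is a minimal pandas sketch of checking and preparing such a table. It assumes (hypothetically) that the first column holds taxon names and every other column is a continuous trait; z-scoring is one common normalization, not necessarily the one used for `Reptilia_species_traits.txt`. `prepare_traits` and the demo columns are made up for illustration.

```python
import io

import pandas as pd


def prepare_traits(df: pd.DataFrame) -> pd.DataFrame:
    """Check and z-score a traits table (hypothetical helper).

    Assumes the first column holds taxon names and all other
    columns are continuous traits.
    """
    taxa_col = df.columns[0]
    traits = df.drop(columns=taxa_col)
    # Consistent data types: every trait column must parse as numeric
    traits = traits.apply(pd.to_numeric, errors="raise").astype(float)
    # No nulls allowed
    if traits.isna().any().any():
        missing = traits.columns[traits.isna().any()].tolist()
        raise ValueError(f"traits with missing values: {missing}")
    # Normalize each trait to mean 0 and standard deviation 1
    std = traits.std(ddof=0)
    if (std == 0).any():
        raise ValueError("a constant trait column cannot be z-scored")
    traits = (traits - traits.mean()) / std
    return pd.concat([df[[taxa_col]], traits], axis=1)


# Made-up demo table, not the real reptile traits
raw = pd.DataFrame({"Taxon": ["A_a", "B_b", "C_c"],
                    "body_size": [10.0, 20.0, 30.0],
                    "latitude": [-5, 0, 5]})
clean = prepare_traits(raw)
buf = io.StringIO()
clean.to_csv(buf, sep="\t", index=False)  # tab-separated text output
print(buf.getvalue())
```

For a real file, the same call would be followed by `clean.to_csv(path, sep="\t", index=False)`.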

BDNN run 2: 8/27, 8/28
- *python PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_ByStages.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_ByStages.txt -A 0 -n 20000000 -s 2000 -BDNNnodes 8 4 -BDNNupdate_f 0.05 0.05 0.25 -singleton 1
    - Removed -mG flag
    - Removed singletons using -singleton 1
    - Reduced network complexity:
        - -BDNNnodes 8 4
        - -BDNNupdate_f 0.05 0.05 0.25
    - Shifted dates towards the present to remove empty space from LAD to present day: -translate 175.0

BDNN run 3: Torsten Reduced Complexity + no -mG, 8/28, 8/29, 8/30, 9/4
- *python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_CrossStage.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_CrossStage.txt -n 50000000 -s 50000 -BDNNnodes 8 4 -translate -175
    - **Result**: low ESS for prior and BD_lik. Burn-in ~15%
- *python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_ByStages.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_ByStages.txt -n 50000000 -s 50000 -BDNNnodes 8 4 -translate -175
    - Starting to use PyRate from PyRate repo, not Arielli repo
    - Removed -mG flag
    - Reduced network complexity: -BDNNnodes 8 4
    - **Result**: low ESS for prior and BD_lik; burn-in very high for those two. Going forward with a 10% burn-in

BDNN run 3 RESTORED:
- python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -restore_mcmc ..../pyrate_mcmc_logs/*_mcmc.log -BDNNmodel 1 -trait_file .../Traits.txt -BDNNtimevar …/Paleotemperature.txt -mG -n 200001 -p 20000 -s 5000


BDNN run 4: Gibbs + no -mG
- **python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_CrossStage.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_CrossStage.txt -n 50000000 -s 50000 -se_gibbs -translate -175
- **python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_ByStages.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_ByStages.txt -n 50000000 -s 50000 -se_gibbs -translate -175
    - Removed -mG flag
    - Uses Gibbs sampler: -se_gibbs

BDNN run 5: BDNN 3 w/ more generations
- python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -j 1 -fixShift data/Time_bins_ByStages.txt -BDNNmodel 1 -trait_file data/reptilia/Reptilia_species_traits.txt -qShift data/Time_bins_ByStages.txt -n 200000000 -s 20000 -BDNNnodes 8 4 -translate -175
    - Only doing By Stages now, since both By and Cross Stages had similar results, and By Stages makes more conceptual sense
    - Run the above command 4 times to check whether independent runs reach the same values (convergence)
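BDNN run 5 is repeated four times to check convergence. One common way to compare replicate runs is the Gelman-Rubin R-hat statistic: values near 1 (a frequently used rule of thumb is below about 1.1) suggest the runs sample the same distribution. This is a minimal numpy/pandas sketch, assuming each replicate's `_mcmc.log` is loaded into its own DataFrame with `pd.read_csv(..., sep='\t')`; `gelman_rubin` and `rhat_from_logs` are hypothetical helpers, not PyRate features.

```python
import numpy as np
import pandas as pd


def gelman_rubin(chains) -> float:
    """Classic Gelman-Rubin potential scale reduction factor (R-hat).

    chains: array of shape (m, n), i.e. m independent runs with n
    post-burn-in samples of the same parameter.
    """
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    b = n * chains.mean(axis=1).var(ddof=1)  # between-run variance
    w = chains.var(axis=1, ddof=1).mean()    # mean within-run variance
    var_hat = (n - 1) / n * w + b / n        # pooled posterior variance
    return float(np.sqrt(var_hat / w))


def rhat_from_logs(logs, column, burnin=0.1):
    """R-hat for one log column across replicate runs (list of DataFrames)."""
    kept = [log[column].to_numpy()[int(burnin * len(log)):] for log in logs]
    n = min(len(k) for k in kept)  # truncate to the shortest run
    return gelman_rubin(np.vstack([k[:n] for k in kept]))


# Synthetic demo: four runs of the same target vs. one run stuck elsewhere
rng = np.random.default_rng(1)
same = rng.normal(0.0, 1.0, size=(4, 5000))
shifted = same + np.array([[0.0], [0.0], [0.0], [3.0]])
logs = [pd.DataFrame({"posterior": run}) for run in same]
print(round(rhat_from_logs(logs, "posterior", burnin=0.1), 3))  # close to 1
print(gelman_rubin(shifted) > 1.1)                            # True
```

Columns such as `posterior` or `BD_lik` can be checked the same way once the four logs are read in.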

COVAR run 2: more generations, to check the RJMCMC times of rate shifts
- python ../PyRate/PyRate.py reptilia/Reptilia_cleaned_pyrate_input_PyRate.py -trait_file data/reptilia/Reptilia_species_traits.txt -mCov 5 -logT 1 -pC 0 -fixShift data/Time_bins_ByStages.txt -qShift data/Time_bins_ByStages.txt -A 4 -n 200000000 -s 20000
    - This is: a Covar BD model with fixed times of rate shifts, log-transformed traits, TPP preservation model without Gamma (-mG removed), RJMCMC (-A 4)
    - Removed -mG
    - Run the above command 4 times to check whether independent runs reach the same values (convergence)


NOTE:
- Cross Stages from BDNN run 4.b and 4.a (in that order) are the only ones currently running as of 9/4!


In [2]:
mcmc = pd.read_csv('pyrate_mcmc_logs/Reptilia_cleaned_pyrate_input_1_G_BDS_BDNN_16_8Tc_mcmc.log', sep='\t')
mcmc.head()

Unnamed: 0,it,posterior,prior,PP_lik,BD_lik,q_0,q_1,q_2,q_3,q_4,...,Yelaphomte_TE,Yimenosaurus_TE,Youngetta_TE,Youngina_TE,Youngosuchus_TE,Yunguisaurus_TE,Yunnanosaurus_TE,Zanclodon_TE,Zhongjiania_TE,Zupaysaurus_TE
0,0,-19488.475162,-519.246613,-11382.427246,-7586.801303,0.283842,0.283842,0.283842,0.283842,0.283842,...,210.773269,199.882408,247.366545,251.166559,247.083426,237.072325,199.3755,191.859012,255.134883,222.219523
1,2000,-15455.251899,-528.505992,-10340.130915,-4586.614992,0.517565,0.428381,0.416735,0.545225,0.394254,...,210.461578,199.579867,247.366545,250.991207,247.083426,237.072325,199.3755,192.214232,255.248001,222.219523
2,4000,-15151.722528,-537.285448,-10387.004305,-4227.432776,0.264763,1.172631,0.658385,0.29891,0.709248,...,210.461578,199.36223,247.281448,250.991207,246.858753,237.102069,198.961248,191.736214,255.628004,222.219523
3,6000,-15143.049096,-543.92294,-10399.280736,-4199.84542,0.599488,1.129137,0.600612,0.443968,0.798106,...,210.461578,199.539161,247.400523,251.372468,246.858753,237.030775,199.055066,191.370345,255.615157,222.016264
4,8000,-15161.341155,-550.906728,-10443.5752,-4166.859227,0.56897,1.591494,0.669723,0.427113,0.798312,...,210.281404,199.919074,247.350455,252.446115,246.858753,235.9386,199.055066,190.719868,255.615157,222.202844


In [4]:
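# Check whether any column holds Python lists (values parsed from a tab-separated log
# are scalars, so this should be empty). pandas >= 2.1 renamed applymap to map.
elementwise = mcmc.map if hasattr(mcmc, "map") else mcmc.applymap
list_columns = mcmc.columns[elementwise(lambda x: isinstance(x, list)).any()].tolist()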

## Post-Processing Commands

### First Steps:
- Move MCMC files into descriptive folder
- Check Tracer to decide on burn-in percentage
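Tracer reports an ESS per column, and the run notes above flag low ESS for prior and BD_lik. As a rough cross-check inside the notebook, here is a sketch of a simple autocorrelation-based ESS after dropping a burn-in fraction. The estimator (summing autocorrelations until the first non-positive one) is a common textbook choice; Tracer may use a different estimator, so the numbers need not match exactly. `effective_sample_size` and `ess_table` are hypothetical helpers.

```python
import numpy as np
import pandas as pd


def effective_sample_size(x) -> float:
    """Autocorrelation-based ESS: n / (1 + 2 * sum of lag-k autocorrelations),
    stopping the sum at the first non-positive autocorrelation."""
    x = np.asarray(x, dtype=float)
    n = x.size
    x = x - x.mean()
    var = np.dot(x, x) / n
    if var == 0.0:
        return float("nan")  # constant trace: ESS undefined
    tau = 1.0  # integrated autocorrelation time
    for lag in range(1, n):
        rho = np.dot(x[:-lag], x[lag:]) / (n * var)
        if rho <= 0.0:
            break
        tau += 2.0 * rho
    return n / tau


def ess_table(log: pd.DataFrame, columns, burnin: float = 0.1) -> pd.Series:
    """ESS per column after discarding the first `burnin` fraction of rows."""
    kept = log.iloc[int(burnin * len(log)):]
    return pd.Series({c: effective_sample_size(kept[c].to_numpy()) for c in columns})


# Demo on synthetic traces (not PyRate output): one independent, one sticky
rng = np.random.default_rng(0)
iid = rng.normal(size=2000)
sticky = np.zeros(2000)
for t in range(1, 2000):
    sticky[t] = 0.95 * sticky[t - 1] + rng.normal()  # AR(1) chain, slow mixing
demo = pd.DataFrame({"posterior": iid, "BD_lik": sticky})
ess = ess_table(demo, ["posterior", "BD_lik"], burnin=0.1)
print(ess)
```

On a loaded log (e.g. the `mcmc` DataFrame read from a `_mcmc.log` file), `ess_table(mcmc, ["posterior", "prior", "PP_lik", "BD_lik"], burnin=0.15)` gives a per-column summary to compare against Tracer.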

### Marginal RTT Plot
**Output in pyrate_mcmc_logs/bdnn...**: _RTT.pdf, _RTT.r
BDNN run 3:
- *python ../PyRate/PyRate.py -plotBDNN reptilia/pyrate_mcmc_logs/bdnn3_cross/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -b 0.15 -translate -175
- *python ../PyRate/PyRate.py -plotBDNN reptilia/pyrate_mcmc_logs/bdnn3_by/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -b 0.15 -translate -175

BDNN run 4 (gibbs):
- python ../PyRate/PyRate.py -plotBDNN reptilia/pyrate_mcmc_logs/bdnn4_cross/  _mcmc.log -b 0.1 -translate -175
- python ../PyRate/PyRate.py -plotBDNN reptilia/pyrate_mcmc_logs/bdnn4_by/  _mcmc.log -b 0.1 -translate -175
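The post-processing commands in these sections differ only in the mode flag, the log path, the burn-in and a few extra options. The sketch below shows a small hypothetical helper, `pyrate_post_cmd`, that assembles such command lines with `shlex.join`, using only flags that appear in these notes. It builds strings and does not run PyRate; flag order does not matter to an argparse-based CLI, so the output may order options differently from the hand-written commands.

```python
import shlex


def pyrate_post_cmd(mode_flag, mcmc_log, burnin, pyrate="../PyRate/PyRate.py",
                    translate=None, extra=()):
    """Build one PyRate post-processing command line (hypothetical helper).

    mode_flag: e.g. "-plotBDNN", "-plotBDNN_effects" or "-BDNN_pred_importance".
    """
    args = ["python", pyrate, mode_flag, mcmc_log, "-b", str(burnin)]
    if translate is not None:
        args += ["-translate", str(translate)]
    args += list(extra)
    return shlex.join(args)  # quotes any argument that needs it


log = ("reptilia/pyrate_mcmc_logs/bdnn3_by/"
       "Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log")
cmd = pyrate_post_cmd("-plotBDNN", log, 0.15, translate=-175)
pdp = pyrate_post_cmd("-plotBDNN_effects", log, 0.15, translate=-175,
                      extra=["-plotBDNN_transf_features",
                             "data/reptilia/reptilia_backscale.txt",
                             "-resample", "100"])
print(cmd)
print(pdp)
```

The strings can be run from Python with `subprocess.run(shlex.split(cmd), check=True)` or pasted into a shell.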

### Partial Dependence Plots (PDP)
**Output in pyrate_mcmc_logs/bdnn...**: _PDP.pdf, _PDP.r
BDNN run 3:
- *python ../PyRate/PyRate.py -plotBDNN_effects reptilia/pyrate_mcmc_logs/bdnn3_cross/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -translate -175 -b 0.15 -resample 100
- *python ../PyRate/PyRate.py -plotBDNN_effects reptilia/pyrate_mcmc_logs/bdnn3_by/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -translate -175 -b 0.15 -resample 100

BDNN run 4:
- python ../PyRate/PyRate.py -plotBDNN_effects reptilia/pyrate_mcmc_logs/bdnn4_cross      mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt
- python ../PyRate/PyRate.py -plotBDNN_effects reptilia/pyrate_mcmc_logs/bdnn4_by      mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt

### Partial Dependence Rates: DEFUNCT
*According to Hauffe: only needed if you want n-way interactions where n>3*
BDNN run 3:
- *python ../PyRate/PyRate.py -BDNN_interaction reptilia/pyrate_mcmc_logs/bdnn3_cross/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.15 -resample 100

BDNN run 4:
- python ../PyRate/PyRate.py -BDNN_interaction reptilia/pyrate_mcmc_logs/bdnn4      mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.1 -resample 100

### Predictor Importance
**Output in pyrate_mcmc_logs/bdnn...**: 
- _contribution_per_species_rates.r
- _contribution_per_species_rates.pdf
- ex_predictor_influence.csv
- ex_shap_per_species.csv
- sp_predictor_influence.csv
- sp_shap_per_species.csv

BDNN run 3:
- *python ../PyRate/PyRate.py -BDNN_pred_importance reptilia/pyrate_mcmc_logs/bdnn3_cross/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.15 -resample 100 -BDNN_nsim_expected_cv 0 -BDNN_pred_importance_interaction
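    - -BDNN_pred_importance_interaction: rank 2-way interactions in addition to single predictors. (In the PyRate code below, `do_inter_imp = args.BDNN_pred_importance_interaction is False`; whether passing the flag switches interactions on or off depends on its argparse definition, so confirm with `python ../PyRate/PyRate.py -h`.)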
    - Notes from the run: Different bin sizes detected due to using -fixShift. Time windows resampled to a resolution of 5.0. 
        - (Because CrossStage's smallest bin size is 5)
- *python ../PyRate/PyRate.py -BDNN_pred_importance reptilia/pyrate_mcmc_logs/bdnn3_by/Reptilia_cleaned_pyrate_input_1_BDS_BDNN_8_4Tc_mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.15 -resample 100 -BDNN_nsim_expected_cv 0 -BDNN_pred_importance_interaction
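The run note above attributes the 5.0 resolution to the smallest CrossStage bin. A quick way to check the smallest interval in a time-bin file is sketched below, assuming (hypothetically) that the file lists one shift time in Ma per line. How PyRate itself picks the resampling resolution may involve more than this, and the demo boundaries are made up, not the real CrossStage file.

```python
import io

import numpy as np


def smallest_bin_width(lines) -> float:
    """Smallest gap between consecutive shift times (hypothetical helper).

    Assumes one shift time per line; blank lines are ignored.
    """
    times = np.sort([float(s) for s in lines if s.strip()])
    return float(np.diff(times).min())


# Made-up boundaries for illustration
demo = io.StringIO("252.0\n247.0\n237.0\n227.0\n201.4\n")
print(smallest_bin_width(demo))  # 5.0
```

For the real file: `with open("data/Time_bins_CrossStage.txt") as fh: print(smallest_bin_width(fh))`.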

BDNN run 4:
- python ../PyRate/PyRate.py -BDNN_pred_importance reptilia/pyrate_mcmc_logs/bdnn4_cross       _mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.1 -resample 100 -BDNN_nsim_expected_cv 0 -BDNN_pred_importance_interaction
- python ../PyRate/PyRate.py -BDNN_pred_importance reptilia/pyrate_mcmc_logs/bdnn4_by       _mcmc.log -plotBDNN_transf_features data/reptilia/reptilia_backscale.txt -b 0.1 -resample 100 -BDNN_nsim_expected_cv 0 -BDNN_pred_importance_interaction
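The `_predictor_influence.csv` files are written after `get_consensus_ranking` combines three importance measures (marginal probabilities, SHAP values and feature-permutation importance; see the PyRate code in the SHAP section below). Its exact combination rule is not shown in these notes, so the following is only a hypothetical mean-rank consensus for intuition, with made-up scores and predictor names; it is not PyRate's implementation.

```python
import pandas as pd


def mean_rank_consensus(scores: pd.DataFrame) -> pd.DataFrame:
    """Rank predictors within each importance measure (1 = most important)
    and average the ranks. Hypothetical rule, not PyRate's."""
    ranks = scores.rank(ascending=False, method="average")
    out = ranks.assign(mean_rank=ranks.mean(axis=1))
    return out.sort_values("mean_rank")


# Made-up importance scores for three hypothetical predictors
scores = pd.DataFrame(
    {"marginal_prob": [0.95, 0.60, 0.80],
     "abs_shap": [0.30, 0.05, 0.10],
     "perm_importance": [12.0, 1.5, 4.0]},
    index=["pred_a", "pred_b", "pred_c"],
)
result = mean_rank_consensus(scores)
print(result)
```

Averaging ranks rather than raw scores keeps measures on very different scales (probabilities, SHAP magnitudes, likelihood changes) from dominating one another.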

## SHAP Code Exploration
Tracing how the `_ex_shap_per_species.csv` output is produced.

In [None]:
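# PYRATE CODE (excerpt from PyRate.py for reference; needs `args`, `burnin`, `os` and the
# surrounding if/elif chain, so it does not run on its own)
    elif args.BDNN_pred_importance != "":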
        import pyrate_lib.bdnn_lib as bdnn_lib
        path_dir_log_files = args.BDNN_pred_importance.replace("_mcmc.log", "")
        pkl_file = path_dir_log_files + ".pkl"
        mcmc_file = path_dir_log_files + "_mcmc.log"
        # Interaction (2-way) importance is only computed when this argument is False
        do_inter_imp = args.BDNN_pred_importance_interaction is False
        BDNNmodel = bdnn_lib.get_bdnn_model(pkl_file)
        sp_taxa_shap, ex_taxa_shap, q_taxa_shap = None, None, None
        sp_main_consrank, ex_main_consrank, q_main_consrank = None, None, None
        if BDNNmodel in [1, 3] and args.BDNN_nsim_expected_cv > 0:
            print("Getting expected coefficient of rate variation")
            bdnn_lib.get_coefficient_rate_variation(path_dir_log_files, burnin,
                                                    combine_discr_features=args.BDNN_groups,
                                                    num_sim=args.BDNN_nsim_expected_cv,
                                                    num_processes=args.thread[0],
                                                    show_progressbar=True)
        if BDNNmodel in [2, 3] and args.BDNN_nsim_expected_cv > 0:
            print("Getting expected coefficient of sampling variation")
            bdnn_lib.get_coefficient_sampling_variation(path_dir_log_files, burnin,
                                                        combine_discr_features=args.BDNN_groups,
                                                        num_sim=args.BDNN_nsim_expected_cv,
                                                        num_processes=args.thread[0],
                                                        show_progressbar=True)
        if BDNNmodel in [1, 3]:
            print("Getting permutation importance birth-death")
            sp_featperm, ex_featperm = bdnn_lib.feature_permutation(mcmc_file, pkl_file,
                                                                    burnin,
                                                                    thin=args.resample,
                                                                    min_bs=args.BDNN_pred_importance_window_size[0],
                                                                    n_perm=args.BDNN_pred_importance_nperm,
                                                                    num_processes=args.thread[0],
                                                                    combine_discr_features=args.BDNN_groups,
                                                                    show_progressbar=True,
                                                                    do_inter_imp=do_inter_imp)
        if BDNNmodel in [2, 3]:
            print("Getting permutation importance sampling")
            q_featperm = bdnn_lib.feature_permutation_sampling(mcmc_file, pkl_file,
                                                               burnin,
                                                               thin=args.resample,
                                                               min_bs=args.BDNN_pred_importance_window_size[-1],
                                                               n_perm=args.BDNN_pred_importance_nperm,
                                                               num_processes=args.thread[0],
                                                               combine_discr_features= args.BDNN_groups,
                                                               show_progressbar=True,
                                                               do_inter_imp=do_inter_imp)
        if BDNNmodel in [1, 3]:
            print("Getting SHAP values birth-death")
            sp_shap, ex_shap, sp_taxa_shap, ex_taxa_shap = bdnn_lib.k_add_kernel_shap(mcmc_file, pkl_file,
                                                                                      burnin,
                                                                                      thin=args.resample,
                                                                                      num_processes=args.thread[0],
                                                                                      combine_discr_features=args.BDNN_groups,
                                                                                      show_progressbar=True,
                                                                                      do_inter_imp=do_inter_imp,
                                                                                      use_mean=args.BDNN_mean_shap_per_group)
        if BDNNmodel in [2, 3]:
            print("Getting SHAP values sampling")
            q_shap, q_taxa_shap = bdnn_lib.k_add_kernel_shap_sampling(mcmc_file, pkl_file,
                                                                      burnin,
                                                                      thin=args.resample,
                                                                      num_processes=args.thread[0],
                                                                      combine_discr_features=args.BDNN_groups,
                                                                      show_progressbar=True,
                                                                      do_inter_imp=do_inter_imp)
        obj_effect = bdnn_lib.get_effect_objects(mcmc_file, pkl_file,
                                                 burnin,
                                                 thin=args.resample,
                                                 combine_discr_features=args.BDNN_groups,
                                                 file_transf_features=args.plotBDNN_transf_features,
                                                 num_processes=args.thread[0],
                                                 show_progressbar=True,
                                                 do_inter_imp=do_inter_imp)
        bdnn_obj, cond_trait_tbl_sp, cond_trait_tbl_ex, cond_trait_tbl_q, names_features_sp, names_features_ex, names_features_q, sp_rate_part, ex_rate_part, q_rate_part, sp_fad_lad, backscale_par = obj_effect
        if BDNNmodel in [1, 3]:
            print("Getting marginal probabilities birth-death")
            sp_pv = bdnn_lib.get_prob_effects(cond_trait_tbl_sp, sp_rate_part, bdnn_obj, names_features_sp, rate_type='speciation')
            ex_pv = bdnn_lib.get_prob_effects(cond_trait_tbl_ex, ex_rate_part, bdnn_obj, names_features_ex, rate_type='extinction')
        if BDNNmodel in [2, 3]:
            print("Getting marginal probabilities sampling")
            q_pv = bdnn_lib.get_prob_effects(cond_trait_tbl_q, q_rate_part, bdnn_obj, names_features_q, rate_type='sampling')
        if BDNNmodel in [1, 3]:
            # consensus among 3 feature importance methods
            print("Getting consensus ranking birth-death")
            sp_feat_importance, sp_main_consrank = bdnn_lib.get_consensus_ranking(sp_pv, sp_shap, sp_featperm)
            ex_feat_importance, ex_main_consrank = bdnn_lib.get_consensus_ranking(ex_pv, ex_shap, ex_featperm)
            output_wd = os.path.dirname(os.path.realpath(path_dir_log_files))
            name_file = os.path.basename(path_dir_log_files)
            ex_feat_merged_file = os.path.join(output_wd, name_file + '_ex_predictor_influence.csv')
            ex_feat_importance.to_csv(ex_feat_merged_file, na_rep='NA', index=False)
            sp_feat_merged_file = os.path.join(output_wd, name_file + '_sp_predictor_influence.csv')
            sp_feat_importance.to_csv(sp_feat_merged_file, na_rep='NA', index=False)
            sp_taxa_shap_file = os.path.join(output_wd, name_file + '_sp_shap_per_species.csv')
            sp_taxa_shap.to_csv(sp_taxa_shap_file, na_rep='NA', index=False)
            ex_taxa_shap_file = os.path.join(output_wd, name_file + '_ex_shap_per_species.csv')
            ex_taxa_shap.to_csv(ex_taxa_shap_file, na_rep='NA', index=False)
        if BDNNmodel in [2, 3]:
            print("Getting consensus ranking sampling")
            q_feat_importance, q_main_consrank = bdnn_lib.get_consensus_ranking(q_pv, q_shap, q_featperm)
            output_wd = os.path.dirname(os.path.realpath(path_dir_log_files))
            name_file = os.path.basename(path_dir_log_files)
            q_feat_merged_file = os.path.join(output_wd, name_file + '_q_predictor_influence.csv')
            q_feat_importance.to_csv(q_feat_merged_file, na_rep='NA', index=False)
            q_taxa_shap_file = os.path.join(output_wd, name_file + '_q_shap_per_species.csv')
            q_taxa_shap.to_csv(q_taxa_shap_file, na_rep='NA', index=False)
        # Plot contribution to species-specific rates
        bdnn_lib.dotplot_species_shap(mcmc_file, pkl_file, burnin, args.resample, output_wd, name_file,
                                      sp_taxa_shap, ex_taxa_shap, q_taxa_shap,
                                      sp_main_consrank, ex_main_consrank, q_main_consrank,
                                      combine_discr_features=args.BDNN_groups,
                                      file_transf_features=args.plotBDNN_transf_features,
                                      translate=args.translate)
        quit()
# Saving files: handled by the `if BDNNmodel in [1, 3]:` consensus-ranking block above, which
# writes <name_file>_sp/_ex_predictor_influence.csv and <name_file>_sp/_ex_shap_per_species.csv
# into output_wd, the directory containing the mcmc log.


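# BDNN_LIB CODE (excerpt from pyrate_lib/bdnn_lib.py; relies on names defined at module level
# there: np, pd, copy_lib, multiprocessing, tqdm, bdnn_parse_results and other helpers)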
def k_add_kernel_shap(mcmc_file, pkl_file, burnin, thin, num_processes=1, combine_discr_features={}, show_progressbar=False, do_inter_imp=True, use_mean=False):
#    if do_inter_imp == False:
#        from fastshap import KernelExplainer
    bdnn_obj, post_w_sp, post_w_ex, _, sp_fad_lad, post_ts, post_te, post_t_reg_lam, post_t_reg_mu, _, post_reg_denom_lam, post_reg_denom_mu, _, _, _ = bdnn_parse_results(mcmc_file, pkl_file, burnin, thin)
    mcmc_samples = post_ts.shape[0]
    trt_tbls = bdnn_obj.trait_tbls[:2]
    n_species = trt_tbls[0].shape[-2]
    n_features = trt_tbls[0].shape[-1]
    names_features_sp = get_names_features(bdnn_obj, rate_type='speciation')
    names_features_ex = copy_lib.deepcopy(names_features_sp)
    n_states = 1
    if len(combine_discr_features) > 0:
        n_states = len(combine_discr_features[list(combine_discr_features.keys())[0]])
#    if n_features == 1 or (n_states == n_features):
#        return make_shap_result_for_single_feature(names_features_sp, names_features_ex, combine_discr_features)
    if n_features == 1:
        if n_states > n_features:
            do_inter_imp = False
        else:
            return make_shap_result_for_single_feature(names_features_sp, names_features_ex, combine_discr_features)
    bdnn_dd = 'diversity' in names_features_sp
    div_idx_trt_tbl = -1
    if is_time_trait(bdnn_obj) and bdnn_dd:
        div_idx_trt_tbl = -2
    hidden_act_f = bdnn_obj.bdnn_settings['hidden_act_f']
    out_act_f = bdnn_obj.bdnn_settings['out_act_f']
    idx_comb_feat_sp = get_idx_comb_feat(names_features_sp, combine_discr_features)
    idx_comb_feat_ex = get_idx_comb_feat(names_features_ex, combine_discr_features)
    shap_names_sp = make_shap_names(names_features_sp, idx_comb_feat_sp, combine_discr_features, do_inter_imp = do_inter_imp)
    shap_names_ex = make_shap_names(names_features_ex, idx_comb_feat_ex, combine_discr_features, do_inter_imp = do_inter_imp)
    n_main_eff_sp = np.sum(shap_names_sp[:, 1] == 'none')
    n_main_eff_ex = np.sum(shap_names_ex[:, 1] == 'none')
    n_inter_eff_sp = int(n_main_eff_sp * (n_main_eff_sp - 1) / 2)
    n_inter_eff_ex = int(n_main_eff_ex * (n_main_eff_ex - 1) / 2)
    if do_inter_imp is False:
        n_inter_eff_sp = 0
        n_inter_eff_ex = 0
    n_effects_sp = n_main_eff_sp + n_inter_eff_sp + 1 + n_species * n_main_eff_sp # np.concatenate((shap_main, shap_interaction, baseline, shap_main_instances.flatten()))
    n_effects_ex = n_main_eff_ex + n_inter_eff_ex + 1 + n_species * n_main_eff_ex
    args = []
    for i in range(mcmc_samples):
        a = [bdnn_obj, post_ts[i, :], post_te[i, :],
             post_w_sp[i], post_w_ex[i], post_t_reg_lam[i], post_t_reg_mu[i], post_reg_denom_lam[i], post_reg_denom_mu[i],
             hidden_act_f, out_act_f, trt_tbls, bdnn_dd, div_idx_trt_tbl, idx_comb_feat_sp, idx_comb_feat_ex, do_inter_imp, use_mean]
        args.append(a)
    unixos = is_unix()
    if unixos and num_processes > 1:
        pool_perm = multiprocessing.Pool(num_processes)
        shap_values = list(tqdm(pool_perm.imap_unordered(k_add_kernel_shap_i, args),
                                total = mcmc_samples, disable = show_progressbar == False))
        pool_perm.close()
    else:
        shap_values = []
        for i in tqdm(range(mcmc_samples), disable = show_progressbar == False):
            shap_values.append(k_add_kernel_shap_i(args[i]))
    shap_values = np.vstack(shap_values)
    shap_summary = get_rates_summary(shap_values.T)
    mean_shap_sp = shap_summary[:(n_main_eff_sp + n_inter_eff_sp), :]
    mean_shap_ex = shap_summary[n_effects_sp:(n_effects_sp + n_main_eff_ex + n_inter_eff_ex), :]
    taxa_shap_sp = shap_summary[(n_main_eff_sp + n_inter_eff_sp):n_effects_sp, :] # First row is baseline
    taxa_shap_ex = shap_summary[(n_effects_sp + n_main_eff_ex + n_inter_eff_ex):, :]
    if bdnn_dd:
        trt_tbls[0][0, :, div_idx_trt_tbl] = 1.0
        trt_tbls[1][0, :, div_idx_trt_tbl] = 1.0
    feature_without_variance_sp = get_idx_feature_without_variance(trt_tbls[0])
    feature_without_variance_ex = get_idx_feature_without_variance(trt_tbls[1])
    remove_sp = []
    for i in feature_without_variance_sp:
        remove_sp.append(np.where(shap_names_sp[:, 0] == names_features_sp[i])[0])
        remove_sp.append(np.where(shap_names_sp[:, 1] == names_features_sp[i])[0])
    remove_ex = []
    for i in feature_without_variance_ex:
        remove_ex.append(np.where(shap_names_ex[:, 0] == names_features_ex[i])[0])
        remove_ex.append(np.where(shap_names_ex[:, 1] == names_features_ex[i])[0])
    remove_sp = np.array(list(pd.core.common.flatten(remove_sp))).astype(int)
    remove_ex = np.array(list(pd.core.common.flatten(remove_ex))).astype(int)
    mean_shap_sp = np.delete(mean_shap_sp, remove_sp[remove_sp < len(mean_shap_sp)], axis = 0)
    mean_shap_ex = np.delete(mean_shap_ex, remove_ex[remove_ex < len(mean_shap_ex)], axis = 0)
    shap_names_sp_del = np.delete(shap_names_sp, remove_sp, axis = 0)
    shap_names_ex_del = np.delete(shap_names_ex, remove_ex, axis = 0)
    shap_values_sp = pd.DataFrame(mean_shap_sp, columns = ['shap', 'lwr_shap', 'upr_shap'])
    shap_values_ex = pd.DataFrame(mean_shap_ex, columns = ['shap', 'lwr_shap', 'upr_shap'])
    shap_names_sp_del = pd.DataFrame(shap_names_sp_del, columns = ['feature1', 'feature2'])
    shap_names_ex_del = pd.DataFrame(shap_names_ex_del, columns = ['feature1', 'feature2'])
    shap_lam = pd.concat([shap_names_sp_del, shap_values_sp], axis = 1)
    shap_ex = pd.concat([shap_names_ex_del, shap_values_ex], axis = 1)
    taxa_names = sp_fad_lad["Taxon"]
    taxa_names_shap_sp = make_taxa_names_shap(taxa_names, n_species, shap_names_sp_del)
    taxa_names_shap_ex = make_taxa_names_shap(taxa_names, n_species, shap_names_ex_del)
    taxa_shap_sp = delete_invariantfeat_from_taxa_shap(feature_without_variance_sp, names_features_sp,
                                                       shap_names_sp, taxa_shap_sp)
    taxa_shap_ex = delete_invariantfeat_from_taxa_shap(feature_without_variance_ex, names_features_ex,
                                                       shap_names_ex, taxa_shap_ex)
    sp_from_shap = get_species_rates_from_shap(shap_values[:, (n_main_eff_sp + n_inter_eff_sp):n_effects_sp],
                                               n_species, n_main_eff_sp, mcmc_samples)
    ex_from_shap = get_species_rates_from_shap(shap_values[:, (n_effects_sp + n_main_eff_ex + n_inter_eff_ex):],
                                               n_species, n_main_eff_ex, mcmc_samples)
    taxa_shap_sp = merge_taxa_shap_and_species_rates(taxa_shap_sp, taxa_names_shap_sp, sp_from_shap, n_species)
    taxa_shap_ex = merge_taxa_shap_and_species_rates(taxa_shap_ex, taxa_names_shap_ex, ex_from_shap, n_species)
    return shap_lam, shap_ex, taxa_shap_sp, taxa_shap_ex
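In `k_add_kernel_shap`, each posterior sample yields one long vector, and the code slices it with `n_main_eff_*`, `n_inter_eff_*` and `n_effects_*` (see the `np.concatenate` comment next to `n_effects_sp`). Below is a small worked sketch of that bookkeeping with hypothetical helper names. It assumes the speciation and extinction blocks share the same predictors (`names_features_ex` is a deep copy of `names_features_sp`) and leaves the per-species part flat, since the ordering inside it is not shown here.

```python
import numpy as np


def shap_block_sizes(n_main: int, n_species: int, do_inter_imp: bool = True):
    """(n_inter, n_effects) for one rate, mirroring k_add_kernel_shap:
    [main effects | 2-way interactions | baseline | per-species main effects]."""
    n_inter = n_main * (n_main - 1) // 2 if do_inter_imp else 0
    n_effects = n_main + n_inter + 1 + n_species * n_main
    return n_inter, n_effects


def split_shap_row(row, n_main: int, n_species: int, do_inter_imp: bool = True):
    """Split one posterior sample's SHAP vector (speciation block followed by
    extinction block) into its pieces."""
    n_inter, n_eff = shap_block_sizes(n_main, n_species, do_inter_imp)
    parts = {}
    for rate, start in (("sp", 0), ("ex", n_eff)):
        block = np.asarray(row[start:start + n_eff])
        parts[rate] = {
            "global": block[:n_main + n_inter],          # -> mean_shap_* rows
            "baseline": block[n_main + n_inter],          # first row of taxa_shap_*
            "per_species": block[n_main + n_inter + 1:],  # n_species * n_main values
        }
    return parts


# Worked example: 3 predictors, 4 species
print(shap_block_sizes(3, 4))                      # (3, 19)
print(shap_block_sizes(3, 4, do_inter_imp=False))  # (0, 16)
parts = split_shap_row(np.arange(38, dtype=float), n_main=3, n_species=4)
print(parts["sp"]["baseline"], parts["ex"]["global"])
```

With 3 predictors there are 3 pairwise interactions, so each rate block holds 3 + 3 + 1 + 4 × 3 = 19 values and the full vector 38.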

