# Regression and Driver Analysis

This notebook demonstrates multiple linear regression and permutation importance analysis to evaluate the impact of selected hydroecological drivers on ∆BGWS. It then repeats the analysis with random forest models, which are interpreted using SHAP (SHapley Additive exPlanations) values.

### Note

To replicate the results presented in the publication, the complete ESM output data must first be preprocessed as described in the README or in the Methods section of the publication.

**The following steps are included in this notebook:**

1. **Load period means**
   - Define data
   - Load the preprocessed data
2. **Compute BGWS and Ensemble Mean Change**
   - Compute ∆BGWS and the associated ensemble mean changes
   - Subdivide the dataset by BGWS regimes
3. **Regression Analysis**
   - Perform multiple linear regression using selected predictors
   - Evaluate model performance and calculate permutation importance
4. **Plot Permutation Importance**
   - Visualize feature importance with directional insight (Fig. 3 b & d / Supplementary Fig. S4)
5. **Random Forest Analysis**
   - Define hyperparameter grid for random forest models
   - Identify the hyperparameters that yield the highest R² on both the training and testing datasets, allowing a maximum gap of 10% of R² between training and testing performance
   - Define the best hyperparameter set for the blue and green water regimes and compute the SHAP values
6. **Plot SHAP Values for Random Forest Models (Supplementary Fig. 8)**
   - Visualize SHAP values for the training and testing dataset of both regimes

In [None]:
# ========== Import Required Libraries ==========
import sys
import dask
from dask.diagnostics import ProgressBar

In [None]:
# ========== Configure Paths ==========
# Define the full path to the directories containing utility scripts and configurations
config_file = '../../src'
data_handling_dir = '../../src/data_handling'
data_analysis_dir = '../../src/analysis'
data_vis_dir = '../../src/visualization'

# Add the directories to sys.path
sys.path.append(config_file)
sys.path.append(data_handling_dir)
sys.path.append(data_analysis_dir)
sys.path.append(data_vis_dir)

# Import custom utility functions and configurations
import load_data as load_dat
import process_data as pro_dat
import regression_analysis as reg_analysis
import regression_analysis_results as reg_results
import compute_statistics as comp_stats
import regression_analysis_rf as rf_analysis

# Import the data directory path
from config import DATA_DIR

In [None]:
# ========== Define Font Style ==========
import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'Nimbus Sans'

### Step 1: Load Period Means

In [None]:
# Step 1.1: Define the datasets
experiments = ['historical', 'ssp370']
models = ['BCC-CSM2-MR', 'CESM2', 'CMCC-CM2-SR5', 'CNRM-CM6-1', 'CNRM-ESM2-1', 'CanESM5', 
          'IPSL-CM6A-LR', 'MIROC-ES2L', 'MPI-ESM1-2-LR', 'NorESM2-MM', 'UKESM1-0-LL']
variables = ['pr', 'tran', 'mrro', 'vpd', 'mrso', 'lai', 'wue', 'RX5day']

In [None]:
# Step 1.2: Load the datasets
print("Loading period means...")
with ProgressBar():
    ds_dict = dask.compute(
        load_dat.load_period_mean(
            DATA_DIR, 'processed', experiments, models, variables
        )
    )[0]

### Step 2: Compute BGWS and Ensemble Mean Change

In [None]:
# Step 2.1: Compute BGWS for both periods and the ensemble mean for the historical period
ds_dict = pro_dat.compute_bgws(ds_dict)
ds_dict['historical'] = comp_stats.compute_ensemble_statistic(ds_dict['historical'], 'mean')

In [None]:
# Step 2.2: Compute change dictionaries and ensemble mean
ds_dict_change = pro_dat.compute_change_dict(ds_dict)
ds_dict_change['ssp370-historical'] = comp_stats.compute_ensemble_statistic(ds_dict_change['ssp370-historical'], 'mean')
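Conceptually, Steps 2.1 and 2.2 reduce to a per-model difference between periods followed by an average across models. The sketch below shows this with plain NumPy arrays, assuming each model's period mean is a 2-D (lat, lon) grid and that missing cells (e.g. ocean) are skipped in the average; `period_change` and `ensemble_mean` are hypothetical helpers, not the `pro_dat`/`comp_stats` API.

```python
import numpy as np

def period_change(future, historical):
    """Per-model change between two period means (future minus historical)."""
    return {model: future[model] - historical[model] for model in historical}

def ensemble_mean(fields):
    """Average equally shaped model fields, ignoring NaNs (assumed handling of missing cells)."""
    stacked = np.stack(fields, axis=0)  # shape: (model, lat, lon)
    return np.nanmean(stacked, axis=0)

# Toy example: two models on a 2x2 grid; model 'A' has a missing cell
historical = {'A': np.array([[0.1, 0.2], [0.3, np.nan]]),
              'B': np.array([[0.3, 0.0], [0.1, 0.2]])}
ssp370 = {'A': np.array([[0.2, 0.1], [0.3, np.nan]]),
          'B': np.array([[0.1, 0.2], [0.3, 0.4]])}

change = period_change(ssp370, historical)
mean_change = ensemble_mean(list(change.values()))
print(mean_change)
```

Averaging the per-model changes (as the notebook does) equals differencing the two ensemble means only when every model covers every cell; with NaN-skipping, the two orders can differ, which is why the change is computed per model first.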

In [None]:
# Step 2.3: Subdivide data by regimes and get final subdivided ensemble mean change dataset
ensemble_mean_change_sub = pro_dat.subdivide_ds_by_regime(
    ds_dict['historical']['Ensemble mean'],
    ds_dict_change['ssp370-historical']['Ensemble mean'],
    'historical', 'ssp370-historical', 'bgws'
)
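The internals of `pro_dat.subdivide_ds_by_regime` are not shown here. As a hypothetical illustration only, assuming a grid cell belongs to the blue water regime when its historical BGWS is positive and to the green water regime when it is negative, the split can be expressed as two masks applied to the change field and stacked along a `subdivision` axis (index 0 = blue, 1 = green, matching how the regression steps use `isel(subdivision=...)`). The function `split_by_regime` is a made-up name.

```python
import numpy as np

def split_by_regime(bgws_hist, change):
    """Hypothetical regime split (sign convention is an assumption):
    positive historical BGWS -> blue water regime, negative -> green water regime.

    Returns an array with a leading 'subdivision' axis (0 = blue, 1 = green);
    cells outside each regime are set to NaN.
    """
    blue = np.where(bgws_hist > 0, change, np.nan)
    green = np.where(bgws_hist < 0, change, np.nan)
    return np.stack([blue, green], axis=0)

bgws_hist = np.array([0.4, -0.2, 0.1, -0.5])     # toy historical BGWS values
delta_bgws = np.array([0.05, -0.03, 0.02, 0.01])  # toy ∆BGWS values
sub = split_by_regime(bgws_hist, delta_bgws)
print(sub[0])  # blue water cells
print(sub[1])  # green water cells
```

Under this convention, cells with a historical BGWS of exactly zero fall into neither regime.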

### Step 3: Multiple Linear Regression Analysis

In [None]:
# Step 3.1: Configure regression parameters
predictor_vars = ['pr', 'vpd', 'mrso', 'lai', 'wue', 'RX5day'] # Define the predictor variables used in the regression analysis
predictant = 'bgws' # Define the predictant (dependent variable) for the regression analysis
test_size = 0.3 # Share of the dataset used as test dataset (default is 30%)
cv_folds = 5 # Number of folds k for k-fold cross-validation (default is 5)
n_permutations = 20 # Number of permutations for importance calculation
param_grid = { # Define the parameter grid for hyperparameter tuning during regression
    'alpha': [0.001, 0.01, 0.1, 1, 10, 100], # Regularization strength for ElasticNet
    'l1_ratio': [0.2, 0.5, 0.8] # Mix ratio between L1 (Lasso) and L2 (Ridge) penalties
}
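Assuming `reg_analysis.regression_analysis` fits scikit-learn's `ElasticNet` (the comments in the grid suggest this, but the module is not shown here), `alpha` ($\alpha$) and `l1_ratio` ($\rho$) enter the objective that scikit-learn minimizes over the coefficients $w$, for $n$ training samples with predictor matrix $X$ and target $y$ (here ∆BGWS):

$$
\min_{w}\; \frac{1}{2n}\,\lVert y - Xw \rVert_2^2 \;+\; \alpha\,\rho\,\lVert w \rVert_1 \;+\; \frac{\alpha\,(1-\rho)}{2}\,\lVert w \rVert_2^2
$$

Larger $\alpha$ shrinks all coefficients more strongly, while larger $\rho$ shifts the penalty toward the L1 term, which can set the coefficients of weak predictors exactly to zero. The grid above spans 6 × 3 = 18 combinations.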

In [None]:
# Step 3.2: Perform regression analysis for both regimes
results_bw_regime = reg_analysis.regression_analysis(
    ensemble_mean_change_sub.isel(subdivision=0), predictor_vars, predictant,
    test_size, param_grid, cv_folds, n_permutations
)
results_gw_regime = reg_analysis.regression_analysis(
    ensemble_mean_change_sub.isel(subdivision=1), predictor_vars, predictant,
    test_size, param_grid, cv_folds, n_permutations
)

print(f"Performance of Blue Water Regression Model (Testing Data): {results_bw_regime['performance']['R2 Test']:.2f}")
print(f"Performance of Green Water Regression Model (Testing Data): {results_gw_regime['performance']['R2 Test']:.2f}")
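Step 3.2 delegates model fitting and permutation importance to `reg_analysis.regression_analysis`, whose internals are not shown. The sketch below illustrates the permutation importance idea on synthetic data: fit a model, record the test R², then shuffle one predictor at a time and measure how much R² drops, averaged over `n_repeats` shuffles (analogous to `n_permutations`). To stay dependency-light it uses an ordinary least-squares fit instead of the tuned ElasticNet and skips cross-validation; `fit_ols` and `permutation_importance` are hypothetical helpers. scikit-learn provides a full implementation as `sklearn.inspection.permutation_importance`.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Coefficient of determination, R^2 = 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def fit_ols(X, y):
    """Ordinary least squares with an intercept (stand-in for the tuned ElasticNet)."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda X_new: np.column_stack([np.ones(len(X_new)), X_new]) @ coef

def permutation_importance(predict, X, y, n_repeats=20, seed=0):
    """Mean drop in R^2 when each predictor column is shuffled independently."""
    rng = np.random.default_rng(seed)
    baseline = r2_score(y, predict(X))
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops.append(baseline - r2_score(y, predict(X_perm)))
        importances[j] = np.mean(drops)
    return importances

# Synthetic data: y depends strongly on column 0, weakly on column 1, not at all on column 2
rng = np.random.default_rng(42)
X = rng.normal(size=(500, 3))
y = 2.0 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=500)
X_train, X_test, y_train, y_test = X[:350], X[350:], y[:350], y[350:]  # 30% test split

model = fit_ols(X_train, y_train)
print("R2 test:", round(r2_score(y_test, model(X_test)), 3))
print("Importance:", permutation_importance(model, X_test, y_test).round(3))
```

Shuffling a predictor breaks its link to the target while preserving its distribution, so a large R² drop marks a predictor the model relies on, and a drop near zero (column 2 here) marks one it ignores.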

### Step 4: Plot Permutation Importance (Fig. 3 a & b / Supplementary Fig. S7)

In [None]:
# Step 4.1: Plot permutation importance for Blue Water regime
reg_results.plot_permutation_importance(
    results_bw_regime,
    predictor_vars,
    regime='bw_regime',  # Specify the regime being analyzed (Blue Water)
    importance_type='test',  # Choose either the 'test' or 'train' dataset for importance analysis
    save_path=None  # Optionally specify a path to save the plot, e.g., "../../results"
)

In [None]:
# Step 4.2: Plot permutation importance for Green Water regime
reg_results.plot_permutation_importance(
    results_gw_regime,
    predictor_vars,
    regime='gw_regime',  # Specify the regime being analyzed (Green Water)
    importance_type='test',  # Choose either the 'test' or 'train' dataset for importance analysis
    save_path=None  # Optionally specify a path to save the plot, e.g., "../../results"
)

### Step 5: Random Forest Analysis

In [None]:
# Step 5.1: Define hyperparameter grid for random forest models
rf_param_grid = {
    'max_depth': [10, 15, 20],
    'max_features': ['sqrt'],
    'min_samples_leaf': [5, 10, 12], 
    'min_samples_split': [2, 3], 
    'n_estimators': [200, 300]
}
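A grid like `rf_param_grid` is searched exhaustively, so its size is the product of the list lengths. The sketch below enumerates the combinations with `itertools.product`, which mirrors what `sklearn.model_selection.ParameterGrid` does; whether `rf_analysis` uses `ParameterGrid` or `GridSearchCV` internally is an assumption, and `expand_grid` is a hypothetical helper.

```python
from itertools import product

rf_param_grid = {
    'max_depth': [10, 15, 20],
    'max_features': ['sqrt'],
    'min_samples_leaf': [5, 10, 12],
    'min_samples_split': [2, 3],
    'n_estimators': [200, 300],
}

def expand_grid(grid):
    """Yield every hyperparameter combination as a dict (Cartesian product of the value lists)."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

candidates = list(expand_grid(rf_param_grid))
print(len(candidates))  # 3 * 1 * 3 * 2 * 2 = 36
print(candidates[0])
```

Note that in scikit-learn's tree implementation a node is only split if each child keeps at least `min_samples_leaf` samples, so `min_samples_split` is effectively raised to at least `2 * min_samples_leaf`. With `min_samples_leaf` ≥ 5, the values 2 and 3 for `min_samples_split` should therefore produce the same trees, and the grid contains pairs of effectively equivalent candidates.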

In [None]:
# Step 5.2: Find the hyperparameters with the highest R² on both datasets, allowing a maximum train-test gap of 10% of R²
print("Testing optimal hyperparameters for the BLUE WATER REGIME...")
rf_results_bw_regime_test = rf_analysis.random_forest_analysis_with_overfitting_check(
    ensemble_mean_change_sub.isel(subdivision=0), predictor_vars, predictant,
    test_size, rf_param_grid
)
print("Testing optimal hyperparameters for the GREEN WATER REGIME...")
rf_results_gw_regime_test = rf_analysis.random_forest_analysis_with_overfitting_check(
    ensemble_mean_change_sub.isel(subdivision=1), predictor_vars, predictant,
    test_size, rf_param_grid
)
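The exact selection rule inside `random_forest_analysis_with_overfitting_check` is not shown. One plausible reading of "highest R² in both datasets with a maximum train-test distance of 10% of R²" is sketched below: discard candidates whose training R² exceeds the testing R² by more than 10% of the training R², then keep the one with the highest testing R². The function `select_hyperparameters`, the result-dictionary layout, and the scores are all hypothetical; the numbers are made up for illustration and are not results from this notebook.

```python
def select_hyperparameters(results, max_rel_gap=0.10):
    """Hypothetical criterion: among candidates whose train-test R^2 gap is at most
    `max_rel_gap` times the training R^2, return the one with the highest test R^2."""
    admissible = [
        r for r in results
        if r['r2_train'] - r['r2_test'] <= max_rel_gap * r['r2_train']
    ]
    if not admissible:
        return None  # every candidate overfits under this criterion
    return max(admissible, key=lambda r: r['r2_test'])

# Made-up scores for illustration only
results = [
    {'params': {'max_depth': 20, 'min_samples_leaf': 5},  'r2_train': 0.95, 'r2_test': 0.78},
    {'params': {'max_depth': 20, 'min_samples_leaf': 10}, 'r2_train': 0.86, 'r2_test': 0.80},
    {'params': {'max_depth': 10, 'min_samples_leaf': 12}, 'r2_train': 0.80, 'r2_test': 0.76},
]
best = select_hyperparameters(results)
print(best['params'])
```

The first candidate has the highest training R² but is rejected because its gap (0.17) exceeds 10% of its training R²; the rule trades a little fit for a smaller generalization gap.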

In [None]:
# Step 5.3: Define the best hyperparameter set for the BLUE WATER REGIME and compute the SHAP values
rf_param_grid_bw = {
    'max_depth': [20],
    'max_features': ['sqrt'],
    'min_samples_leaf': [10],
    'min_samples_split': [2], 
    'n_estimators': [200]
}

rf_results_bw_regime = rf_analysis.random_forest_analysis(
    ensemble_mean_change_sub.isel(subdivision=0), predictor_vars, predictant, test_size,
    rf_param_grid_bw, n_permutations=n_permutations, shap=True
)
print(f"Performance of Blue Water Random Forest Model (Testing Data): {rf_results_bw_regime['performance']['R2 Test']:.2f}")
print(f"Performance of Blue Water Random Forest Model (Training Data): {rf_results_bw_regime['performance']['R2 Train']:.2f}")

In [None]:
# Step 5.4: Define the best hyperparameter set for the GREEN WATER REGIME and compute the SHAP values
rf_param_grid_gw = {
    'max_depth': [15],
    'max_features': ['sqrt'],
    'min_samples_leaf': [5], 
    'min_samples_split': [2],
    'n_estimators': [300]
}

rf_results_gw_regime = rf_analysis.random_forest_analysis(
    ensemble_mean_change_sub.isel(subdivision=1), predictor_vars, predictant, test_size,
    rf_param_grid_gw, n_permutations=n_permutations, shap=True
)
print(f"Performance of Green Water Random Forest Model (Testing Data): {rf_results_gw_regime['performance']['R2 Test']:.2f}")
print(f"Performance of Green Water Random Forest Model (Training Data): {rf_results_gw_regime['performance']['R2 Train']:.2f}")
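With `shap=True`, `rf_analysis.random_forest_analysis` computes SHAP values internally; its implementation is not shown and presumably relies on the `shap` package (an assumption). To make the quantity concrete, the sketch below computes exact Shapley values by brute force for a toy model, replacing features outside each coalition with a single baseline value. This is a simplification: tree-based SHAP explainers compute the same kind of attribution efficiently and use a background dataset or the tree structure rather than one baseline point. `shapley_values` and `toy_model` are hypothetical names.

```python
from itertools import combinations
from math import factorial

def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict(x) relative to a single baseline point.

    Features outside a coalition take their baseline value. This brute-force
    definition needs on the order of 2^n model calls, so it only suits a few features.
    """
    n = len(x)
    phi = [0.0] * n

    def value(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return predict(z)

    for j in range(n):
        others = [i for i in range(n) if i != j]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[j] += weight * (value(s | {j}) - value(s))
    return phi

# Toy model with an interaction between features 0 and 1
def toy_model(z):
    return 2.0 * z[0] + 1.0 * z[1] + 3.0 * z[0] * z[1] - 0.5 * z[2]

x = [1.0, 2.0, 4.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)
# Efficiency property: contributions sum to f(x) - f(baseline)
print(sum(phi), toy_model(x) - toy_model(baseline))
```

Here the interaction term contributes 6 to f(x) and is split evenly between features 0 and 1, giving contributions of 5, 5 and -2, which sum to f(x) - f(baseline) = 8.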

### Step 6: Plot SHAP Values for Random Forest Models (Supplementary Fig. 8)

In [None]:
# Step 6.1: Plot SHAP values for the TRAINING dataset of the BLUE WATER REGIME
reg_results.plot_shap_summary(
    results=rf_results_bw_regime,
    X=rf_results_bw_regime['X_train'],
    predictor_vars=predictor_vars,
    test_train='Train',
    save_path=None,
    title="bw_regime_train"
)

In [None]:
# Step 6.2: Plot SHAP values for the TESTING dataset of the BLUE WATER REGIME
reg_results.plot_shap_summary(
    results=rf_results_bw_regime,
    X=rf_results_bw_regime['X_test'],
    predictor_vars=predictor_vars,
    test_train='Test',
    save_path=None,
    title="bw_regime_test"
)

In [None]:
# Step 6.3: Plot SHAP values for the TRAINING dataset of the GREEN WATER REGIME
reg_results.plot_shap_summary(
    results=rf_results_gw_regime,
    X=rf_results_gw_regime['X_train'],
    predictor_vars=predictor_vars,
    test_train='Train',
    save_path=None,  # Optionally specify a path, e.g., "../../results/ssp370-historical/regression_analysis/rf/shap_importance/global/"
    title="gw_regime_train"
)

In [None]:
# Step 6.4: Plot SHAP values for the TESTING dataset of the GREEN WATER REGIME
reg_results.plot_shap_summary(
    results=rf_results_gw_regime,
    X=rf_results_gw_regime['X_test'],
    predictor_vars=predictor_vars,
    test_train='Test',
    save_path=None,
    title="gw_regime_test"
)
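SHAP summary plots are commonly ordered by each predictor's mean absolute SHAP value, which serves as a global importance score comparable across the training and testing plots. Whether `reg_results.plot_shap_summary` orders features this way is an assumption. The sketch below computes that ranking from a SHAP matrix of shape (samples, predictors); `global_shap_importance` is a hypothetical helper, and the matrix values are made up for illustration.

```python
import numpy as np

def global_shap_importance(shap_values, feature_names):
    """Rank features by mean absolute SHAP value across samples (rows)."""
    mean_abs = np.abs(np.asarray(shap_values)).mean(axis=0)
    order = np.argsort(mean_abs)[::-1]  # descending importance
    return [(feature_names[i], float(mean_abs[i])) for i in order]

# Made-up SHAP matrix: 4 samples x 3 predictors, for illustration only
shap_matrix = np.array([
    [ 0.30, -0.05,  0.01],
    [-0.20,  0.10,  0.00],
    [ 0.25, -0.15,  0.02],
    [-0.35,  0.05, -0.01],
])
ranking = global_shap_importance(shap_matrix, ['pr', 'vpd', 'lai'])
for name, value in ranking:
    print(f"{name}: {value:.3f}")
```

Taking absolute values before averaging matters: positive and negative contributions of a strong driver (such as the first column here) would otherwise cancel and hide its importance, while the sign information remains visible in the summary plot itself.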