# CMIP6 Statistics and Plots

**The following steps are included in this script:**

1. Load netCDF files
2. Compute statistics
3. Plot statistics

### Import Packages

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
import matplotlib.cm
from matplotlib import rcParams
import math
import multiprocessing as mp
from cftime import DatetimeNoLeap
import glob
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr, kendalltau
from matplotlib.lines import Line2D
from sklearn.metrics import r2_score
import matplotlib.colors as colors
from matplotlib.patches import Patch
import matplotlib.patches as mpatches



%matplotlib inline

# Render mathtext ($...$ in labels and titles) in the regular (upright) font
rcParams["mathtext.default"] = 'regular'

### Functions

#### Open files

In [None]:
# ========= Helper function to open the dataset ========
def open_dataset(filename):
    """Open one netCDF file with xarray and return the resulting Dataset."""
    return xr.open_dataset(filename)

In [None]:
# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables):
    """
    Open the regridded netCDF file of every requested variable of one model and merge them.

    Each file is looked up as
    ``../../data/CMIP6/{experiment_id}/{folder}/{var}/CMIP.{model}.{experiment_id}.{var}_regridded.nc``.

    Parameters:
    folder (str): Sub-folder below the experiment directory.
    model (str): Model name as used in the file name.
    experiment_id (str): Experiment name as used in the path and file name.
    variables (list of str): Variable names to load.

    Returns:
    xarray.Dataset: The found variables merged into one dataset. Variables without a file are
    skipped with a printed message; if no file is found, the merged dataset is empty.
    """
    filepaths = []
    for var in variables:
        path = f'../../data/CMIP6/{experiment_id}/{folder}/{var}'
        pattern = os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc')
        matches = glob.glob(pattern)
        if matches:
            filepaths.append(matches[0])
        else:
            # Report which variable/model/file is missing so the gap can be traced
            print(f"No file found for variable '{var}' in model '{model}' ({pattern}).")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

#### Helper functions

In [None]:
def drop_var(ds_dict, var):
    """
    Drop one or more variables from every dataset in a dictionary.

    The dictionary is updated in place and also returned.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.
    var (str or list of str): Name(s) of the variable(s) to drop.

    Returns:
    dict: The same dictionary with the variable(s) removed from each dataset.
    """
    for name, ds in ds_dict.items():
        # ``Dataset.drop`` is deprecated for removing variables; ``drop_vars`` replaces it
        ds_dict[name] = ds.drop_vars(var)

    return ds_dict

In [None]:
def select_period(ds_dict, start_year=None, end_year=None, period=None):
    '''
    Helper function to select periods.

    Each dataset is restricted to the requested years and/or months. The datasets are
    replaced in ``ds_dict`` in place, and the dictionary is also returned.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets that have a 'time' coordinate.
    start_year (int): The start year of the period (inclusive). Only used together with end_year.
    end_year (int): The end year of the period (inclusive). Only used together with start_year.
    period (int, list, str, None): Single month (int), list of months (list), a season
        ('nh_winter', 'nh_spring', 'nh_summer', 'nh_fall') or several seasons separated
        by 'and' (str), or None to keep all months.

    Returns:
    dict: ``ds_dict`` with the selected datasets.

    Raises:
    ValueError: If period has an unsupported type or contains an unknown season name.
    '''
    # Define season to month mapping for northern hemisphere
    seasons_to_months = {
        'nh_winter': [12, 1, 2],
        'nh_spring': [3, 4, 5],
        'nh_summer': [6, 7, 8],
        'nh_fall': [9, 10, 11]
    }

    if period is None:
        # No month selection: keep all time steps
        months = []
    elif isinstance(period, int):
        months = [period]
    elif isinstance(period, str):
        # A single season is simply the one-element case of 'season1 and season2 and ...'
        months = []
        for season in period.lower().split('and'):
            season = season.strip()
            if season not in seasons_to_months:
                raise ValueError(f"Unknown season '{season}'. "
                                 f"Valid seasons are: {', '.join(seasons_to_months)}.")
            months.extend(seasons_to_months[season])
    elif isinstance(period, list):
        months = period
    else:
        raise ValueError("Period must be None, an integer, a string representing a single season, "
                         "a string with multiple seasons separated by 'and', or a list of integers.")

    for k, ds in ds_dict.items():
        # Select the time slice based on start and end year
        if start_year and end_year:
            ds = ds.sel(time=slice(f'{start_year}-01-01', f'{end_year}-12-31'))

        # Keep only the time steps in the requested months. Boolean selection along 'time'
        # does not broadcast time-independent variables or promote integers to float.
        if months:
            ds = ds.sel(time=ds['time.month'].isin(months))

        # Update the dictionary with the selected data
        ds_dict[k] = ds

    return ds_dict

In [None]:
def select_period(ds_dict, start_year=None, end_year=None, period=None, yearly_sum=False):
    '''
    Helper function to select periods and optionally compute yearly sums.

    Each dataset is restricted to the requested years and/or months. The datasets are
    replaced in ``ds_dict`` in place, and the dictionary is also returned.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets that have a 'time' coordinate.
    start_year (int): The start year of the period (inclusive). Only used together with end_year.
    end_year (int): The end year of the period (inclusive). Only used together with start_year.
    period (int, list, str, None): Single month (int), list of months (list), a season
        ('nh_winter', 'nh_spring', 'nh_summer', 'nh_fall') or several seasons separated
        by 'and' (str), or None to keep all months.
    yearly_sum (bool): If True, sum the selected time steps per calendar year. The bins are
        calendar years, so for 'nh_winter' December is summed with January and February of
        the same year.

    Returns:
    dict: ``ds_dict`` with the selected datasets.

    Raises:
    ValueError: If period has an unsupported type or contains an unknown season name.
    '''
    # Define season to month mapping for northern hemisphere
    seasons_to_months = {
        'nh_winter': [12, 1, 2],
        'nh_spring': [3, 4, 5],
        'nh_summer': [6, 7, 8],
        'nh_fall': [9, 10, 11]
    }

    if period is None:
        # No month selection: keep all time steps
        months = []
    elif isinstance(period, int):
        months = [period]
    elif isinstance(period, str):
        # A single season is simply the one-element case of 'season1 and season2 and ...'
        months = []
        for season in period.lower().split('and'):
            season = season.strip()
            if season not in seasons_to_months:
                raise ValueError(f"Unknown season '{season}'. "
                                 f"Valid seasons are: {', '.join(seasons_to_months)}.")
            months.extend(seasons_to_months[season])
    elif isinstance(period, list):
        months = period
    else:
        raise ValueError("Period must be None, an integer, a string representing a single season, "
                         "a string with multiple seasons separated by 'and', or a list of integers.")

    for k, ds in ds_dict.items():
        # Select the time slice based on start and end year
        if start_year and end_year:
            ds = ds.sel(time=slice(f'{start_year}-01-01', f'{end_year}-12-31'))

        # Keep only the time steps in the requested months. Boolean selection along 'time'
        # does not broadcast time-independent variables or promote integers to float.
        if months:
            ds = ds.sel(time=ds['time.month'].isin(months))

        # Sum per calendar year; 'YS' (year start) replaces the deprecated 'AS' alias
        if yearly_sum:
            ds = ds.resample(time='YS').sum(dim='time')

        # Update the dictionary with the selected data
        ds_dict[k] = ds

    return ds_dict

In [None]:
def select_period(ds_dict, start_year=None, end_year=None, period=None, yearly_sum=False):
    '''
    Helper function to select periods and optionally compute yearly sums.

    Each dataset is restricted to the requested years and/or months. The datasets are
    replaced in ``ds_dict`` in place (``compute_statistic`` ignores the return value), and
    the dictionary is also returned.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets that have a 'time' coordinate.
    start_year (int): The start year of the period (inclusive). Only used together with end_year.
    end_year (int): The end year of the period (inclusive). Only used together with start_year.
    period (int, list, str, None): Single month (int), list of months (list), a season
        ('nh_winter', 'nh_spring', 'nh_summer', 'nh_fall') or several seasons separated
        by 'and' (str), or None to not select any specific period.
    yearly_sum (bool): If True, sum the selected time steps per calendar year. The bins are
        calendar years, so for 'nh_winter' December is summed with January and February of
        the same year.

    Returns:
    dict: ``ds_dict`` with the selected datasets.

    Raises:
    ValueError: If period has an unsupported type or contains an unknown season name.
    '''
    # Define season to month mapping for northern hemisphere
    seasons_to_months = {
        'nh_winter': [12, 1, 2],
        'nh_spring': [3, 4, 5],
        'nh_summer': [6, 7, 8],
        'nh_fall': [9, 10, 11]
    }

    # If no specific period is selected, all data will be used.
    if period is None:
        months = None
    elif isinstance(period, int):
        months = [period]
    elif isinstance(period, str):
        # A single season is simply the one-element case of 'season1 and season2 and ...'.
        # Unknown names raise instead of yielding an empty month list, which would drop all data.
        months = []
        for season in period.lower().split('and'):
            season = season.strip()
            if season not in seasons_to_months:
                raise ValueError(f"Unknown season '{season}'. "
                                 f"Valid seasons are: {', '.join(seasons_to_months)}.")
            months.extend(seasons_to_months[season])
    elif isinstance(period, list):
        months = period
    else:
        raise ValueError("Period must be None, an integer, a string representing a single season, "
                         "a string with multiple seasons separated by 'and', or a list of integers.")

    for k, ds in ds_dict.items():
        if start_year and end_year:
            ds = ds.sel(time=slice(f'{start_year}-01-01', f'{end_year}-12-31'))

        # Keep only the time steps in the requested months. Boolean selection along 'time'
        # does not broadcast time-independent variables or promote integers to float.
        if months is not None:
            ds = ds.sel(time=ds['time.month'].isin(months))

        # Sum per calendar year; 'YS' (year start) replaces the deprecated 'AS' alias
        if yearly_sum:
            ds = ds.resample(time='YS').sum(dim='time')

        ds_dict[k] = ds

    return ds_dict

In [None]:
# ======== Standardize ========
def standardize(ds_dict):
    '''
    Standardize every dataset of a dictionary to zero mean and unit standard deviation.

    Each variable is centred by its mean and divided by its standard deviation, both taken
    over all of its dimensions. Variable-level and dataset-level attributes of the input
    are carried over to the result.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.

    Returns:
    dict: New dictionary with the same keys and the standardized datasets.
    '''
    standardized = {}

    for key, original in ds_dict.items():
        result = (original - original.mean()) / original.std()

        # The arithmetic drops attributes, so copy them back for every shared variable
        shared_names = [name for name in original.variables if name in result.variables]
        for name in shared_names:
            result[name].attrs = original[name].attrs

        result.attrs = original.attrs
        standardized[key] = result

    return standardized

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate plotting arguments and collect metadata from the first dataset of ``ds_dict``.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets. The first dataset must provide the
        attributes 'period' (two-element sequence), 'experiment_id', 'statistic_dimension',
        'statistic' and 'frequency'; its ``variable`` must have 'long_name' and 'units'.
    variable (str): Name of the variable to describe.

    Returns:
    tuple: (var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles,
        frequency). ``period`` is formatted as 'start-end', ``titles`` maps statistic and
        dimension names to plot titles, and ``frequency`` is a readable label (the raw
        attribute value when no label is defined).

    Raises:
    TypeError: If ``ds_dict`` is not a dict of xarray datasets or ``variable`` is not a str.
    ValueError: If ``ds_dict`` is empty.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')
    if not ds_dict:
        raise ValueError("ds_dict must contain at least one dataset.")

    # Dictionary to store plot titles for each statistic
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum", "median": "Median", "time": "Time", "space": "Space"}
    freq = {"mon": "Monthly"}

    # Data information, all taken from the first dataset
    first_ds = next(iter(ds_dict.values()))
    var_long_name = first_ds[variable].long_name
    period = f"{first_ds.attrs['period'][0]}-{first_ds.attrs['period'][1]}"
    experiment_id = first_ds.experiment_id
    unit = first_ds[variable].units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']
    # Fall back to the raw frequency string for frequencies without a readable label
    frequency = freq.get(first_ds.frequency, first_ds.frequency)

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency

In [None]:
import regionmask

def apply_region_mask(ds_dict):
    """
    Applies the AR6 land region mask to datasets in the provided dictionary and adds a region dimension.

    For every data variable, a boolean 3D mask (one layer per region) is built with
    ``regionmask.defined_regions.ar6.land.mask_3D``. Values inside a region are kept
    unchanged (including genuine zeros) and values outside it become NaN.

    Args:
        ds_dict (dict): A dictionary of xarray datasets.

    Returns:
        dict: A new dictionary where keys are the same as in the input dictionary,
              and each value is an xarray Dataset with a region dimension added to each variable.
              Dataset and variable attributes are carried over.
    """
    land_regions = regionmask.defined_regions.ar6.land
    ds_masked_dict = {}

    for ds_name, ds in ds_dict.items():
        ds_out = xr.Dataset()  # Initiate an empty Dataset for the masked data

        for var in ds.data_vars:
            # Boolean mask: True where a grid cell belongs to the respective region
            mask = land_regions.mask_3D(ds[var])

            # Keep values inside each region and NaN outside. Multiplying by the mask and then
            # discarding zeros would also turn real zero values (e.g. no runoff) into NaN.
            ds_out[var] = ds[var].where(mask)
            ds_out[var].attrs = ds[var].attrs

        ds_out.attrs = ds.attrs
        ds_masked_dict[ds_name] = ds_out

    return ds_masked_dict

In [None]:
def compute_change(ds_dict_hist_mean, ds_dict_ssp370_mean, relative_change=False):
    """
    Compute the SSP370 minus historical change for every model present in both dictionaries.

    Only variables (including coordinates) present in both the historical and the SSP370
    dataset of a model are kept. Neither input dictionary nor its datasets are modified.

    Parameters:
    ds_dict_hist_mean (dict): Historical datasets keyed by model name.
    ds_dict_ssp370_mean (dict): SSP370 datasets keyed by model name. Each needs the attributes
        'statistic', 'statistic_dimension' and 'source_id'. Models missing here are skipped.
    relative_change (bool): If True, compute ``(ssp370 - hist) / hist * 100``, with historical
        values equal to 0 replaced by 1 in the denominator. Otherwise compute ``ssp370 - hist``.

    Returns:
    dict: Change datasets keyed by model name. Variable attributes are copied from the SSP370
    datasets; for relative changes 'units_rel' = '%' is added.
    """
    ds_dict_change = {}

    for name, ds in ds_dict_hist_mean.items():
        if name not in ds_dict_ssp370_mean:
            continue
        ds_fut = ds_dict_ssp370_mean[name]

        # Keep only variables present in both datasets. Only local names are rebound,
        # so the caller's dictionaries stay untouched.
        common_vars = set(ds.variables) & set(ds_fut.variables)
        ds = ds.drop_vars([var for var in ds.variables if var not in common_vars])
        ds_fut = ds_fut.drop_vars([var for var in ds_fut.variables if var not in common_vars])

        # Compute either absolute or relative change
        if relative_change:
            # Replace zero denominators by a constant to avoid division by zero
            epsilon = 1
            ds_nonzero = ds.where(ds != 0, epsilon)
            change = ((ds_fut - ds) / ds_nonzero) * 100
        else:
            change = ds_fut - ds

        change.attrs = {'period': 'SSP370 - Historical',
                        'statistic': ds_fut.statistic,
                        'statistic_dimension': ds_fut.statistic_dimension,
                        'experiment_id': 'ssp370-historical',
                        'source_id': ds_fut.source_id,
                        'change': 'Relative Change' if relative_change else 'Absolute Change',
                        }

        for var in ds.data_vars:
            # Copy the dict so that adding 'units_rel' does not write into the SSP370 attrs
            var_attrs = dict(ds_fut[var].attrs)
            if relative_change:
                var_attrs['units_rel'] = '%'
            change[var].attrs = var_attrs

        ds_dict_change[name] = change

    return ds_dict_change


In [None]:
def compute_ensemble(ds_dict_change, metric='mean'):
    """
    Add an ensemble statistic computed across all datasets of a dictionary.

    Ensemble entries from earlier calls (keys starting with 'Ensemble ', for any metric)
    are removed first so they are not folded into the new ensemble. A 'member_id'
    coordinate is dropped from every dataset before concatenation. The dictionary is
    updated in place and returned.

    Parameters:
    ds_dict_change (dict): Dictionary of xarray datasets with compatible coordinates.
    metric (str): Name of the xarray reduction applied over the datasets, e.g. 'mean' or 'median'.

    Returns:
    dict: The same dictionary with the new entry f'Ensemble {metric}'.
    """
    # Remove previous ensemble results (e.g. 'Ensemble std', not only mean/median)
    for key in [k for k in ds_dict_change if str(k).startswith('Ensemble ')]:
        del ds_dict_change[key]

    # Drop 'member_id' coordinate if it exists in any of the datasets
    for ds_key, ds in ds_dict_change.items():
        if 'member_id' in ds.coords:
            ds_dict_change[ds_key] = ds.drop_vars('member_id')

    combined = xr.concat(list(ds_dict_change.values()), dim='ensemble')
    ds_dict_change[f'Ensemble {metric}'] = getattr(combined, metric)(dim='ensemble')

    return ds_dict_change

#### Compute statistics

In [None]:
def calculate_spatial_mean(ds_dict):
    """
    Compute the unweighted mean over 'lon' and 'lat' of every data variable.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets with 'lon' and 'lat' dimensions.

    Returns:
    dict: New dictionary with the same keys. Each value is a Dataset holding the spatial
    mean of every data variable, with variable and dataset attributes preserved.
    """
    ds_dict_mean = {}
    for key, ds in ds_dict.items():
        # A fresh Dataset per key; each variable gets its own mean, not the whole dataset's
        ds_mean = xr.Dataset()
        for var in ds.data_vars:
            ds_mean[var] = ds[var].mean(['lon', 'lat'])
            ds_mean[var].attrs = ds[var].attrs

        ds_mean.attrs = ds.attrs
        ds_dict_mean[key] = ds_mean

    return ds_dict_mean

In [None]:
def compute_statistic_single(ds, statistic, dimension, yearly_mean=True):
    """
    Compute one statistic of a dataset over time or over space.

    Parameters:
    ds (xarray.Dataset): Dataset with 'time', 'lat' and 'lon'. It is not modified.
    statistic (str): 'mean', 'std', 'var', 'min', 'max' or 'median'.
    dimension (str): 'time' reduces over time. 'space' reduces over lon/lat with
        cos(latitude) weights (min/max are taken unweighted; the median is the weighted
        0.5 quantile).
    yearly_mean (bool): Only for dimension='space': first average to yearly means
        (grouped by 'time.year') and then reduce spatially.

    Returns:
    xarray.Dataset: The reduced dataset with the attributes 'period'
    ([first_year, last_year] as strings), 'statistic' and 'statistic_dimension'.

    Raises:
    ValueError: If ``dimension`` is neither 'time' nor 'space'.
    """
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    period = [str(ds.time.dt.year[0].values), str(ds.time.dt.year[-1].values)]

    if dimension == "time":
        stat_ds = getattr(ds, statistic)("time", keep_attrs=True, skipna=True)
        stat_ds.attrs['period'] = period
    else:
        # Record the period before grouping by year removes 'time'. assign_attrs returns a
        # new object, so the caller's dataset is not modified.
        ds = ds.assign_attrs(period=period)

        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            ds.attrs['mean'] = 'yearly mean'

        spatial_dims = ("lon", "lat")
        if statistic in ("min", "max"):
            # Weighted objects have no min/max, and weights do not change extremes anyway
            stat_ds = getattr(ds, statistic)(spatial_dims, keep_attrs=True, skipna=True)
        else:
            # Get the area weights, apply them to the data, and compute the statistic
            weights = np.cos(np.deg2rad(ds.lat))
            weights.name = "weights"
            ds_weighted = ds.weighted(weights)
            if statistic == "median":
                # Weighted objects have no median; use the weighted 0.5 quantile instead
                stat_ds = ds_weighted.quantile(0.5, dim=spatial_dims, keep_attrs=True,
                                               skipna=True).drop_vars('quantile', errors='ignore')
            else:
                stat_ds = getattr(ds_weighted, statistic)(spatial_dims, keep_attrs=True, skipna=True)

    stat_ds.attrs['statistic'] = statistic
    stat_ds.attrs['statistic_dimension'] = dimension

    return stat_ds

In [None]:
def compute_statistic(ds_dict, statistic, dimension, start_year=None, end_year=None, yearly_mean=True):
    """
    Computes the specified statistic for each dataset in the dictionary.

    The per-dataset work is done by ``compute_statistic_single`` in a multiprocessing pool.
    The caller's dictionary is not modified.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        statistic (str): The statistic to compute, which can be one of 'mean', 'std', 'min', 'max',
            'var', or 'median'.
        dimension (str): The dimension to compute over, which can be 'time' or 'space'.
        start_year (int or str, optional): The start year of the period to compute the statistic over.
        end_year (int or str, optional): The end year of the period to compute the statistic over.
            The period is only applied when both start_year and end_year are given.
        yearly_mean (bool): For dimension='space', average to yearly means before the spatial
            reduction (passed to compute_statistic_single).

    Returns:
        dict: A dictionary with computed statistic for each dataset.

    Raises:
        TypeError: If ds_dict is not a dict of xarray datasets.
        ValueError: If statistic or dimension is not supported.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if statistic not in ["mean", "std", "min", "max", "var", "median"]:
        raise ValueError(f"Invalid statistic '{statistic}' specified.")
    if dimension not in ["time", "space"]:
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # Nothing to compute: avoid spawning worker processes
    if not ds_dict:
        return {}

    # Select period. select_period replaces the datasets in the dict it receives,
    # so hand it a shallow copy to keep the caller's dictionary unchanged.
    if start_year is not None and end_year is not None:
        ds_dict = select_period(dict(ds_dict), start_year=start_year, end_year=end_year)

    # Use multiprocessing to compute the statistic for each dataset in parallel
    args = [(ds, statistic, dimension, yearly_mean) for ds in ds_dict.values()]
    with mp.Pool() as pool:
        results = pool.starmap(compute_statistic_single, args)

    return dict(zip(ds_dict.keys(), results))

In [None]:
from itertools import permutations

def precompute_metrics(ds_dict, variables, metrics=None):
    """
    Compute pairwise metrics between variables for every dataset.

    For each dataset, the listed variables are flattened into the columns of a DataFrame.
    ``compute_metrics_for_pair`` (defined elsewhere) is then evaluated in a process pool
    for every ordered pair of distinct variables.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets.
    variables (list of str): Variables to pair; all must flatten to the same length.
    metrics (list of str, optional): Metric names passed to compute_metrics_for_pair.
        Defaults to ['pearson'].

    Returns:
    dict: {metric: {dataset_name: {f'{var1}_{var2}': value}}}.
    """
    # None instead of a mutable list default, which would be shared between calls
    if metrics is None:
        metrics = ['pearson']

    # Initialize the results dictionary
    results_dict = {metric: {} for metric in metrics}
    if not ds_dict:
        return results_dict

    # Ordered pairs: both (a, b) and (b, a) are evaluated. The pairs are the same for every dataset.
    pairs = list(permutations(variables, 2))

    # One process pool for all datasets instead of a new pool per dataset
    with mp.Pool() as pool:
        for name, ds in ds_dict.items():
            # Create a DataFrame with all the variables
            df = pd.DataFrame({var: ds[var].values.flatten() for var in variables})
            args = [(df, var1, var2, metrics) for var1, var2 in pairs]
            results = pool.map(compute_metrics_for_pair, args)

            # Store the results in the results_dict
            for var1, var2, metric_dict in results:
                for metric, value in metric_dict.items():
                    results_dict[metric].setdefault(name, {}).setdefault(f'{var1}_{var2}', value)

    return results_dict

In [None]:
def compute_stats(ds_dict):
    """
    Compute the yearly, spatially averaged time series of every variable.

    Each dataset is resampled to yearly means with ``resample(time='1Y')``. Every data
    variable is then averaged over 'lat' and 'lon' without area weights.

    Parameters:
    ds_dict (dict): The input dictionary of xarray.Dataset.

    Returns:
    dict: {model: {variable: DataArray of yearly spatial means}}.
    """
    stats = {}
    for model, ds in ds_dict.items():
        annual = ds.resample(time='1Y').mean()
        stats[model] = {
            name: annual[name].mean(dim=['lat', 'lon'])
            for name in annual.data_vars
        }
    return stats

In [None]:
def compute_yearly_means(ds_dict):
    """
    Average every dataset to yearly means.

    Only the 'time' dimension is reduced (grouped by 'time.year'); spatial dimensions are kept.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets with a 'time' coordinate.

    Returns:
    dict: New dictionary with the same keys and the yearly-mean datasets ('year' dimension).
    """
    return {name: ds.groupby('time.year').mean('time') for name, ds in ds_dict.items()}

In [None]:
def compute_yearly_regional_means(ds_dict_region):
    """
    Compute area-weighted yearly means for every dataset of every region.

    Each dataset is first averaged to yearly means (grouped by 'time.year'). It is then
    averaged over 'lat' and 'lon' with cos(latitude) weights.

    Parameters:
    ds_dict_region (dict): {region: {dataset_name: xarray.Dataset}}.

    Returns:
    dict: {region: {dataset_name: yearly spatial-mean Dataset}}.
    """
    yearly_means_dict = {}
    for region, models in ds_dict_region.items():
        per_model = {}
        for ds_name, ds in models.items():
            # Area weights proportional to the cosine of latitude
            lat_weights = np.cos(np.deg2rad(ds.lat))
            annual = ds.groupby('time.year').mean('time')
            per_model[ds_name] = annual.weighted(lat_weights).mean(('lat', 'lon'))
        yearly_means_dict[region] = per_model
    return yearly_means_dict

In [None]:
def compute_change(ds_dict_hist, ds_dict_fut, var_rel_change=None):
    """
    Compute future-minus-historical changes, as relative changes for selected variables.

    Parameters:
    ds_dict_hist (dict): Historical datasets keyed by model name.
    ds_dict_fut (dict): Future-scenario datasets keyed by model name. Each needs the attributes
        'experiment_id', 'statistic', 'statistic_dimension', 'source_id' and 'frequency'.
        Models missing here are skipped.
    var_rel_change (str, list of str or None): Variables for which the relative change in %,
        ``(fut - hist) / |hist| * 100``, is computed (NaN where hist is 0). "ALL" selects every
        data variable of the first historical dataset; a single variable name is also accepted.
        All other variables get the absolute change ``fut - hist``.

    Returns:
    dict: Change datasets keyed by model name. Variables missing from the future dataset
    are skipped; variable attributes are copied from the future dataset.
    """
    if var_rel_change == "ALL":
        first_hist = next(iter(ds_dict_hist.values()), None)
        var_rel_change = list(first_hist.data_vars.keys()) if first_hist is not None else []
    elif var_rel_change is None:
        var_rel_change = []
    elif isinstance(var_rel_change, str):
        # A single name; membership in a str would otherwise be a substring test
        var_rel_change = [var_rel_change]
    rel_vars = set(var_rel_change)

    ds_dict_change = {}

    for name, ds in ds_dict_hist.items():
        if name not in ds_dict_fut:
            continue
        ds_f = ds_dict_fut[name]
        fut_scenario = ds_f.experiment_id

        change_ds = xr.Dataset()

        for variable in ds.data_vars:
            if variable not in ds_f:
                continue
            hist = ds[variable]
            if variable in rel_vars:
                # Zero baselines become NaN, so their relative change is NaN instead of inf
                base = hist.where(hist != 0)
                change_ds[variable] = ((ds_f[variable] - base) / abs(base)) * 100
            else:
                change_ds[variable] = ds_f[variable] - hist

            change_ds[variable].attrs = ds_f[variable].attrs

        change_ds.attrs = {
            'period': f'Change {fut_scenario} - Historical',
            'statistic': ds_f.statistic,
            'statistic_dimension': ds_f.statistic_dimension,
            'experiment_id': fut_scenario.lower() + '-historical',
            'source_id': ds_f.source_id,
            'frequency': ds_f.frequency,
            'change': 'Mixed Change',
        }
        ds_dict_change[name] = change_ds

    return ds_dict_change

In [None]:
def calculate_spatial_mean(ds_dict):
    """
    Compute the unweighted mean over 'lon' and 'lat' of every data variable of every dataset.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets with 'lon' and 'lat' dimensions.

    Returns:
    dict: New dictionary with the same keys; each value is a Dataset of spatial means with
    the variable and dataset attributes of the input.
    """
    means = {}

    for key, source in ds_dict.items():
        averaged = xr.Dataset()
        means[key] = averaged

        for name in source.data_vars:
            averaged[name] = source[name].mean(['lon', 'lat'])
            averaged[name].attrs = source[name].attrs

        averaged.attrs = source.attrs

    return means

In [None]:
def compute_bgws(ds_dict):
    """
    Add the Blue Green Water Share ``bgws = (mrro - tran) / pr`` to every dataset.

    Infinite values and values outside [-2, 2] are set to NaN. Each dataset is updated in
    place, and the dictionary is returned.

    Parameters:
    ds_dict (dict): Datasets containing the variables 'mrro', 'tran' and 'pr'.

    Returns:
    dict: The same dictionary; every dataset gains the variable 'bgws'.
    """
    for model in ds_dict:
        ds = ds_dict[model]
        share = (ds['mrro'] - ds['tran']) / ds['pr']

        # Invalidate infinities (division by zero precipitation) and values beyond +-2
        invalid = np.isinf(share) | (share > 2) | (share < -2)
        share = xr.where(invalid, float('nan'), share)

        ds_dict[model]['bgws'] = share
        ds_dict[model]['bgws'].attrs = {'long_name': 'Blue Green Water Share', 'units': ''}

    return ds_dict

#### Plot metrics

In [None]:
def plot_time_series_correlations(yearly_correlations, variable_pairs, target_variable):
    """
    Plots the time series of correlations for each variable pair that includes the target variable.

    Every model's series of a pair is drawn in that pair's subplot, together with the
    multi-model mean (black dashed line) when at least one model has the pair. Unused grid
    cells are removed.

    Parameters:
    yearly_correlations (dict): The output from calculate_yearly_correlations; each dataset
        holds variables named f'{var1}-{var2}' with a 'time' coordinate.
    variable_pairs (list of tuples): The pairs of variables that the correlations were calculated for.
    target_variable (str): The variable that must be included in a pair for it to be plotted.
    """
    # Only pairs that involve the target variable are plotted
    pairs = [(var1, var2) for var1, var2 in variable_pairs if target_variable in (var1, var2)]
    if not pairs:
        print(f"No variable pairs include '{target_variable}'; nothing to plot.")
        return

    # Calculate the dimensions of the grid of subplots
    grid_size = math.ceil(math.sqrt(len(pairs)))

    # squeeze=False always returns a 2D array of axes, even for a single subplot
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15), sharex=True, sharey=True,
                            squeeze=False)
    axs = axs.flatten()

    # Handles and labels for the shared legend (one entry per model)
    handles, labels = [], []
    mean_line = None

    for ax, (var1, var2) in zip(axs, pairs):
        corr_var = f'{var1}-{var2}'
        all_corrs = []

        for name, ds in yearly_correlations.items():
            if corr_var not in ds:
                continue
            line, = ax.plot(ds['time'], ds[corr_var], label=name)
            all_corrs.append(ds[corr_var])

            if name not in labels:
                handles.append(line)
                labels.append(name)

        # Mean over models; concatenating an empty list would fail
        if all_corrs:
            mean_corr = xr.concat(all_corrs, dim='model').mean(dim='model')
            mean_line, = ax.plot(mean_corr['time'], mean_corr, color='black', linestyle='--')

        ax.set_xlabel('Year')
        ax.set_ylabel('Correlation')
        ax.set_title(f'{var1} vs {var2}')
        ax.grid(True)

    # Remove grid cells that received no plot
    for ax in axs[len(pairs):]:
        fig.delaxes(ax)

    # Add 'Mean' to the legend if a mean line was drawn
    if mean_line is not None:
        handles.append(mean_line)
        labels.append('Mean')

    # Show the figure with a legend
    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.1, 0.5))
    plt.tight_layout()
    plt.show()

In [None]:
def plot_mean_correlations(correlations_hist, correlations_ssp370, variable_pairs, target_variable, scale_axis=False, variable_captions=None):
    """
    Plots the yearly metric values per model for each variable pair that includes the target variable.

    Each subplot shows, for every model, a historical (blue) and an SSP370 (orange) box plot
    side by side. The figure is saved to
    '../../results/CMIP6/yearly_metrics_comparison/{metric}_changes_{target_variable}[_scaled_axis].png'
    and then shown.

    Parameters:
    correlations_hist (dict): The output from calculate_correlations for the historical period.
        The first dataset must carry the attributes 'Metric' and 'Metric_sign'; every dataset
        needs a variable named f'{var1}-{var2}' for each plotted pair.
    correlations_ssp370 (dict): The output from calculate_correlations for the SSP370 scenario,
        with the same keys as correlations_hist.
    variable_pairs (list of tuples): The pairs of variables that the correlations were calculated for.
    target_variable (str): The variable that must be included in a pair for it to be plotted.
    scale_axis (bool): If True, fix the y-axis to [-1, 1] ([0, 1] for metric 'r2_lr'). Default is False.
    variable_captions (dict, optional): Maps variable names to captions used in subplot titles.
        Variables without an entry (or all, when None) are shown by name.
    """
    if variable_captions is None:
        variable_captions = {}

    # Filter variable pairs
    variable_pairs = [(var1, var2) for var1, var2 in variable_pairs
                      if var1 == target_variable or var2 == target_variable]
    if not variable_pairs or not correlations_hist:
        print(f"Nothing to plot for target variable '{target_variable}'.")
        return

    # Get info from the first dataset
    first_ds = next(iter(correlations_hist.values()))
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign

    model_names = list(correlations_hist.keys())
    positions = np.arange(len(model_names))
    # Offset of the two boxes per model so both periods sit side by side
    offset = 0.15

    # Calculate the number of plots and dimensions of the grid of subplots
    n_plots = len(variable_pairs)
    n_cols = min(n_plots, 3)  # Maximum 3 plots in a row
    n_rows = int(np.ceil(n_plots / n_cols))

    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
    axs = axs.flatten()

    for ax, (var1, var2) in zip(axs, variable_pairs):
        pair_key = f'{var1}-{var2}'
        yearly_corr_hist = [correlations_hist[name][pair_key].values for name in model_names]
        yearly_corr_ssp370 = [correlations_ssp370[name][pair_key].values for name in model_names]

        # Historical period
        ax.boxplot(yearly_corr_hist, positions=positions-offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='cornflowerblue'), medianprops=dict(color='black'),
                   showfliers=True)

        # SSP370 scenario
        ax.boxplot(yearly_corr_ssp370, positions=positions+offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='sandybrown'), medianprops=dict(color='black'),
                   showfliers=True)

        ax.set_ylabel(f'{metric_sign}')
        ax.set_title(f'{variable_captions.get(var1, var1)} x {variable_captions.get(var2, var2)}', fontsize=9)
        ax.set_xticks(positions)
        ax.set_xticklabels(model_names, rotation=90)

    fig.suptitle(f'Yearly {metric} for Historical (1985-2014) and SSP370 (2071-2100) Period', fontsize=12, y=1.0)

    # Set a legend
    axs[0].legend([Patch(facecolor='cornflowerblue'), Patch(facecolor='sandybrown')], ['Historical', 'SSP370'])

    # Remove empty subplots in case n_plots is less than n_rows * n_cols
    for ax in axs[n_plots:]:
        fig.delaxes(ax)

    # Handle y-axis scaling
    if scale_axis:
        for ax in axs[:n_plots]:
            ax.set_ylim([-1, 1] if metric != 'r2_lr' else [0, 1])

    plt.tight_layout()

    # Save before showing: some backends close or clear the figure in plt.show()
    suffix = "_scaled_axis" if scale_axis else ""
    filename = f"{metric}_changes_{target_variable}{suffix}.png"

    savepath = '../../results/CMIP6/yearly_metrics_comparison'
    os.makedirs(savepath, exist_ok=True)
    filepath = os.path.join(savepath, filename)

    fig.savefig(filepath, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
def ensemble_corr_change_plot(ds_dict, target_variable, full_var_names_and_unit, cmap='viridis', save_fig=False, file_format='png'):
    """
    Plot maps of the Ensemble_mean correlations between a target variable and every other variable.

    Each variable of ``ds_dict["Ensemble_mean"]`` whose name contains
    ``"<target_variable> x "`` or ``" x <target_variable>"`` is drawn in its own
    Robinson-projection panel (3 columns, at least 3 rows). All panels use the fixed
    colour range [-1, 1] and share one horizontal colorbar.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. The experiment_id, Metric, Metric_sign and
            'means' attributes are read from the first dataset.
        target_variable (str): The target variable to plot correlations with.
        full_var_names_and_unit (dict): Maps a variable name to a sequence whose first element
            is the long name used in titles. Names missing from it are shown unchanged.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: Path of the saved figure, a message explaining why nothing was saved,
        or None if the Ensemble_mean dataset or matching variables are missing.
    """
    # Info (taken from the first dataset in the dictionary)
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']

    def long_name(var):
        # Fall back to the short name instead of raising KeyError
        return full_var_names_and_unit.get(var, (var,))[0]

    target_var_long_name = long_name(target_variable)

    # Look up the Ensemble_mean dataset before creating any figure
    ensemble_ds = ds_dict.get("Ensemble_mean", None)
    if ensemble_ds is None:
        print("Ensemble_mean dataset not found.")
        return None

    # Collect (variable name, partner variable) pairs that involve the target variable
    target_prefix = f'{target_variable} x '
    target_suffix = f' x {target_variable}'
    combinations = []
    for variable in ensemble_ds.variables:
        if target_prefix in variable:
            combinations.append((variable, variable.replace(target_prefix, '')))
        elif target_suffix in variable:
            combinations.append((variable, variable.replace(target_suffix, '')))

    if not combinations:
        # Without a single panel there is no mappable to build the colorbar from
        print(f"No variable combinations with '{target_variable}' found in Ensemble_mean.")
        return None

    # 3 columns; keep the 3x3 grid and only add rows when more panels are needed
    n_cols = 3
    n_rows = max(3, math.ceil(len(combinations) / n_cols))

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    for position, (variable, corr_var) in enumerate(combinations, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, position, projection=ccrs.Robinson())
        im = ensemble_ds[variable].plot(ax=ax, cmap=cmap, vmin=-1, vmax=1, transform=ccrs.PlateCarree(), add_colorbar=False)
        ax.set_title(f'{target_var_long_name} x {long_name(corr_var)}', fontsize=18)  # Use the long names in the title
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Common colorbar at the bottom; valid for every panel because all use vmin=-1, vmax=1
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)
    cbar.set_ticks([-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75])
    cbar.ax.tick_params(labelsize=20)
    cbar.set_label(f'{metric_sign} change', size=26)

    is_single_experiment = experiment_id in ('historical', 'ssp370')
    is_change = experiment_id == 'ssp370-historical'

    if is_single_experiment:
        fig.suptitle(f"{metric} ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)
    elif is_change:
        fig.suptitle(f"{metric} Change ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        return 'Figure not saved. If you want to save the figure add save_fig=True to the function call'
    if not (is_single_experiment or is_change):
        # No output location is defined for other experiments
        return f"Figure not saved: unsupported experiment_id '{experiment_id}'."

    savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
    os.makedirs(savepath, exist_ok=True)
    if is_change:
        filename = f'ensmean_{target_variable}_correlations_change.{file_format}'
    else:
        filename = f'ensmean_{target_variable}_correlations.{file_format}'
    filepath = os.path.join(savepath, filename)
    fig.savefig(filepath, dpi=300)
    return filepath

In [None]:
def corr_maps(ds_dict, target_variable_combination, cmap='viridis', save_fig=False, file_format='png', vmin=None, vmax=None):
    """
    Plots a map of the specified variable combination for each dataset in the dictionary.

    Every dataset that contains ``target_variable_combination`` gets one Robinson-projection
    panel (3 columns). All panels share one colour range so that the single colorbar at the
    bottom is valid for every panel.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. The experiment_id, Metric, Metric_sign and
            'means' attributes are read from the first dataset.
        target_variable_combination (str): The target variable combination to plot.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.
        vmin (float, optional): Lower colour limit for all panels. Defaults to the smallest
            finite value over all plotted datasets.
        vmax (float, optional): Upper colour limit for all panels. Defaults to the largest
            finite value over all plotted datasets.

    Returns:
        str or None: Path of the saved figure, a message explaining why nothing was saved,
        or None if no dataset contains the variable combination.
    """
    # Info (taken from the first dataset in the dictionary)
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']

    panels = [(name, ds[target_variable_combination])
              for name, ds in ds_dict.items()
              if target_variable_combination in ds.variables]
    if not panels:
        # Nothing to plot: avoid a zero-height figure and an unbound colorbar mappable
        print(f"No dataset contains '{target_variable_combination}'.")
        return None

    # One colorbar is drawn for the whole figure, so every panel needs the same colour limits
    if vmin is None:
        finite_mins = [value for value in (float(data.min()) for _, data in panels) if math.isfinite(value)]
        vmin = min(finite_mins) if finite_mins else None
    if vmax is None:
        finite_maxs = [value for value in (float(data.max()) for _, data in panels) if math.isfinite(value)]
        vmax = max(finite_maxs) if finite_maxs else None

    # Create a figure
    n_cols = 3  # Set number of columns to 3
    n_rows = math.ceil(len(panels) / n_cols)  # Calculate rows

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    for position, (name, data_to_plot) in enumerate(panels, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, position, projection=ccrs.Robinson())
        im = data_to_plot.plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree(), add_colorbar=False)
        ax.set_title(name, fontsize=18)
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)
    cbar.ax.tick_params(labelsize=14)
    cbar.set_label(metric_sign, size=18)

    is_single_experiment = experiment_id in ('historical', 'ssp370')
    is_change = experiment_id == 'ssp370-historical'

    # `experiment_id == 'historical' or 'ssp370'` was always true; compare explicitly
    if is_single_experiment:
        fig.suptitle(f"{metric} ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)
    elif is_change:
        fig.suptitle(f"{metric} Change ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        return 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    if is_single_experiment:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
        filename = f'{target_variable_combination}_correlation.{file_format}'
    elif is_change:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'corr_maps', means)
        filename = f'{target_variable_combination}_correlation_change.{file_format}'
    else:
        # No output location is defined for other experiments
        return f"Figure not saved: unsupported experiment_id '{experiment_id}'."

    os.makedirs(savepath, exist_ok=True)
    filepath = os.path.join(savepath, filename)
    fig.savefig(filepath, dpi=300)
    return filepath

### Load netCDF files

In [None]:
# ========= Variables, input folder and models ==============
# Alternative variable set kept for reference:
# ['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'tran', 'lai', 'gpp', 'EI', 'wue']
variables = ['tas', 'pr', 'vpd', 'evapo', 'mrro', 'mrso', 'tran', 'lai', 'gpp']
folder = 'preprocessed'

# The same 14 models are loaded for both experiments
source_id = [
    'BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM',
    'CNRM-CM6-1', 'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L',
    'MPI-ESM1-2-LR', 'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL',
]

# ========= Dask scheduler ==========
# NOTE(review): open_and_merge_datasets runs eagerly inside the dict comprehensions below,
# before dask.compute is called, so the file loading itself is not parallelised here;
# dask.compute only evaluates dask-backed data inside the datasets (if any) — confirm intent.
dask.config.set(scheduler='processes')

experiment_id = 'historical'
ds_dict_hist = dask.compute(
    {model: open_and_merge_datasets(folder, model, experiment_id, variables) for model in source_id}
)[0]

experiment_id = 'ssp370'
ds_dict_ssp370 = dask.compute(
    {model: open_and_merge_datasets(folder, model, experiment_id, variables) for model in source_id}
)[0]

In [None]:
# Display the dataset for inspection.
# NOTE(review): in the visible part of the notebook `ds` is only assigned further down
# (the 'Select period' and 'Compute Change' sections), so running top to bottom raises
# NameError here.
ds

In [None]:
def monthly_mean(ds, variable):
    """
    Average one variable per calendar month over all years (a monthly climatology).

    Parameters:
    ds (xarray.Dataset): Dataset holding ``variable`` on a ``time`` axis.
    variable (str): Name of the variable to average.

    Returns:
    xarray.Dataset: Dataset with only ``variable``, which now has a ``month`` dimension
    instead of ``time``. Dataset and variable attributes are copied from ``ds``, and the
    variable gets a 'description' attribute.
    """
    source = ds[variable]

    # Mean over every time step falling in the same calendar month
    climatology = source.groupby('time.month').mean('time')

    # The reduction drops attributes, so restore them and add a description
    climatology.attrs = dict(source.attrs, description=f'Monthly mean of {variable}')

    return xr.Dataset({variable: climatology}, attrs=ds.attrs)

In [None]:
# Monthly LAI climatology of `ds`: the `time` dimension is replaced by a `month` dimension.
ds_monthly_mean = monthly_mean(ds, 'lai')

In [None]:
# Inspect the monthly LAI climatology.
ds_monthly_mean

In [None]:
def calculate_growing_season_length(ds):
    """
    Add a per-year growing season length (in months) derived from LAI to ``ds``.

    For each year the LAI of every grid cell is averaged per calendar month and turned into
    month-to-month percentage changes, with January compared to December. A season starts
    in a month whose change is positive after a non-positive month. It ends in the first
    later month (wrapping around the year) whose change is no longer positive. The season
    length is the sum of these spans.

    Args:
        ds (xarray.Dataset): dataset with an ``lai`` variable on a ``time`` axis and
            ``lat``/``lon`` dimensions (assumed to be lai(time, lat, lon) — TODO confirm).

    Returns:
        xarray.Dataset: ``ds`` (modified in place) with ``growing_season_length`` of dims
        (year, lat, lon). Cells without any valid LAI in a year are NaN.
    """
    lai = ds['lai']
    time_years = ds['time'].dt.year

    years = list(range(int(time_years.min()), int(time_years.max()) + 1))
    lengths = np.full((len(years), ds.sizes['lat'], ds.sizes['lon']), np.nan)

    for year_index, year in enumerate(years):
        lai_yearly = lai.isel(time=(time_years == year).values)
        if lai_yearly.sizes['time'] == 0:
            continue  # year missing from the record: leave NaN

        # Monthly mean LAI for all cells at once -> NumPy array (month, lat, lon)
        monthly = (lai_yearly.groupby('time.month').mean('time')
                   .transpose('month', 'lat', 'lon').values)
        previous = np.roll(monthly, 1, axis=0)  # January is compared with December
        with np.errstate(divide='ignore', invalid='ignore'):
            pct_change = (monthly - previous) / previous * 100

        for i in range(monthly.shape[1]):
            for j in range(monthly.shape[2]):
                if np.isnan(monthly[:, i, j]).all():
                    continue  # no data (e.g. ocean): keep NaN rather than 0
                starts, ends = _gsl_starts_and_ends(pct_change[:, i, j])
                lengths[year_index, i, j] = _gsl_length(starts, ends, n_months=monthly.shape[0])

    # Update dataset with new variable
    ds['growing_season_length'] = xr.DataArray(
        lengths,
        dims=('year', 'lat', 'lon'),
        coords={'year': years, 'lat': ds['lat'].values, 'lon': ds['lon'].values},
    )
    ds['growing_season_length'].attrs = {
        'description': 'Length of the growing season in months',
        'calculated_from': 'LAI'
    }

    return ds


def _gsl_starts_and_ends(monthly_pct_change):
    """Return 1-based months where the % change turns positive (starts) / stops being positive (ends), cyclically."""
    positive = [bool(value > 0) for value in monthly_pct_change]
    starts, ends = [], []
    for index, is_positive in enumerate(positive):
        was_positive = positive[index - 1]  # index -1 wraps January back to December
        if is_positive and not was_positive:
            starts.append(index + 1)
        elif was_positive and not is_positive:
            ends.append(index + 1)
    return starts, ends


def _gsl_length(starts, ends, n_months=12):
    """Sum, over all season starts, the number of months until the next end (wrapping the year)."""
    total = 0
    for start in starts:
        gaps = [(end - start) % n_months for end in ends if (end - start) % n_months]
        if gaps:
            total += min(gaps)
    return total

In [None]:
# NOTE(review): this calls whichever calculate_growing_season_length definition ran last.
# The definition directly above reads `ds.time`, but `ds_monthly_mean` has no `time`
# dimension (monthly_mean averages it out into `month`), so that version fails here.
# The later variant whose parameter is named `ds_monthly_mean` is written for this input.
ds_gs = calculate_growing_season_length(ds_monthly_mean)

In [None]:
def calculate_growing_season_length(ds):
    """
    Add a per-year growing season length (months) derived from LAI to ``ds``.

    For each year and grid cell the monthly mean LAI and its month-to-month percentage
    change are computed with ``calculate_monthly_changes``. Season starts and ends come from
    ``detect_season_starts_and_ends`` and the length from ``calculate_season_length``.

    Args:
        ds (xarray.Dataset): dataset with an ``lai`` variable on a ``time`` axis and
            ``lat``/``lon`` coordinates.

    Returns:
        xarray.Dataset: ``ds`` (modified in place) with ``growing_season_length`` of dims
        (year, lat, lon). Cells whose LAI is entirely missing in a year stay NaN.
    """
    lai = ds['lai']

    # Built once (the previous version constructed this array twice)
    years = list(range(int(ds.time.dt.year.min()), int(ds.time.dt.year.max()) + 1))
    growing_season_length = xr.DataArray(
        np.nan,
        dims=('year', 'lat', 'lon'),
        coords={'year': years, 'lat': ds.lat, 'lon': ds.lon}
    )

    for year in years:
        lai_yearly = lai.sel(time=str(year))
        for lat in ds.lat.values:
            for lon in ds.lon.values:
                lai_ts = lai_yearly.sel(lat=lat, lon=lon, method='nearest')
                if bool(lai_ts.isnull().all()):
                    continue  # no data (e.g. ocean): keep NaN instead of reporting 0 months
                _, monthly_pct_change = calculate_monthly_changes(lai_ts)
                starts, ends = detect_season_starts_and_ends(monthly_pct_change)
                growing_season_length.loc[dict(year=year, lat=lat, lon=lon)] = calculate_season_length(starts, ends)

    ds['growing_season_length'] = growing_season_length
    ds['growing_season_length'].attrs = {
        'description': 'Length of the growing season in months',
        'calculated_from': 'LAI'
    }

    return ds

def calculate_monthly_changes(lai_ts):
    """
    Monthly mean of an LAI time series and its percentage change from the previous month.

    Args:
        lai_ts: LAI DataArray with a ``time`` dimension (one grid cell).

    Returns:
        tuple: ``(climatology, pct_change)``. ``climatology`` is the mean per calendar month.
        ``pct_change`` is ``100 * (m - previous) / previous``, where the first month is
        compared with the last.
    """
    climatology = lai_ts.groupby('time.month').mean('time')
    previous = climatology.roll(month=1)
    pct_change = (climatology - previous) / previous * 100
    # Explicit wrap-around: first month relative to the last one
    pct_change[0] = (climatology[0] - climatology[-1]) / climatology[-1] * 100
    return climatology, pct_change

def detect_season_starts_and_ends(monthly_pct_change):
    """
    Find the months in which growing seasons start and end.

    A month starts a season when its percentage change is positive and the previous month's
    is not. A month ends a season when the previous month's change is positive and its own
    is not. The series is treated as cyclic, so January is compared with December; a season
    starting in December is detected too. Every start is matched by exactly one end.

    Args:
        monthly_pct_change (sequence of float): month-to-month % change of LAI, January
            first (list, 1-D NumPy array or 1-D DataArray). NaN counts as "not positive".

    Returns:
        tuple[list[int], list[int]]: 1-based month numbers of season starts and of season
        ends, each in calendar order.
    """
    positive = [bool(value > 0) for value in monthly_pct_change]

    starts, ends = [], []
    for index, is_positive in enumerate(positive):
        # positive[-1] (December) is the predecessor of January
        was_positive = positive[index - 1]
        if is_positive and not was_positive:
            starts.append(index + 1)
        elif was_positive and not is_positive:
            ends.append(index + 1)

    return starts, ends

def calculate_season_length(starts, ends, n_months=12):
    """
    Total growing season length in months.

    Each season start is paired with the nearest following end, wrapping around the year:
    a season starting in month 12 and ending in month 3 lasts 3 months. The lengths of all
    seasons are summed. A start without any end contributes nothing, where the previous
    implementation raised IndexError/ValueError.

    Args:
        starts (list[int]): 1-based months in which seasons start.
        ends (list[int]): 1-based months in which seasons end.
        n_months (int): number of months in the cycle. Default 12.

    Returns:
        int: total season length in months (0 when no season is found).
    """
    total = 0
    for start in starts:
        # Months from this start to each end, counted forward through the year
        gaps = [(end - start) % n_months for end in ends]
        gaps = [gap for gap in gaps if gap > 0]
        if gaps:
            total += min(gaps)
    return total

In [None]:
def calculate_growing_season_length(ds_monthly_mean):
    """
    Growing season length (months) for every grid cell of a monthly LAI climatology.

    Args:
        ds_monthly_mean (xarray.Dataset): dataset with an ``lai`` variable that has a
            ``month`` dimension and ``lat``/``lon`` coordinates (e.g. the output of
            ``monthly_mean(ds, 'lai')``).

    Returns:
        xarray.DataArray: season length per (lat, lon), from
        ``calculate_monthly_pct_change`` -> ``detect_season_starts_and_ends`` ->
        ``calculate_season_length``.
    """
    lai_clim = ds_monthly_mean['lai']
    lats = ds_monthly_mean.lat
    lons = ds_monthly_mean.lon

    season_length = xr.DataArray(
        np.nan,
        dims=('lat', 'lon'),
        coords={'lat': lats, 'lon': lons}
    )

    for lat_value in lats.values:
        for lon_value in lons.values:
            cell_series = lai_clim.sel(lat=lat_value, lon=lon_value)
            pct = calculate_monthly_pct_change(cell_series)
            season_starts, season_ends = detect_season_starts_and_ends(pct)
            season_length.loc[dict(lat=lat_value, lon=lon_value)] = calculate_season_length(season_starts, season_ends)

    return season_length

def calculate_monthly_pct_change(lai_ts):
    """
    Percentage change of each month relative to the previous month.

    ``lai_ts`` is a DataArray with a ``month`` dimension. The first month is compared with
    the last one (January vs December).
    """
    previous_month = lai_ts.roll(month=1)
    pct_change = (lai_ts - previous_month) / previous_month * 100
    pct_change[0] = (lai_ts[0] - lai_ts[-1]) / lai_ts[-1] * 100
    return pct_change

In [None]:
def calculate_growing_season_length(ds):
    """
    Dask-parallel variant: add a per-year growing season length (months) derived from LAI.

    The full LAI time series of every grid cell is passed to ``calculate_gsl_for_chunk``
    through ``xr.apply_ufunc`` (time is the core dimension and one value is produced per
    year).

    Args:
        ds (xarray.Dataset): dataset with an ``lai`` variable on ``time``, ``lat``, ``lon``.

    Returns:
        xarray.Dataset: computed copy of ``ds`` with ``growing_season_length`` of dims
        (year, lat, lon). Cells without valid LAI in a year are NaN.
    """
    # One chunk along time: each cell's whole series must be in the same block
    ds = ds.chunk({'time': -1, 'lat': 'auto', 'lon': 'auto'})
    lai = ds['lai']

    time_years = ds['time'].dt.year.values
    time_months = ds['time'].dt.month.values
    years = np.arange(int(time_years.min()), int(time_years.max()) + 1)

    gsl = xr.apply_ufunc(
        calculate_gsl_for_chunk,
        lai,
        input_core_dims=[['time']],
        output_core_dims=[['year']],
        dask='parallelized',
        output_dtypes=[float],
        dask_gufunc_kwargs={'output_sizes': {'year': len(years)}},
        kwargs={'years': years, 'time_years': time_years, 'time_months': time_months},
    )
    # apply_ufunc appends the new core dim last; reorder to (year, lat, lon)
    gsl = gsl.assign_coords(year=years).transpose('year', ...)
    gsl.attrs = {
        'description': 'Length of the growing season in months',
        'calculated_from': 'LAI'
    }

    # Add the result to the dataset and trigger the computation
    ds['growing_season_length'] = gsl
    ds = ds.compute()

    return ds


def calculate_gsl_for_chunk(lai_chunk, years, time_years=None, time_months=None):
    """
    Growing season length for every cell of a NumPy block of LAI values.

    Args:
        lai_chunk (numpy.ndarray): LAI with time on the LAST axis (the layout
            ``xr.apply_ufunc`` uses for core dimensions).
        years (iterable of int): years to evaluate; one output value per year.
        time_years (array-like of int): calendar year of each step along the last axis.
        time_months (array-like of int): calendar month (1-12) of each step along the last axis.

    Returns:
        numpy.ndarray: shape ``lai_chunk.shape[:-1] + (len(years),)``. Values are NaN where
        a cell has no valid LAI in that year.

    Raises:
        ValueError: if ``time_years`` or ``time_months`` is not given.
    """
    import warnings

    if time_years is None or time_months is None:
        raise ValueError('time_years and time_months are required to group the time axis by year and month')

    values = np.asarray(lai_chunk, dtype=float)
    years = np.asarray(list(years))
    time_years = np.asarray(time_years)
    time_months = np.asarray(time_months)

    n_cells = int(np.prod(values.shape[:-1]))
    lengths = np.full((n_cells, years.size), np.nan)
    if n_cells == 0 or values.shape[-1] == 0:
        return lengths.reshape(values.shape[:-1] + (years.size,))

    cells = values.reshape(n_cells, values.shape[-1])

    for year_index, year in enumerate(years):
        in_year = time_years == year
        if not in_year.any():
            continue

        # Monthly mean LAI, shape (cells, 12); NaN where a month has no valid value
        monthly = np.full((n_cells, 12), np.nan)
        for month in range(1, 13):
            selected = in_year & (time_months == month)
            if selected.any():
                with warnings.catch_warnings():
                    # all-NaN cells (e.g. ocean) would warn "Mean of empty slice"
                    warnings.simplefilter('ignore', category=RuntimeWarning)
                    monthly[:, month - 1] = np.nanmean(cells[:, selected], axis=1)

        previous = np.roll(monthly, 1, axis=1)  # January is compared with December
        with np.errstate(divide='ignore', invalid='ignore'):
            pct_change = (monthly - previous) / previous * 100

        for cell in range(n_cells):
            if np.isnan(monthly[cell]).all():
                continue
            starts, ends = detect_season_starts_and_ends(pct_change[cell])
            lengths[cell, year_index] = calculate_season_length(starts, ends)

    return lengths.reshape(values.shape[:-1] + (years.size,))

In [None]:
# ============= Have a look into the data ==============
#print(ds_dict_ssp126.keys())
#ds_dict_ssp370[list(ds_dict_hist.keys())[3]].parent_variant_label

In [None]:
# Sample grid points (latitude/longitude values passed to `.sel(..., method='nearest')`)
# used to inspect the LAI seasonal cycle. Each key is labelled with the climate /
# vegetation type in parentheses.
# NOTE(review): these longitudes use the -180..180 convention (negative = west). If the
# regridded CMIP6 grid uses 0..360, nearest-neighbour selection snaps every negative
# longitude to lon=0 instead of the intended place — confirm the grid convention.
locations = {
    "Manaus, Brazil (Tropical Rainforest)": {"lat": -3, "lon": -60},
    "Sahara Desert (Desert)": {"lat": 23, "lon": 13},
    "Madrid, Spain (Mediterranean)": {"lat": 40, "lon": -4},
    "Moscow, Russia (Continental)": {"lat": 56, "lon": 37},
    "Northern Siberia, Russia (Tundra)": {"lat": 70, "lon": 78},
    "Hanoi, Vietnam (Multiple Growing Seasons)": {"lat": 21, "lon": 105},
    "New Delhi, India (Monsoon)": {"lat": 28, "lon": 77},
    "Central Greenland (Greenland)": {"lat": 72, "lon": -40},
    "Melbourne, Australia (Temperate)": {"lat": -38, "lon": 145},
    "Cape Town, South Africa (Mediterranean)": {"lat": -33, "lon": 19},
    "Buenos Aires, Argentina (Pampas)": {"lat": -34, "lon": -58},
    "Central Sweden (Evergreen Forests)": {"lat": 60, "lon": 15},
    "San José, Costa Rica (Tropical)": {"lat": 9.93, "lon": -84.08},
}

In [None]:
# NOTE(review): no top-level `growing_season_length` exists at this point. It is a local
# inside the calculate_growing_season_length variants and is only bound globally by the
# per-location loop in the next cell, so this raises NameError when run in order.
growing_season_length

In [None]:
# For each sample location, plot the monthly mean LAI and its month-to-month % change,
# and mark the detected growing season starts (green) and ends (brown).
# NOTE(review): `lai` must already exist as a global DataArray with a `time` dimension
# (e.g. ds['lai']); no top-level cell above assigns it — confirm before running.
for name, coords in locations.items():
    lai_ts = lai.sel(lat=coords['lat'], lon=coords['lon'], method='nearest')

    # Monthly mean LAI and its % change relative to the previous month (January vs
    # December). Named lai_monthly so the global `monthly_mean` function is not overwritten.
    lai_monthly, monthly_pct_change = calculate_monthly_changes(lai_ts)

    # Midpoints used to draw the December -> January wrap-around
    dec_to_jan_midpoint_lai = (lai_monthly[-1] + lai_monthly[0]) / 2
    dec_to_jan_midpoint_pct_change = (monthly_pct_change[-1] + monthly_pct_change[0]) / 2

    # Recomputed for every location, so a location without a detected season shows 0
    # instead of the previous location's value
    starts, ends = detect_season_starts_and_ends(monthly_pct_change)
    growing_season_length = calculate_season_length(starts, ends)

    fig, ax1 = plt.subplots()

    # Plot Monthly Mean LAI including the wrap-around from December to January
    color = 'tab:blue'
    ax1.set_xlabel('Month')
    ax1.set_ylabel('Mean LAI', color=color)
    ax1.plot(range(1, 13), lai_monthly, color=color)
    ax1.plot([12, 12.5], [lai_monthly[-1], dec_to_jan_midpoint_lai], color=color)
    ax1.plot([0.5, 1], [dec_to_jan_midpoint_lai, lai_monthly[0]], color=color)

    # Season boundaries lie between the previous month and the detected month
    for start in starts:
        ax1.axvline(x=start-0.5, color='green', linestyle='--', label='Start of Growing Season')
    for end in ends:
        ax1.axvline(x=end-0.5, color='brown', linestyle='--', label='End of Growing Season')

    # Plot vertical lines for each month
    for month in range(1, 13):
        ax1.axvline(x=month, color='gray', linestyle='--', linewidth=0.5)

    ax1.set_xticks(range(1, 13))
    ax1.set_xticklabels([str(month) for month in range(1, 13)])

    # Plot Monthly Percentage Change on secondary y-axis
    ax2 = ax1.twinx()
    color = 'tab:red'
    ax2.set_ylabel('% Change', color=color)
    ax2.plot(range(1, 13), monthly_pct_change, color=color)
    ax2.axhline(y=0, color='gray', linestyle='--', linewidth=1)
    ax2.tick_params(axis='y', labelcolor=color)

    # Plot the December to January transition for percentage change
    ax2.plot([12, 12.5], [monthly_pct_change[-1], dec_to_jan_midpoint_pct_change], color='tab:red')
    ax2.plot([0.5, 1], [dec_to_jan_midpoint_pct_change, monthly_pct_change[0]], color='tab:red')

    plt.title(f"{name} - Growing Season: {growing_season_length} Months")

    plt.show()

### Select period

In [None]:
# Historical reference period 1985-2014 for every model.
# NOTE(review): the select_period defined at the top of this notebook has the signature
# (ds_dict, start_year=None, end_year=None, period=None) and no `yearly_sum` parameter, so
# this call raises TypeError unless select_period is redefined elsewhere — confirm.
ds_dict_hist_period = select_period(ds_dict_hist, start_year=1985, end_year=2014, period=None, yearly_sum=False)

In [None]:
# Ensemble of the historical-period datasets; only its 'Ensemble mean' entry is kept as `ds`.
ds = (compute_ensemble(ds_dict_hist_period))['Ensemble mean']

In [None]:
# Future period 2071-2100 of the SSP370 runs for every model.
# NOTE(review): like the historical call, this passes `yearly_sum`, which the select_period
# defined at the top of this notebook does not accept — confirm.
ds_dict_ssp370_period = select_period(ds_dict_ssp370, start_year=2071, end_year=2100, period=None, yearly_sum=False)

### Compute statistics

In [None]:
# ========= Compute statistic for plot ===============
# Apply the 'mean' statistic along 'time' to every historical model dataset
# (compute_statistic is defined elsewhere; presumably a temporal mean per grid cell).
ds_dict_hist_period_metric = compute_statistic(ds_dict_hist_period, 'mean', 'time')

In [None]:
# ========= Compute statistic for plot ===============
# Same 'mean' statistic along 'time' for every SSP370 model dataset.
ds_dict_ssp370_period_metric = compute_statistic(ds_dict_ssp370_period, 'mean', 'time')

### Compute BGWS

In [None]:
# Compute BGWS for the historical model datasets (compute_bgws is defined elsewhere);
# the result replaces ds_dict_hist_period_metric.
ds_dict_hist_period_metric = compute_bgws(ds_dict_hist_period_metric)

In [None]:
# Ensemble median across the historical models.
# NOTE(review): ds_ensmed_glob is reassigned to the ensemble median of the change in the
# 'Compute global change' section, so this value only lives until then.
ds_dict_hist_period_metric_ensemble = compute_ensemble(ds_dict_hist_period_metric, metric='median')
ds_ensmed_glob = ds_dict_hist_period_metric_ensemble['Ensemble median']

In [None]:
# Same BGWS computation for the SSP370 model datasets.
ds_dict_ssp370_period_metric = compute_bgws(ds_dict_ssp370_period_metric)

### Compute global change

In [None]:
# Compute Change
# Per-model change between the SSP370 and historical period statistics (compute_change is
# defined elsewhere; presumably SSP370 minus historical, with no relative change since
# var_rel_change=None — confirm).
ds_dict_change = compute_change(ds_dict_hist_period_metric, ds_dict_ssp370_period_metric, var_rel_change=None)

In [None]:
# Compute ensemble median for the change
# (metric='median' selects the 'Ensemble median' entry read in the next cell).
ds_dict_change_ensemble = compute_ensemble(ds_dict_change, metric='median')

In [None]:
# Only use ensemble median for further analysis
# (overwrites the historical ensemble median stored under the same name earlier).
ds_ensmed_glob = ds_dict_change_ensemble['Ensemble median']

### Select regions

In [None]:
# Apply the region mask to every historical model dataset (apply_region_mask is defined elsewhere).
ds_dict_hist_period_metric_regions = apply_region_mask(ds_dict_hist_period_metric)

In [None]:
# Apply the region mask to every SSP370 model dataset.
ds_dict_ssp370_period_metric_regions = apply_region_mask(ds_dict_ssp370_period_metric)

In [None]:
# Inspect the region-masked historical datasets.
ds_dict_hist_period_metric_regions

### Compute Change

In [None]:
# Compute Change
# Per-model regional change between the SSP370 and historical period statistics.
ds_dict_region_change = compute_change(ds_dict_hist_period_metric_regions, ds_dict_ssp370_period_metric_regions, var_rel_change=None)

In [None]:
# Ensemble statistics of the regional change using compute_ensemble's default metric
# (no metric= argument); the next cell reads its 'Ensemble mean' entry, not a median.
ds_dict_region_change_ensemble = compute_ensemble(ds_dict_region_change)

In [None]:
# Keep only the ensemble-mean entry of the regional change for further analysis.
ds_dict_region_change_ensemble_metric = {
    'Ensemble mean': ds_dict_region_change_ensemble['Ensemble mean'],
}

In [None]:
# Shorthand `ds` for the regional ensemble-mean change (replaces any earlier `ds`).
ds = ds_dict_region_change_ensemble_metric['Ensemble mean']

In [None]:
# Ensemble of the regional historical period statistics (compute_ensemble default metric);
# only the 'Ensemble mean' entry is kept.
ds_dict_hist_region_ensemble = compute_ensemble(ds_dict_hist_period_metric_regions)
ds_hist_ensemble_metric = ds_dict_hist_region_ensemble['Ensemble mean']

### Extreme Gradient Boosting Regression Analysis

#### Build extreme gradient boosting model

In [None]:
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import cross_val_score, GridSearchCV, train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.inspection import permutation_importance
from xgboost import XGBRegressor

In [None]:
def custom_scale(data):
    """
    Scale each column of ``data`` by its maximum absolute value.

    Columns whose maximum absolute value is 0 are divided by 1 instead, so they stay
    all-zero rather than becoming NaN (0/0).

    Args:
        data (numpy.ndarray or pandas.DataFrame): samples in rows, features in columns.

    Returns:
        tuple: ``(scaled_data, max_val)``. ``scaled_data`` has the same type as ``data``.
        ``max_val`` is the per-column maximum absolute value as returned by ``np.max``
        (0 for all-zero columns).
    """
    # Find the absolute maximum value per column
    max_val = np.max(np.abs(data), axis=0)

    # Guard against 0/0 for constant-zero features; max_val itself is returned unchanged
    divisor = np.where(max_val == 0, 1, max_val)
    scaled_data = data / divisor

    return scaled_data, max_val

In [None]:
def scale_data(X, method):
    """
    Scale the feature matrix ``X`` with the requested method.

    Args:
        X: feature matrix (rows = samples, columns = features).
        method (str): 'std' (StandardScaler), 'minmax' (MinMaxScaler to [-1, 1]),
            'max' (divide by per-column max absolute value) or 'no_scaling'.

    Returns:
        tuple: ``(X_scaled, scaler_data)``. ``scaler_data`` holds the fitted scaling
        parameters: 'mean'/'std', 'min'/'max', 'max', or an empty dict for 'no_scaling'.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == 'no_scaling':
        return X, {}

    if method == 'max':
        X_scaled, max_val = custom_scale(X)
        return X_scaled, {'max': max_val}

    if method == 'std':
        std_scaler = StandardScaler()
        X_scaled = std_scaler.fit_transform(X)
        return X_scaled, {'mean': std_scaler.mean_, 'std': std_scaler.scale_}

    if method == 'minmax':
        range_scaler = MinMaxScaler(feature_range=(-1, 1))
        X_scaled = range_scaler.fit_transform(X)
        return X_scaled, {'min': range_scaler.data_min_, 'max': range_scaler.data_max_}

    raise ValueError('Scaling method not known')

In [None]:
from sklearn.model_selection import cross_val_score, ParameterGrid, RandomizedSearchCV

def train_xgb_models(ds, predictor_vars, predictant, scaling_method, cv_folds=5, param_grid=None):
    """
    Train one XGBoost regressor per region with Grid Search, Cross-Validation and a Train/Test split.

    For every region the NaN-free data are scaled, split 70/30 into a train and a test set, and the
    train set is split again 70/30 into a tuning set (used by GridSearchCV) and a validation set.
    The tuned model is then evaluated on the test set, on the validation set, with permutation
    importance, and with k-fold cross-validation on the whole, train and test data.

    Parameters:
    - ds: xarray dataset with a 'region' dimension and a 'names' variable holding the region names
    - predictor_vars: List of predictor variable names
    - predictant: Name of the predictant variable
    - scaling_method: Method for scaling features (passed to scale_data)
    - cv_folds: Number of folds used by GridSearchCV and cross_val_score (default 5)
    - param_grid: Optional hyperparameter grid for GridSearchCV; defaults to the grid below

    Returns (tuple, in this order; every dict is keyed by region name):
    - feature_importances_df: DataFrame (regions x predictors) of XGBoost feature importances
    - permutation_importances_test, permutation_importances_train: 'importances' arrays of
      sklearn's permutation_importance (n_features x 20 repeats) on the test / train data
    - best_params: best hyperparameters found by the grid search
    - xgb_models: fitted best estimators
    - best_scores: best mean cross-validated R2 of the grid search
    - performance_metrics_test: {'MSE', 'R2'} on the test set
    - performance_metrics_train: {'MSE', 'R2'} on the validation split held out from tuning
    - cv_scores_r2, cv_scores_mse: per-fold CV scores on the whole region data
    - cv_scores_r2_train, cv_scores_mse_train: per-fold CV scores on the train set
    - cv_scores_r2_test, cv_scores_mse_test: per-fold CV scores on the test set
    - scaled_data: scaling parameters returned by scale_data
    """
    # Parameter grid for GridSearchCV
    if param_grid is None:
        param_grid = {
            'n_estimators': [100, 200, 400, 800, 1000, 1100, 1200, 1300],
            'learning_rate': [0.01, 0.05, 0.1],
            'max_depth': [5, 7, 10, 14],
            'min_child_weight': [0.01, 0.1, 1, 10, 12],
            'lambda': [1, 3],
            'alpha': [0.001, 0.01],
        }

    # Initialize dictionaries
    xgb_models = {}
    best_params = {}
    best_scores = {}
    feature_importances = {}
    permutation_importances_test = {}
    permutation_importances_train = {}
    performance_metrics_test = {}
    performance_metrics_train = {}
    cv_scores_r2 = {}
    cv_scores_mse = {}
    cv_scores_r2_test = {}
    cv_scores_mse_test = {}
    cv_scores_r2_train = {}
    cv_scores_mse_train = {}
    scaled_data = {}

    for region in ds.region.values:
        # Data preparation
        df = ds.sel(region=region).to_dataframe().reset_index()
        key = f'{ds.names.sel(region=region).values}'
        df.dropna(inplace=True)
        X = df[predictor_vars]
        y = df[predictant]

        # NOTE: the scaler is fitted on all region data before the split, so for 'std'/'minmax'
        # the test set statistics leak into the scaling parameters.
        X_scaled, scaled_data[key] = scale_data(X, method=scaling_method)

        # Train/Test split
        X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)

        # Further split the training set into tuning and validation sets for hyperparameter tuning
        X_train_tuning, X_val, y_train_tuning, y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=42)

        # GridSearchCV for hyperparameter tuning on the tuning data
        xgb = XGBRegressor(random_state=42, objective='reg:squarederror')
        grid_search = GridSearchCV(xgb, param_grid, cv=cv_folds, n_jobs=-1, scoring='r2')
        grid_search.fit(X_train_tuning, y_train_tuning)

        # Store the best model, parameters, and score for the region
        model = grid_search.best_estimator_
        xgb_models[key] = model
        best_params[key] = grid_search.best_params_
        # scoring='r2' is "greater is better", so best_score_ already is the R2 itself;
        # negating is only appropriate for 'neg_*' scorers such as neg_mean_squared_error.
        best_scores[key] = grid_search.best_score_

        feature_importances[key] = model.feature_importances_

        # Permutation importance on the test and on the (full) train data
        permutation_importances_test[key] = permutation_importance(
            model, X_test, y_test, n_repeats=20, random_state=42
        )['importances']
        permutation_importances_train[key] = permutation_importance(
            model, X_train, y_train, n_repeats=20, random_state=42
        )['importances']

        # Performance metrics on the test set
        y_pred = model.predict(X_test)
        performance_metrics_test[key] = {'MSE': mean_squared_error(y_test, y_pred), 'R2': r2_score(y_test, y_pred)}

        # Despite the "train" name, these metrics use the validation split the grid search never saw
        y_pred_val = model.predict(X_val)
        performance_metrics_train[key] = {'MSE': mean_squared_error(y_val, y_pred_val), 'R2': r2_score(y_val, y_pred_val)}

        # k-fold cross-validation of the tuned model on the whole, train and test data
        cv_scores_r2[key] = cross_val_score(model, X_scaled, y, cv=cv_folds, scoring='r2')
        cv_scores_mse[key] = -cross_val_score(model, X_scaled, y, cv=cv_folds, scoring='neg_mean_squared_error')
        cv_scores_r2_train[key] = cross_val_score(model, X_train, y_train, cv=cv_folds, scoring='r2')
        cv_scores_mse_train[key] = -cross_val_score(model, X_train, y_train, cv=cv_folds, scoring='neg_mean_squared_error')
        cv_scores_r2_test[key] = cross_val_score(model, X_test, y_test, cv=cv_folds, scoring='r2')
        cv_scores_mse_test[key] = -cross_val_score(model, X_test, y_test, cv=cv_folds, scoring='neg_mean_squared_error')

    feature_importances_df = pd.DataFrame.from_dict(feature_importances, orient='index', columns=predictor_vars)

    return (feature_importances_df, permutation_importances_test, permutation_importances_train, best_params,
            xgb_models, best_scores, performance_metrics_test, performance_metrics_train, cv_scores_r2,
            cv_scores_mse, cv_scores_r2_train, cv_scores_mse_train, cv_scores_r2_test, cv_scores_mse_test,
            scaled_data)

In [None]:
# Predictor (explanatory) variables and predictant (target) shared by the models below.
# Other candidates not currently used: 'evspsbl', 'mrro', 'tran'
predictor_vars = [
    'tas', 'pr', 'vpd', 'evapo', 'mrso', 'lai', 'gpp',
]
predictant = 'bgws'

In [None]:
# Restrict the XGBoost analysis to a subset of regions. xarray .sel slices are label-based and
# include both end points, so this keeps region labels 10 through 15.
ds_reduced_regions = ds.sel(region=slice(10, 15))

In [None]:
# Tune and evaluate one XGBoost model per region on the reduced region set (unscaled predictors)
(
    feature_importances_df_xgb,
    permutation_importances_test_xgb,
    permutation_importances_train_xgb,
    best_params_xgb,
    x_gradient_boosting_models,
    best_scores_xgb,
    performance_metrics_test_xgb,
    performance_metrics_train_xgb,
    cv_scores_r2_xgb,
    cv_scores_mse_xgb,
    cv_scores_r2_train_xgb,
    cv_scores_mse_train_xgb,
    cv_scores_r2_test_xgb,
    cv_scores_mse_test_xgb,
    scaler_data_xgb,
) = train_xgb_models(
    ds_reduced_regions,
    predictor_vars,
    predictant,
    scaling_method='no_scaling',
    cv_folds=5,
)


In [None]:
# Display the best hyperparameters found per region
best_params_xgb

#### Test for overfitting (learning curves)

In [None]:
from sklearn.model_selection import learning_curve

def plot_learning_curves_for_all_regions(models, ds, predictor_vars, predictant, cv_folds=5, train_sizes=None):
    """
    Plot R^2 learning curves (training vs. cross-validation score) for each region's model.

    Parameters:
    - models: Dictionary of trained models keyed by region name (string form of ds.names for the region).
    - ds: xarray dataset used in training models.
    - predictor_vars: List of predictor variable names.
    - predictant: Name of the predictant variable.
    - cv_folds: Number of folds for cross-validation.
    - train_sizes: Fractions (or absolute counts) of the training data passed to sklearn's
      learning_curve; defaults to 10 evenly spaced fractions from 0.1 to 1.0.

    NOTE: the predictors are used unscaled, which matches models trained with
    scaling_method='no_scaling' only.
    """
    if train_sizes is None:
        # 0.1, 0.2, ..., 1.0 (np.linspace's 4th positional argument is `endpoint`, not a size)
        train_sizes = np.linspace(0.1, 1.0, 10)

    for region in ds.region.values:
        # Extract the NaN-free data for the region
        df = ds.sel(region=region).to_dataframe().reset_index()
        df.dropna(inplace=True)
        X = df[predictor_vars]
        y = df[predictant]

        region_name = ds.names.sel(region=region).values

        # Learning curve computation (the estimator is cloned and refitted for every size/fold)
        sizes, train_scores, validation_scores = learning_curve(
            estimator=models[f'{region_name}'],
            X=X,
            y=y,
            train_sizes=train_sizes,
            cv=cv_folds,
            scoring='r2',
            n_jobs=-1,
        )

        # Mean and standard deviation over the CV folds
        train_mean = np.mean(train_scores, axis=1)
        train_std = np.std(train_scores, axis=1)
        validation_mean = np.mean(validation_scores, axis=1)
        validation_std = np.std(validation_scores, axis=1)

        # Plot learning curves (scores are R^2, so higher is better)
        plt.figure(figsize=(10, 6))
        plt.plot(sizes, train_mean, label='Training R^2', color='blue', marker='o')
        plt.fill_between(sizes, train_mean - train_std, train_mean + train_std, color='blue', alpha=0.15)
        plt.plot(sizes, validation_mean, label='Validation R^2', color='green', marker='o')
        plt.fill_between(sizes, validation_mean - validation_std, validation_mean + validation_std, color='green', alpha=0.15)

        plt.title(f'Learning Curves for {region_name}')
        plt.xlabel('Training Data Size')
        plt.ylabel('R^2')
        plt.legend(loc='upper right')
        plt.grid()
        plt.show()

In [None]:

# R^2 learning curves of the tuned XGBoost models, one figure per region
plot_learning_curves_for_all_regions(x_gradient_boosting_models, ds_reduced_regions, predictor_vars, predictant)

In [None]:
def plot_performance_metrics(performance_metrics_train, performance_metrics_test, cv_scores_r2, cv_scores_mse,
                             cv_scores_r2_train, cv_scores_mse_train, cv_scores_r2_test, cv_scores_mse_test,
                             r2_ylim=(0, 1), mse_ylim=(0, 0.006)):
    """
    Plot six stacked per-region panels comparing model skill.

    1. R2 on the "train" and test data (from performance_metrics_*)
    2. MSE on the "train" and test data
    3./4. Mean / standard deviation of the cross-validated R2 (whole, train and test data)
    5./6. Mean / standard deviation of the cross-validated MSE (whole, train and test data)

    Parameters:
    - performance_metrics_train, performance_metrics_test: dict region -> {'MSE': float, 'R2': float}
    - cv_scores_*: dict region -> array of per-fold scores
    - r2_ylim, mse_ylim: y-axis limits of the R2 and MSE panels; pass None to autoscale.
      The defaults clip negative R2 values and MSE values above 0.006.

    The x-axis order of the regions follows the keys of performance_metrics_train.
    """
    regions = list(performance_metrics_train.keys())

    def _plot_lines(ax, series, title, ylabel, ylim):
        # One line per (label, values) pair, plus the shared axis decoration
        for label, values in series:
            ax.plot(regions, values, label=label, marker='o')
        ax.set_title(title)
        ax.set_xlabel('Regions')
        ax.set_ylabel(ylabel)
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.legend()
        ax.tick_params(axis='x', rotation=90)

    fig, ax = plt.subplots(6, 1, figsize=(15, 30))

    # Panels 1-2: metrics of the single train/test split
    for axis, metric, ylim in ((ax[0], 'R2', r2_ylim), (ax[1], 'MSE', mse_ylim)):
        series = [
            (f'Train {metric}', [performance_metrics_train[region][metric] for region in regions]),
            (f'Test {metric}', [performance_metrics_test[region][metric] for region in regions]),
        ]
        _plot_lines(axis, series, f'{metric} Comparison', metric, ylim)

    # Panels 3-6: mean / standard deviation of the per-fold cross-validation scores
    cv_sets = {
        'R2': ((cv_scores_r2, cv_scores_r2_train, cv_scores_r2_test), r2_ylim),
        'MSE': ((cv_scores_mse, cv_scores_mse_train, cv_scores_mse_test), mse_ylim),
    }
    panels = [
        (ax[2], 'R2', np.mean, 'Mean CV R2 Scores', 'Mean Scores'),
        (ax[3], 'R2', np.std, 'Std CV R2 Scores', 'Standard Deviation'),
        (ax[4], 'MSE', np.mean, 'Mean CV MSE Scores', 'Mean Scores'),
        (ax[5], 'MSE', np.std, 'Std CV MSE Scores', 'Standard Deviation'),
    ]
    for axis, metric, stat, title, ylabel in panels:
        (whole, train, test), ylim = cv_sets[metric]
        series = [
            (label, [stat(scores[region]) for region in regions])
            for label, scores in (('Whole Data', whole), ('Train Data', train), ('Test Data', test))
        ]
        _plot_lines(axis, series, title, ylabel, ylim)

    plt.tight_layout()
    plt.show()

In [None]:
# Compare single-split and cross-validated skill of the XGBoost models across regions
plot_performance_metrics(performance_metrics_train_xgb, performance_metrics_test_xgb, cv_scores_r2_xgb, cv_scores_mse_xgb, cv_scores_r2_train_xgb, cv_scores_mse_train_xgb, cv_scores_r2_test_xgb, cv_scores_mse_test_xgb)


In [None]:
def plot_permutation_importances(permutation_importances_test, permutation_importances_train, predictor_vars):
    """
    Box-plot the permutation importance of every predictor: one row per region,
    test data on the left and training data on the right.

    Parameters:
    - permutation_importances_test, permutation_importances_train: dict region -> array of shape
      (n_features, n_repeats), i.e. the 'importances' entry of sklearn's permutation_importance
    - predictor_vars: predictor names in the model's feature order

    Regions are taken from permutation_importances_test; the training dict must have the same keys.
    Nothing is plotted when there are no regions.
    """
    num_regions = len(permutation_importances_test)
    if num_regions == 0:
        return

    # squeeze=False keeps axs 2-D even when there is a single region
    fig, axs = plt.subplots(num_regions, 2, figsize=(15, num_regions * 4), squeeze=False)

    for row, region in enumerate(permutation_importances_test):
        panels = (
            (permutation_importances_test[region], 'Test Data'),
            (permutation_importances_train[region], 'Training Data'),
        )
        for col, (importances, label) in enumerate(panels):
            ax = axs[row, col]
            # Transpose to (n_repeats, n_features) so that each predictor becomes one box
            sns.boxplot(data=pd.DataFrame(importances.T, columns=predictor_vars), orient='h', ax=ax)
            ax.axvline(0, color='grey', linestyle='--')
            ax.set_title(f'{region} - {label}')
            # permutation_importance uses the estimator's default score, which is R2 for regressors
            ax.set_xlabel('Decrease in R2 Score')
            ax.set_ylabel('Variables')

    plt.tight_layout()
    plt.show()

In [None]:
# Permutation importances of the XGBoost models (test vs. training data)
plot_permutation_importances(permutation_importances_test_xgb, permutation_importances_train_xgb, predictor_vars)

#### Assess variable importance

In [None]:
# NOTE(review): permutation_importance_df_gb is not created in this section — presumably from an
# earlier gradient boosting section; verify it exists before running this cell.
permutation_importance_df_gb

In [None]:
# Feature importances in percent.
# NOTE(review): feature_importances_df is not created in this section (the XGBoost run above
# stores its table as feature_importances_df_xgb) — confirm which table is intended.
feature_importances_df * 100

In [None]:
# NOTE(review): same expression as the previous cell; feature_importances_df is not created in
# this section (the XGBoost table is feature_importances_df_xgb) — verify or remove this duplicate.
feature_importances_df * 100

#### Model diagnostics

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import shapiro
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.stats.diagnostic import het_breuschpagan

def test_regression_assumptions_scikit(regression_models, test_data, predictor_vars):
    """
    Check classical regression assumptions on each region's test data and plot the diagnostics.

    For every region the residuals of the model on its test data are tested for normality
    (Shapiro-Wilk) and homoscedasticity (Breusch-Pagan), and the predictors are checked for
    multicollinearity with variance inflation factors (VIF). One row of three panels is drawn
    per region: residuals vs. predictions, residual histogram and VIF bars (dashed lines at 5 and 10).

    Parameters:
    - regression_models: dict region name -> fitted model with a `predict` method
    - test_data: dict region name -> (X_test, y_test); must use the same keys as regression_models
    - predictor_vars: predictor names in the column order of X_test (VIF panel labels)

    Returns:
    - DataFrame with one row per region (predictions, residuals, test statistics and VIFs).
    """
    results = []

    for region_name, model in regression_models.items():
        X_test, y_test = test_data[region_name]

        # Predict and calculate residuals
        predictions = model.predict(X_test)
        residuals = y_test - predictions

        # Design matrix with a leading intercept column (used by Breusch-Pagan and VIF)
        X_test_with_constant = np.column_stack((np.ones(X_test.shape[0]), X_test))

        # Test for normality of residuals
        shapiro_stat, shapiro_p = shapiro(residuals)

        # Test for homoscedasticity
        _, _, _, bp_pvalue = het_breuschpagan(residuals, X_test_with_constant)

        # VIF for multicollinearity. statsmodels expects the exog to include a constant; without it
        # the VIFs of non-centred predictors are inflated. Column 0 is the intercept, so predictor i
        # sits in column i + 1.
        vif = [variance_inflation_factor(X_test_with_constant, i + 1) for i in range(X_test.shape[1])]

        results.append({
            'Region': region_name,
            'Predictions': predictions,
            'Residuals': residuals,
            'Shapiro-Wilk': shapiro_stat,
            'Shapiro-Wilk p-value': shapiro_p,
            'Breusch-Pagan p-value': bp_pvalue,
            'VIF': vif
        })

    num_regions = len(results)
    if num_regions == 0:
        return pd.DataFrame(results)

    # squeeze=False keeps axs 2-D even for a single region
    fig, axs = plt.subplots(num_regions, 3, figsize=(22, 5 * num_regions), squeeze=False)

    for i, result in enumerate(results):
        sns.residplot(x=result['Predictions'], y=result['Residuals'], lowess=True, ax=axs[i, 0])
        axs[i, 0].set_title(f'Residuals vs Predictions for {result["Region"]}')
        axs[i, 0].set_xlabel('Predicted values')
        axs[i, 0].set_ylabel('Residuals')

        # Annotate the statistical test results
        axs[i, 0].text(0.05, 0.95, f"Shapiro-Wilk: {result['Shapiro-Wilk']:.2f}", transform=axs[i, 0].transAxes)
        axs[i, 0].text(0.05, 0.90, f"Shapiro-Wilk p-value: {result['Shapiro-Wilk p-value']:.2f}", transform=axs[i, 0].transAxes)
        axs[i, 0].text(0.05, 0.85, f"Breusch-Pagan p-value: {result['Breusch-Pagan p-value']:.2f}", transform=axs[i, 0].transAxes)

        sns.histplot(result['Residuals'], kde=True, ax=axs[i, 1])
        axs[i, 1].set_title(f'Residual Distribution for {result["Region"]}')
        axs[i, 1].set_xlabel('Residuals')
        axs[i, 1].set_ylabel('Frequency')

        # VIF bar plot with common rule-of-thumb thresholds
        sns.barplot(x=predictor_vars, y=result['VIF'], ax=axs[i, 2])
        axs[i, 2].hlines(5, xmin=-0.5, xmax=len(result['VIF'])-0.5, colors='orange', linestyles='dashed')
        axs[i, 2].hlines(10, xmin=-0.5, xmax=len(result['VIF'])-0.5, colors='r', linestyles='dashed')
        axs[i, 2].set_title(f'VIF for {result["Region"]}')
        axs[i, 2].set_xlabel('Predictor Variables')
        axs[i, 2].set_ylabel('VIF Value')
        axs[i, 2].set_ylim(0, max(result['VIF']) + 1)

    plt.tight_layout()
    plt.show()

    return pd.DataFrame(results)

In [None]:
# NOTE(review): regression_models and test_data are not created by the cells above in this section;
# they are returned by lr_models, which is defined further below — make sure it has been run first.
assumptions_df = test_regression_assumptions_scikit(regression_models, test_data, predictor_vars)

#### Test model performance

In [None]:
def _apply_scaling(X, scaler):
    """Apply an already-fitted scaler object, or a parameter dict produced by scale_data, to X."""
    if hasattr(scaler, 'transform'):
        return scaler.transform(X)
    if not scaler:
        # scale_data(..., method='no_scaling') stores an empty dict
        return X
    if 'mean' in scaler and 'std' in scaler:
        # StandardScaler: StandardScaler.scale_ (stored as 'std') already maps zero variances to 1
        return (np.asarray(X, dtype=float) - scaler['mean']) / scaler['std']
    if 'min' in scaler and 'max' in scaler:
        # MinMaxScaler(feature_range=(-1, 1)): map [data_min, data_max] onto [-1, 1]
        data_min = np.asarray(scaler['min'], dtype=float)
        data_range = np.asarray(scaler['max'], dtype=float) - data_min
        data_range = np.where(data_range == 0, 1.0, data_range)  # constant columns, as sklearn does
        return 2.0 * (np.asarray(X, dtype=float) - data_min) / data_range - 1.0
    if 'max' in scaler:
        # custom_scale: divide by the per-column maximum absolute value
        return X / scaler['max']
    raise ValueError(f'Unknown scaling parameters: {sorted(scaler)}')


def assess_model_performance_all_regions(ds, regression_models, best_params, best_scores, scalers, predictor_vars, predictant='bgws'):
    """
    Evaluate each region's model on all of the region's data and plot residuals vs. true values.

    Parameters:
    - ds: xarray dataset with a 'region' dimension and a 'names' variable (region names)
    - regression_models: dict region name -> fitted model
    - best_params: dict region name -> best hyperparameters (shown as panel title)
    - best_scores: dict region name -> best grid-search score (shown in the panel)
    - scalers: dict region name -> either a fitted scaler with a `transform` method or the
      scaling-parameter dict returned by scale_data (empty dict for 'no_scaling')
    - predictor_vars: List of predictor variable names
    - predictant: Name of the predictant variable (default 'bgws')

    Returns:
    - dict region name -> {'MSE', 'R^2', 'Best Parameters', 'Best CV Score'}

    Note: the metrics use every NaN-free sample of a region, including the samples the model was
    trained on, so they are optimistic compared with held-out test scores.
    """
    region_indices = ds.region.values

    # Three columns, as many rows as needed (at least one so that plt.subplots is valid)
    num_regions = len(region_indices)
    cols = 3
    rows = max(1, math.ceil(num_regions / cols))

    fig, axs = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    flat_axes = axs.ravel()

    performance_metrics = {}

    for ax, region in zip(flat_axes, region_indices):
        # Same key format as the training functions (string form of the 0-d names array)
        region_name = f'{ds.names.sel(region=region).values}'

        # Prepare the data for the region
        df = ds.sel(region=region).to_dataframe().dropna()
        X = df[predictor_vars]
        y_true = df[predictant].values

        # Apply the region's already-fitted scaling (never refit on the evaluation data)
        X_scaled = _apply_scaling(X, scalers[region_name])

        # Predict using the trained model
        model = regression_models[region_name]
        y_pred = model.predict(X_scaled)
        residuals = y_true - y_pred

        # Calculate and store performance metrics
        mse_value = mean_squared_error(y_true, y_pred)
        r2_value = r2_score(y_true, y_pred)
        best_param = best_params[region_name]
        best_cv_score = best_scores[region_name]

        performance_metrics[region_name] = {
            'MSE': mse_value,
            'R^2': r2_value,
            'Best Parameters': best_param,
            'Best CV Score': best_cv_score
        }

        # Plot the residuals
        ax.scatter(y_true, residuals, c='blue', alpha=0.5, s=10)
        ax.axhline(0, color='red', lw=2)
        ax.text(0.05, 0.95, f'MSE: {mse_value:.6f}\nR^2: {r2_value:.2f}\nBest CV: {best_cv_score:.4f}',
                transform=ax.transAxes, verticalalignment='top')
        ax.set_title(f'{region_name}\n{best_param}')
        ax.set_xlabel('True Values')
        ax.set_ylabel('Residuals')

    # Hide the grid cells that received no region (num_regions < rows * cols)
    for ax in flat_axes[num_regions:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

    return performance_metrics

In [None]:
# NOTE(review): gradient_boosting_models, best_params, best_scores and scaler_data are not created in
# this section (the XGBoost run above uses the *_xgb names) — presumably from an earlier gradient
# boosting section; verify before running.
gb_performance_metrics = assess_model_performance_all_regions(ds, gradient_boosting_models, best_params, best_scores, scaler_data, predictor_vars)

### Multiple Linear Regression Analysis

#### Plot data statistics

In [None]:
def summarize_standardized_data(X_standardized, y, predictor_vars, predictant='bgws'):
    """
    Return descriptive statistics of the predictors and the predictant.

    Parameters:
    - X_standardized: 2-D array-like (samples x predictors), e.g. the output of scale_data
    - y: 1-D array-like or pandas Series with the predictant values
    - predictor_vars: column names for X_standardized
    - predictant: row label used for y in the summary (default 'bgws')

    Returns:
    - DataFrame with one row per variable (predictors first, then the predictant) and the
      columns count, mean, std, min, 25%, 50%, 75%, max.
    """
    # Summarize predictors
    X_summary = pd.DataFrame(X_standardized, columns=predictor_vars).describe().transpose()
    # Summarize response variable. Build the frame from y's values: pd.DataFrame(series, columns=[...])
    # selects by the Series' name, so a Series named differently would produce an empty column.
    y_summary = pd.DataFrame({predictant: np.asarray(y).ravel()}).describe().transpose()
    # Combine summaries
    summary = pd.concat([X_summary, y_summary])
    return summary[['count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]

In [None]:
# NOTE(review): X and y are not defined at notebook level in this section (only as locals inside the
# functions above) — presumably left over from an earlier cell; verify before running.
summarize_standardized_data(X, y, predictor_vars)

#### Plot data correlation

In [None]:
def plot_spearman_correlation(ds, predictor_vars, predictant):
    """
    Draw one Spearman rank-correlation heatmap per region.

    Each panel shows the pairwise Spearman correlations between the predictors and the
    predictant, computed on the region's NaN-free rows. Panels are arranged in three columns,
    only the first panel carries a colour bar, and unused grid cells are removed.
    """
    n_panels = len(ds.region)
    ncols = 3
    nrows = -(-n_panels // ncols)  # ceiling division

    fig, grid = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4))
    flat_axes = grid.flatten()

    for position, (region, ax) in enumerate(zip(ds.region.values, flat_axes)):
        label = ds.names.sel(region=region).values
        region_df = ds.sel(region=region).to_dataframe().dropna()

        spearman = region_df[predictor_vars + [predictant]].corr(method='spearman')

        sns.heatmap(
            spearman,
            annot=True,
            fmt=".2f",
            cmap="coolwarm",
            vmin=-1,
            vmax=1,
            cbar=(position == 0),
            ax=ax,
        )
        ax.set_title(f"Region: {label}")

    plt.tight_layout()

    # Drop the grid cells that did not receive a region
    for leftover in flat_axes[n_panels:]:
        fig.delaxes(leftover)

    plt.show()

In [None]:
# Spearman correlation heatmaps of the predictors and the predictant for every region
plot_spearman_correlation(ds, predictor_vars, predictant)

#### Plot data distribution

In [None]:
import seaborn as sns

def plot_all_distributions(ds, predictor_vars, predictant='bgws'):
    """
    Plot histograms (with KDE) of the standardized predictors and the raw predictant, one panel per region.

    Predictors are z-scored per region with StandardScaler so they share a common axis; the
    predictant is plotted on its original scale. Only the first panel gets a legend and unused
    grid cells are hidden.

    Parameters:
    - ds: xarray dataset with a 'region' dimension and a 'names' variable (region names)
    - predictor_vars: predictor variable names
    - predictant: name of the target variable (default 'bgws')
    """
    num_regions = len(ds.region)
    # Three columns, as many rows as needed (at least one so that plt.subplots is valid)
    cols = 3
    rows = max(1, math.ceil(num_regions / cols))

    fig, axes = plt.subplots(rows, cols, figsize=(15, rows * 5), constrained_layout=True)
    axes = axes.flatten()  # Flatten to 1D array for easy iteration

    for idx, region in enumerate(ds.region.values):
        # Select data for the region; drop NaNs only in the relevant columns
        df = ds.sel(region=region).to_dataframe().reset_index()
        df = df.dropna(subset=predictor_vars + [predictant])

        X = df[predictor_vars]
        y = df[predictant]

        # Standardize the predictors; the predictant stays on its original scale
        X_standardized = StandardScaler().fit_transform(X)
        df_standardized = pd.DataFrame(X_standardized, columns=predictor_vars)
        df_standardized[predictant] = y.values

        ax = axes[idx]
        for col in df_standardized.columns:
            sns.histplot(df_standardized[col], kde=True, ax=ax, label=col)

        region_name = ds.names.sel(region=region).values
        ax.set_title(f'Region {region_name}')

        # Only add legend to the first subplot
        if idx == 0:
            ax.legend(title='Variable')

    # Hide any unused subplots
    for ax in axes[num_regions:]:
        ax.set_visible(False)

    plt.show()

In [None]:
# Distributions of the standardized predictors and the predictant for every region
plot_all_distributions(ds, predictor_vars)

#### Build regression model

In [None]:
from sklearn.linear_model import Ridge, Lasso, ElasticNet

def lr_models(ds_change, predictor_vars, predictant, scaling_method, regression_type='ridge', cv_folds=5, scaling_back=False):
    """
    Train a regularized linear regression model (Ridge, Lasso or ElasticNet) for each region.

    For each region the NaN-free data are scaled, split 70/30 into train and test sets, the
    regularization hyperparameters are tuned with GridSearchCV (negative MSE) on the train set,
    and the tuned model is evaluated with k-fold cross-validation, on the train/test sets and
    with permutation importance.

    Parameters:
    - ds_change: xarray dataset with a 'region' dimension and a 'names' variable (region names)
    - predictor_vars: List of predictor variable names
    - predictant: Name of the predictant variable
    - scaling_method: Scaling method passed to scale_data ('std', 'minmax', 'max' or 'no_scaling')
    - regression_type: Type of regression ('ridge', 'lasso', 'elasticnet')
    - cv_folds: Number of folds for GridSearchCV and cross-validation
    - scaling_back: currently unused (kept for interface compatibility)

    Returns (tuple, in this order; every dict is keyed by region name):
      coeffs_df, regression_models, best_hyperparams, best_scores, train_data, test_data,
      cv_scores_r2, cv_scores_mse, cv_scores_r2_train, cv_scores_mse_train, cv_scores_r2_test,
      cv_scores_mse_test, performance_metrics_test, performance_metrics_train,
      permutation_importances_test, permutation_importances_train, scaled_data
    best_scores holds GridSearchCV.best_score_, i.e. the NEGATIVE mean squared error of the
    best parameter set. train_data / test_data map to (X, y) tuples.

    Raises:
    - ValueError: for an unknown regression_type (or scaling_method, via scale_data)
    """
    # Initialize dictionaries
    regression_models = {}
    best_hyperparams = {}
    best_scores = {}
    regression_coeffs = {}
    cv_scores_r2 = {}
    cv_scores_mse = {}
    cv_scores_r2_test = {}
    cv_scores_mse_test = {}
    cv_scores_r2_train = {}
    cv_scores_mse_train = {}
    train_data = {}
    test_data = {}
    performance_metrics_test = {}
    performance_metrics_train = {}
    permutation_importances_test = {}
    permutation_importances_train = {}
    scaled_data = {}

    # Define estimator and parameter grid based on regression type
    alphas = [0.001, 0.01, 0.1, 1, 10, 100]
    if regression_type == 'ridge':
        model = Ridge(random_state=42)
        param_grid = {'alpha': alphas}
    elif regression_type == 'lasso':
        model = Lasso(random_state=42)
        param_grid = {'alpha': alphas}
    elif regression_type == 'elasticnet':
        model = ElasticNet(random_state=42)
        param_grid = {'alpha': alphas, 'l1_ratio': [0.2, 0.5, 0.8]}
    else:
        raise ValueError("Invalid regression type. Choose 'ridge', 'lasso', or 'elasticnet'.")

    for region in ds_change.region.values:
        # Data preparation
        df = ds_change.sel(region=region).to_dataframe().reset_index()
        df.dropna(inplace=True)
        X = df[predictor_vars]
        y = df[predictant]

        # Region name from the dataset being processed (not from a notebook-level global)
        key = f'{ds_change.names.sel(region=region).values}'

        # Scale the data
        X_scaled, scaled_data[key] = scale_data(X, method=scaling_method)

        # Train/Test split (30 % of the data are held out for testing)
        X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)

        # Grid search (GridSearchCV clones `model`, so it can be reused across regions)
        grid_search = GridSearchCV(model, param_grid, cv=cv_folds, scoring='neg_mean_squared_error', n_jobs=-1)
        grid_search.fit(X_train, y_train)

        # Store best model and hyperparameters
        best_model = grid_search.best_estimator_
        regression_models[key] = best_model
        best_hyperparams[key] = grid_search.best_params_
        best_scores[key] = grid_search.best_score_
        regression_coeffs[key] = best_model.coef_

        # Perform cross-validation and store results
        cv_scores_r2[key] = cross_val_score(best_model, X_scaled, y, cv=cv_folds, scoring='r2')
        cv_scores_mse[key] = -cross_val_score(best_model, X_scaled, y, cv=cv_folds, scoring='neg_mean_squared_error')
        cv_scores_r2_train[key] = cross_val_score(best_model, X_train, y_train, cv=cv_folds, scoring='r2')
        cv_scores_mse_train[key] = -cross_val_score(best_model, X_train, y_train, cv=cv_folds, scoring='neg_mean_squared_error')
        cv_scores_r2_test[key] = cross_val_score(best_model, X_test, y_test, cv=cv_folds, scoring='r2')
        cv_scores_mse_test[key] = -cross_val_score(best_model, X_test, y_test, cv=cv_folds, scoring='neg_mean_squared_error')

        # Compute and store performance metrics
        y_pred_test = best_model.predict(X_test)
        performance_metrics_test[key] = {'MSE': mean_squared_error(y_test, y_pred_test), 'R2': r2_score(y_test, y_pred_test)}

        y_pred_train = best_model.predict(X_train)
        performance_metrics_train[key] = {'MSE': mean_squared_error(y_train, y_pred_train), 'R2': r2_score(y_train, y_pred_train)}

        # Compute permutation importance
        permutation_importances_test[key] = permutation_importance(best_model, X_test, y_test, n_repeats=20, random_state=42)['importances']
        permutation_importances_train[key] = permutation_importance(best_model, X_train, y_train, n_repeats=20, random_state=42)['importances']

        # Store train and test data under the region name, like every other result dict,
        # so that lookups such as test_data[region_name] work
        train_data[key] = (X_train, y_train)
        test_data[key] = (X_test, y_test)

    # Convert regression coefficients to DataFrame
    coeffs_df = pd.DataFrame.from_dict(regression_coeffs, orient='index', columns=predictor_vars)

    return (coeffs_df, regression_models, best_hyperparams, best_scores, train_data, test_data,
            cv_scores_r2, cv_scores_mse, cv_scores_r2_train, cv_scores_mse_train, cv_scores_r2_test,
            cv_scores_mse_test, performance_metrics_test, performance_metrics_train,
            permutation_importances_test, permutation_importances_train, scaled_data)

In [None]:
def lr_models(ds_change, predictor_vars, predictant, scaling_method, cv_folds=5, scaling_back=False):
    """
    Train one linear regression model per region, cross-validate it and compute
    test-set performance metrics.

    For every region the data are converted to a DataFrame, rows with NaN values
    are dropped, the predictors are scaled with ``scale_data`` and a
    ``LinearRegression`` is fitted on a 70/30 train/test split.

    Parameters:
    - ds_change: xarray Dataset with a ``region`` dimension, region ``names``,
      the predictor variables and the predictant
    - predictor_vars: List of predictor variable names
    - predictant: Name of the predictant variable
    - scaling_method: Scaling method passed to ``scale_data`` ('std', 'minmax', ...)
    - cv_folds: Number of folds for cross-validation (default is 5)
    - scaling_back: If True and scaling_method is 'std' or 'minmax', convert the
      coefficients back to the units of the unscaled predictors

    Returns:
    - coeffs_df: DataFrame of coefficients (one row per region, one column per predictor)
    - regression_models, train_data, test_data, cv_scores, performance_metrics,
      residuals: dictionaries keyed by region name
    """
    regression_models = {}
    regression_coeffs = {}
    cv_scores = {}
    train_data = {}
    test_data = {}
    residuals = {}
    performance_metrics = {}
    scaled_data = {}

    for region in ds_change.region.values:
        # Convert the region's data to a DataFrame and drop rows with NaN values
        df = ds_change.sel(region=region).to_dataframe().reset_index().dropna()

        # Take the region name from ds_change itself (previously read from the global ``ds``)
        region_name = f'{ds_change.names.sel(region=region).values}'

        X = df[predictor_vars]
        y = df[predictant]

        # Scale the predictors; scale_data also returns the scaling parameters
        X_scaled, scaled_data[region_name] = scale_data(X, method=scaling_method)

        # 30% of the data are held out to test the model
        X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)

        model = LinearRegression()
        model.fit(X_train, y_train)

        regression_models[region_name] = model
        regression_coeffs[region_name] = model.coef_
        train_data[region_name] = (X_train, y_train)
        test_data[region_name] = (X_test, y_test)

        y_pred = model.predict(X_test)
        residuals[region_name] = y_test - y_pred
        performance_metrics[region_name] = {
            'MSE': mean_squared_error(y_test, y_pred),
            'R2': r2_score(y_test, y_pred),
        }

        # Cross-validate on the same scaled predictors the model is trained on
        # (previously the unscaled X was used, so CV scores described a different model).
        cv_scores[region_name] = cross_val_score(model, X_scaled, y, cv=cv_folds)

    # Collect the coefficients, optionally converted back to unscaled predictor units.
    # Intercepts are not returned, so they are not converted.
    coeffs_list = []
    for region_name, coefs in regression_coeffs.items():
        coefs = np.asarray(coefs, dtype=float)
        if scaling_back and scaling_method == 'std':
            # Assumes scale_data uses x_scaled = (x - mean) / std (TODO confirm),
            # hence beta_original = beta_scaled / std.
            coefs = coefs / np.asarray(scaled_data[region_name]['std'], dtype=float)
        elif scaling_back and scaling_method == 'minmax':
            # Assumes scale_data uses x_scaled = (x - min) / (max - min) (TODO confirm),
            # hence beta_original = beta_scaled / (max - min).
            feature_mins = np.asarray(scaled_data[region_name]['min'], dtype=float)
            feature_maxs = np.asarray(scaled_data[region_name]['max'], dtype=float)
            coefs = coefs / (feature_maxs - feature_mins)
        coeffs_list.append(coefs)

    coeffs_df = pd.DataFrame(coeffs_list, index=regression_coeffs.keys(), columns=predictor_vars)

    return coeffs_df, regression_models, train_data, test_data, cv_scores, performance_metrics, residuals

In [None]:
# Regression set-up: the predictor (explanatory) variables and the predictant
# (target) variable used by the model-fitting cells.
# 'evspsbl' is noted in an old commented-out remark as an alternative predictor; it is not used.
predictor_vars = [
    'tas', 'pr', 'vpd', 'mrro', 'tran',
    'evapo', 'mrso', 'lai', 'gpp',
]
predictant = 'bgws'

In [None]:
# Fit the per-region regression models and unpack all of their outputs.
# NOTE(review): this call passes `regression_type=...` and unpacks 17 values, but the
# `lr_models` defined in the cell above takes no `regression_type` argument and returns
# 7 values, so against that definition this raises TypeError. It appears to target a
# different version of `lr_models`; confirm which definition is intended.
coeffs_df_lr, linear_regression_models, best_hyperparams_lr, best_scores_lr, train_data_lr, test_data_lr, cv_scores_r2_lr, cv_scores_mse_lr, cv_scores_r2_train_lr, cv_scores_mse_train_lr, cv_scores_r2_test_lr, cv_scores_mse_test_lr, performance_metrics_test_lr, performance_metrics_train_lr, permutation_importances_test_lr, permutation_importances_train_lr, scaled_data_lr = lr_models(ds, predictor_vars, predictant, scaling_method='no_scaling', regression_type='elasticnet', cv_folds=4, scaling_back=False)


#### Test overfitting

In [None]:
# Compare train, test and cross-validation scores to check for overfitting.
# (plot_performance_metrics is defined elsewhere in the notebook; presumably it plots these side by side.)
plot_performance_metrics(performance_metrics_train_lr, performance_metrics_test_lr, cv_scores_r2_lr, cv_scores_mse_lr, cv_scores_r2_train_lr, cv_scores_mse_train_lr, cv_scores_r2_test_lr, cv_scores_mse_test_lr)


In [None]:
# Plot the permutation importances computed on the test and train splits.
# (plot_permutation_importances is defined elsewhere in the notebook.)
plot_permutation_importances(permutation_importances_test_lr, permutation_importances_train_lr, predictor_vars)

#### Test variable importance

In [None]:
# Display the regression-coefficient table.
# NOTE(review): `coefficients_df` is not created in the visible cells (the lr_models call
# binds `coeffs_df_lr`); presumably it comes from an earlier cell - confirm.
coefficients_df

In [None]:
# Display the regression-coefficient table again (duplicate display cell, presumably after a re-run).
coefficients_df

In [None]:
# Display the regression-coefficient table again (duplicate display cell, presumably after a re-run).
coefficients_df

In [None]:
# Display the regression-coefficient table again (duplicate display cell, presumably after a re-run).
coefficients_df

#### Model diagnostics

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import shapiro
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.stats.diagnostic import het_breuschpagan

def test_regression_assumptions_scikit(regression_models, test_data, predictor_vars):
    """
    Check linear-regression assumptions on each region's test split and plot diagnostics.

    For every region the test-set residuals are checked for normality
    (Shapiro-Wilk) and homoscedasticity (Breusch-Pagan), and the variance
    inflation factor (VIF) of every predictor is computed. One row of three
    panels is drawn per region: residuals vs. predictions (annotated with the
    test statistics), the residual distribution, and the VIF per predictor with
    reference lines at 5 and 10.

    Parameters
    ----------
    regression_models : dict
        Fitted models (anything with ``predict``) keyed by region name.
    test_data : dict
        ``(X_test, y_test)`` tuples keyed by the same region names.
    predictor_vars : list of str
        Predictor names used as VIF bar labels; must follow the column order of X_test.

    Returns
    -------
    pandas.DataFrame
        One row per region with predictions, residuals, test statistics and VIFs.
    """
    results = []

    for region_name, model in regression_models.items():
        X_test, y_test = test_data[region_name]

        # Predict and calculate residuals
        predictions = model.predict(X_test)
        residuals = y_test - predictions

        # Design matrix with a leading constant column, used by both the
        # Breusch-Pagan test and the VIF computation
        X_test_with_constant = np.column_stack((np.ones(X_test.shape[0]), X_test))

        # Normality of residuals
        shapiro_stat, shapiro_p = shapiro(residuals)

        # Homoscedasticity
        _, _, _, bp_pvalue = het_breuschpagan(residuals, X_test_with_constant)

        # statsmodels' variance_inflation_factor does not add an intercept itself, so
        # the design matrix must contain one; without it the VIFs come from an
        # uncentered R^2 and are typically inflated. Column 0 (the constant) is
        # skipped so there is exactly one VIF per predictor.
        vif = [variance_inflation_factor(X_test_with_constant, i)
               for i in range(1, X_test_with_constant.shape[1])]

        results.append({
            'Region': region_name,
            'Predictions': predictions,
            'Residuals': residuals,
            'Shapiro-Wilk': shapiro_stat,
            'Shapiro-Wilk p-value': shapiro_p,
            'Breusch-Pagan p-value': bp_pvalue,
            'VIF': vif
        })

    if not results:
        # plt.subplots(0, 3) would fail; there is nothing to plot.
        return pd.DataFrame(results)

    num_regions = len(results)
    # squeeze=False keeps ``axs`` 2-D, so axs[i, j] also works for a single region
    fig, axs = plt.subplots(num_regions, 3, figsize=(22, 5 * num_regions), squeeze=False)

    for i, result in enumerate(results):
        sns.residplot(x=result['Predictions'], y=result['Residuals'], lowess=True, ax=axs[i, 0])
        axs[i, 0].set_title(f'Residuals vs Predictions for {result["Region"]}')
        axs[i, 0].set_xlabel('Predicted values')
        axs[i, 0].set_ylabel('Residuals')

        # Statistical test results
        axs[i, 0].text(0.05, 0.95, f"Shapiro-Wilk: {result['Shapiro-Wilk']:.2f}", transform=axs[i, 0].transAxes)
        axs[i, 0].text(0.05, 0.90, f"Shapiro-Wilk p-value: {result['Shapiro-Wilk p-value']:.2f}", transform=axs[i, 0].transAxes)
        axs[i, 0].text(0.05, 0.85, f"Breusch-Pagan p-value: {result['Breusch-Pagan p-value']:.2f}", transform=axs[i, 0].transAxes)

        sns.histplot(result['Residuals'], kde=True, ax=axs[i, 1])
        axs[i, 1].set_title(f'Residual Distribution for {result["Region"]}')
        axs[i, 1].set_xlabel('Residuals')
        axs[i, 1].set_ylabel('Frequency')

        # VIF bar plot with the common thresholds 5 (moderate) and 10 (serious)
        sns.barplot(x=predictor_vars, y=result['VIF'], ax=axs[i, 2])
        axs[i, 2].hlines(5, xmin=-0.5, xmax=len(result['VIF'])-0.5, colors='orange', linestyles='dashed')
        axs[i, 2].hlines(10, xmin=-0.5, xmax=len(result['VIF'])-0.5, colors='r', linestyles='dashed')
        axs[i, 2].set_title(f'VIF for {result["Region"]}')
        axs[i, 2].set_xlabel('Predictor Variables')
        axs[i, 2].set_ylabel('VIF Value')
        # VIF is infinite for perfectly collinear predictors; set_ylim needs finite limits
        finite_vif = [v for v in result['VIF'] if np.isfinite(v)]
        axs[i, 2].set_ylim(0, (max(finite_vif) if finite_vif else 10) + 1)

    plt.tight_layout()
    plt.show()

    return pd.DataFrame(results)

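Linearity: In the residuals-vs-predictions plot, a random scatter without any discernible pattern indicates that the linearity assumption holds; a systematic pattern (e.g. a curve) indicates non-linearity.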

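Independence of Errors: Durbin-Watson values between 1.5 and 2.5 are generally considered acceptable (values close to 2 indicate no autocorrelation of the residuals). Note that `test_regression_assumptions_scikit` does not compute the Durbin-Watson statistic.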

Homoscedasticity (Equal Variance of Errors): The variance of the residuals should be consistent across all levels of the predicted values.
Breusch-Pagan test: A small p-value (typically <= 0.05) suggests heteroscedasticity, which can invalidate some of the statistical conclusions of the regression.
                                              
Normality of Errors: Shapiro-Wilk test statistic: the closer this value is to 1, the more closely the residuals follow a normal distribution.
Shapiro-Wilk p-value: The p-value of the Shapiro-Wilk test. A small p-value (typically <= 0.05) indicates that the residuals do not follow a normal distribution.
                     
Variance Inflation Factor (VIF): Measures multicollinearity among the independent variables in the regression model. A VIF above 5 (orange line in the plot) indicates moderate multicollinearity; a VIF greater than 10 (red line) is typically considered an indicator of serious multicollinearity that could affect the model's estimates.

In [None]:
# Run the assumption checks and collect one row of statistics per region.
# NOTE(review): `regression_models` and `test_data` are globals not produced by the visible
# lr_models call (which binds `linear_regression_models` / `test_data_lr`) - confirm the source.
assumptions_df = test_regression_assumptions_scikit(regression_models, test_data, predictor_vars)

#### Test model performance

In [None]:
def assess_model_performance_all_regions(ds_change, ds_hist, regression_models, test_data, performance_metrics, predictor_vars, residuals, cv_scores):
    """
    Plot test-set diagnostics for every region in one figure (one row per region).

    Panels per region: residuals vs. predicted values, predicted vs. true values
    with a 1:1 line, and density plots of true and predicted values annotated
    with MSE, R^2 and the mean cross-validation score.

    Parameters
    ----------
    ds_change : xarray.Dataset
        Provides the ``region`` dimension and the region ``names``.
    ds_hist, predictor_vars :
        Not used; kept for call compatibility.
    regression_models, test_data, performance_metrics, residuals, cv_scores : dict
        Per-region results keyed by region name. ``test_data`` holds
        ``(X_test, y_test)``; ``performance_metrics`` holds 'MSE' and 'R2'.
    """
    region_indices = ds_change.region.values
    num_regions = len(region_indices)
    if num_regions == 0:
        # plt.subplots(0, 3) would fail; nothing to plot
        return

    fig, axs = plt.subplots(num_regions, 3, figsize=(20, num_regions * 5), squeeze=False)

    for row, region in enumerate(region_indices):
        region_name = ds_change.names.sel(region=region).values.item()
        X_test, y_test = test_data[region_name]
        model = regression_models[region_name]
        y_pred = model.predict(X_test)
        resids = residuals[region_name]

        mse_value = performance_metrics[region_name]['MSE']
        r2_value = performance_metrics[region_name]['R2']
        cv_score = np.mean(cv_scores[region_name])  # Average CV score

        ax_resid, ax_pred, ax_density = axs[row]

        # Plot 1: residuals against the *predicted* values. Residuals e = y - y_pred
        # tend to be positively correlated with the true values y, so plotting them
        # against y shows a spurious upward trend even for a well-specified model.
        ax_resid.scatter(y_pred, resids, c='blue', alpha=0.5, s=10, label='Residuals')
        ax_resid.axhline(0, color='red', lw=2, label='Zero Residual')
        ax_resid.set_xlabel('Predicted Values')
        ax_resid.set_ylabel('Residuals')
        ax_resid.set_title(f"{region_name} - Residuals")
        ax_resid.legend()

        # Plot 2: predictions vs. true values with a 1:1 line spanning both ranges
        line_min = min(np.min(y_test), np.min(y_pred))
        line_max = max(np.max(y_test), np.max(y_pred))
        ax_pred.scatter(y_test, y_pred, c='green', alpha=0.5, s=10, label='Predictions')
        ax_pred.plot([line_min, line_max], [line_min, line_max], color='orange', label='Ideal Prediction')
        ax_pred.set_xlabel('True Values')
        ax_pred.set_ylabel('Predicted Values')
        ax_pred.set_title(f"{region_name} - Predictions")
        ax_pred.legend()

        # Plot 3: density of true and predicted values
        sns.kdeplot(y_test, ax=ax_density, label='Actual Values', fill=True)
        sns.kdeplot(y_pred, ax=ax_density, label='Predicted Values', fill=True)
        ax_density.set_xlabel('Values')
        ax_density.set_title(f"{region_name} - Density Plot")
        ax_density.legend()

        # Performance metrics
        ax_density.text(0.05, 0.95, f'MSE: {mse_value:.6f}\nR^2: {r2_value:.2f}\nCV: {cv_score:.2f}', transform=ax_density.transAxes, verticalalignment='top')

    plt.tight_layout()
    plt.show()

In [None]:
# Diagnostics for the run on relative change with scaled predictors.
# Note: the ds_hist argument (ds_hist_ensemble_metric) is not used by the function.
assess_model_performance_all_regions(ds, ds_hist_ensemble_metric, regression_models, test_data, performance_metrics, predictor_vars, residuals, cv_scores)


In [None]:
# Same diagnostics again (duplicate cell, presumably after re-fitting with other settings).
assess_model_performance_all_regions(ds, ds_hist_ensemble_metric, regression_models, test_data, performance_metrics, predictor_vars, residuals, cv_scores)


In [None]:
# Same diagnostics again (duplicate cell, presumably after re-fitting with other settings).
assess_model_performance_all_regions(ds, ds_hist_ensemble_metric, regression_models, test_data, performance_metrics, predictor_vars, residuals, cv_scores)


#### Test optimal variable subset

In [None]:
from itertools import chain, combinations

In [None]:
# Count the candidate predictor subsets with at least 4 variables. The subsets are
# generated inline because `all_subsets` is only defined in a later cell, so calling
# it here raised NameError when the notebook was run top to bottom.
all_combinations = list(chain.from_iterable(
    combinations(predictor_vars, size) for size in range(4, len(predictor_vars) + 1)
))
print(f"Number of variable combinations with at least 4 variables: {len(all_combinations)}")

In [None]:
def all_subsets(ss, min_length=4):
    """
    Return a lazy iterator over every combination of ``ss`` with at least ``min_length`` items.

    Combinations are produced by increasing size; within one size they follow
    the order of ``itertools.combinations``.
    """
    sizes = range(min_length, len(ss) + 1)
    return chain.from_iterable(combinations(ss, size) for size in sizes)

In [None]:
def test_variable_combinations(ds, all_vars, predictant):
    """
    Evaluate every predictor subset (at least 4 variables) and rank the subsets by performance.

    For each combination produced by ``all_subsets(all_vars)``, models are trained
    with ``train_multivariate_models`` and evaluated per region on the region's
    NaN-free data. MSE and R^2 are then aggregated (mean, min, max) across regions.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``region`` dimension, region ``names`` and all variables.
    all_vars : list of str
        Candidate predictor variables.
    predictant : str
        Name of the target variable.

    Returns
    -------
    pandas.DataFrame
        One row per variable combination with mean/min/max MSE and R^2, sorted by
        mean MSE (ascending) and then mean R^2 (descending).
    """
    # All combinations of predictor variables with at least 4 variables
    all_combinations = list(all_subsets(all_vars))
    print(f"Number of variable combinations: {len(all_combinations)}")

    # The regional DataFrames do not depend on the variable combination, so build
    # them once rather than once per combination.
    region_frames = [
        (ds.names.sel(region=region).values.item(), ds.sel(region=region).to_dataframe().dropna())
        for region in ds.region.values
    ]

    performance_list = []
    for combination in all_combinations:
        current_vars = list(combination)

        # Train the models using the current combination of variables
        _, regression_models, _, _, _, _ = train_multivariate_models(ds, current_vars, predictant)

        for region_name, df in region_frames:
            X = df[current_vars]
            y_true = df[predictant].values

            # NOTE(review): a new StandardScaler is fitted on the full regional data here.
            # Confirm this matches the scaling used inside train_multivariate_models; the
            # original author also questioned this standardization because the direction
            # of the effects matters.
            X_standardized = StandardScaler().fit_transform(X)

            y_pred = regression_models[region_name].predict(X_standardized)

            performance_list.append({
                'Variables': ', '.join(current_vars),
                'MSE': mean_squared_error(y_true, y_pred),
                'R^2': r2_score(y_true, y_pred)
            })

    performance_summary = pd.DataFrame(performance_list)

    # Aggregate each variable combination across all regions
    aggregated_performance = performance_summary.groupby('Variables').agg(['mean', 'min', 'max']).reset_index()
    aggregated_performance.columns = [' '.join(col).strip() for col in aggregated_performance.columns.values]

    # Sort by mean MSE (ascending) and mean R^2 (descending)
    aggregated_performance.sort_values(by=['MSE mean', 'R^2 mean'], ascending=[True, False], inplace=True)

    # Select and rename in one step: renaming a column subset in place triggers
    # pandas' SettingWithCopyWarning.
    final_table = aggregated_performance[['Variables', 'MSE mean', 'MSE min', 'MSE max', 'R^2 mean', 'R^2 min', 'R^2 max']].rename(columns={
        'MSE mean': 'Mean MSE',
        'MSE min': 'Min MSE',
        'MSE max': 'Max MSE',
        'R^2 mean': 'Mean R^2',
        'R^2 min': 'Min R^2',
        'R^2 max': 'Max R^2'
    })

    return final_table

In [None]:
# Example usage: show every row of the ranking table (display.max_rows=None removes the
# row limit for all later DataFrame displays in this session), then rank the
# predictor subsets for 'bgws'.
pd.set_option('display.max_rows', None)  
test_variable_combinations(ds, predictor_vars, 'bgws')

#### Test variable importance

In [None]:
# Display the coefficients multiplied by 1000 (display only; coefficients_df is not modified).
coefficients_df*1000

#### Cluster regions based on regression coefficients 

In [None]:
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import cdist

def cluster_regions(coefficients_df, max_clusters=40):
    """
    Cluster regions by their regression coefficients with KMeans, choosing k via the elbow method.

    The KMeans inertia is plotted for k = 1 .. max_clusters (capped at the number
    of regions), then the user is asked to enter the number of clusters.

    :param coefficients_df: A pandas DataFrame containing the coefficients with regions as the index.
    :param max_clusters: The maximum number of clusters to test for the elbow method.
    :return: A tuple ``(clustered_df, centroids_df)``.
             ``clustered_df`` contains the coefficients, a 'Cluster' column, and a
             'Distance_to_Centroid' column (Euclidean distance to the region's own
             cluster centroid, multiplied by 100 and rounded to 2 decimals). Regions
             are in a 'Region' column, sorted by cluster and distance.
             ``centroids_df`` contains the centroid of each cluster.
    """
    # KMeans is only imported in a later cell in the visible notebook; import it here
    # so this definition also works when the cells are run top to bottom.
    from sklearn.cluster import KMeans

    # Elbow method. KMeans requires n_clusters <= number of samples, so never test
    # more clusters than there are regions.
    K = range(1, min(max_clusters, len(coefficients_df)) + 1)
    sum_of_squared_distances = [
        KMeans(n_clusters=k, n_init=10, random_state=42).fit(coefficients_df).inertia_
        for k in K
    ]

    # Plot the elbow curve
    plt.plot(K, sum_of_squared_distances, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Sum of squared distances')
    plt.title('Elbow Method For Optimal k')
    plt.show()

    # Ask user for the optimal number of clusters
    n_clusters = int(input("Enter the optimal number of clusters: "))

    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(coefficients_df)

    clusters = kmeans.labels_
    centroids = kmeans.cluster_centers_

    # Distance of every region to the centroid of the cluster it was assigned to
    distances_to_centroid = cdist(coefficients_df, centroids, 'euclidean')
    own_distances = distances_to_centroid[np.arange(len(clusters)), clusters]

    # Add the cluster labels and distances to a copy of the coefficients
    df = coefficients_df.copy()
    df['Cluster'] = clusters
    df['Distance_to_Centroid'] = (own_distances * 100).round(decimals=2)
    df.index.name = 'Region'
    df.reset_index(inplace=True)

    # Order by cluster label and distance to centroid
    clustered_df = df.sort_values(['Cluster', 'Distance_to_Centroid'])

    centroids_df = pd.DataFrame(centroids, columns=coefficients_df.columns)
    centroids_df['Cluster'] = range(n_clusters)

    return clustered_df, centroids_df

In [None]:
from sklearn.cluster import KMeans

def cluster_regions(coefficients_df, max_clusters=40):
    """
    This function finds the optimal number of clusters using the elbow method
    and clusters regions based on their coefficients.

    The KMeans inertia is plotted for k = 1 .. max_clusters (capped at the number
    of regions), then the user is asked to enter the number of clusters.

    :param coefficients_df: A pandas DataFrame containing the coefficients with regions as the index.
    :param max_clusters: The maximum number of clusters to test for the elbow method.
    :return: A pandas DataFrame with the regions in a 'Region' column and an additional
             column 'Cluster' indicating the cluster of each region, sorted by cluster.
    """
    # Elbow method. KMeans requires n_clusters <= number of samples, so never test
    # more clusters than there are regions.
    K = range(1, min(max_clusters, len(coefficients_df)) + 1)
    sum_of_squared_distances = []
    for k in K:
        km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(coefficients_df)
        sum_of_squared_distances.append(km.inertia_)

    # Plot the elbow curve
    plt.plot(K, sum_of_squared_distances, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Sum of squared distances')
    plt.title('Elbow Method For Optimal k')
    plt.show()

    # Ask user for the optimal number of clusters
    n_clusters = int(input("Enter the optimal number of clusters: "))

    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(coefficients_df)

    # Add the cluster labels to a copy of the coefficients
    df = coefficients_df.copy()
    df['Cluster'] = kmeans.labels_
    df.index.name = 'Region'
    df.reset_index(inplace=True)

    # Order the dataframe by cluster label
    clustered_df = df.sort_values('Cluster')

    return clustered_df

In [None]:
import numpy as np
import pandas as pd
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial import distance_matrix

def custom_distance(u, v, sign_penalty=1):
    """
    Euclidean distance between two coefficient vectors, scaled when their signs disagree.

    If any element pair of ``u`` and ``v`` differs in sign according to ``np.sign``
    (0 has its own sign, so 0 vs. a positive value counts as a difference), the
    distance is multiplied by ``sign_penalty``.

    The default ``sign_penalty=1`` reproduces the previous behaviour, where the
    "penalty" branch multiplied by 1 and therefore had no effect. Pass a value
    greater than 1 to actually penalize sign disagreements.

    Parameters
    ----------
    u, v : array-like
        Vectors of equal length.
    sign_penalty : float
        Factor applied to the distance when any signs differ.

    Returns
    -------
    float
        The (possibly penalized) Euclidean distance.
    """
    u = np.asarray(u)
    v = np.asarray(v)
    distance = np.linalg.norm(u - v)
    if np.any(np.sign(u) != np.sign(v)):
        return distance * sign_penalty
    return distance
    
def cluster_regions(coefficients_df, n_clusters=12):
    """
    Cluster regions by their coefficients with complete-linkage agglomerative clustering.

    Pairwise distances come from ``custom_distance`` and are passed to
    ``AgglomerativeClustering`` as a precomputed matrix.

    Parameters
    ----------
    coefficients_df : pandas.DataFrame
        Coefficients with regions as the index. It is not modified.
    n_clusters : int
        Number of clusters to form.

    Returns
    -------
    pandas.DataFrame
        The coefficients plus a 'Cluster' column, with the regions in a 'Region'
        column, sorted by cluster label.
    """
    # Work on a copy so the caller's DataFrame does not gain a 'Cluster' column
    # (previously the labels were written into coefficients_df itself).
    X = coefficients_df.copy()
    values = X.values
    n_regions = len(values)

    # Symmetric distance matrix with a zero diagonal; custom_distance is symmetric,
    # so each pair is evaluated only once.
    dist_matrix = np.zeros((n_regions, n_regions))
    for i in range(n_regions):
        for j in range(i + 1, n_regions):
            dist_matrix[i, j] = dist_matrix[j, i] = custom_distance(values[i], values[j])

    # scikit-learn >= 1.2 names the parameter ``metric`` (``affinity`` was removed in 1.4);
    # fall back to ``affinity`` for older releases.
    try:
        clustering = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='complete')
    except TypeError:
        clustering = AgglomerativeClustering(n_clusters=n_clusters, affinity='precomputed', linkage='complete')
    X['Cluster'] = clustering.fit_predict(dist_matrix)

    # Order the dataframe by cluster label
    clustered_df = X.sort_values('Cluster')
    clustered_df.index.name = 'Region'
    clustered_df.reset_index(inplace=True)

    return clustered_df

In [None]:
# Cluster the regions and display the result.
# NOTE(review): the last `cluster_regions` defined before this cell (agglomerative version)
# returns a single DataFrame, so unpacking two values here only matches the first
# KMeans/centroid version. Re-run the intended definition first, or drop `centroids_df`.
clustered_df, centroids_df = cluster_regions(coefficients_df)
clustered_df

#### Plot clusters on map

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Anchor colours of the diverging colormap: green -> white -> blue.
# Named ``cmap_colors`` so it does not shadow the ``matplotlib.colors`` module,
# which the notebook imports as ``colors``.
cmap_colors = [(34/255, 139/255, 34/255), (1, 1, 1), (60/255, 145/255, 230/255)]
cmap_name = 'custom_div_cmap'

# ``cm`` is the colormap used by the map-plotting functions; N is the number of
# discrete levels in its lookup table.
cm = LinearSegmentedColormap.from_list(cmap_name, cmap_colors, N=100)

# Preview the colormap
fig, ax = plt.subplots(figsize=(6, 1))
ax.set_title('Custom Diverging Colormap')
ax.imshow(np.linspace(0, 1, 256).reshape(1, -1), aspect='auto', cmap=cm)
ax.axis('off')
plt.show()

In [None]:
import cartopy.crs as ccrs
from cartopy.feature import ShapelyFeature
import regionmask
from shapely.geometry.polygon import Polygon
from shapely.geometry.multipolygon import MultiPolygon

In [None]:
from shapely.geometry import MultiPolygon, Polygon

def split_polygon(polygon, meridian=180):
    """
    Split a Shapely (Multi)Polygon at +/- ``meridian`` and wrap overhanging parts back.

    Parts east of ``+meridian`` are shifted by -360 degrees and parts west of
    ``-meridian`` by +360 degrees, so every returned piece lies inside
    [-meridian, meridian].

    The previous implementation only split polygons extending past both +meridian
    and -meridian (i.e. wider than 360 degrees), so the split practically never
    happened. It also partitioned vertices without inserting the crossing points,
    which yields malformed polygons. Clipping with boxes fixes both problems.

    Parameters
    ----------
    polygon : shapely Polygon or MultiPolygon
        Geometry in longitude/latitude degrees.
    meridian : float
        Longitude at which to split (default 180).

    Returns
    -------
    list
        Polygon/MultiPolygon pieces. A geometry already inside
        [-meridian, meridian] is returned unchanged as ``[polygon]``.
    """
    from shapely.affinity import translate
    from shapely.geometry import box

    minx, miny, maxx, maxy = polygon.bounds
    if minx >= -meridian and maxx <= meridian:
        return [polygon]

    # (west edge, east edge, longitude shift) for each band the polygon may cover
    bands = [
        (minx, -meridian, 360),
        (-meridian, meridian, 0),
        (meridian, maxx, -360),
    ]
    pieces = []
    for west, east, shift in bands:
        if east <= west:
            continue
        clipped = polygon.intersection(box(west, miny, east, maxy))
        parts = list(clipped.geoms) if clipped.geom_type == 'GeometryCollection' else [clipped]
        for part in parts:
            # Keep only areal results; clipping can also produce lines/points on the cut
            if part.is_empty or part.geom_type not in ('Polygon', 'MultiPolygon'):
                continue
            pieces.append(translate(part, xoff=shift) if shift else part)

    return pieces if pieces else [polygon]

In [None]:
def plot_clusters_and_regions(ds_ensmed_glob, ds, clustered_df):
    """
    Map the ensemble-median 'bgws' field with every AR6 land region outlined in its cluster colour.

    Each region is outlined in its cluster's colour and labelled with its
    abbreviation and cluster number at the centroid of its easternmost part.

    Parameters
    ----------
    ds_ensmed_glob : xarray.Dataset
        Gridded dataset with a 'bgws' variable, drawn as the background field.
    ds : xarray.Dataset
        Regional dataset providing ``names``, ``region`` (used to index
        ``regionmask.defined_regions.ar6.land``) and ``abbrevs`` along the region
        dimension. It is not modified.
    clustered_df : pandas.DataFrame
        Table with a 'Region' column (names matching ``ds.names``) and a 'Cluster' column.

    Notes
    -----
    The previous version stored the cluster labels in ``ds['Cluster']``, which
    mutated the caller's dataset. The labels are now kept in a local list.
    """
    fig = plt.figure(figsize=(30, 15))
    ax_main = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # Background field
    ds_ensmed_glob['bgws'].plot(ax=ax_main, vmin=-0.2, vmax=0.2, cmap=cm, transform=ccrs.PlateCarree(), add_colorbar=False)

    # Coastlines and gridlines
    ax_main.coastlines()
    ax_main.tick_params(axis='both', which='major', labelsize=20)
    gridlines = ax_main.gridlines(draw_labels=True, color='black', alpha=0.2, linestyle='--')
    gridlines.top_labels = gridlines.right_labels = False
    gridlines.xlabel_style = {'size': 18}
    gridlines.ylabel_style = {'size': 18}

    land_regions = regionmask.defined_regions.ar6.land

    # Cluster label of every region, in the order of ds.names
    region_to_cluster_map = dict(zip(clustered_df['Region'], clustered_df['Cluster']))
    region_clusters = [int(region_to_cluster_map[name]) for name in ds.names.values]

    # One colour per distinct cluster label, assigned in sorted label order
    unique_clusters = np.unique(region_clusters)
    palette = plt.cm.tab20b(np.linspace(0, 1, len(unique_clusters)))
    cluster_colors = dict(zip(unique_clusters.tolist(), palette))

    for reg_num, region_abbr, cluster_number in zip(ds.region.values, ds.abbrevs.values, region_clusters):
        cluster_color = cluster_colors[cluster_number]

        # Fetch the polygon or polygons for this region
        region_obj = land_regions[reg_num]
        if hasattr(region_obj, 'polygons'):
            region_polygons = region_obj.polygons
        elif hasattr(region_obj, 'polygon'):
            region_polygons = [region_obj.polygon]
        else:
            raise AttributeError(f"The region object does not have 'polygons' or 'polygon' attribute.")

        # Outline every (antimeridian-split) part in the cluster colour
        plotted_parts = []
        for region_polygon in region_polygons:
            for poly in split_polygon(region_polygon):
                if isinstance(poly, Polygon):
                    parts = [poly]
                elif isinstance(poly, MultiPolygon):
                    parts = list(poly.geoms)
                else:
                    raise TypeError(f"Unhandled geometry type: {type(poly)}")

                for part in parts:
                    feature = ShapelyFeature([part], ccrs.PlateCarree(), edgecolor=cluster_color, facecolor='none', linewidth=2)
                    ax_main.add_feature(feature)
                plotted_parts.extend(parts)

        if not plotted_parts:
            continue

        # Label the region once, at the centroid of its easternmost part
        text_lon, text_lat = max((part.centroid for part in plotted_parts), key=lambda c: c.x).coords[0]
        ax_main.text(text_lon, text_lat, f"{region_abbr}\n{cluster_number}",
                     horizontalalignment='center', verticalalignment='center', transform=ccrs.PlateCarree(),
                     fontsize=20, bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

    plt.show()

In [None]:
# Draw the cluster map over the ensemble-median 'bgws' field.
plot_clusters_and_regions(ds_ensmed_glob, ds, clustered_df)

In [None]:
# Display the cluster assignment table.
clustered_df

In [None]:
def plot_simple_map(ds_ensmed_glob, projection, save_fig=False):
    """
    Plot the ensemble-median Blue-Green Water Share ('bgws') on a global map.

    Parameters
    ----------
    ds_ensmed_glob : xarray.Dataset
        Gridded dataset containing a 'bgws' variable in PlateCarree (lon/lat) coordinates.
    projection : cartopy.crs.Projection
        Projection of the map axes.
    save_fig : bool
        If True, save the figure as a PDF under
        ../../results/CMIP6/historical/time/median/bgws.

    Returns
    -------
    str
        Path of the saved PDF, or a message explaining that the figure was not saved.
    """
    fig = plt.figure(figsize=(30, 15))
    ax_main = fig.add_subplot(1, 1, 1, projection=projection)

    # Plot the 'bgws' variable from the dataset
    img = ds_ensmed_glob['bgws'].plot(ax=ax_main, vmin=-0.6, vmax=0.6, cmap=cm, transform=ccrs.PlateCarree(), add_colorbar=False)

    # Coastlines and gridlines
    ax_main.coastlines()
    ax_main.tick_params(axis='both', which='major', labelsize=20)
    gridlines = ax_main.gridlines(draw_labels=True, color='black', alpha=0.2, linestyle='--')
    gridlines.top_labels = gridlines.right_labels = False
    gridlines.xlabel_style = {'size': 18}
    gridlines.ylabel_style = {'size': 18}

    # Horizontal colorbar below the map
    cbar_ax = fig.add_axes([0.314, 0.1, 0.4, 0.025])  # left, bottom, width, height
    cbar = fig.colorbar(img, cax=cbar_ax, extend='both', orientation='horizontal')
    cbar.set_label("Blue-Green Water Share", fontsize=22, weight='bold', labelpad=15)
    cbar.ax.tick_params(labelsize=18)

    # Save before plt.show(): depending on the backend, show() can close or clear
    # the figure, so saving afterwards may write an empty file.
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'historical', 'time', 'median', 'bgws')
        os.makedirs(savepath, exist_ok=True)
        filename = 'Ensemble_median.1985-2014.bgws.historical.pdf'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    plt.show()

    return filepath

In [None]:
# Robinson projection, saved to disk (alternative: ccrs.PlateCarree())
plot_simple_map(ds_ensmed_glob, ccrs.Robinson(), save_fig=True)

In [None]:
# PlateCarree projection, displayed only (not saved)
plot_simple_map(ds_ensmed_glob, ccrs.PlateCarree(), save_fig=False)

### Compare Permutation Importance of different models

##### Compare LR test and train

In [None]:
def compare_permutation_importance(permutation_importances_train, permutation_importances_test, performance_metrics_train, performance_metrics_test, predictor_vars):
    """
    Compare train vs. test permutation-importance rankings for every region.

    Importances are averaged over the permutation repeats (axis 1) and ranked
    densely from most to least important. Predictors with negative mean
    importance get a NaN rank. Predictors with exactly zero importance in both
    sets share one rank below all others. The agreement is the percentage of
    predictors (non-negative in at least one set) whose train and test ranks match.

    Returns
    -------
    dict
        Region -> DataFrame with variables, ranks, mean importances, R2 scores
        and agreement, sorted by train rank.
    """
    def _mean_per_variable(importances):
        return pd.Series(importances.mean(axis=1), index=predictor_vars)

    def _rank_non_negative(mean_importance):
        dense = mean_importance.rank(method='dense', ascending=False)
        return dense.where(mean_importance >= 0, np.nan)

    comparison = {}
    for region, raw_train in permutation_importances_train.items():
        train_mean = _mean_per_variable(raw_train)
        test_mean = _mean_per_variable(permutation_importances_test[region])

        train_rank = _rank_non_negative(train_mean)
        test_rank = _rank_non_negative(test_mean)

        both_zero = (train_mean == 0) & (test_mean == 0)
        shared_last_rank = max(train_rank.max(), test_rank.max()) + 1
        train_rank[both_zero] = shared_last_rank
        test_rank[both_zero] = shared_last_rank

        considered = (train_mean >= 0) | (test_mean >= 0)
        agreement = np.mean(train_rank[considered] == test_rank[considered]) * 100

        table = pd.DataFrame({
            'Variable': predictor_vars,
            'Train_Rank': train_rank,
            'Test_Rank': test_rank,
            'Train_Mean_Importance': train_mean,
            'Test_Mean_Importance': test_mean,
            'Train_R2': performance_metrics_train[region]['R2'],
            'Test_R2': performance_metrics_test[region]['R2'],
            'Agreement (%)': agreement
        })
        comparison[region] = table.sort_values(by='Train_Rank')

    return comparison

In [None]:
def compare_permutation_importance(permutation_importances_train, permutation_importances_test, performance_metrics_train, performance_metrics_test, predictor_vars, zero_threshold=0.001):
    """Compare per-region permutation-importance rankings between train and test data.

    Variables whose mean importance is negligible (``|importance| <= zero_threshold``)
    in BOTH splits are set aside. The remaining variables are dense-ranked among
    themselves (descending importance, NaN importances last), so a negligible
    variable cannot shift or leave gaps in the ranks of the others. All negligible
    variables then share the rank directly after the worst remaining rank.

    Parameters
    ----------
    permutation_importances_train, permutation_importances_test : dict
        Region -> importance values whose ``.mean(axis=1)`` gives one mean
        importance per predictor (presumably an (n_features, n_repeats) array
        like sklearn's ``permutation_importance(...).importances`` -- confirm).
    performance_metrics_train, performance_metrics_test : dict
        Region -> mapping holding an ``'R2'`` entry.
    predictor_vars : list of str
        Predictor names, in the row order of the importance arrays.
    zero_threshold : float, default 0.001
        Absolute importance at or below which a variable counts as negligible.

    Returns
    -------
    dict
        Region -> DataFrame with columns ``Variable``, ``Train_Rank``,
        ``Test_Rank`` (nullable ``Int64``), ``Train_Mean_Importance``,
        ``Test_Mean_Importance``, ``Train_R2``, ``Test_R2`` and
        ``Agreement (%)`` (percentage of variables with identical train and test
        rank; the same value on every row), sorted by ``Train_Rank``.
    """
    def _dense_rank_excluding(importance, excluded):
        # Dense descending rank over the non-excluded variables only; excluded -> NaN
        return (importance[~excluded]
                .rank(method='dense', ascending=False, na_option='bottom')
                .reindex(importance.index))

    comparison_results = {}

    for region in permutation_importances_train.keys():
        mean_importance_train = pd.Series(permutation_importances_train[region].mean(axis=1), index=predictor_vars)
        mean_importance_test = pd.Series(permutation_importances_test[region].mean(axis=1), index=predictor_vars)

        # Negligible in both splits -> handled separately from the real ranking
        negligible = ((mean_importance_train.abs() <= zero_threshold)
                      & (mean_importance_test.abs() <= zero_threshold))

        rank_train = _dense_rank_excluding(mean_importance_train, negligible)
        rank_test = _dense_rank_excluding(mean_importance_test, negligible)

        # Negligible variables share the rank after the worst non-negligible rank
        worst_rank = pd.concat([rank_train, rank_test]).max()
        negligible_rank = 1 if pd.isna(worst_rank) else int(worst_rank) + 1
        rank_train = rank_train.fillna(negligible_rank).astype('Int64')
        rank_test = rank_test.fillna(negligible_rank).astype('Int64')

        # Percentage of variables ranked identically in train and test
        agreement = (rank_train == rank_test).mean() * 100

        data = {
            'Variable': predictor_vars,
            'Train_Rank': rank_train,
            'Test_Rank': rank_test,
            'Train_Mean_Importance': mean_importance_train,
            'Test_Mean_Importance': mean_importance_test,
            'Train_R2': performance_metrics_train[region]['R2'],
            'Test_R2': performance_metrics_test[region]['R2'],
            'Agreement (%)': agreement
        }

        df = pd.DataFrame(data).sort_values(by='Train_Rank')
        comparison_results[region] = df

    return comparison_results

In [None]:
# Example usage: compare train vs. test permutation-importance rankings of the `_lr` models, per region
comparison_results = compare_permutation_importance(permutation_importances_train_lr, permutation_importances_test_lr, performance_metrics_train_lr, performance_metrics_test_lr, predictor_vars)

In [None]:
# List the region keys with their positional index (used to pick `region_number` below)
pd.DataFrame(comparison_results.keys())

In [None]:
# Show the comparison table of one region, chosen by its position in the results dict
region_number = 8
region_key = list(comparison_results)[region_number]
print(region_key)
comparison_results[region_key]

In [None]:
# Get the mean overall rank agreement and the mean top-rank agreement across regions

In [None]:
def compute_agreement_metrics(comparison_results):
    """Summarise train/test ranking agreement across all regions.

    Parameters
    ----------
    comparison_results : dict
        Region -> DataFrame as returned by ``compare_permutation_importance``
        (needs ``Variable``, ``Train_Rank``, ``Test_Rank`` and ``Agreement (%)``).

    Returns
    -------
    avg_overall_agreement : float
        Mean of the per-region ``Agreement (%)`` values.
    avg_first_rank_agreement : float
        Percentage of regions whose rank-1 variable is the same in train and
        test. A region where neither split has a rank-1 variable counts as agreeing.

    Both values are NaN when ``comparison_results`` is empty.
    """
    if not comparison_results:
        # Nothing to average; avoid a ZeroDivisionError
        return float('nan'), float('nan')

    def _top_ranked_variable(df, rank_column):
        # First variable (in table order) with rank 1, or None if none holds rank 1.
        # fillna(False) guards nullable (Int64) rank columns that contain NA.
        is_top = (df[rank_column] == 1.0).fillna(False).astype(bool)
        top_vars = df.loc[is_top, 'Variable']
        return top_vars.iloc[0] if not top_vars.empty else None

    overall_agreement = []
    first_rank_agreement = []

    for df in comparison_results.values():
        # The agreement value is repeated on every row; take the first
        overall_agreement.append(df['Agreement (%)'].iloc[0])

        # Is the top-ranked variable the same in training and test data?
        first_rank_agreement.append(
            _top_ranked_variable(df, 'Train_Rank') == _top_ranked_variable(df, 'Test_Rank')
        )

    avg_overall_agreement = sum(overall_agreement) / len(overall_agreement)
    avg_first_rank_agreement = sum(first_rank_agreement) / len(first_rank_agreement) * 100

    return avg_overall_agreement, avg_first_rank_agreement

In [None]:
avg_overall_agreement, avg_first_rank_agreement = compute_agreement_metrics(comparison_results)
# Report both summary metrics, one line each
for caption, metric in (("Average Overall Agreement:", avg_overall_agreement),
                        ("Average Agreement on First Rank:", avg_first_rank_agreement)):
    print(caption, metric)

##### Compare xgb test and train

In [None]:

# Same train/test ranking comparison for the `_xgb` models
comparison_results_xgb = compare_permutation_importance(permutation_importances_train_xgb, permutation_importances_test_xgb, performance_metrics_train_xgb, performance_metrics_test_xgb, predictor_vars)

In [None]:
# Show the xgb comparison table of one region, chosen by its position in the results dict
region_number = 15
region_key = list(comparison_results_xgb)[region_number]
print(region_key)
comparison_results_xgb[region_key]

In [None]:
avg_overall_agreement, avg_first_rank_agreement = compute_agreement_metrics(comparison_results_xgb)
# Report both summary metrics, one line each
for caption, metric in (("Average Overall Agreement:", avg_overall_agreement),
                        ("Average Agreement on First Rank:", avg_first_rank_agreement)):
    print(caption, metric)

##### Compare xgb and lr

### Build Gaussian Processes model

Probabilistic Outputs: GPs not only provide a prediction for each data point but also give a measure of uncertainty (variance) associated with that prediction. This can help in understanding the confidence of the model in different regions of the input space.

Non-Linear Relationships: GPs, with the right choice of kernel, can model complex non-linear relationships between inputs and outputs, making them more flexible than traditional linear regression models.

Kernel Flexibility: The kernel in a GP defines the relationship between data points. By selecting or designing a kernel that captures the underlying structure of the data, GPs can be adapted to various types of data patterns.

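Finally, regarding variable importance: In the context of GPs, interpreting variable importance is not as straightforward as in linear regression. However, one common approach is to examine the sensitivity of the GP's predictions to changes in each input variable. The Automatic Relevance Determination (ARD) kernel, for instance, learns a separate length scale for each dimension of the input space, which gives an indication of the relative importance of each input variable. A short length scale means the prediction changes quickly along that input (high relevance), whereas a long length scale means the input has little influence. Length scales are only comparable across inputs when the inputs share a common scale, which is why the predictors are standardized before fitting.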

In [None]:
import numpy as np
import pandas as pd
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel
from sklearn.preprocessing import StandardScaler

def prepare_data_for_gp(ds, region_index, predictor_vars=None):
    """Build a NaN-free design matrix ``X`` and target vector ``y`` for one region.

    Parameters
    ----------
    ds : xarray.Dataset
        Must be selectable with ``ds.sel(region=...)`` and contain the predictor
        variables and ``bgws``, each with ``lat``/``lon`` dimensions
        (presumably only those two after selecting a region -- confirm).
    region_index :
        Value of the ``region`` coordinate to select.
    predictor_vars : list of str, optional
        Predictor variable names; the column order of ``X`` follows this list.
        Defaults to ``['pr', 'vpd', 'evspsbl', 'mrro', 'mrso', 'tran', 'lai', 'gpp']``.

    Returns
    -------
    X_filtered : numpy.ndarray, shape (n_cells, n_predictors)
    y_filtered : numpy.ndarray, shape (n_cells,)
        Only grid cells where ``bgws`` and every predictor are non-NaN are kept;
        cells are ordered by (lat, lon).
    """
    if predictor_vars is None:
        predictor_vars = ['pr', 'vpd', 'evspsbl', 'mrro', 'mrso', 'tran', 'lai', 'gpp']

    region_data = ds.sel(region=region_index)

    # Predictors -> 2-D array (grid cell z, variable); stacking order matches y below
    X = (region_data[list(predictor_vars)]
         .to_array()
         .transpose('lat', 'lon', 'variable')
         .stack(z=("lat", "lon"))
         .transpose('z', 'variable')
         .values)
    y = region_data['bgws'].stack(z=("lat", "lon")).values

    # Keep only cells where the target and all predictors are present
    mask = ~np.isnan(y) & ~np.isnan(X).any(axis=1)

    return X[mask, :], y[mask]

def train_gp_for_region(ds, region_name, n_restarts_optimizer=10, random_state=None):
    """Fit a Gaussian-process regressor with an ARD RBF kernel for one region.

    Parameters
    ----------
    ds : xarray.Dataset
        Passed to ``prepare_data_for_gp`` to build the region's (X, y).
    region_name :
        Value of the ``region`` coordinate (callers in this notebook pass the
        region index, despite the parameter name).
    n_restarts_optimizer : int, default 10
        Number of extra random restarts of the kernel hyper-parameter optimiser.
    random_state : int, RandomState instance or None, default None
        Seeds the optimiser restarts; ``None`` keeps sklearn's default behaviour.

    Returns
    -------
    gp : GaussianProcessRegressor
        Fitted on the standardized predictors.
    scaler : StandardScaler
        Fitted on the region's predictors; apply it to any new inputs before ``gp.predict``.

    Raises
    ------
    ValueError
        If the region has no grid cell without missing values.
    """
    X, y = prepare_data_for_gp(ds, region_name)
    if len(y) == 0:
        raise ValueError(f"No complete (non-NaN) samples for region {region_name!r}")

    # Standardize the features so ARD length scales are comparable across predictors
    scaler = StandardScaler().fit(X)
    X_standardized = scaler.transform(X)

    # Kernel = covariance function: it defines how strongly two input points are related.
    # - ConstantKernel scales the signal variance,
    # - RBF with one length scale per predictor (ARD) lets each input have its own relevance,
    # - WhiteKernel absorbs observation noise.
    # The `C * RBF + White` structure is relied on where `kernel_.k1.k2.length_scale` is read.
    n_features = X_standardized.shape[1]
    kernel = C(1.0, (1e-3, 1e3)) * RBF(np.ones(n_features), (1e-2, 1e2)) + WhiteKernel(noise_level=1, noise_level_bounds=(1e-10, 1e+1))

    gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=n_restarts_optimizer,
                                  normalize_y=True, random_state=random_state)
    gp.fit(X_standardized, y)

    return gp, scaler

def predict_with_gp(gp, scaler, X):
    """Return the GP predictive mean and standard deviation for raw inputs ``X``.

    ``X`` is first transformed with ``scaler`` (the scaler the GP was fitted
    with), then passed to ``gp.predict(..., return_std=True)``.

    Returns
    -------
    (y_pred, y_std) : tuple
    """
    mean, std = gp.predict(scaler.transform(X), return_std=True)
    return mean, std

def train_and_predict_for_all_regions(ds):
    """Fit one GP per region and predict on that region's own training inputs.

    Parameters
    ----------
    ds : xarray.Dataset
        Must expose ``region`` (values passed to ``train_gp_for_region`` /
        ``prepare_data_for_gp``) and ``names`` (one display name per region).

    Returns
    -------
    gp_models, scalers, predictions, std_devs : dict
        Each keyed by region name. Because predictions are made on the same
        points the GP was fitted on, they describe in-sample fit quality.
    """
    regions = ds.region.values
    labels = ds.names.values
    outputs = {'gp': {}, 'scaler': {}, 'pred': {}, 'std': {}}

    for position, region in enumerate(regions):
        label = labels[position]

        model, fitted_scaler = train_gp_for_region(ds, region)
        features, _ = prepare_data_for_gp(ds, region)
        mean, spread = predict_with_gp(model, fitted_scaler, features)

        outputs['gp'][label] = model
        outputs['scaler'][label] = fitted_scaler
        outputs['pred'][label] = mean
        outputs['std'][label] = spread

    return outputs['gp'], outputs['scaler'], outputs['pred'], outputs['std']

In [None]:
# Fit one GP per region and predict on each region's own (training) grid cells
gp_models, scalers, predictions, std_devs = train_and_predict_for_all_regions(ds)

In [None]:
# Display the per-region predictions (dict: region name -> predicted bgws array)
predictions

In [None]:
import math
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import matplotlib.lines as mlines

def assess_model_performance_all_regions(ds, predictions, std_devs):
    """Plot GP residuals against the true target for every region.

    One panel per region (three per row) shows residuals (observed - predicted)
    versus the true ``bgws`` values, a ±1 standard-deviation band from the GP
    predictive std, the zero-residual line, and the MSE / R² of the predictions.

    Parameters
    ----------
    ds : xarray.Dataset
        Provides ``names`` and ``region`` (one name per region) and is passed to
        ``prepare_data_for_gp`` to recover the true targets.
    predictions, std_devs : dict
        Region name -> predicted means / predictive standard deviations, aligned
        with the targets returned by ``prepare_data_for_gp`` (as produced by
        ``train_and_predict_for_all_regions``).

    Notes
    -----
    ``train_and_predict_for_all_regions`` predicts on the same points the GP was
    fitted on, so the MSE and R² shown for its output are in-sample scores.
    """
    region_names = ds.names.values
    region_indices = ds.region.values

    num_regions = len(region_names)
    cols = 3  # three panels per row
    rows = math.ceil(num_regions / cols)

    # squeeze=False keeps `axs` 2-D even when all regions fit into a single row
    fig, axs = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    flat_axes = axs.flatten()

    for i, region_index in enumerate(region_indices):
        region_name = region_names[i]
        ax = flat_axes[i]

        # True targets, in the same order as the stored predictions
        _, y_true = prepare_data_for_gp(ds, region_index)
        y_pred = np.asarray(predictions[region_name])
        y_std = np.asarray(std_devs[region_name])

        # Residual = observed - predicted, matching the legend label below
        residuals = y_true - y_pred

        mse_value = mean_squared_error(y_true, y_pred)
        r2_s = r2_score(y_true, y_pred)

        # fill_between draws a polygon along x, so x must be monotonically sorted
        order = np.argsort(y_true)

        ax.scatter(y_true, residuals, c='blue', alpha=0.5, s=10)
        ax.fill_between(y_true[order], (residuals - y_std)[order], (residuals + y_std)[order], color='gray', alpha=0.2)
        ax.axhline(0, color='red', lw=2)
        ax.text(0.05, 0.95, f'MSE: {mse_value:.6f}', transform=ax.transAxes, verticalalignment='top')
        ax.text(0.05, 0.9, f'R2 score: {r2_s:.2f}', transform=ax.transAxes, verticalalignment='top')
        ax.set_title(region_name)
        ax.set_xlabel('True Values')
        ax.set_ylabel('Residuals')

    legend_elements = [mlines.Line2D([0], [0], color='blue', marker='o', markersize=10, label='Residuals (Observed - Predicted)', linestyle='None'),
                       mlines.Line2D([0], [0], color='gray', alpha=0.2, linewidth=10, label='Prediction Uncertainty (±1 Std. Dev.)'),
                       mlines.Line2D([0], [0], color='red', lw=2, label='Zero Residual Line')]

    # Hide every unused panel; put the legend in the last one if there is one
    unused_axes = flat_axes[num_regions:]
    for ax in unused_axes:
        ax.axis('off')
    if len(unused_axes) > 0:
        unused_axes[-1].legend(handles=legend_elements, loc='center')
    else:
        # No spare panel: attach the legend to the last region's panel instead
        flat_axes[num_regions - 1].legend(handles=legend_elements, loc='best', fontsize='small')

    plt.tight_layout()
    plt.show()

In [None]:
# Example usage: residual diagnostics per region for the GP predictions computed above
assess_model_performance_all_regions(ds, predictions, std_devs)

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt

def assess_variable_importance_all_regions(ds, gp_models, variable_names):
    """Bar-plot the fitted ARD length scale of every predictor, one panel per region.

    The RBF term is read as ``gp.kernel_.k1.k2``, which holds for kernels built as
    ``C * RBF + WhiteKernel`` (as in ``train_gp_for_region``). A shorter length
    scale means the prediction varies faster along that (standardized) input,
    i.e. the input is more relevant; a longer one means it matters less.

    Parameters
    ----------
    ds : xarray.Dataset
        Provides ``names`` and ``region`` (one name per region).
    gp_models : dict
        Region name -> fitted GaussianProcessRegressor.
    variable_names : list of str
        Predictor names in the column order the GPs were trained on; must have
        one entry per length scale.
    """
    region_names = ds.names.values
    region_indices = ds.region.values

    num_regions = len(region_names)
    cols = 3  # three panels per row
    rows = math.ceil(num_regions / cols)

    # squeeze=False keeps `axs` 2-D even when all regions fit into a single row
    fig, axs = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    flat_axes = axs.flatten()

    for i, _region_index in enumerate(region_indices):
        region_name = region_names[i]
        ax = flat_axes[i]

        # Fitted ARD length scales (one per predictor) of the RBF kernel term
        length_scales = gp_models[region_name].kernel_.k1.k2.length_scale

        ax.bar(variable_names, length_scales)
        ax.set_title(region_name)
        ax.set_xlabel('Variable Name')
        ax.set_ylabel('Length Scale')
        ax.tick_params(axis='x', rotation=45)  # Rotate x-axis labels for better visibility

    # Remove empty subplots
    for ax in flat_axes[num_regions:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Call the function
# Predictor names, in the same column order as the GP inputs built by prepare_data_for_gp
variable_names = 'pr vpd evspsbl mrro mrso tran lai gpp'.split()

assess_variable_importance_all_regions(ds, gp_models, variable_names=variable_names)

In [None]:
from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt
import numpy as np
import math

def assess_permutation_importance(ds, gp_models, variable_names, predictions, n_repeats=30, scalers=None, random_state=None):
    """
    Assess and plot the permutation importance of each predictor for all regions.

    For every region the (X, y) data are rebuilt with ``prepare_data_for_gp`` (the
    same call ``train_gp_for_region`` fits on), standardized, and each predictor
    column is permuted ``n_repeats`` times to measure the drop in the GP's score.

    Parameters:
    - ds: Dataset providing ``names`` and ``region`` (one name per region); passed
      to ``prepare_data_for_gp``.
    - gp_models: Dictionary containing trained GP models for each region (keyed by name).
    - variable_names: Predictor names, in the column order produced by
      ``prepare_data_for_gp`` (used as plot labels).
    - predictions: Not used; kept so existing calls keep working.
    - n_repeats: Number of times to permute a feature.
    - scalers: Optional dict of fitted StandardScalers keyed by region name (as
      returned by ``train_and_predict_for_all_regions``). If omitted, a
      StandardScaler is refit on the region's X, which is the same data the
      scaler in ``train_gp_for_region`` was fitted on.
    - random_state: Passed to ``permutation_importance`` for reproducible permutations.

    Returns:
    - None. Plots the importance.

    Note: importances are computed on the training points, i.e. they are in-sample.
    """
    from sklearn.preprocessing import StandardScaler

    region_names = ds.names.values
    region_indices = ds.region.values

    num_regions = len(region_names)
    cols = 3  # three panels per row
    rows = math.ceil(num_regions / cols)

    # squeeze=False keeps `axs` 2-D even when all regions fit into a single row
    fig, axs = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    flat_axes = axs.flatten()
    labels = np.array(variable_names)

    for i, region_index in enumerate(region_indices):
        region_name = region_names[i]
        ax = flat_axes[i]

        # Rebuild the region's predictors/target the same way the GP was trained
        X_region, y_region = prepare_data_for_gp(ds, region_index)

        # The GP was fitted on standardized inputs, so permute in that space
        if scalers is not None:
            scaler = scalers[region_name]
        else:
            scaler = StandardScaler().fit(X_region)
        X_standardized = scaler.transform(X_region)

        result = permutation_importance(gp_models[region_name], X_standardized, y_region,
                                        n_repeats=n_repeats, random_state=random_state)

        # Sort variables by mean importance (least important at the bottom)
        sorted_idx = result.importances_mean.argsort()

        ax.boxplot(result.importances[sorted_idx].T, vert=False)
        # Set labels explicitly (works across matplotlib versions of boxplot's label kwarg)
        ax.set_yticks(np.arange(1, len(sorted_idx) + 1))
        ax.set_yticklabels(labels[sorted_idx])
        ax.set_title(region_name)
        ax.set_xlabel('Importance Score')

    # Remove empty subplots
    for ax in flat_axes[num_regions:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Example usage: per-region permutation importance of the GP predictors
# (the `predictions` argument is accepted but not used by the function)
assess_permutation_importance(ds, gp_models, variable_names, predictions)