# CMIP6 Regional variable change

**The following steps are included in this script:**

1. Load netCDF files
2. Compute regional variable change
3. Plot the regional changes in parallel coordinate plots

### Import Packages

In [None]:
# ========== Packages ==========
import copy
import glob
import multiprocessing as mp
import os

import numpy as np
import pandas as pd
import xarray as xr

#import numpy as np
#import copy
#import seaborn as sns
#import matplotlib.pyplot as plt
#import cartopy.crs as ccrs

#import matplotlib.cm
#from matplotlib import rcParams
#import math
#import multiprocessing as mp
#from cftime import DatetimeNoLeap


#from matplotlib.lines import Line2D
#from sklearn.metrics import r2_score
#import matplotlib.colors as colors
#from matplotlib.patches import Patch
#import matplotlib.patches as mpatches



#%matplotlib inline

#rcParams["mathtext.default"] = 'regular'

### Functions

#### Open files

In [None]:
def open_dataset(filename):
    """Open the file at ``filename`` with xarray and return the resulting Dataset."""
    dataset = xr.open_dataset(filename)
    return dataset

def open_and_merge_datasets(folder, model, experiment_id, temp_res, variables):
    """Open one regridded CMIP6 file per variable and merge them into a single Dataset.

    For every variable the file
    ``../../data/CMIP6/<experiment_id>/<folder>/<temp_res>/<var>/CMIP.<model>.<experiment_id>.<var>.regridded.nc``
    is looked up. Variables without a matching file are reported and skipped.

    Args:
        folder (str): Sub-folder below the experiment directory.
        model (str): Model identifier used in the file name.
        experiment_id (str): Experiment identifier, used as directory and in the file name.
        temp_res (str): Temporal-resolution sub-folder.
        variables (iterable of str): Names of the variables to load.

    Returns:
        xarray.Dataset: The found variables, combined with ``xr.merge``.

    Raises:
        ValueError: If no file was found for any of the requested variables.
    """
    merged_ds = None
    for var in variables:
        pattern = os.path.join(
            '../../data/CMIP6', experiment_id, folder, temp_res, var,
            f'CMIP.{model}.{experiment_id}.{var}.regridded.nc',
        )
        hits = glob.glob(pattern)
        if not hits:
            print(f"No file found for variable '{var}' in model '{model}'. Searched pattern: {pattern}")
            continue

        # One file per variable is expected; only the first match is used
        new_ds = open_dataset(hits[0])
        merged_ds = new_ds if merged_ds is None else xr.merge([merged_ds, new_ds])

    if merged_ds is None:
        raise ValueError(f"No datasets found for model '{model}' and experiment '{experiment_id}'. Please check the inputs.")

    return merged_ds

#### Preprocessing

##### Select period

In [None]:
def select_period(ds_dict, start_year=None, end_year=None, period=None, yearly_sum=False):
    '''
    Select a year range and/or a subset of months from every dataset, optionally summing per year.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets. It is deep-copied, so the input is not modified.
    start_year (int): The start year of the period (from 1 January, inclusive).
    end_year (int): The end year of the period (up to 31 December, inclusive).
                    The year range is only applied when both years are given.
    period (int, list, str, None): Single month (int), list of months (list), one season name or several
                                   season names joined by 'and' (str, case-insensitive), or None to keep
                                   all months. Season names: 'winter' (DJF), 'spring' (MAM),
                                   'summer' (JJA), 'fall' (SON).
    yearly_sum (bool): If True, multiply every value by the number of days of its month and sum per
                       calendar year. Only meaningful for accumulative variables such as pr or tran.

    Returns:
    dict: New dictionary with the processed datasets. Each dataset gets the attributes 'months'
          (month-initial code of the selection, or 'whole_year') and 'yearly_sum'
          ('yearly_sum' or 'monthly_mean').

    Raises:
    ValueError: If ``period`` has an unsupported type or names an unknown season.
    '''
    # Deep copy so that the caller's datasets (and their attrs) are never modified
    ds_dict_copy = copy.deepcopy(ds_dict)

    # Season to month mapping for the northern hemisphere
    seasons_to_months = {
        'winter': [12, 1, 2],
        'spring': [3, 4, 5],
        'summer': [6, 7, 8],
        'fall': [9, 10, 11]
    }

    # Month number -> initial letter, used to build the period name (e.g. 'DJF')
    month_names = {
        1: 'J', 2: 'F', 3: 'M', 4: 'A', 5: 'M', 6: 'J',
        7: 'J', 8: 'A', 9: 'S', 10: 'O', 11: 'N', 12: 'D'
    }

    if period is None:
        # No specific period: all months are used
        period_name = 'whole_year'
        months = list(range(1, 13))
    elif isinstance(period, int):
        period_name = month_names[period]
        months = [period]
    elif isinstance(period, str):
        # One season, or several seasons joined by 'and' (e.g. 'winter and summer')
        months = []
        for season in period.lower().split('and'):
            season = season.strip()
            if season not in seasons_to_months:
                # Fail loudly: an unknown name would otherwise select no months and
                # silently fall back to the whole year
                raise ValueError(f"Unknown season '{season}'. Valid seasons are: "
                                 f"{', '.join(seasons_to_months)}.")
            months.extend(seasons_to_months[season])
        period_name = ''.join(month_names[m] for m in months)
    elif isinstance(period, list):
        period_name = ''.join(month_names[m] for m in period if m in month_names)
        months = period
    else:
        raise ValueError("Period must be None, an integer, a string representing a single season, "
                         "a string with multiple seasons separated by 'and', or a list of integers.")

    for k, ds in ds_dict_copy.items():
        if start_year is not None and end_year is not None:
            ds = ds.sel(time=slice(f'{start_year}-01-01', f'{end_year}-12-31'))

        # Keep only the selected months
        if months:
            month_mask = ds['time.month'].isin(months)
            ds = ds.where(month_mask, drop=True)

        # Arithmetic/resampling below can drop variable attributes, so remember them first
        original_attrs = {var: ds[var].attrs for var in ds.data_vars}

        if yearly_sum:  # only meaningful for accumulative variables, e.g. pr or tran
            attrs = ds.attrs
            # Weight each monthly value by the number of days in its month
            days = ds['time'].dt.days_in_month
            # 'YS' (year start) replaces the deprecated 'AS' alias. min_count=1 keeps cells
            # without any valid value (e.g. ocean) as NaN instead of summing them to 0.
            ds = (ds * days).resample(time='YS').sum(dim='time', min_count=1)
            ds.attrs = attrs
            sum_type = 'yearly_sum'
        else:
            sum_type = 'monthly_mean'

        # Restore the original variable attributes
        for var in ds.data_vars:
            ds[var].attrs = original_attrs[var]

        ds.attrs['months'] = period_name
        ds.attrs['yearly_sum'] = sum_type
        ds_dict_copy[k] = ds

    return ds_dict_copy

##### Compute period statistics (mean, median, std, ...)

In [None]:
def compute_statistic(ds_dict, period_statistic, dimension, yearly_mean=True):
    """
    Computes the specified statistic for each dataset in the dictionary, in parallel processes.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        period_statistic (str): The statistic to compute: 'mean', 'std', 'min', 'max', 'var' or 'median'.
            For ``dimension='space'`` the statistic is area-weighted, and xarray's weighted
            reductions only provide 'mean', 'std' and 'var'.
        dimension (str): The dimension to compute over, which can be 'time' or 'space'.
        yearly_mean (bool): Only used for ``dimension='space'``. If True, the data are first averaged
            per calendar year, so e.g. a 30-year period yields 30 values.

    Returns:
        dict: A dictionary with the computed statistic for each dataset, keyed like ``ds_dict``.

    Raises:
        TypeError: If ``ds_dict`` is not a dictionary of xarray Datasets.
        ValueError: If the statistic or dimension is invalid, or the statistic has no
            area-weighted version for ``dimension='space'``.
    """
    valid_statistics = ("mean", "std", "min", "max", "var", "median")
    weighted_statistics = ("mean", "std", "var")

    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if period_statistic not in valid_statistics:
        raise ValueError(f"Invalid statistic '{period_statistic}' specified.")
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")
    if dimension == "space" and period_statistic not in weighted_statistics:
        # Fail here instead of with an AttributeError inside a worker process
        raise ValueError(f"Statistic '{period_statistic}' is not available for area-weighted "
                         f"(space) computations; use one of {weighted_statistics}.")

    # Nothing to compute: do not start a process pool for an empty input
    if not ds_dict:
        return {}

    # Compute the statistic for each dataset in parallel
    args = [(ds, period_statistic, dimension, yearly_mean) for ds in ds_dict.values()]
    with mp.Pool() as pool:
        results = pool.starmap(compute_statistic_single, args)

    return dict(zip(ds_dict.keys(), results))

In [None]:
def compute_statistic_single(ds, period_statistic, dimension, yearly_mean=True):
    """
    Compute one statistic for a single dataset, either over time or area-weighted over space.

    Args:
        ds (xarray.Dataset): Input data with a 'time' coordinate (and 'lat'/'lon' for 'space').
        period_statistic (str): Name of the reduction method to call, e.g. 'mean'. For 'space'
            it is called on ``ds.weighted(...)``, so it must be a weighted reduction.
        dimension (str): 'time' to reduce over time, 'space' to reduce over lon/lat with
            cos(latitude) weights.
        yearly_mean (bool): For 'space' only: average per calendar year before the spatial reduction.

    Returns:
        xarray.Dataset: The reduced dataset. Its attributes are those of ``ds`` plus 'period'
        ([first_year, last_year] as strings), 'statistic', 'statistic_dimension' and, when yearly
        means were taken, 'mean' = 'yearly mean'. ``ds`` itself is not modified.

    Raises:
        ValueError: If ``dimension`` is neither 'time' nor 'space'.
    """
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # First and last year of the input, recorded before any reduction removes the time axis
    period = [str(ds.time.dt.year[0].values), str(ds.time.dt.year[-1].values)]

    if dimension == "time":
        stat_ds = getattr(ds, period_statistic)("time", keep_attrs=True, skipna=True)
    else:
        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            # Assign a new dict so no attrs dict shared with the input is mutated
            ds.attrs = {**ds.attrs, 'mean': 'yearly mean'}

        # Area weights: cosine of latitude
        weights = np.cos(np.deg2rad(ds.lat))
        weights.name = "weights"
        ds_weighted = ds.weighted(weights)
        stat_ds = getattr(ds_weighted, period_statistic)(("lon", "lat"), keep_attrs=True, skipna=True)

    # Build a new attrs dict on the result; reductions with keep_attrs may share the input's dict
    stat_ds.attrs = {
        **stat_ds.attrs,
        'period': period,
        'statistic': period_statistic,
        'statistic_dimension': dimension,
    }

    return stat_ds

#### Helper functions

In [None]:
def drop_var(ds_dict, var):
    """
    Drop a variable (or several) from every dataset in the dictionary.

    The dictionary is updated in place and also returned for convenience.

    Args:
        ds_dict (dict): Mapping of names to xarray Datasets.
        var (str or list of str): Name(s) of the variable(s) to drop.

    Returns:
        dict: The same ``ds_dict`` object with the variable(s) removed from each dataset.
    """
    for name, ds in ds_dict.items():
        # drop_vars is the non-deprecated replacement for Dataset.drop when removing variables
        ds_dict[name] = ds.drop_vars(var)

    return ds_dict

In [None]:
def select_period(ds_dict, start_year=None, end_year=None, period=None, yearly_sum=False):
    '''
    Select a year range and/or a subset of months from every dataset, optionally summing per year.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets. It is deep-copied, so the input is not modified.
    start_year (int): The start year of the period (from 1 January, inclusive).
    end_year (int): The end year of the period (up to 31 December, inclusive).
                    The year range is only applied when both years are given.
    period (int, list, str, None): Single month (int), list of months (list), one season name or several
                                   season names joined by 'and' (str, case-insensitive), or None to keep
                                   all months. Season names: 'nh_winter' (DJF), 'nh_spring' (MAM),
                                   'nh_summer' (JJA), 'nh_fall' (SON).
    yearly_sum (bool): If True, multiply every value by the number of days of its month and sum per
                       calendar year. Only meaningful for accumulative variables such as pr or tran.

    Returns:
    dict: New dictionary with the processed datasets. Each dataset gets the attributes 'months'
          (month-initial code of the selection, or 'whole_year') and 'yearly_sum'
          ('yearly_sum' or 'monthly_mean').

    Raises:
    ValueError: If ``period`` has an unsupported type or names an unknown season.
    '''
    # Deep copy so that the caller's datasets (and their attrs) are never modified
    ds_dict_copy = copy.deepcopy(ds_dict)

    # Season to month mapping for the northern hemisphere
    seasons_to_months = {
        'nh_winter': [12, 1, 2],
        'nh_spring': [3, 4, 5],
        'nh_summer': [6, 7, 8],
        'nh_fall': [9, 10, 11]
    }

    # Month number -> initial letter, used to build the period name (e.g. 'DJF')
    month_names = {
        1: 'J', 2: 'F', 3: 'M', 4: 'A', 5: 'M', 6: 'J',
        7: 'J', 8: 'A', 9: 'S', 10: 'O', 11: 'N', 12: 'D'
    }

    if period is None:
        # No specific period: all months are used
        period_name = 'whole_year'
        months = list(range(1, 13))
    elif isinstance(period, int):
        period_name = month_names[period]
        months = [period]
    elif isinstance(period, str):
        # One season, or several seasons joined by 'and' (e.g. 'nh_winter and nh_summer')
        months = []
        for season in period.lower().split('and'):
            season = season.strip()
            if season not in seasons_to_months:
                # Fail loudly: an unknown name would otherwise select no months and
                # silently fall back to the whole year
                raise ValueError(f"Unknown season '{season}'. Valid seasons are: "
                                 f"{', '.join(seasons_to_months)}.")
            months.extend(seasons_to_months[season])
        period_name = ''.join(month_names[m] for m in months)
    elif isinstance(period, list):
        period_name = ''.join(month_names[m] for m in period if m in month_names)
        months = period
    else:
        raise ValueError("Period must be None, an integer, a string representing a single season, "
                         "a string with multiple seasons separated by 'and', or a list of integers.")

    for k, ds in ds_dict_copy.items():
        if start_year is not None and end_year is not None:
            ds = ds.sel(time=slice(f'{start_year}-01-01', f'{end_year}-12-31'))

        # Keep only the selected months
        if months:
            month_mask = ds['time.month'].isin(months)
            ds = ds.where(month_mask, drop=True)

        # Arithmetic/resampling below can drop variable attributes, so remember them first
        original_attrs = {var: ds[var].attrs for var in ds.data_vars}

        if yearly_sum:  # only meaningful for accumulative variables, e.g. pr or tran
            attrs = ds.attrs
            # Weight each monthly value by the number of days in its month
            days = ds['time'].dt.days_in_month
            # 'YS' (year start) replaces the deprecated 'AS' alias. min_count=1 keeps cells
            # without any valid value (e.g. ocean) as NaN instead of summing them to 0.
            ds = (ds * days).resample(time='YS').sum(dim='time', min_count=1)
            ds.attrs = attrs
            sum_type = 'yearly_sum'
        else:
            sum_type = 'monthly_mean'

        # Restore the original variable attributes
        for var in ds.data_vars:
            ds[var].attrs = original_attrs[var]

        ds.attrs['months'] = period_name
        ds.attrs['yearly_sum'] = sum_type
        ds_dict_copy[k] = ds

    return ds_dict_copy


In [None]:
# ======== Standardize ========
def standardize(ds_dict):
    '''
    Return a new dictionary in which every dataset is converted to z-scores.

    Each dataset is transformed as (ds - ds.mean()) / ds.std(), i.e. with the mean and the
    standard deviation taken over all dimensions of each variable. Dataset attributes and the
    attributes of every variable that survives the arithmetic are copied over from the input.
    '''
    standardized = {}

    for key, original in ds_dict.items():
        z_scores = (original - original.mean()) / original.std()

        # Restore attributes that the arithmetic may have dropped
        surviving = (name for name in original.variables if name in z_scores.variables)
        for name in surviving:
            z_scores[name].attrs = original[name].attrs

        z_scores.attrs = original.attrs
        standardized[key] = z_scores

    return standardized

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate the inputs and collect plot metadata from the first dataset in ``ds_dict``.

    Args:
        ds_dict (dict): Non-empty dictionary of xarray Datasets. The first dataset must carry the
            attributes 'period', 'experiment_id', 'statistic_dimension', 'statistic' and 'frequency'.
        variable (str): Name of the variable whose 'long_name' and 'units' are returned.

    Returns:
        tuple: (var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency)
            where ``period`` is formatted as '<start>-<end>', ``titles`` maps statistic/dimension
            keys to readable titles, and ``frequency`` is a readable label for the 'frequency'
            attribute (the raw attribute value if no label is defined).

    Raises:
        TypeError: If ``ds_dict`` is not a dict of Datasets or ``variable`` is not a string.
        ValueError: If ``ds_dict`` is empty.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')
    if not ds_dict:
        raise ValueError("ds_dict must contain at least one dataset.")

    # Plot titles for each statistic / dimension
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum", "median": "Median", "time": "Time", "space": "Space"}
    freq = {"mon": "Monthly"}

    # NOTE: metadata is read from the first dataset only; presumably all datasets share it
    first_ds = next(iter(ds_dict.values()))

    var_long_name = first_ds[variable].long_name
    period = f"{first_ds.attrs['period'][0]}-{first_ds.attrs['period'][1]}"
    experiment_id = first_ds.experiment_id
    unit = first_ds[variable].units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']
    # Fall back to the raw frequency string instead of raising KeyError for unmapped frequencies
    frequency = freq.get(first_ds.frequency, first_ds.frequency)

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency

In [None]:
def compute_ensemble(ds_dict_change):
    """
    Add the ensemble median across all datasets to the dictionary.

    Existing 'Ensemble mean' / 'Ensemble median' entries are removed first so they do not feed
    into the new median. The 'member_id' coordinate is dropped from each dataset where present
    (it presumably differs between models and would conflict in the concatenation).

    Args:
        ds_dict_change (dict): Mapping of model names to xarray Datasets. Modified in place.

    Returns:
        dict: The same dictionary with an added 'Ensemble median' entry (median over all
            remaining datasets). If no datasets remain, it is returned unchanged.
    """
    for key in ('Ensemble mean', 'Ensemble median'):
        ds_dict_change.pop(key, None)

    # Drop 'member_id' where present; drop_vars replaces the deprecated Dataset.drop
    for name, ds in ds_dict_change.items():
        if 'member_id' in ds.coords:
            ds_dict_change[name] = ds.drop_vars('member_id')

    # No models left: nothing to combine (xr.concat would fail on an empty list)
    if not ds_dict_change:
        return ds_dict_change

    combined = xr.concat(list(ds_dict_change.values()), dim='ensemble')
    ds_dict_change['Ensemble median'] = combined.median(dim='ensemble')

    return ds_dict_change

In [None]:
import regionmask

def apply_region_mask(ds_dict):
    """
    Applies the AR6 land region mask to datasets in the provided dictionary and adds a region dimension.

    For every data variable, ``regionmask.defined_regions.ar6.land.mask_3D`` provides a 3D
    region mask. Values outside a region are set to NaN. Values inside it are kept unchanged,
    including genuine zeros (e.g. zero precipitation or zero change).

    Args:
        ds_dict (dict): A dictionary of xarray datasets.

    Returns:
        dict: A new dictionary where keys are the same as in the input dictionary,
              and each value is an xarray Dataset with a region dimension added to each variable.
              Dataset and variable attributes are preserved.
    """
    land_regions = regionmask.defined_regions.ar6.land
    ds_masked_dict = {}

    for ds_name, ds in ds_dict.items():
        ds_out = xr.Dataset()  # Collects the masked variables

        for var in ds.data_vars:
            mask = land_regions.mask_3D(ds[var])

            # Mask with `where` instead of multiplying by the mask and turning 0 into NaN:
            # the latter also discarded real zero values inside a region.
            ds_out[var] = ds[var].where(mask)
            ds_out[var].attrs = ds[var].attrs

        ds_out.attrs = ds.attrs
        ds_masked_dict[ds_name] = ds_out

    return ds_masked_dict

In [None]:
def is_numeric(data):
    """Return True if ``data.astype(float)`` succeeds, otherwise False.

    Only ValueError and TypeError raised by the cast count as "not numeric";
    any other exception propagates to the caller.
    """
    try:
        data.astype(float)
    except (TypeError, ValueError):
        return False
    return True

def compute_change(ds_dict_hist, ds_dict_fut, var_rel_change=None):
    """
    Compute the future-minus-historical change for every model present in both dictionaries.

    For each variable that exists in both datasets and is numeric in both, the change is either
    absolute (future - historical) or relative in percent
    ((future - historical) / historical * 100, NaN where historical == 0).
    'mrso' is always reported as relative change because models use different soil depths.

    Args:
        ds_dict_hist (dict): Historical datasets keyed by model name.
        ds_dict_fut (dict): Future datasets keyed by model name; models missing here are skipped.
        var_rel_change (None, str or iterable of str): Variables to report as relative change.
            'all' selects every common variable of each model; a single variable name is also
            accepted. None means absolute change for everything except 'mrso'.

    Returns:
        dict: Change datasets keyed by model name. Each is a deep copy of the historical dataset
            holding the changes; relative-change variables get units '%', and the dataset
            attributes are taken from the future dataset.
    """
    ds_dict_change = {}

    # A single variable name must match as a whole name, not as a substring of it
    if isinstance(var_rel_change, str) and var_rel_change != 'all':
        var_rel_change = {var_rel_change}

    for name, ds_hist in ds_dict_hist.items():
        if name not in ds_dict_fut:
            continue

        ds_future = ds_dict_fut[name]
        common_vars = set(ds_hist.data_vars).intersection(ds_future.data_vars)

        # Resolve 'all' per model. Reassigning var_rel_change here would reuse the first
        # model's variable set for all later models.
        rel_vars = common_vars if var_rel_change == 'all' else (var_rel_change or ())

        ds_change = ds_hist.copy(deep=True)

        for var in common_vars:
            if not (is_numeric(ds_hist[var].data) and is_numeric(ds_future[var].data)):
                continue

            diff = ds_future[var] - ds_hist[var]
            if var == 'mrso' or var in rel_vars:
                # Relative change; division by zero is avoided by masking zeros to NaN
                rel_change = diff / ds_hist[var].where(ds_hist[var] != 0) * 100
                ds_change[var].data = rel_change.data
                ds_change[var].attrs['units'] = '%'
            else:
                ds_change[var].data = diff.data

        ds_change.attrs = ds_future.attrs
        ds_dict_change[name] = ds_change

    return ds_dict_change

#### Compute statistics

In [None]:
def calculate_spatial_mean(ds_dict):
    """
    Average every data variable over 'lon' and 'lat' (unweighted) for each dataset.

    Args:
        ds_dict (dict): Mapping of names to xarray Datasets with 'lon' and 'lat' dimensions.

    Returns:
        dict: New dictionary with, per name, a Dataset holding the spatially averaged variables.
            Dataset and variable attributes are copied from the input.
    """
    ds_dict_mean = {}

    for key, ds in ds_dict.items():
        # The per-key Dataset must exist before variables can be assigned into it
        ds_mean = xr.Dataset()

        for var in ds.data_vars:
            # Average this variable only (not the whole Dataset) over the spatial dimensions
            ds_mean[var] = ds[var].mean(['lon', 'lat'])
            ds_mean[var].attrs = ds[var].attrs

        ds_mean.attrs = ds.attrs
        ds_dict_mean[key] = ds_mean

    return ds_dict_mean

In [None]:
from itertools import permutations

def precompute_metrics(ds_dict, variables, metrics=('pearson',)):
    """
    Compute pairwise metrics between variables for every dataset, using one process pool.

    For each dataset the requested variables are flattened into the columns of a DataFrame, and
    ``compute_metrics_for_pair`` (defined elsewhere; it is called with
    ``(df, var1, var2, metrics)`` and must return ``(var1, var2, {metric: value})``) is evaluated
    for every ordered pair of distinct variables.

    Args:
        ds_dict (dict): Mapping of names to xarray Datasets.
        variables (list of str): Variables to compare.
        metrics (iterable of str): Metric names; each becomes a top-level key of the result.

    Returns:
        dict: ``{metric: {dataset_name: {'<var1>_<var2>': value}}}``.
    """
    results_dict = {metric: {} for metric in metrics}
    if not ds_dict:
        return results_dict

    # The variable pairs do not depend on the dataset, so build them once
    pairs = list(permutations(variables, 2))

    # One pool for all datasets instead of a new pool per dataset
    with mp.Pool() as pool:
        for name, ds in ds_dict.items():
            # One column per variable, all grid points/time steps flattened
            df = pd.DataFrame({var: ds[var].values.flatten() for var in variables})
            args = [(df, var1, var2, metrics) for var1, var2 in pairs]
            results = pool.map(compute_metrics_for_pair, args)

            for var1, var2, metric_dict in results:
                for metric, value in metric_dict.items():
                    results_dict[metric].setdefault(name, {}).setdefault(f'{var1}_{var2}', value)

    return results_dict

In [None]:
def compute_stats(ds_dict):
    """
    Compute the yearly, spatially averaged time series of every variable.

    Each dataset is resampled to yearly means (``resample(time='1Y')``) and each variable is
    then averaged (unweighted) over 'lat' and 'lon'.

    Parameters:
    ds_dict (dict): The input dictionary of xarray.Dataset.

    Returns:
    dict: ``{dataset_name: {variable_name: DataArray of yearly spatial means}}``.
    """
    stats = {}
    for model_name, dataset in ds_dict.items():
        annual = dataset.resample(time='1Y').mean()
        stats[model_name] = {
            name: annual[name].mean(dim=['lat', 'lon'])
            for name in annual.data_vars
        }
    return stats

In [None]:
def compute_yearly_means(ds_dict):
    """
    Average each dataset per calendar year.

    Only the time dimension is reduced (``groupby('time.year').mean('time')``); any spatial
    dimensions are kept.

    Args:
        ds_dict (dict): Mapping of names to xarray Datasets with a 'time' coordinate.

    Returns:
        dict: Same keys, each value the yearly-mean Dataset.
    """
    return {name: ds.groupby('time.year').mean('time') for name, ds in ds_dict.items()}

In [None]:
def compute_yearly_regional_means(ds_dict_region):
    """
    Compute area-weighted yearly spatial means for a nested dictionary of datasets.

    Each dataset is first averaged per calendar year and then averaged over 'lat' and 'lon'
    with cos(latitude) weights.

    Args:
        ds_dict_region (dict): ``{region: {dataset_name: xarray.Dataset}}``.

    Returns:
        dict: Same nesting, with every Dataset reduced over time (per year), lat and lon.
    """
    result = {}
    for region, datasets in ds_dict_region.items():
        region_means = {}
        for ds_name, ds in datasets.items():
            # Latitude-only area weights, taken from the input dataset's 'lat' coordinate
            lat_weights = np.cos(np.deg2rad(ds.lat))
            annual = ds.groupby('time.year').mean('time')
            region_means[ds_name] = annual.weighted(lat_weights).mean(('lat', 'lon'))
        result[region] = region_means
    return result

In [None]:
def calculate_spatial_mean(ds_dict):
    """
    Average every data variable over 'lon' and 'lat' (unweighted) for each dataset.

    Args:
        ds_dict (dict): Mapping of names to xarray Datasets with 'lon' and 'lat' dimensions.

    Returns:
        dict: New dictionary with, per name, a Dataset of spatially averaged variables;
            dataset and variable attributes are copied from the input.
    """
    result = {}
    for key, source in ds_dict.items():
        averaged = xr.Dataset()
        for name, data in source.data_vars.items():
            averaged[name] = data.mean(['lon', 'lat'])
            averaged[name].attrs = data.attrs
        averaged.attrs = source.attrs
        result[key] = averaged
    return result

#### Compute metrics

In [None]:
def compute_bgws(ds_dict):
    """
    Add the Blue Green Water Share, (mrro - tran) / pr, to every dataset.

    Infinite values (e.g. from division by zero precipitation) and values outside [-2, 2]
    become NaN. The datasets are modified in place and the same dictionary is returned.
    """
    for model in ds_dict:
        ds = ds_dict[model]
        share = (ds['mrro'] - ds['tran']) / ds['pr']

        # Discard infinite values and anything outside [-2, 2]
        out_of_range = np.isinf(share) | (share > 2) | (share < -2)
        share = xr.where(out_of_range, float('nan'), share)

        ds_dict[model]['bgws'] = share
        ds_dict[model]['bgws'].attrs = {'long_name': 'Blue Green Water Share', 'units': ''}

    return ds_dict

In [None]:
def compute_wue(ds_dict):
    """
    Add Water Use Efficiency, gpp / tran, to every dataset that contains 'gpp'.

    Infinite ratios become NaN; values above 4 are set to 5 and values below -4 to -5.
    Datasets without 'gpp' are left untouched. Datasets are modified in place and the same
    dictionary is returned.
    """
    for model, ds in ds_dict.items():
        if 'gpp' not in ds.variables:
            continue

        efficiency = ds['gpp'] / ds['tran']
        efficiency = xr.where(np.isinf(efficiency), float('nan'), efficiency)
        # Values beyond +/-4 are pushed to +/-5
        efficiency = xr.where(efficiency > 4, 5, efficiency)
        efficiency = xr.where(efficiency < -4, -5, efficiency)

        ds_dict[model]['wue'] = efficiency
        ds_dict[model]['wue'].attrs = {'long_name': 'Water Use Efficiency', 'units': ''}

    return ds_dict

In [None]:
def compute_rgtr(ds_dict):
    """
    Add the Relative GPP-Temperature Response to every dataset that contains 'gpp'.

    GPP is divided by its maximum and (tas + 273.15) by its maximum; the metric is the ratio of
    the two. Infinite values become NaN and negative values are set to 0. Datasets without
    'gpp' are left untouched; datasets are modified in place and the same dictionary is returned.
    """
    for model, ds in ds_dict.items():
        if 'gpp' not in ds.variables:
            continue

        # NOTE(review): adding 273.15 assumes 'tas' is in degrees Celsius at this point; CMIP6
        # 'tas' is normally stored in K — confirm the preprocessing converts it.
        tas_kelvin = ds['tas'] + 273.15
        gpp_norm = ds['gpp'] / ds['gpp'].max()
        tas_norm = tas_kelvin / tas_kelvin.max()

        response = gpp_norm / tas_norm
        response = xr.where(np.isinf(response), float('nan'), response)
        response = xr.where(response < 0, 0, response)

        ds_dict[model]['rgtr'] = response
        ds_dict[model]['rgtr'].attrs = {'long_name': 'Relative GPP-Temperature Response', 'units': ''}

    return ds_dict

In [None]:
def compute_etp(ds_dict):
    """
    Computes the partitioning of evapotranspiration into evaporation and transpiration components.

    The metric is ``(evapo - tran) / evspsbl * 100``. It lies in [-100, 100] when evapo and tran
    are non-negative and add up to evspsbl (the total ET). In that case -100 means ET is
    entirely transpiration and 100 means ET is entirely evaporation.

    Parameters:
    - ds_dict (dict): A dictionary of datasets, where each dataset contains the variables
                      'evapo' (evaporation), 'tran' (transpiration) and 'evspsbl' (evapotranspiration).

    Returns:
    - dict: The same dictionary, modified in place: every dataset that has all three inputs gets an
            'et_partitioning' variable. Datasets missing any input are reported and skipped.
    """
    for model, ds in ds_dict.items():
        if {'evapo', 'tran', 'evspsbl'}.issubset(ds.variables):
            # evspsbl is used directly as total ET (not recomputed as evapo + tran)
            total_ET = ds['evspsbl']

            # Positive: evaporation dominates; negative: transpiration dominates
            et_partitioning = ((ds['evapo'] - ds['tran']) / total_ET) * 100

            ds_dict[model]['et_partitioning'] = et_partitioning
            ds_dict[model]['et_partitioning'].attrs = {
                'long_name': 'Evapotranspiration Partitioning',
                'units': '%',
                # Sign convention matches the formula (evapo - tran) / evspsbl
                'description': 'Indicates the share of Evaporation to Transpiration in ET, scaling between -100 (total transpiration) and 100 (total evaporation)'
            }
        else:
            print(f"Model {model} lacks necessary variables ('evapo', 'tran', 'evspsbl'). Skipping.")

    return ds_dict

In [None]:
def calculate_monthly_mean(ds, variable):
    """
    Build the mean annual cycle of one variable: its mean for each calendar month over all years.

    Parameters:
    ds (xarray.Dataset): The input dataset with a 'time' coordinate.
    variable (str): Name of the variable to average.

    Returns:
    xarray.Dataset: Dataset holding only ``variable``, indexed by 'month' instead of 'time'.
    It carries the dataset attributes of ``ds``, the variable attributes of ``ds[variable]`` and
    an additional 'description' attribute.
    """
    per_month = ds[variable].groupby('time.month').mean('time')

    out = xr.Dataset({variable: per_month})
    out.attrs = ds.attrs
    out[variable].attrs = {**ds[variable].attrs, 'description': f'Monthly mean of {variable}'}

    return out

In [None]:
def calculate_growing_season_length(ds_monthly_mean):
    """
    Compute the growing-season length in months for every grid cell from its LAI annual cycle.

    For each (lat, lon) cell, the month-to-month percentage change of 'lai' is computed
    (calculate_monthly_pct_change). The months where it changes sign are taken as season
    starts/ends (detect_season_starts_and_ends) and converted into a length
    (calculate_season_length). Cells whose LAI is entirely NaN (e.g. ocean) stay NaN.

    Args:
        ds_monthly_mean (xarray.Dataset): Dataset with a 'lai' variable and 'lat'/'lon'
            coordinates; presumably the 'month'-indexed output of calculate_monthly_mean (confirm).

    Returns:
        xarray.Dataset: Dataset with the (lat, lon) variable 'growing_season_length'.
    """
    lai = ds_monthly_mean['lai']
    lats = ds_monthly_mean.lat
    lons = ds_monthly_mean.lon

    season_length = xr.DataArray(
        np.nan,
        dims=('lat', 'lon'),
        coords={'lat': lats, 'lon': lons}
    )

    for lat_value in lats.values:
        for lon_value in lons.values:
            cell_lai = lai.sel(lat=lat_value, lon=lon_value)

            # All-NaN cells carry no vegetation signal and remain NaN
            if cell_lai.isnull().all():
                continue

            pct_change = calculate_monthly_pct_change(cell_lai)
            starts, ends = detect_season_starts_and_ends(pct_change)
            season_length.loc[dict(lat=lat_value, lon=lon_value)] = calculate_season_length(starts, ends)

    result = xr.Dataset({'growing_season_length': season_length})
    result['growing_season_length'].attrs = {
        'description': 'Length of the growing season in months',
        'calculated_from': 'LAI'
    }

    return result


def calculate_monthly_pct_change(lai_ts):
    """
    Percentage change of LAI relative to the previous month, treating the year as cyclic.

    Element i is ``(lai[i] - lai[i-1]) / lai[i-1] * 100``. For the first month, the previous
    value is the last one (December -> January).

    Args:
        lai_ts (xarray.DataArray): LAI values along a 'month' dimension, e.g. one grid cell of
            a monthly climatology.

    Returns:
        xarray.DataArray: Percentage changes with the same 'month' coordinate as ``lai_ts``.
    """
    # roll_coords=False shifts only the values. If the coordinates were rolled as well (the
    # default in older xarray versions), the subtraction would re-align by month label and
    # every change would be 0.
    previous = lai_ts.roll(month=1, roll_coords=False)

    # The roll is cyclic, so element 0 already compares January with December; no separate
    # wrap-around correction is needed.
    return (lai_ts - previous) / previous * 100

def detect_season_starts_and_ends(monthly_pct_change):
    """
    Find the months in which a growing season starts or ends.

    A season starts in a month whose change is > 0 while the previous month's change is <= 0,
    and ends in a month whose change is <= 0 while the previous month's change is > 0. The year
    is treated as cyclic: January is compared with December.

    Months are scanned in the order December, January, February ... November, and the returned
    lists keep that order (calculate_season_length relies on it).

    Args:
        monthly_pct_change: Sequence of 12 monthly changes (index 0 = January).

    Returns:
        tuple[list[int], list[int]]: (start months, end months), as month numbers 1-12.
    """
    last = len(monthly_pct_change) - 1
    scan_order = [last, 0] + list(range(1, last))

    starts, ends = [], []
    for idx in scan_order:
        previous = monthly_pct_change[idx - 1]  # idx 0 wraps around to the last month
        current = monthly_pct_change[idx]
        month = 12 if idx == last else idx + 1

        if previous <= 0 and current > 0:
            starts.append(month)
        elif previous > 0 and current <= 0:
            ends.append(month)

    return starts, ends

def calculate_season_length(starts, ends):
    """
    Turn growing-season start and end months into a total length in months.

    ``starts`` and ``ends`` are month numbers (1-12) in the order produced by
    detect_season_starts_and_ends. At most two seasons are counted:

    - two or more starts: length of the season beginning at ``starts[0]`` plus length of the
      season beginning at ``starts[1]``; further starts are ignored;
    - exactly one start: distance from ``starts[0]`` to ``ends[0]``, wrapping around the end of
      the year when the end month is earlier than the start month (0 if they are equal);
    - no start: 0.

    Args:
        starts (list of int): Months in which a growing season starts.
        ends (list of int): Months in which a growing season ends.

    Returns:
        int: Growing-season length in months.

    NOTE(review): ``ends`` must not be empty when there is at least one start. With one start,
    ``ends[0]`` raises IndexError; with two or more, ``min(ends)`` raises ValueError.
    Confirm that callers guarantee this.
    """
    growing_season_length = 0
    
    # Calculate the growing season length(s)
    if len(starts) > 1:
        # First season: nearest end after starts[0]; otherwise it wraps into the next year
        if any(end > starts[0] for end in ends):
            closest_end = min(filter(lambda end: end > starts[0], ends), key=lambda end: end - starts[0], default=None)
            growing_season_length = closest_end - starts[0]
        else:
            growing_season_length = min(ends) + (12 - starts[0])

        # Second season: ends at the latest end month, or wraps to the earliest one.
        # NOTE(review): this uses max(ends), not the nearest end after starts[1]; with more
        # than one end after starts[1] this may overcount — confirm intended.
        if any(end > starts[1] for end in ends):
            growing_season_length = growing_season_length + max(ends) - starts[1]
        else:
            growing_season_length = growing_season_length + min(ends) + (12 - starts[1])
    elif len(starts) == 1:
        # Single season: same-year span, or wrap around December -> January
        if starts[0] < ends[0]:
            growing_season_length = ends[0] - starts[0]
        elif starts[0] > ends[0]:
            growing_season_length = (12 - starts[0]) + ends[0]
    
    return growing_season_length

### Load netCDF files

In [None]:
def load_and_preprocess(var='all', scenarios='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Load every requested scenario and collapse it to one period statistic.

    Each scenario is restricted to its reference window (historical:
    1985-2014, ssp370: 2071-2100) with ``select_period`` and then reduced over
    ``time`` with ``compute_statistic``. Scenarios without a configured window
    are left untouched and a warning is printed.

    Returns the dictionary produced by ``load_data`` with processed entries.
    """
    analysis_windows = {'historical': (1985, 2014), 'ssp370': (2071, 2100)}
    datasets = load_data(var=var, scenarios=scenarios)

    for scenario_name in datasets:
        window = analysis_windows.get(scenario_name)
        if window is None:
            print(f"Warning: Start and end years not defined for scenario '{scenario_name}'. Skipping.")
            continue

        first_year, last_year = window
        in_window = select_period(datasets[scenario_name], start_year=first_year, end_year=last_year,
                                  period=period, yearly_sum=yearly_sum)
        datasets[scenario_name] = compute_statistic(in_window, period_statistic=period_statistic,
                                                    dimension='time')

    return datasets


In [None]:
def load_data(var='all', scenarios='all', models='all'):
    """Load monthly preprocessed CMIP6 data for several scenarios and models.

    Parameters
    ----------
    var : 'all' or list of str
        Variables to load; 'all' selects the default monthly variable set.
    scenarios : 'all' or list of str
        Experiment IDs; 'all' selects ['historical', 'ssp370'].
    models : 'all' or list of str
        Source IDs to load; 'all' selects the default model list. (This
        argument used to be accepted but ignored.)

    Returns
    -------
    dict
        ``{experiment_id: {model: dataset}}`` with the values returned by
        ``open_and_merge_datasets``. Models whose loading raises ValueError are
        reported and left out.
    """
    folder = 'preprocessed'
    temp_res = 'month'

    default_variables = ['tas', 'pr', 'vpd', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo']
    variables = var if var != 'all' else default_variables

    default_experiment_ids = ['historical', 'ssp370']
    experiment_ids = scenarios if scenarios != 'all' else default_experiment_ids

    default_source_ids = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM',
                          'CNRM-CM6-1', 'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L',
                          'MPI-ESM1-2-LR', 'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL']
    source_ids = models if models != 'all' else default_source_ids

    ds_dict = {}
    for experiment_id in experiment_ids:
        ds_dict[experiment_id] = {}
        for model in source_ids:
            try:
                ds = open_and_merge_datasets(folder, model, experiment_id, temp_res, variables)
            except ValueError as e:
                print(f"Failed to load data for model {model} in scenario {experiment_id}: {e}")
                continue
            ds_dict[experiment_id][model] = ds

    return ds_dict

In [None]:
# Load the default monthly variables for the default scenarios and models
# (three-keyword `load_data(var, scenarios, models)` version).
# NOTE(review): a later cell redefines `load_data` with five positional
# parameters; re-running this cell after that one raises TypeError.
ds_dict = load_data()

In [None]:
# Group the variables by the temporal-resolution folder they are stored in.
# NOTE(review): `categorize_variables` is defined in a later cell; on a
# top-to-bottom run this line raises NameError unless that cell ran first.
categorized_variables=categorize_variables(['tas', 'pr', 'vpd', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo', 'RX5day', 'growing_season_length'], period=None)


In [None]:
# Display the mapping of temporal resolution -> variable names.
categorized_variables

In [None]:
# Display the first temporal-resolution key of the categorisation.
list(categorized_variables.keys())[0]

In [None]:
def categorize_variables(vars_selected, period):
    """Group variable names by the temporal-resolution folder they are stored in.

    Parameters
    ----------
    vars_selected : 'all' or iterable of str
        Variables to categorise. 'growing_season_length' is mapped to the
        period-specific name ``growing_season_length_<period key>``.
    period : str or None
        Period label. When None, yearly variables go under 'year' and the
        growing-season length under 'period'. When a label is given, yearly
        variables AND the growing-season length are both filed under that
        label, because the two category keys coincide.

    Returns
    -------
    dict
        Category name -> list of variable names, with 'month' always first.
        Unrecognised variables are reported with a warning and skipped.
    """
    monthly_names = ['tas', 'pr', 'vpd', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo']
    yearly_names = ['RX5day']

    if period is None:
        year_key, period_key = 'year', 'period'
    else:
        # A specific period replaces both the yearly and the period category.
        year_key = period_key = period
    season_length_name = f'growing_season_length_{period_key}'

    # setdefault keeps first-insertion order and collapses coinciding keys.
    categories = {'month': []}
    categories.setdefault(year_key, [])
    categories.setdefault(period_key, [])

    if vars_selected == 'all':
        categories['month'].extend(monthly_names)
        categories[year_key].extend(yearly_names)
        categories[period_key].append(season_length_name)
        return categories

    for name in vars_selected:
        if name in monthly_names:
            categories['month'].append(name)
        elif name in yearly_names:
            categories[year_key].append(name)
        elif name == 'growing_season_length':
            categories[period_key].append(season_length_name)
        else:
            print(f"Warning: Variable '{name}' not recognized.")

    return categories

In [None]:
def load_data(scenario, source_id, folder, variables, temp_res):
    """
    Load one dataset per model for a single scenario and temporal resolution.

    Returns ``{model: dataset}`` in the order of ``source_id``; each value is
    whatever ``open_and_merge_datasets`` returns for that model.
    """
    return {
        model: open_and_merge_datasets(folder, model, scenario, temp_res, variables)
        for model in source_id
    }

In [None]:
def load_and_preprocess(vars='all', scenarios='all', models='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Load, time-reduce and merge the requested variables per scenario and model.

    Variables are grouped by temporal resolution with ``categorize_variables``
    and each group is loaded with ``load_data``. The 'month' group is
    restricted to the reference window (historical: 1985-2014, otherwise
    2071-2100) and reduced over ``time``. The 'year' group is reduced over
    ``time``. Any other group is used as loaded. Finally all groups are merged
    per model.

    Parameters
    ----------
    vars : 'all' or list of str
    scenarios : 'all' or list of str
        'all' selects historical, ssp126, ssp370 and ssp585.
    models : 'all' or list of str
    period, yearly_sum, period_statistic
        Passed through to ``categorize_variables`` / ``select_period`` /
        ``compute_statistic``.

    Returns
    -------
    dict
        ``{scenario: {model: xr.Dataset}}``. Every model with at least one
        loaded dataset is included; models without RX5day keep their other
        variables. Models with no data at all are omitted.
    """
    categorized_variables = categorize_variables(vars, period)
    ds_dict = {}

    experiment_ids = ['historical', 'ssp126', 'ssp370', 'ssp585'] if scenarios == 'all' else scenarios
    print(experiment_ids)

    source_ids = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                  'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                  'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL'] if models == 'all' else models

    # Fewer models provide daily pr, so RX5day is only available for this subset.
    rx5day_models = {'BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                     'CNRM-ESM2-1', 'GFDL-ESM4', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                     'NorESM2-MM', 'UKESM1-0-LL'}
    source_ids_rx5day = [model for model in source_ids if model in rx5day_models]

    for scenario in experiment_ids:
        print(scenario)
        start_year, end_year = (1985, 2014) if scenario == 'historical' else (2071, 2100)
        ds_dict_temp = {}
        ds_dict[scenario] = {}
        for temp_res, vars_in_res in categorized_variables.items():
            print(temp_res, vars_in_res)

            # With a specific period RX5day can share a group with other
            # variables, so test membership rather than list equality.
            current_ids = source_ids_rx5day if 'RX5day' in vars_in_res else source_ids
            ds_dict_temp[temp_res] = load_data(scenario, current_ids, 'preprocessed', vars_in_res, temp_res)

            if temp_res == 'month':
                ds_dict_temp[temp_res] = select_period(ds_dict_temp[temp_res], start_year=start_year, end_year=end_year, period=period, yearly_sum=yearly_sum)
                ds_dict_temp[temp_res] = compute_statistic(ds_dict_temp[temp_res], period_statistic=period_statistic, dimension='time')
                print(f'Period mean of {temp_res} data computed')
            elif temp_res == 'year':
                ds_dict_temp[temp_res] = compute_statistic(ds_dict_temp[temp_res], period_statistic=period_statistic, dimension='time')
                print(f'Period mean of {temp_res} data computed')

            print(ds_dict_temp.keys())

        # Merge each model across every loaded group (not only month/year/period),
        # skipping groups in which the model has no dataset.
        for model in source_ids:
            datasets_to_merge = [
                per_model[model]
                for per_model in ds_dict_temp.values()
                if per_model.get(model) is not None
            ]
            if datasets_to_merge:
                ds_dict[scenario][model] = xr.merge(datasets_to_merge)

    return ds_dict

In [None]:
def load_and_preprocess(vars='all', scenarios='all', models='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Load, time-reduce and merge the requested variables per scenario and model.

    Variables are grouped by temporal resolution with ``categorize_variables``
    and each group is loaded with ``load_data``. The 'month' group is
    restricted to the reference window (historical: 1985-2014, otherwise
    2071-2100) and reduced over ``time``. The 'year' group is reduced over
    ``time``. Any other group (e.g. 'period' or a period-specific key such as
    'winter') is used as loaded. Finally all groups are merged per model.

    Parameters
    ----------
    vars : 'all' or list of str
    scenarios : 'all' or list of str
        'all' selects historical, ssp126, ssp370 and ssp585.
    models : 'all' or list of str
    period, yearly_sum, period_statistic
        Passed through to ``categorize_variables`` / ``select_period`` /
        ``compute_statistic``.

    Returns
    -------
    dict
        ``{scenario: {model: xr.Dataset}}``. Every model with at least one
        loaded dataset is included; models without RX5day keep their other
        variables. Models with no data at all are omitted.
    """
    categorized_variables = categorize_variables(vars, period)
    ds_dict = {}

    experiment_ids = ['historical', 'ssp126', 'ssp370', 'ssp585'] if scenarios == 'all' else scenarios
    print(experiment_ids)

    source_ids = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                  'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                  'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL'] if models == 'all' else models

    # Fewer models provide daily pr, so RX5day is only available for this subset.
    rx5day_models = {'BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                     'CNRM-ESM2-1', 'GFDL-ESM4', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                     'NorESM2-MM', 'UKESM1-0-LL'}
    source_ids_rx5day = [model for model in source_ids if model in rx5day_models]

    for scenario in experiment_ids:
        print(scenario)
        start_year, end_year = (1985, 2014) if scenario == 'historical' else (2071, 2100)
        ds_dict_temp = {}
        ds_dict[scenario] = {}
        for temp_res, vars_in_res in categorized_variables.items():
            print(temp_res, vars_in_res)

            # With a specific period RX5day can share a group with other
            # variables, so test membership rather than list equality.
            current_ids = source_ids_rx5day if 'RX5day' in vars_in_res else source_ids
            ds_dict_temp[temp_res] = load_data(scenario, current_ids, 'preprocessed', vars_in_res, temp_res)

            if temp_res == 'month':
                ds_dict_temp[temp_res] = select_period(ds_dict_temp[temp_res], start_year=start_year, end_year=end_year, period=period, yearly_sum=yearly_sum)
                ds_dict_temp[temp_res] = compute_statistic(ds_dict_temp[temp_res], period_statistic=period_statistic, dimension='time')
                print(f'Period mean of {temp_res} data computed')
            elif temp_res == 'year':
                ds_dict_temp[temp_res] = compute_statistic(ds_dict_temp[temp_res], period_statistic=period_statistic, dimension='time')
                print(f'Period mean of {temp_res} data computed')

            print(ds_dict_temp.keys())

        # Merge each model across every loaded group (not only month/year/period),
        # skipping groups in which the model has no dataset.
        for model in source_ids:
            datasets_to_merge = [
                per_model[model]
                for per_model in ds_dict_temp.values()
                if per_model.get(model) is not None
            ]
            if datasets_to_merge:
                ds_dict[scenario][model] = xr.merge(datasets_to_merge)

    return ds_dict


In [None]:
def get_default_settings(models='all'):
    """Return the model lists used for loading.

    Parameters
    ----------
    models : 'all', str or list of str
        'all' selects the full default ensemble. A single model name or a
        list of names selects those models instead.

    Returns
    -------
    tuple of (list of str, list of str)
        ``(source_ids, source_ids_rx5day)``: the models to load, and the subset
        of them for which RX5day (derived from daily pr) is available. The
        subset keeps the order of ``source_ids``. Previously a custom model
        list was also returned unfiltered as the RX5day list, so RX5day was
        requested for models that do not provide it.
    """
    all_models = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                  'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                  'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL']
    # Fewer models provide daily pr, so RX5day exists only for these.
    rx5day_models = {'BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                     'CNRM-ESM2-1', 'GFDL-ESM4', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                     'NorESM2-MM', 'UKESM1-0-LL'}

    if models == 'all':
        source_ids = list(all_models)
    elif isinstance(models, str):
        # A bare string would otherwise be iterated character by character.
        source_ids = [models]
    else:
        source_ids = list(models)

    source_ids_rx5day = [model for model in source_ids if model in rx5day_models]
    return source_ids, source_ids_rx5day

In [None]:
def load_and_preprocess_data_for_temp_res(scenario, temp_res, vars_in_res, period, yearly_sum, period_statistic, source_ids, source_ids_rx5day):
    """Load one temporal-resolution group for a scenario and reduce it over time.

    Models come from ``source_ids_rx5day`` whenever RX5day is among the
    requested variables, and from ``source_ids`` otherwise. RX5day is only
    available for that subset and can share a group with other variables when
    a specific period is requested, so membership is tested rather than
    equality with ``['RX5day']``.

    The loaded ``{model: dataset}`` dict is restricted to the reference window
    (historical: 1985-2014, otherwise 2071-2100) and reduced over ``time`` with
    ``period_statistic``.

    NOTE(review): unlike earlier loader versions, the window selection and
    time reduction are applied to every group, including yearly/period data.
    Confirm that ``select_period`` handles those inputs.
    """
    current_source_ids = source_ids_rx5day if 'RX5day' in vars_in_res else source_ids
    start_year, end_year = (1985, 2014) if scenario == 'historical' else (2071, 2100)

    loaded_data = load_data(scenario, current_source_ids, 'preprocessed', vars_in_res, temp_res)
    processed_data = select_period(loaded_data, start_year=start_year, end_year=end_year, period=period, yearly_sum=yearly_sum)
    return compute_statistic(processed_data, period_statistic=period_statistic, dimension='time')

In [None]:
def merge_datasets_for_model(scenario, ds_dict_temp, source_ids, source_ids_rx5day):
    """Merge each model's datasets from all temporal resolutions into one Dataset.

    Parameters
    ----------
    scenario : str
        Unused; kept for call compatibility.
    ds_dict_temp : dict
        ``{temp_res: {model: dataset}}``. Every resolution key is considered,
        not only 'month'/'year'/'period', so period-specific keys such as
        'winter' are merged too.
    source_ids : list of str
        Models to merge, in output order.
    source_ids_rx5day : list of str
        Unused. Models without RX5day are handled by merging whichever
        resolutions they have. Kept for call compatibility.

    Returns
    -------
    dict
        ``{model: xr.Dataset}``. Models without any non-None dataset are omitted.
    """
    merged_data = {}
    for model in source_ids:
        datasets_to_merge = [
            per_model[model]
            for per_model in ds_dict_temp.values()
            if per_model.get(model) is not None  # failed loads may yield None
        ]
        if datasets_to_merge:
            merged_data[model] = xr.merge(datasets_to_merge)
    return merged_data

In [None]:
def load_and_preprocess(vars='all', scenarios='all', models='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Load, time-reduce and merge all requested variables for each scenario.

    Composes ``get_default_settings``, ``categorize_variables``,
    ``load_and_preprocess_data_for_temp_res`` and ``merge_datasets_for_model``.
    'all' scenarios means historical, ssp126, ssp370 and ssp585.

    Returns ``{scenario: {model: dataset}}``.
    """
    all_ids, rx5day_ids = get_default_settings(models)
    groups = categorize_variables(vars, period)

    if scenarios == 'all':
        scenario_list = ['historical', 'ssp126', 'ssp370', 'ssp585']
    else:
        scenario_list = scenarios

    results = {}
    for scenario in scenario_list:
        # Pre-seed the standard resolution keys so the merge step finds them.
        per_resolution = {'month': {}, 'year': {}, 'period': {}}
        for resolution, names in groups.items():
            per_resolution[resolution] = load_and_preprocess_data_for_temp_res(
                scenario, resolution, names, period, yearly_sum,
                period_statistic, all_ids, rx5day_ids,
            )
        results[scenario] = merge_datasets_for_model(scenario, per_resolution, all_ids, rx5day_ids)

    return results

In [None]:
# Load all variables for 'historical' and 'ssp370' with all models, restricted
# to the 'winter' period and averaged over time.
# NOTE(review): with a specific period, categorize_variables files RX5day and
# the growing-season length together under the 'winter' key; confirm the
# downstream merge handles that key.
ds_dict=load_and_preprocess(vars='all', scenarios=['historical', 'ssp370'], models='all', period='winter', yearly_sum=False, period_statistic='mean')


In [None]:
# Inspect the per-model datasets of the historical scenario.
ds_dict['historical']

In [None]:
def load_and_preprocess(var='all', scenarios='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Combine monthly, yearly (RX5day) and period-mean (growing season length)
    data into one time-reduced Dataset per model.

    Monthly data are restricted to the reference window (historical:
    1985-2014, otherwise 2071-2100) and reduced over ``time`` exactly once.
    Previously they were selected and reduced a second time per model, after
    ``time`` had already been removed. Yearly RX5day is reduced over ``time``
    with the same statistic before it is attached, so it no longer
    reintroduces a time dimension. The growing-season length is attached as
    loaded from the 'period_mean' folder.

    NOTE(review): relies on ``load_data_general`` (not defined in this part of
    the notebook); it is assumed to return ``{scenario: {model: dataset}}``.

    Returns
    -------
    dict
        ``{scenario: {model: xr.Dataset}}``.
    """
    default_monthly_variables = ['tas', 'pr', 'vpd', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo']
    variables = var if var != 'all' else default_monthly_variables
    scenarios = scenarios if scenarios != 'all' else ['historical', 'ssp370']

    # Model subset used for all three temporal resolutions.
    source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1', 'CNRM-ESM2-1',
                 'GFDL-ESM4', 'MIROC-ES2L', 'MPI-ESM1-2-LR', 'NorESM2-MM', 'UKESM1-0-LL']

    ds_dict_monthly = load_data_general(scenarios, source_id, 'preprocessed', variables, 'month')
    ds_dict_period_mean = load_data_general(scenarios, source_id, 'preprocessed', ['growing_season_length'], 'period_mean')
    ds_dict_year = load_data_general(scenarios, source_id, 'preprocessed', ['RX5day'], 'year')

    for scenario in scenarios:
        start_year, end_year = (1985, 2014) if scenario == 'historical' else (2071, 2100)

        # Select the reference window and reduce over time once per scenario.
        monthly = select_period(ds_dict_monthly[scenario], start_year=start_year, end_year=end_year,
                                period=period, yearly_sum=yearly_sum)
        monthly = compute_statistic(monthly, period_statistic=period_statistic, dimension='time')
        yearly = compute_statistic(ds_dict_year[scenario], period_statistic=period_statistic, dimension='time')
        period_mean = ds_dict_period_mean[scenario]

        for model in source_id:
            computed_ds = monthly.get(model)
            if computed_ds is None:
                continue
            if period_mean.get(model) is not None:
                computed_ds['growing_season_length'] = period_mean[model]['growing_season_length']
            if yearly.get(model) is not None:
                computed_ds['RX5day'] = yearly[model]['RX5day']

        ds_dict_monthly[scenario] = monthly

    return ds_dict_monthly

In [None]:
def load_and_preprocess(vars='all', scenarios='all', models='all', period=None, yearly_sum=False, period_statistic='mean'):
    """Load every variable group per model, reduce it over time and combine.

    For every scenario and every temporal-resolution group from
    ``categorize_variables``, each model is loaded separately with
    ``load_data``. The data are restricted to the reference window
    (historical: 1985-2014, otherwise 2071-2100) and reduced over ``time``.
    A model's per-group Datasets are then combined with
    ``xr.combine_by_coords``. Groups containing RX5day are only loaded for
    models that provide it.

    ``load_data`` returns ``{model: dataset}``. Previously that dict itself was
    collected and passed to ``combine_by_coords``, which requires Datasets;
    now the Dataset is extracted first.

    Returns
    -------
    dict
        ``{scenario: {model: xr.Dataset}}``. Models for which nothing could be
        loaded are omitted instead of being left as an empty list.
    """
    categorized_variables = categorize_variables(vars, period)

    experiment_ids = ['historical', 'ssp126', 'ssp370', 'ssp585'] if scenarios == 'all' else scenarios

    initial_source_ids = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                          'CNRM-ESM2-1', 'GFDL-ESM4', 'GISS-E2-1-G', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                          'NorESM2-MM', 'TaiESM1', 'UKESM1-0-LL'] if models == 'all' else models

    # RX5day is derived from daily pr, which fewer models provide.
    rx5day_models = {'BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM', 'CNRM-CM6-1',
                     'CNRM-ESM2-1', 'GFDL-ESM4', 'MIROC-ES2L', 'MPI-ESM1-2-LR',
                     'NorESM2-MM', 'UKESM1-0-LL'}

    ds_dict = {}
    for scenario in experiment_ids:
        start_year, end_year = (1985, 2014) if scenario == 'historical' else (2071, 2100)
        parts_per_model = {model_name: [] for model_name in initial_source_ids}

        for temp_res, vars_in_res in categorized_variables.items():
            if 'RX5day' in vars_in_res:
                temp_source_ids = [src for src in initial_source_ids if src in rx5day_models]
            else:
                temp_source_ids = list(initial_source_ids)

            for model_name in temp_source_ids:
                data = load_data(scenario, [model_name], 'preprocessed', vars_in_res, temp_res)
                data = select_period(data, start_year=start_year, end_year=end_year, period=period, yearly_sum=yearly_sum)
                data = compute_statistic(data, period_statistic=period_statistic, dimension='time')
                # Keep the Dataset, not the {model: dataset} wrapper.
                dataset = data.get(model_name)
                if dataset is not None:
                    parts_per_model[model_name].append(dataset)

        ds_dict[scenario] = {
            model_name: xr.combine_by_coords(parts, combine_attrs='override')
            for model_name, parts in parts_per_model.items()
            if parts
        }

    return ds_dict

In [None]:
# Run the most recently defined loader with its defaults: all variables, the
# four scenarios historical/ssp126/ssp370/ssp585, and all models.
ds_dict = load_and_preprocess()

In [None]:
# ========= Define period, models and path ==============
#variables=['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'tran', 'lai', 'gpp', 'EI', 'wue']
variables=['tas', 'pr', 'vpd', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo']
folder='preprocessed'
temp_res = 'month'

# ========= Use Dask to parallelize computations ==========
# NOTE(review): `dask` is not imported in the notebook's import cell; confirm
# it is imported elsewhere. Also note that the dict comprehensions below call
# open_and_merge_datasets eagerly, one model after another, before
# dask.compute receives the dict, so the loading itself runs sequentially.
dask.config.set(scheduler='processes')

# Create dictionary using a dictionary comprehension and Dask
# Monthly data of all 14 models for the historical experiment.
experiment_id = 'historical'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','GISS-E2-1-G','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','TaiESM1','UKESM1-0-LL']
ds_dict_hist = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

# Same models for the ssp370 experiment.
experiment_id = 'ssp370'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5-CanOE', 'CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','GISS-E2-1-G','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','TaiESM1','UKESM1-0-LL']
ds_dict_ssp370 = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

In [None]:
# ========= Define period, models and path ==============
# Growing-season length, stored in the 'period_mean' folder (11-model subset).
variables=['growing_season_length']
folder='preprocessed'
temp_res = 'period_mean'

# ========= Use Dask to parallelize computations ==========
# NOTE(review): as in the monthly cell, the comprehension runs eagerly before
# dask.compute sees it, so models are loaded sequentially; `dask` must be
# imported elsewhere.
dask.config.set(scheduler='processes')

# Create dictionary using a dictionary comprehension and Dask
experiment_id = 'historical'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0','CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','UKESM1-0-LL']
ds_dict_hist_period_mean = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

experiment_id = 'ssp370'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','UKESM1-0-LL']
ds_dict_ssp370_period_mean = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

In [None]:
# ========= Define period, models and path ==============
# Yearly RX5day index, available for the 11-model subset only.
variables=['RX5day']
folder='preprocessed'
temp_res = 'year'

# ========= Use Dask to parallelize computations ==========
# NOTE(review): as in the monthly cell, the comprehension runs eagerly before
# dask.compute sees it, so models are loaded sequentially; `dask` must be
# imported elsewhere.
dask.config.set(scheduler='processes')

# Create dictionary using a dictionary comprehension and Dask
experiment_id = 'historical'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0','CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','UKESM1-0-LL']
ds_dict_hist_idx = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

experiment_id = 'ssp370'
source_id = ['BCC-CSM2-MR', 'CAMS-CSM1-0', 'CanESM5', 'CESM2-WACCM','CNRM-CM6-1', 'CNRM-ESM2-1','GFDL-ESM4','MIROC-ES2L', 'MPI-ESM1-2-LR','NorESM2-MM','UKESM1-0-LL']
ds_dict_ssp370_idx = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, temp_res, variables) for model in source_id})[0]

In [None]:
# Inspect LAI of the first model (in ds_dict_hist key order) in the ssp370 data.
ds_dict_ssp370[list(ds_dict_hist.keys())[0]].lai

In [None]:
# ============= Have a look into the data ==============
#print(ds_dict_ssp126.keys())
# Plot the first time step of evspsbl for the second model.
# NOTE(review): `ds_dict_ssp370_period` is created in the "Select period"
# section further down; on a top-to-bottom run this cell raises NameError.
ds_dict_ssp370_period[list(ds_dict_hist.keys())[1]].evspsbl.isel(time=0).plot()

### Select period

In [None]:
#'nh_winter': [12, 1, 2],'nh_spring': [3, 4, 5],'nh_summer': [6, 7, 8], 'nh_fall': [9, 10, 11]

In [None]:
# Restrict the historical monthly data to 1985-2014 (period=None, no yearly sum).
ds_dict_hist_period = select_period(ds_dict_hist, start_year=1985, end_year=2014, period=None, yearly_sum=False)

In [None]:
#ds_dict_ssp126_period = select_period(ds_dict_ssp126, start_year=2071, end_year=2100, period=None, yearly_sum=False)

In [None]:
# Restrict the ssp370 monthly data to 2071-2100 (period=None, no yearly sum).
ds_dict_ssp370_period = select_period(ds_dict_ssp370, start_year=2071, end_year=2100, period=None, yearly_sum=False)

In [None]:
#ds_dict_ssp585_period = select_period(ds_dict_ssp585, start_year=2071, end_year=2100, period=None, yearly_sum=False)

### Compute statistics

In [None]:
# ========= Compute statistic for plot ===============
# Time mean of each model's historical 1985-2014 data.
ds_dict_hist_period_metric = compute_statistic(ds_dict_hist_period, 'mean', 'time')

In [None]:
#ds_dict_ssp126_period_metric = compute_statistic(ds_dict_ssp126_period, 'mean', 'time')

In [None]:
# ========= Compute statistic for plot ===============
# Time mean of each model's ssp370 2071-2100 data.
ds_dict_ssp370_period_metric = compute_statistic(ds_dict_ssp370_period, 'mean', 'time')

In [None]:
#ds_dict_ssp585_period_metric = compute_statistic(ds_dict_ssp585_period, 'mean', 'time')

In [None]:
# Time mean of the yearly RX5day data.
# NOTE(review): unlike the monthly data, no select_period is applied here, so
# this averages every year in the files; confirm they cover only 1985-2014.
ds_dict_hist_idx_period_metric = compute_statistic(ds_dict_hist_idx, 'mean', 'time')

In [None]:
# Attach RX5day and the growing-season length to the historical metric
# datasets (in place), for the models that have RX5day data.
for name, ds in ds_dict_hist_idx_period_metric.items():
    ds_dict_hist_period_metric[name]['RX5day'] = ds.RX5day
    ds_dict_hist_period_metric[name]['growing_season_length'] = ds_dict_hist_period_mean[name]['growing_season_length']

In [None]:
# Time mean of the yearly ssp370 RX5day data.
# NOTE(review): no select_period is applied here; confirm the files cover only 2071-2100.
ds_dict_ssp370_idx_period_metric = compute_statistic(ds_dict_ssp370_idx, 'mean', 'time')

In [None]:
# Attach RX5day and the growing-season length to the ssp370 metric datasets
# (in place), for the models that have RX5day data.
for name, ds in ds_dict_ssp370_idx_period_metric.items():
    ds_dict_ssp370_period_metric[name]['RX5day'] = ds.RX5day
    ds_dict_ssp370_period_metric[name]['growing_season_length'] = ds_dict_ssp370_period_mean[name]['growing_season_length']

### Compute derived variables (BGWS, WUE, ET partitioning)

#### Compute BGWS

In [None]:
# Derive BGWS for the historical metrics (compute_bgws is defined elsewhere).
ds_dict_hist_period_metric = compute_bgws(ds_dict_hist_period_metric)

In [None]:
#ds_dict_ssp126_period_metric = compute_bgws(ds_dict_ssp126_period_metric)

In [None]:
# Derive BGWS for the ssp370 metrics (compute_bgws is defined elsewhere).
ds_dict_ssp370_period_metric = compute_bgws(ds_dict_ssp370_period_metric)

In [None]:
#ds_dict_ssp585_period_metric = compute_bgws(ds_dict_ssp585_period_metric)

#### Compute WUE

In [None]:
# Derive WUE for the historical metrics (compute_wue is defined elsewhere).
ds_dict_hist_period_metric = compute_wue(ds_dict_hist_period_metric)

In [None]:
# Derive WUE for the ssp370 metrics (compute_wue is defined elsewhere).
ds_dict_ssp370_period_metric = compute_wue(ds_dict_ssp370_period_metric)

#### Compute ET Partitioning

In [None]:
# Add ET-partitioning variables to the historical metrics (compute_etp is defined elsewhere).
ds_dict_hist_period_metric = compute_etp(ds_dict_hist_period_metric)

In [None]:
# Add ET-partitioning variables to the ssp370 metrics (compute_etp is defined elsewhere).
ds_dict_ssp370_period_metric = compute_etp(ds_dict_ssp370_period_metric)

### Select regions

In [None]:
# Split each model's historical metrics into regions (apply_region_mask is
# defined elsewhere); the result replaces any previous value of this name.
ds_dict_hist_period_metric_regions = apply_region_mask(ds_dict_hist_period_metric)

In [None]:
#ds_dict_ssp126_period_metric_regions = {}
#ds_dict_ssp126_period_metric_regions = apply_region_mask(ds_dict_ssp126_period_metric)

In [None]:
# Split each model's ssp370 metrics into regions (apply_region_mask is
# defined elsewhere); the result replaces any previous value of this name.
ds_dict_ssp370_period_metric_regions = apply_region_mask(ds_dict_ssp370_period_metric)

In [None]:
#ds_dict_ssp585_period_metric_regions = {}
#ds_dict_ssp585_period_metric_regions = apply_region_mask(ds_dict_ssp585_period_metric)

### Compute regional mean

In [None]:
# Spatial mean over each region of the historical metrics
# (calculate_spatial_mean is defined elsewhere).
ds_dict_hist_period_metric_regional_mean = calculate_spatial_mean(ds_dict_hist_period_metric_regions)

In [None]:
# Compute spatial mean of regional data
#ds_dict_ssp126_period_metric_regional_mean = {}
#ds_dict_ssp126_period_metric_regional_mean = calculate_spatial_mean(ds_dict_ssp126_period_metric_regions)

In [None]:
# Spatial mean over each region of the ssp370 metrics
# (calculate_spatial_mean is defined elsewhere).
ds_dict_ssp370_period_metric_regional_mean = calculate_spatial_mean(ds_dict_ssp370_period_metric_regions)

In [None]:
# Compute spatial mean of regional data
#ds_dict_ssp585_period_metric_regional_mean = {}
#ds_dict_ssp585_period_metric_regional_mean = calculate_spatial_mean(ds_dict_ssp585_period_metric_regions)

### Compute Change

In [None]:
#rel
# Relative change (ssp126 vs historical) for the listed variables.
# NOTE(review): ds_dict_ssp126_period_metric_regional_mean is only created in
# commented-out cells, so this cell raises NameError unless those are enabled.
ds_dict_region_change_ssp126 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp126_period_metric_regional_mean, var_rel_change=['pr', 'vpd', 'evspsbl', 'mrro', 'mrso', 'tran', 'lai', 'gpp'])
for name, ds in ds_dict_region_change_ssp126.items():
    if 'member_id' in ds:
        # Dataset.drop is deprecated for removing variables; drop_vars replaces it.
        ds_dict_region_change_ssp126[name] = ds.drop_vars('member_id')
#abs
#ds_dict_region_change_ssp126 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp126_period_metric_regional_mean)

In [None]:
#rel
#ds_dict_region_change_ssp370 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp370_period_metric_regional_mean)#, var_rel_change=['pr', 'vpd', 'evspsbl', 'mrro', 'mrso', 'tran', 'lai', 'gpp'])

#abs
# Active: change ssp370 vs historical without var_rel_change (absolute change).
ds_dict_region_change_ssp370 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp370_period_metric_regional_mean)

In [None]:
#rel
# NOTE(review): ds_dict_ssp585_period_metric_regional_mean is only created in
# commented-out cells, so this cell raises NameError unless those are enabled.
ds_dict_region_change_ssp585 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp585_period_metric_regional_mean, var_rel_change=['pr', 'vpd', 'evspsbl', 'mrro', 'mrso', 'tran', 'lai', 'gpp'])

#abs
#ds_dict_region_change_ssp585 = compute_change(ds_dict_hist_period_metric_regional_mean, ds_dict_ssp585_period_metric_regional_mean)

In [None]:
#for name, ds in ds_dict_region_change_ssp585.items():
#    # Drop variables only if they are present in the dataset
#    variables_to_drop_present = [var for var in variables_to_drop if var in ds.variables]
#    ds_dict_region_change_ssp585[name] = ds.drop_vars(variables_to_drop_present)

### Select important regions

In [None]:
# Regions of interest; names must match entries of the regional `names` coordinate.
selected_regions = [
    'W.North-America',
    'Mediterranean',
    'S.E.Asia',
    'N.South-America',
    'S.E.South-America',
    'Central-Africa',
]

In [None]:
# Position of each selected region in the `names` coordinate of the first
# model's historical regional means (raises StopIteration if a name is missing).
selected_indices = [next(i for i, name in enumerate(ds_dict_hist_period_metric_regional_mean[list(ds_dict_hist.keys())[0]].names.values) if name == region) for region in selected_regions]

In [None]:
# Display the resolved region indices.
selected_indices

### Plot regional variable change

In [None]:
def compute_ensemble(ds_dict_change):
    """Add an 'Ensemble mean' entry to a dict of per-model datasets.

    Existing 'Ensemble mean' / 'Ensemble median' entries are removed first so
    they are not averaged into the new mean. The remaining datasets are
    stacked along a new 'ensemble' dimension and averaged over it.

    The input dict is modified in place and also returned.
    """
    for stale_key in ('Ensemble mean', 'Ensemble median'):
        ds_dict_change.pop(stale_key, None)

    stacked = xr.concat(list(ds_dict_change.values()), dim='ensemble')
    ds_dict_change['Ensemble mean'] = stacked.mean(dim='ensemble')
    return ds_dict_change

In [None]:
def calculate_ticks(min_val, max_val, base=20):
    """Calculate tick positions for the y-axis.

    The outer ticks are multiples of ``base`` just beyond the data range. An
    outer tick is pulled one step inward when it would lie more than 10 % of
    ``|min_val|`` / ``max_val`` beyond the data. If either outer tick ends up
    within one ``base`` of zero, ``base`` is halved and the outer ticks are
    recomputed. Zero is always included, and intermediate ticks are spaced by
    ``base``.

    Returns
    -------
    list
        Sorted tick values without duplicates. ``np.arange`` also produces 0
        (and can repeat the outer ticks), which previously gave repeated
        entries.
    """
    import numpy as np

    # 10 % margins relative to the data extremes
    offset_min = abs(min_val) * 0.1
    offset_max = max_val * 0.1

    # Next "round" numbers beyond the min and max values
    lower_tick = base * np.floor(min_val / base)
    upper_tick = base * np.ceil(max_val / base)

    # Ensure lower_tick is not more than 10% lower than min_val
    if min_val - lower_tick > offset_min:
        lower_tick += base

    # Ensure upper_tick is not more than 10% higher than max_val
    if upper_tick - max_val > offset_max:
        upper_tick -= base

    # Outer ticks too close to zero: use a finer step and recompute them
    if lower_tick >= -base or upper_tick <= base:
        base = base / 2
        lower_tick = base * np.floor(min_val / base)
        upper_tick = base * np.ceil(max_val / base)

    ticks = [lower_tick] if lower_tick < 0 else []
    ticks += [0]  # Always include zero
    ticks += [upper_tick] if upper_tick > 0 else []

    # Intermediate ticks between the outer ones
    ticks.extend(np.arange(lower_tick + base, upper_tick, base))

    return sorted(set(ticks))

In [None]:
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import BoundaryNorm

# Colormap for the change in Blue-Green Water Share (BGWS): four green shades for negative
# values (deep green at the most negative end), then four blue shades for positive values
# (deep blue at the most positive end).

# Define start and end colors for both gradients
deep_blue = (20/255, 110/255, 180/255)
light_blue = (180/255, 215/255, 255/255)
deep_green = (14/255, 119/255, 14/255)
light_green = (160/255, 220/255, 140/255)

# Create custom colormaps
blue_cmap = LinearSegmentedColormap.from_list("blue_cmap", [light_blue, deep_blue], N=4)
green_cmap = LinearSegmentedColormap.from_list("green_cmap", [deep_green, light_green], N=4)

# Sample colors from colormaps
blue_colors = [blue_cmap(i) for i in np.linspace(0, 1, 4)]
green_colors = [green_cmap(i) for i in np.linspace(0, 1, 4)]

# Combine both gradients (8 colours in total)
combined_grad = green_colors + blue_colors

# Bin edges for the 8 colours: 9 monotonically increasing values in steps of 0.025, as
# BoundaryNorm requires. They match the colorbar tick locations used by the plotting functions.
boundaries = [-0.1, -0.075, -0.05, -0.025, 0, 0.025, 0.05, 0.075, 0.1]
norm = BoundaryNorm(boundaries, len(combined_grad), clip=True)

cmap_name = 'BGWS colormap'
bgws_cm = LinearSegmentedColormap.from_list(cmap_name, combined_grad, N=len(combined_grad))

# To test and display the colormap
fig, ax = plt.subplots(figsize=(6, 1))
ax.set_title(cmap_name)
plt.imshow(np.linspace(0, 1, 256).reshape(1, -1), aspect='auto', cmap=bgws_cm)
plt.axis('off')
plt.show()

In [None]:
def plot_region_change(ds_dict_change, selected_indices="ALL", selected_vars=None, save_fig=True):
    r"""Draw one parallel-coordinate panel per region showing every model's variable changes.

    Each panel has the variables on the x-axis. Each model is drawn as a thin line with its
    1-based model number at every point. The line is coloured by the model's regional BGWS
    change via ``bgws_cm``: solid if that change is >= 0, dashed otherwise, and grey/dashed if
    the model has no 'bgws'. The ensemble mean is drawn as open red diamonds. Values whose
    magnitude exceeds 100 are drawn at +/-105 and flagged with a dashed line at +/-100 and an
    arrow on the y-axis.

    Parameters
    ----------
    ds_dict_change : dict
        Model name -> xarray Dataset of regional changes, indexed by a ``region`` coordinate.
        The first entry supplies the attributes ``experiment_id``, ``description``, ``months``
        and ``yearly_sum`` and the 'pr' units. The ensemble mean supplies the ``names`` used
        as panel titles. The dict is modified in place: ``compute_ensemble`` (re)adds an
        'Ensemble mean' entry.
    selected_indices : "ALL" or list
        Region labels to plot; "ALL" uses every region of the ensemble mean.
    selected_vars : list of str, optional
        Variables to plot; names absent from the ensemble mean are ignored. Defaults to every
        data variable except 'bgws'.
    save_fig : bool
        If True, save the figure as a PDF under ../../results/CMIP6/comparison/regional_var_change.

    Returns
    -------
    str
        Path of the saved PDF, or a message that the figure was not saved.
    """
    ds_dict_change = compute_ensemble(ds_dict_change)
    first = ds_dict_change[list(ds_dict_change.keys())[0]]

    # Precipitation given in '%' is taken to mean all changes are relative
    change = 'rel_change' if first['pr'].units == '%' else 'abs_change'

    models = list(ds_dict_change.keys())
    ensemble = ds_dict_change['Ensemble mean']
    if selected_vars is None:
        variables = [var for var in ensemble.data_vars.keys() if var != 'bgws']
    else:
        variables = [var for var in selected_vars if var in ensemble.data_vars.keys()]
    experiment_id = first.experiment_id
    description = first.description
    months = first.months
    yearly_sum = first.yearly_sum

    # Map variable names to short display labels.
    # NOTE(review): 'tas' and 'tran' both map to 'T' -- confirm this is intended.
    var_map = {
        'tas': 'T', 'vpd': 'VPD', 'gpp': 'GPP', 'pr': 'P', 'mrro': 'R',
        'evspsbl': 'ET', 'tran': 'T', 'evapo': 'E', 'lai': 'Lai', 'mrso': 'SM',
        'rgtr': 'P/T', 'et_partitioning': 'EP', 'growing_season_length': 'GS', 'RX5day': 'P5'
    }
    # Raw f-string: the LaTeX backslashes are literal text, not (invalid) escape sequences
    display_vars = [
        rf"$\Delta\, \mathrm{{\it{{{var_map[var]}}}}}$" if var in var_map else var for var in variables
    ]
    x = range(len(display_vars))

    if isinstance(selected_indices, str) and selected_indices == "ALL":
        selected_indices = ensemble.region.values.tolist()

    # Setup plot grid: 9 panels per row, always a 2-D axes array
    ncols = 9
    nrows = (len(selected_indices) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8 * ncols, 7 * nrows), squeeze=False)
    flat_axes = axes.flatten()

    # Fixed colour scale for the BGWS change, matching the colorbar ticks below
    norm = plt.Normalize(vmin=-0.1, vmax=0.1)

    threshold = 100     # |values| above this are capped
    capped_value = 105  # where capped values are drawn
    y_label = 'End of century response [%]' if change == 'rel_change' else 'End of century response'

    for ridx, region_idx in enumerate(selected_indices):
        ax = flat_axes[ridx]
        # Extent of the uncapped data, always including 0
        max_change = 0
        min_change = 0

        for idx, model in enumerate(models):
            model_data = ds_dict_change[model]
            y_array = np.array([
                model_data[var].sel(region=region_idx).values
                if var in model_data.data_vars else float('nan')
                for var in variables
            ])

            valid = y_array[~np.isnan(y_array)]
            if valid.size:
                max_change = max(max_change, valid.max())
                min_change = min(min_change, valid.min())

            # Draw out-of-range values just beyond the threshold so they stay visible
            y_capped = np.where(y_array > threshold, capped_value,
                                np.where(y_array < -threshold, -capped_value, y_array))

            # Line style and colour encode the model's BGWS change in this region
            if 'bgws' in model_data.data_vars:
                bgws_val = model_data['bgws'].sel(region=region_idx).values
                linestyle = '-' if bgws_val >= 0 else '--'
                color = bgws_cm(norm(bgws_val))
            else:
                linestyle, color = '--', 'grey'

            if model != "Ensemble mean":
                ax.plot(x, y_capped, linestyle=linestyle, color=color, linewidth=1.0, alpha=0.6, zorder=1)
                for xi, yi in zip(x, y_capped):
                    if not np.isnan(yi):
                        ax.text(xi, yi, str(idx + 1), ha='center', va='center', fontsize=16, color=color, fontweight='bold', zorder=2)
            else:
                ax.scatter(x, y_capped, marker='D', edgecolors='red', s=150, label=model, zorder=3, facecolors='none', lw=2)

        # Titles and labels
        ax.set_title(ensemble.names.sel(region=region_idx).values, fontsize=20)
        ax.set_xticks(range(len(display_vars)))
        ax.set_xticklabels(display_vars, rotation=45, fontsize=18, ha='right')
        ax.set_ylabel(y_label, fontsize=18)

        lower_limit = max(-threshold, min_change)
        upper_limit = min(threshold, max_change)

        # Flag capped data with a dashed line and an arrow on the y-axis
        # (x in axes coordinates, y in data coordinates).
        if max_change > threshold:
            ax.axhline(threshold, color='grey', linestyle='--', linewidth=0.5, zorder=1)
            ax.plot(0, threshold + 10, "^k", transform=ax.get_yaxis_transform(), clip_on=False)
        if min_change < -threshold:
            ax.axhline(-threshold, color='grey', linestyle='--', linewidth=0.5, zorder=1)
            ax.plot(0, -threshold - 10, "vk", transform=ax.get_yaxis_transform(), clip_on=False)

        if lower_limit == upper_limit:
            # Degenerate panel (no data or all zeros): avoid identical y-limits
            ax.set_ylim(lower_limit - 10, upper_limit + 10)
        else:
            ax.set_ylim(lower_limit * 1.1, upper_limit * 1.1)
        ax.set_yticks(calculate_ticks(lower_limit, upper_limit))

        ax.axhline(0, color='grey', linewidth=0.5)
        ax.tick_params(axis='x', length=0)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(True)
        ax.yaxis.set_tick_params(labelsize=14)

        for var_index in range(len(display_vars)):
            ax.axvline(x=var_index, color='gray', linestyle='--', linewidth=0.5)

    caption = f"{description} (2071-2100) - Historical (1985-2014)"
    fig.text(0.52, 1.04, caption, ha='center', va='top', fontsize=45, wrap=True, weight='bold')

    plt.subplots_adjust(wspace=0.4, hspace=0.7)

    # Remove every unused grid cell (never a populated panel); the space hosts the legend
    for unused_ax in flat_axes[len(selected_indices):]:
        unused_ax.remove()

    # Positions for the legend and colorbar: [x0, y0, width, height] in figure coordinates
    legend_position = [0.8, 0.04, 0.3, 0.15]
    colorbar_position = [0.905, 0.055, 0.09, 0.015]

    legend_ax = fig.add_axes(legend_position, frame_on=False)
    legend_elements = [plt.Line2D([0], [0], marker='D', markeredgecolor='red', markerfacecolor='none', label='Ensemble mean', markersize=10, linestyle='None', lw=2)]
    for idx, model in enumerate(models):
        if model != "Ensemble mean":
            legend_elements.append(plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='w', label=f"{idx + 1}: {model}", markersize=10))
    legend_ax.legend(handles=legend_elements, fontsize=14, ncol=2, loc='center')
    legend_ax.axis('off')

    cbar_ax = fig.add_axes(colorbar_position)
    cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=bgws_cm), cax=cbar_ax, orientation='horizontal', extend='both')
    cbar.set_label(r"$\Delta$ Blue-Green Water Share", fontsize=16, weight='bold')
    cbar.set_ticks([-0.1, -0.075, -0.05, -0.025, 0, 0.025, 0.05, 0.075, 0.1])
    cbar.set_ticklabels(["-0.1", "-0.075", "-0.05", "-0.025", "0", "0.025", "0.05", "0.075", "0.1"])
    cbar.ax.tick_params(labelsize=14)

    plt.tight_layout()

    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'regional_var_change')
        os.makedirs(savepath, exist_ok=True)
        filename = f'regional_var_{change}_{experiment_id}_{months}_{yearly_sum}.pdf'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    return filepath

In [None]:
def plot_custom_region_change(ds_dict_change, selected_indices="ALL", selected_vars=None, y_axis_limits=(-3, 3), save_fig=True):
    r"""Draw one parallel-coordinate panel per region without value capping or custom y-ticks.

    Each model is drawn as a thin line with its 1-based model number at every point. The line
    is coloured by the model's regional BGWS change via ``bgws_cm``: solid if >= 0, dashed
    otherwise, and grey/dashed if the model has no 'bgws'. The ensemble mean is drawn as open
    red diamonds. The y-axis is left to matplotlib's autoscaling.

    Parameters
    ----------
    ds_dict_change : dict
        Model name -> xarray Dataset of regional changes indexed by ``region``. It is modified
        in place: ``compute_ensemble`` (re)adds an 'Ensemble mean' entry.
    selected_indices : "ALL" or list
        Region labels to plot; "ALL" uses every region of the ensemble mean.
    selected_vars : list of str, optional
        Variables to plot. If empty or None, every data variable except 'bgws' is used.
    y_axis_limits : tuple
        Currently NOT applied (the original ``set_ylim`` call is disabled); kept for API
        compatibility.
    save_fig : bool
        If True, save the figure as a PDF under ../../results/CMIP6/comparison/regional_var_change.
        NOTE(review): the filename equals the one ``plot_region_change`` writes for the same
        inputs, so the two figures overwrite each other.

    Returns
    -------
    str
        Path of the saved PDF, or a message that the figure was not saved.
    """
    ds_dict_change = compute_ensemble(ds_dict_change)
    first = ds_dict_change[list(ds_dict_change.keys())[0]]

    # Precipitation given in '%' is taken to mean all changes are relative
    change = 'rel_change' if first['pr'].units == '%' else 'abs_change'

    models = list(ds_dict_change.keys())
    ensemble = ds_dict_change['Ensemble mean']
    variables = selected_vars if selected_vars else [var for var in ensemble.data_vars.keys() if var not in ['bgws']]
    if isinstance(selected_indices, str) and selected_indices == "ALL":
        selected_indices = ensemble.region.values.tolist()
    experiment_id = first.experiment_id
    description = first.description
    months = first.months
    yearly_sum = first.yearly_sum

    # NOTE(review): 'tas' and 'tran' both map to 'T' -- confirm this is intended.
    var_map = {
        'tas': 'T', 'vpd': 'VPD', 'gpp': 'GPP', 'pr': 'P', 'mrro': 'R',
        'evspsbl': 'ET', 'tran': 'T', 'evapo': 'E', 'lai': 'Lai', 'mrso': 'SM',
        'rgtr': 'P/T', 'et_partitioning': 'EP', 'growing_season_length': 'GS', 'RX5day': 'P5'
    }
    # Raw f-string: the LaTeX backslashes are literal text, not (invalid) escape sequences
    display_vars = [
        rf"$\Delta\, \mathrm{{\it{{{var_map[var]}}}}}$" if var in var_map else var for var in variables
    ]
    x = range(len(display_vars))

    # Setup plot grid: 9 panels per row
    ncols = 9
    nrows = (len(selected_indices) + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(8 * ncols, 7 * nrows), squeeze=False)
    flat_axes = axes.flatten()

    # Fixed colour scale for the BGWS change, matching the colorbar ticks below
    norm = plt.Normalize(vmin=-0.1, vmax=0.1)
    y_label = 'End of century response [%]' if change == 'rel_change' else 'End of century response'

    for ax, region_idx in zip(flat_axes, selected_indices):
        for idx, model in enumerate(models):
            model_data = ds_dict_change[model]
            y = [
                model_data[var].sel(region=region_idx).values
                if var in model_data.data_vars else float('nan')
                for var in variables
            ]

            # Line style and colour encode the model's BGWS change in this region
            if 'bgws' in model_data.data_vars:
                bgws_val = model_data['bgws'].sel(region=region_idx).values
                linestyle = '-' if bgws_val >= 0 else '--'
                color = bgws_cm(norm(bgws_val))
            else:
                linestyle, color = '--', 'grey'

            if model != "Ensemble mean":
                ax.plot(x, y, linestyle=linestyle, color=color, linewidth=1.0, alpha=0.6, zorder=1)
                for xi, yi in zip(x, y):
                    if not np.isnan(yi):
                        ax.text(xi, yi, str(idx + 1), ha='center', va='center', fontsize=16, color=color, fontweight='bold', zorder=2)
            else:
                ax.scatter(x, y, marker='D', edgecolors='red', s=150, label=model, zorder=3, facecolors='none', lw=2)

        ax.set_title(ensemble.names.sel(region=region_idx).values, fontsize=20)
        ax.set_xticks(range(len(display_vars)))
        ax.set_xticklabels(display_vars, rotation=45, fontsize=18, ha='right')
        ax.set_ylabel(y_label, fontsize=18)

        ax.axhline(0, color='grey', linewidth=0.5)
        ax.tick_params(axis='x', length=0)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(True)
        ax.yaxis.set_tick_params(labelsize=14)

        for var_index in range(len(display_vars)):
            ax.axvline(x=var_index, color='gray', linestyle='--', linewidth=0.5)

    caption = f"{description} (2071-2100) - Historical (1985-2014)"
    fig.text(0.52, 1.04, caption, ha='center', va='top', fontsize=45, wrap=True, weight='bold')

    plt.subplots_adjust(wspace=0.4, hspace=0.7)

    # Remove every unused grid cell (never a populated panel); the space hosts the legend
    for unused_ax in flat_axes[len(selected_indices):]:
        unused_ax.remove()

    # Positions for the legend and colorbar: [x0, y0, width, height] in figure coordinates
    legend_position = [0.8, 0.04, 0.3, 0.15]
    colorbar_position = [0.905, 0.055, 0.09, 0.015]

    legend_ax = fig.add_axes(legend_position, frame_on=False)
    legend_elements = [plt.Line2D([0], [0], marker='D', markeredgecolor='red', markerfacecolor='none', label='Ensemble mean', markersize=10, linestyle='None', lw=2)]
    for idx, model in enumerate(models):
        if model != "Ensemble mean":
            legend_elements.append(plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='w', label=f"{idx + 1}: {model}", markersize=10))
    legend_ax.legend(handles=legend_elements, fontsize=14, ncol=2, loc='center')
    legend_ax.axis('off')

    cbar_ax = fig.add_axes(colorbar_position)
    cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=bgws_cm), cax=cbar_ax, orientation='horizontal', extend='both')
    cbar.set_label(r"$\Delta$ Blue-Green Water Share", fontsize=16, weight='bold')
    cbar.set_ticks([-0.1, -0.075, -0.05, -0.025, 0, 0.025, 0.05, 0.075, 0.1])
    cbar.set_ticklabels(["-0.1", "-0.075", "-0.05", "-0.025", "0", "0.025", "0.05", "0.075", "0.1"])
    cbar.ax.tick_params(labelsize=14)

    plt.tight_layout()

    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'regional_var_change')
        os.makedirs(savepath, exist_ok=True)
        filename = f'regional_var_{change}_{experiment_id}_{months}_{yearly_sum}.pdf'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    return filepath

In [None]:
#plot_region_change(ds_dict_region_change_ssp126, selected_indices="ALL", save_fig=False)

In [None]:
# Build a single-region subset of the ssp370 change dictionary (model name -> Dataset).

# Create an empty dictionary to store the selected region data for each model
selected_region_data = {}

# Loop through each model in the original dictionary
# (the loop variables `model` and `ds` stay bound at notebook level afterwards)
for model, ds in ds_dict_region_change_ssp370.items():
    # Keep only the first region (positional index 0); slice(0, 1) keeps the region dimension
    selected_region_data[model] = ds.isel(region=slice(0, 1))

In [None]:
# List the entries of the single-region subset
selected_region_data.keys()

In [None]:
# Inspect the single-region subset of one model
selected_region_data['BCC-CSM2-MR']

In [None]:
# Variables to plot (short subset)
selected_vars = ['tas', 'vpd', 'RX5day', 'et_partitioning', 'pr']

In [None]:
def determine_change_type(ds_dict):
    """Return 'rel_change' if the first dataset's 'pr' units are '%', otherwise 'abs_change'."""
    reference = ds_dict[list(ds_dict.keys())[0]]
    is_relative = reference['pr'].units == '%'
    if is_relative:
        return 'rel_change'
    return 'abs_change'

def extract_variables(ds_dict, selected_vars):
    """Return ``(variables, experiment_id, description, months, yearly_sum)`` for plotting.

    ``variables`` is every data variable of the 'Ensemble mean' entry except bookkeeping
    names when ``selected_vars`` is None. Otherwise it is the entries of ``selected_vars``
    that exist in the ensemble mean, in the given order. The metadata attributes are read
    from the first entry of ``ds_dict``.
    """
    ensemble_vars = ds_dict['Ensemble mean'].data_vars.keys()
    if selected_vars is None:
        excluded = ['bgws', 'region', 'abbrevs', 'names', 'member_id']
        variables = [name for name in ensemble_vars if name not in excluded]
    else:
        variables = [name for name in selected_vars if name in ensemble_vars]

    reference = ds_dict[list(ds_dict.keys())[0]]
    return (
        variables,
        reference.experiment_id,
        reference.description,
        reference.months,
        reference.yearly_sum,
    )

def prepare_display_variables(variables):
    r"""Map variable names to LaTeX axis labels of the form ``$\Delta\, \mathrm{\it{X}}$``.

    Names without an entry in the mapping are returned unchanged.
    NOTE(review): 'tas' and 'tran' both map to 'T', so temperature and transpiration get the
    same label -- confirm whether one of them should be renamed.
    """
    var_map = {
        'tas': 'T', 'vpd': 'VPD', 'gpp': 'GPP', 'pr': 'P', 'mrro': 'R',
        'evspsbl': 'ET', 'tran': 'T', 'evapo': 'E', 'lai': 'Lai', 'mrso': 'SM',
        'rgtr': 'P/T', 'et_partitioning': 'EP', 'growing_season_length': 'GS', 'RX5day': 'P5'
    }
    # Raw f-string so the LaTeX backslashes are not parsed as (invalid) escape sequences,
    # which raise SyntaxWarning on Python 3.12+. The produced strings are unchanged.
    return [rf"$\Delta\, \mathrm{{\it{{{var_map[var]}}}}}$" if var in var_map else var for var in variables]



def setup_plot_grid(selected_indices):
    """Create a 9-column subplot grid with at least one cell per selected region.

    Returns ``(fig, axes)``; ``axes`` is always 2-D because ``squeeze=False``.
    """
    ncols = 9
    nrows, remainder = divmod(len(selected_indices), ncols)
    if remainder:
        nrows += 1
    return plt.subplots(nrows=nrows, ncols=ncols, figsize=(8 * ncols, 7 * nrows), squeeze=False)

def plot_models(axes, ds_dict, models, display_vars, selected_indices, threshold, capped_value, variables=None):
    """Draw every model's regional changes into the panels of ``axes``.

    Parameters
    ----------
    axes : array of matplotlib Axes
        Panel grid (any shape; it is flattened). Panel ``i`` receives ``selected_indices[i]``.
    ds_dict : dict
        Model name -> xarray Dataset of regional changes indexed by ``region``.
    models : list of str
        Keys of ``ds_dict`` to draw; a model's number in the plot is its position + 1.
        The 'Ensemble mean' entry is drawn as open red diamonds.
    display_vars : list of str
        Axis labels; only their count is used here (one x position per variable).
    selected_indices : list
        Region labels, one per panel.
    threshold, capped_value : float
        Values with magnitude above ``threshold`` are drawn at +/-``capped_value``.
    variables : list of str, optional
        Dataset variable names matching ``display_vars``. Defaults to
        ``extract_variables(ds_dict, None)[0]``.

    Returns
    -------
    list of tuple
        ``(min_change, max_change)`` per panel over the uncapped values, both starting at 0.
    """
    if variables is None:
        variables = extract_variables(ds_dict, None)[0]
    flat_axes = np.asarray(axes).flatten()
    # Fixed colour scale for the BGWS change, matching the colorbar ticks
    norm = plt.Normalize(vmin=-0.1, vmax=0.1)
    x = range(len(display_vars))
    extents = []

    for ax, region_idx in zip(flat_axes, selected_indices):
        min_change, max_change = 0, 0
        for idx, model in enumerate(models):
            model_data = ds_dict[model]
            y = np.array([
                model_data[var].sel(region=region_idx).values
                if var in model_data.data_vars else float('nan')
                for var in variables
            ])
            valid = y[~np.isnan(y)]
            if valid.size:
                max_change = max(max_change, valid.max())
                min_change = min(min_change, valid.min())

            # Draw out-of-range values just beyond the threshold so they stay visible
            y_capped = np.where(y > threshold, capped_value,
                                np.where(y < -threshold, -capped_value, y))

            # Line style and colour encode the model's BGWS change in this region
            if 'bgws' in model_data.data_vars:
                bgws_val = model_data['bgws'].sel(region=region_idx).values
                linestyle = '-' if bgws_val >= 0 else '--'
                color = bgws_cm(norm(bgws_val))
            else:
                linestyle, color = '--', 'grey'

            if model != "Ensemble mean":
                ax.plot(x, y_capped, linestyle=linestyle, color=color, linewidth=1.0, alpha=0.6, zorder=1)
                for xi, yi in zip(x, y_capped):
                    if not np.isnan(yi):
                        ax.text(xi, yi, str(idx + 1), ha='center', va='center', fontsize=16, color=color, fontweight='bold', zorder=2)
            else:
                ax.scatter(x, y_capped, marker='D', edgecolors='red', s=150, label=model, zorder=3, facecolors='none', lw=2)
        extents.append((min_change, max_change))
    return extents

def format_axes(axes, ensemble, region_idx, display_vars, change, threshold, max_change, min_change):
    """Set title, labels, limits, ticks and guide lines of ONE region panel.

    Parameters
    ----------
    axes : matplotlib Axes
        The single panel to format (parameter name kept for compatibility).
    ensemble : xarray Dataset
        Ensemble-mean dataset; its ``names`` variable supplies the panel title.
    region_idx :
        Region label of this panel.
    display_vars : list of str
        X tick labels.
    change : str
        'rel_change' adds a '[%]' unit to the y-label.
    threshold : float
        Capping threshold used by ``plot_models``.
    max_change, min_change : float
        Extremes of the uncapped data in this panel.
    """
    ax = axes
    ax.set_title(ensemble.names.sel(region=region_idx).values, fontsize=20)
    ax.set_xticks(range(len(display_vars)))
    ax.set_xticklabels(display_vars, rotation=45, fontsize=18, ha='right')
    y_label = 'End of century response [%]' if change == 'rel_change' else 'End of century response'
    ax.set_ylabel(y_label, fontsize=18)

    lower_limit = max(-threshold, min_change)
    upper_limit = min(threshold, max_change)

    # Flag capped data with a dashed line and an arrow on the y-axis
    # (x in axes coordinates, y in data coordinates).
    if max_change > threshold:
        ax.axhline(threshold, color='grey', linestyle='--', linewidth=0.5, zorder=1)
        ax.plot(0, threshold + 10, "^k", transform=ax.get_yaxis_transform(), clip_on=False)
    if min_change < -threshold:
        ax.axhline(-threshold, color='grey', linestyle='--', linewidth=0.5, zorder=1)
        ax.plot(0, -threshold - 10, "vk", transform=ax.get_yaxis_transform(), clip_on=False)

    if lower_limit == upper_limit:
        # Degenerate panel (no data or all zeros): avoid identical y-limits
        ax.set_ylim(lower_limit - 10, upper_limit + 10)
    else:
        ax.set_ylim(lower_limit * 1.1, upper_limit * 1.1)
    ax.set_yticks(calculate_ticks(lower_limit, upper_limit))

    ax.axhline(0, color='grey', linewidth=0.5)
    ax.tick_params(axis='x', length=0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(True)
    ax.yaxis.set_tick_params(labelsize=14)
    for var_index in range(len(display_vars)):
        ax.axvline(x=var_index, color='gray', linestyle='--', linewidth=0.5)

def add_legend_and_colorbar(fig, models):
    """Add the model-number legend and the BGWS-change colorbar at the figure's bottom right."""
    legend_ax = fig.add_axes([0.8, 0.04, 0.3, 0.15], frame_on=False)
    legend_elements = [plt.Line2D([0], [0], marker='D', markeredgecolor='red', markerfacecolor='none', label='Ensemble mean', markersize=10, linestyle='None', lw=2)]
    legend_elements += [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='w', label=f"{idx + 1}: {model}", markersize=10)
        for idx, model in enumerate(models) if model != "Ensemble mean"
    ]
    legend_ax.legend(handles=legend_elements, fontsize=14, ncol=2, loc='center')
    legend_ax.axis('off')

    cbar_ax = fig.add_axes([0.905, 0.055, 0.09, 0.015])
    norm = plt.Normalize(vmin=-0.1, vmax=0.1)
    cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=bgws_cm), cax=cbar_ax, orientation='horizontal', extend='both')
    cbar.set_label(r"$\Delta$ Blue-Green Water Share", fontsize=16, weight='bold')
    cbar.set_ticks([-0.1, -0.075, -0.05, -0.025, 0, 0.025, 0.05, 0.075, 0.1])
    cbar.set_ticklabels(["-0.1", "-0.075", "-0.05", "-0.025", "0", "0.025", "0.05", "0.075", "0.1"])
    cbar.ax.tick_params(labelsize=14)

def save_figure(fig, change, experiment_id, months, yearly_sum, save_fig):
    """Save ``fig`` as a PDF under ../../results/CMIP6/comparison/regional_var_change.

    Uses the same '../..'-relative results folder as the other plotting functions in this
    notebook. Returns the file path, or a message if ``save_fig`` is False.
    """
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'regional_var_change')
        os.makedirs(savepath, exist_ok=True)
        filename = f'regional_var_{change}_{experiment_id}_{months}_{yearly_sum}.pdf'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=600, bbox_inches='tight', format='pdf')
        return filepath
    else:
        return 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

def plot_region_change(ds_dict_change, selected_indices="ALL", selected_vars=None, save_fig=True):
    """Plot per-region parallel-coordinate panels of variable changes for all models.

    ``ds_dict_change`` (model name -> regional-change Dataset) is modified in place:
    ``compute_ensemble`` (re)adds an 'Ensemble mean' entry. ``selected_indices`` is "ALL"
    or a list of region labels. ``selected_vars`` optionally restricts the variables.
    Returns the saved PDF path, or a message if ``save_fig`` is False.
    """
    ds_dict_change = compute_ensemble(ds_dict_change)
    change = determine_change_type(ds_dict_change)
    models = list(ds_dict_change.keys())
    ensemble = ds_dict_change['Ensemble mean']

    variables, experiment_id, description, months, yearly_sum = extract_variables(ds_dict_change, selected_vars)
    display_vars = prepare_display_variables(variables)
    if isinstance(selected_indices, str) and selected_indices == "ALL":
        selected_indices = ensemble.region.values.tolist()
    fig, axes = setup_plot_grid(selected_indices)
    flat_axes = axes.flatten()

    threshold, capped_value = 100, 105
    extents = plot_models(flat_axes, ds_dict_change, models, display_vars, selected_indices,
                          threshold, capped_value, variables=variables)
    for ax, region_idx, (min_change, max_change) in zip(flat_axes, selected_indices, extents):
        format_axes(ax, ensemble, region_idx, display_vars, change, threshold, max_change, min_change)

    caption = f"{description} (2071-2100) - Historical (1985-2014)"
    fig.text(0.52, 1.04, caption, ha='center', va='top', fontsize=45, wrap=True, weight='bold')
    plt.subplots_adjust(wspace=0.4, hspace=0.7)

    # Remove every unused grid cell; the space hosts the legend and colorbar
    for unused_ax in flat_axes[len(selected_indices):]:
        unused_ax.remove()

    add_legend_and_colorbar(fig, models)

    plt.tight_layout()
    return save_figure(fig, change, experiment_id, months, yearly_sum, save_fig)

# Call the main function with your dataset dictionary
# plot_region_change(ds_dict_change)

In [None]:
# Full list of variables of interest
selected_vars = ['tas', 'vpd', 'RX5day', 'et_partitioning', 'pr', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo', 'growing_season_length', 'wue']

In [None]:
# Short subset; the remaining variables are kept in the trailing comment for quick swapping
selected_vars = ['tas', 'vpd', 'RX5day', 'et_partitioning']#['pr', 'mrro', 'mrso', 'tran', 'lai', 'gpp', 'evspsbl', 'evapo', 'growing_season_length', 'wue', 'rgtr']

In [None]:
# Plot the ssp370 regional changes for the selected variables without saving the figure
plot_region_change(ds_dict_region_change_ssp370, selected_indices='ALL', selected_vars=selected_vars, save_fig=False)

In [None]:
# Plot all variables of the ssp585 regional changes and save the PDF
plot_region_change(ds_dict_region_change_ssp585, selected_indices="ALL", save_fig=True)

In [None]:
# Cluster, and plot BGWS itself rather than its change; otherwise I find the plot very hard
# to understand. Green can mean that strongly positive (runoff-dominated) BGWS regions become
# slightly less positive. Possibly plot the absolute change after all: an equally large relative
# decrease in runoff and transpiration can mean that runoff decreases much more in absolute
# terms, so that BGWS decreases (green).

### Cluster regions based on variable change

In [None]:
from sklearn.cluster import KMeans

In [None]:
def custom_distance(point1, point2):
    """Euclidean distance between two points, doubled if any coordinate differs in sign.

    Signs are compared element-wise with ``np.sign``, so a zero facing a non-zero value
    also counts as a sign difference.
    """
    euclidean = np.linalg.norm(point1 - point2)
    signs_disagree = np.any(np.sign(point1) != np.sign(point2))
    return euclidean * (2 if signs_disagree else 1)

def custom_kmeans(data, n_clusters, max_iter=100, random_state=None):
    """k-means variant that assigns points to centers using ``custom_distance``.

    Centers start as ``n_clusters`` distinct rows of ``data``. Each round assigns every point
    to its nearest center (by ``custom_distance``) and recomputes each center as the mean of
    its points. Iteration stops when the centers no longer change or after ``max_iter`` rounds.

    Parameters
    ----------
    data : array-like, 2-D
        Points to cluster, one per row (converted with ``np.asarray``).
    n_clusters : int
        Number of clusters; must not exceed the number of rows.
    max_iter : int
        Maximum number of assignment/update rounds.
    random_state : int, optional
        Seed for choosing the initial centers. ``None`` (default) draws from NumPy's global
        random state, exactly as before.

    Returns
    -------
    clusters : ndarray of int
        Index of the assigned center for each row.
    centers : ndarray
        Final cluster centers, one per row.
    """
    data = np.asarray(data)
    choose = np.random.choice if random_state is None else np.random.RandomState(random_state).choice
    centers = data[choose(data.shape[0], n_clusters, replace=False)]

    def assign(current_centers):
        # Index of the closest center for every point
        return np.array([np.argmin([custom_distance(x, c) for c in current_centers]) for x in data])

    clusters = None
    for _ in range(max_iter):
        clusters = assign(centers)

        # An empty cluster keeps its previous center; the mean of zero rows would be NaN
        # and would make that cluster unusable for all later iterations.
        new_centers = np.array([
            data[clusters == k].mean(axis=0) if np.any(clusters == k) else centers[k]
            for k in range(n_clusters)
        ])

        if np.all(centers == new_centers):
            break
        centers = new_centers

    if clusters is None:  # max_iter <= 0: label the points against the initial centers
        clusters = assign(centers)

    return clusters, centers

In [None]:
def cluster_regions(ds_dict, max_clusters=40):
    """
    Cluster regions by the sign pattern of their numeric variables across models (KMeans).

    Each dataset is converted to a DataFrame indexed by ``region``. Every numeric column is
    replaced by its sign (-1, 0 or 1; NaN stays NaN) and suffixed with the dict key. The
    frames are joined column-wise and columns containing NaN are dropped. An elbow curve of
    KMeans inertia is shown, the user types the number of clusters, and the final KMeans
    labels are returned.

    :param ds_dict: A dictionary of xarray.Dataset objects, each representing a different model.
        Every entry is used as a model, including an 'Ensemble mean' if present. Datasets need a
        ``region`` coordinate and a ``names`` variable along it.
    :param max_clusters: The maximum number of clusters to test for the elbow method; capped at
        the number of regions because KMeans needs at least k samples.
    :return: A pandas DataFrame with columns 'Region' and 'Cluster', sorted by cluster.
    """
    dfs = []
    for model, ds in ds_dict.items():
        df = ds.to_dataframe().reset_index().set_index('region')
        # Keep numeric columns only and reduce them to their sign
        numeric_cols = df.select_dtypes(include=[float, int]).columns
        df = df[numeric_cols].apply(np.sign)
        df.columns = [f"{var}_{model}" for var in df.columns]  # Rename columns to include model name
        dfs.append(df)
    # KMeans cannot handle missing values: drop every column containing one
    combined_df = pd.concat(dfs, axis=1).dropna(axis=1)

    # Elbow method; never test more clusters than there are regions
    sum_of_squared_distances = []
    K = range(1, min(max_clusters, len(combined_df)) + 1)
    for k in K:
        km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(combined_df)
        sum_of_squared_distances.append(km.inertia_)

    plt.plot(K, sum_of_squared_distances, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Sum of squared distances')
    plt.title('Elbow Method For Optimal k')
    plt.show()

    n_clusters = int(input("Enter the optimal number of clusters: "))

    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(combined_df)

    # Look up names by the rows' own region labels so names and cluster labels stay aligned
    region_names = list(ds_dict.values())[0].names.sel(region=combined_df.index.values).values

    clustered_df = pd.DataFrame({
        'Region': region_names,
        'Cluster': kmeans.labels_
    }).sort_values('Cluster')

    return clustered_df

In [None]:
from sklearn.preprocessing import StandardScaler

def cluster_regions(ds_dict, max_clusters=40):
    """
    Cluster regions by their standardized numeric variables across models (KMeans).

    Each dataset is converted to a DataFrame indexed by ``region``. Its numeric columns are
    suffixed with the dict key, the frames are joined column-wise, and columns containing NaN
    are dropped. Every column is standardized with ``StandardScaler``. An elbow curve of
    KMeans inertia is shown, the user types the number of clusters, and the final KMeans
    labels are returned.

    :param ds_dict: A dictionary of xarray.Dataset objects, each representing a different model.
        Every entry is used as a model, including an 'Ensemble mean' if present. Datasets need a
        ``region`` coordinate and a ``names`` variable along it.
    :param max_clusters: The maximum number of clusters to test for the elbow method; capped at
        the number of regions because KMeans needs at least k samples.
    :return: A pandas DataFrame with columns 'Region' and 'Cluster', sorted by cluster.
    """
    dfs = []
    for model, ds in ds_dict.items():
        df = ds.to_dataframe().reset_index().set_index('region')
        numeric_cols = df.select_dtypes(include=[float, int]).columns
        df = df[numeric_cols]
        df.columns = [f"{var}_{model}" for var in df.columns]  # Rename columns to include model name
        dfs.append(df)
    # KMeans cannot handle missing values: drop every column containing one
    combined_df = pd.concat(dfs, axis=1).dropna(axis=1)

    # Standardize so variables with large magnitudes do not dominate the distances
    scaled_data = StandardScaler().fit_transform(combined_df)

    # Elbow method; never test more clusters than there are regions
    sum_of_squared_distances = []
    K = range(1, min(max_clusters, len(combined_df)) + 1)
    for k in K:
        km = KMeans(n_clusters=k, n_init=10, random_state=42).fit(scaled_data)
        sum_of_squared_distances.append(km.inertia_)

    plt.plot(K, sum_of_squared_distances, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Sum of squared distances')
    plt.title('Elbow Method For Optimal k')
    plt.show()

    n_clusters = int(input("Enter the optimal number of clusters: "))

    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(scaled_data)

    # Look up names by the rows' own region labels so names and cluster labels stay aligned
    region_names = list(ds_dict.values())[0].names.sel(region=combined_df.index.values).values

    clustered_df = pd.DataFrame({
        'Region': region_names,
        'Cluster': kmeans.labels_
    }).sort_values('Cluster')

    return clustered_df

In [None]:
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

def cluster_regions_gmm(ds_dict, max_clusters=40):
    """
    Cluster regions with a Gaussian Mixture Model fitted to the sign pattern of their variables.

    Each dataset is converted to a DataFrame indexed by ``region``. Every numeric column is
    replaced by its sign (-1, 0 or 1; NaN stays NaN) and suffixed with the dict key. The
    frames are joined column-wise and columns containing NaN are dropped. BIC scores for
    1..k components are plotted, the user types the number of clusters, and the predicted
    component of each region is returned.

    :param ds_dict: A dictionary of xarray.Dataset objects, each representing a different model.
        Every entry is used as a model, including an 'Ensemble mean' if present. Datasets need a
        ``region`` coordinate and a ``names`` variable along it.
    :param max_clusters: The maximum number of components to score; capped at the number of
        regions because a mixture cannot have more components than samples.
    :return: A pandas DataFrame with columns 'Region' and 'Cluster', sorted by cluster.
    """
    dfs = []
    for model, ds in ds_dict.items():
        df = ds.to_dataframe().reset_index().set_index('region')
        # Keep numeric columns only and reduce them to their sign
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        df = df[numeric_cols].apply(np.sign)
        df.columns = [f"{var}_{model}" for var in df.columns]  # Rename columns to include model name
        dfs.append(df)
    # The mixture model cannot handle missing values: drop every column containing one
    combined_df = pd.concat(dfs, axis=1).dropna(axis=1)

    # Score 1..k components by BIC; never test more components than there are regions
    bic_scores = []
    K = range(1, min(max_clusters, len(combined_df)) + 1)
    for k in K:
        gmm = GaussianMixture(n_components=k, n_init=10, random_state=42).fit(combined_df)
        bic_scores.append(gmm.bic(combined_df))

    plt.plot(K, bic_scores, 'bx-')
    plt.xlabel('k (number of components)')
    plt.ylabel('BIC score')
    plt.title('BIC Scoring for Gaussian Mixture Model')
    plt.show()

    n_clusters = int(input("Enter the optimal number of clusters: "))

    gmm = GaussianMixture(n_components=n_clusters, n_init=10, random_state=42).fit(combined_df)
    cluster_labels = gmm.predict(combined_df)

    # Look up names by the rows' own region labels so names and cluster labels stay aligned
    region_names = list(ds_dict.values())[0].names.sel(region=combined_df.index.values).values

    clustered_df = pd.DataFrame({
        'Region': region_names,
        'Cluster': cluster_labels
    }).sort_values('Cluster')

    return clustered_df

In [None]:
from sklearn.preprocessing import StandardScaler

def cluster_regions(ds_dict, max_clusters=40, n_clusters=None):
    """
    Cluster regions on their standardized numeric variables across multiple models.

    Every model's Dataset is flattened to a region-indexed DataFrame, its numeric columns
    are suffixed with the model name, and all models are joined side by side. Columns that
    contain NaN for any region are dropped and the remaining data are standardized.

    An elbow curve (sum of squared distances of every region to its nearest centre, with
    centres from ``custom_kmeans``) is plotted for k = 1..max_clusters. The final clustering
    then uses ``n_clusters`` clusters. If ``n_clusters`` is not given, the user is prompted
    for it.

    :param ds_dict: A dictionary of xarray.Dataset objects, each representing a different model.
    :param max_clusters: The maximum number of clusters to test for the elbow method. It is
        capped at the number of regions, since k-means cannot form more clusters than samples.
    :param n_clusters: Number of clusters for the final clustering. If None (default), the
        user is asked interactively after the elbow plot has been shown.
    :return: A pandas DataFrame with columns 'Region' and 'Cluster', sorted by 'Cluster'.
    :raises ValueError: if ``ds_dict`` is empty or no NaN-free numeric column remains.
    """
    if not ds_dict:
        raise ValueError("ds_dict must contain at least one dataset")

    # One region-indexed table per model; numeric columns are renamed to '<var>_<model>'
    dfs = []
    for model, ds in ds_dict.items():
        df = ds.to_dataframe().reset_index().set_index('region')
        # np.number covers every numeric width (same selection as the GMM variant)
        df = df.select_dtypes(include=[np.number])
        df.columns = [f"{var}_{model}" for var in df.columns]
        dfs.append(df)
    combined_df = pd.concat(dfs, axis=1)

    # Features with NaN in any region cannot be clustered on
    combined_df = combined_df.dropna(axis=1)
    if combined_df.shape[1] == 0:
        raise ValueError("No NaN-free numeric variables are left to cluster on")

    # Standardize the data so that every feature contributes on the same scale
    scaled_data = StandardScaler().fit_transform(combined_df)

    # Elbow method; more clusters than regions is not meaningful
    max_clusters = min(max_clusters, len(scaled_data))
    K = range(1, max_clusters + 1)
    sum_of_squared_distances = []
    for k in K:
        _, centers = custom_kmeans(scaled_data, n_clusters=k)
        # (n_regions, k) matrix of region-to-centre distances; keep the nearest centre
        distances = np.linalg.norm(
            scaled_data[:, np.newaxis, :] - np.asarray(centers)[np.newaxis, :, :], axis=2
        )
        sum_of_squared_distances.append(np.sum(np.square(distances.min(axis=1))))

    # Plot the elbow curve
    plt.plot(K, sum_of_squared_distances, 'bx-')
    plt.xlabel('k')
    plt.ylabel('Sum of squared distances')
    plt.title('Elbow Method For Optimal k')
    plt.show()

    if n_clusters is None:
        # Ask user for the optimal number of clusters
        n_clusters = int(input("Enter the optimal number of clusters: "))

    # Use the custom KMeans algorithm to find clusters with the chosen number of clusters
    kmeans_labels, _ = custom_kmeans(scaled_data, n_clusters=n_clusters)

    # Name every row of combined_df through its region number, so that names stay aligned
    # with the clustered rows even if the concatenation reordered regions
    first_ds = next(iter(ds_dict.values()))
    region_to_name = dict(zip(first_ds.region.values, first_ds.names.values))
    region_names = [region_to_name[region] for region in combined_df.index]

    # Create a DataFrame with region names and their assigned clusters
    clustered_df = pd.DataFrame({
        'Region': region_names,
        'Cluster': kmeans_labels
    }).sort_values('Cluster')

    return clustered_df

In [None]:
# Option A: GMM-based clustering of the regional changes (presumably the Gaussian-mixture
# variant above: it plots a BIC curve and asks for the number of clusters).
# NOTE: the next cell also assigns `clustered_df`; whichever of the two cells ran last
# determines the clustering used by all following cells.
clustered_df = cluster_regions_gmm(ds_dict_region_change_ssp126, max_clusters=40)

In [None]:
# Option B: k-means clustering of the standardized regional changes (plots an elbow
# curve and asks for the number of clusters).
clustered_df = cluster_regions(ds_dict_region_change_ssp126, max_clusters=40)

In [None]:
# Display the region -> cluster assignment
clustered_df

### Compute global median change

In [None]:
# Compute Change
ds_dict_change = compute_change(ds_dict_hist_period_metric, ds_dict_ssp126_period_metric, var_rel_change=None)

In [None]:
# Compute ensemble median for the change
ds_dict_change_ensemble = compute_ensemble(ds_dict_change)

In [None]:
# Only use ensemble median for further analysis
ds_ensmed_glob = ds_dict_change_ensemble['Ensemble median']

In [None]:
ds_ensmed_glob

In [None]:
ds_dict_region_change_ssp126_ensemble = compute_ensemble(ds_dict_region_change_ssp126)

In [None]:
ds = ds_dict_region_change_ssp126_ensemble['Ensemble median']

In [None]:
ds

### Plot clusters on map

In [None]:
import cartopy.crs as ccrs
from cartopy.feature import ShapelyFeature
import regionmask
from shapely.geometry.polygon import Polygon
from shapely.geometry.multipolygon import MultiPolygon

In [None]:
from shapely.geometry import MultiPolygon, Polygon

def split_polygon(polygon, meridian=180):
    """
    Split a Shapely polygon where it extends past the longitude window that ends at `meridian`.

    The target window is [meridian - 360, meridian], i.e. [-180, 180] for the default.
    - Geometry inside the window is kept where it is.
    - Parts east of `meridian` are cut exactly at the meridian and shifted by -360 degrees.
    - Parts west of `meridian - 360` are cut and shifted by +360 degrees.

    Every returned piece therefore lies inside the window.

    :param polygon: Shapely Polygon or MultiPolygon with x = longitude, y = latitude (degrees).
    :param meridian: Eastern edge of the longitude window in degrees (default 180).
    :return: A list of geometries. When no split is needed, it is ``[polygon]`` (the input
        object itself). Otherwise it holds the non-empty polygonal pieces as individual
        Polygons.
    """
    # Local imports keep this notebook cell self-contained
    from shapely import affinity
    from shapely.geometry import box

    west_edge = meridian - 360
    minx, miny, maxx, maxy = polygon.bounds
    if minx >= west_edge and maxx <= meridian:
        return [polygon]  # Wrap the single polygon in a list for consistency

    def polygon_parts(geom):
        # Clipping can yield (Multi)Polygons, GeometryCollections or zero-area slivers
        # (lines/points where the polygon only touches a cut); keep only the Polygons.
        if geom.is_empty:
            return []
        if geom.geom_type == 'Polygon':
            return [geom]
        if hasattr(geom, 'geoms'):
            return [part for sub in geom.geoms for part in polygon_parts(sub)]
        return []

    # (west bound, east bound, longitude shift) of the three slices around the window
    slices = [
        (minx, west_edge, 360.0),    # west of the window -> shift east
        (west_edge, meridian, 0.0),  # already inside the window
        (meridian, maxx, -360.0),    # east of the window -> shift west
    ]
    pieces = []
    for lo, hi, xoff in slices:
        lo, hi = max(lo, minx), min(hi, maxx)
        if lo >= hi:
            continue  # the polygon has no extent in this slice
        clipped = polygon.intersection(box(lo, miny, hi, maxy))
        if xoff:
            clipped = affinity.translate(clipped, xoff=xoff)
        pieces.extend(polygon_parts(clipped))
    return pieces

In [None]:
def plot_clusters_and_regions(ds_ensmed_glob, ds, clustered_df):
    """
    Map the global 'bgws' change and outline every AR6 land region in the colour of its cluster.

    Each region of ``ds`` is outlined in its cluster's colour; region polygons are passed
    through ``split_polygon`` first. Each region is labelled once with its abbreviation and
    cluster number, at the easternmost centroid of its plotted parts.

    :param ds_ensmed_glob: xarray Dataset with a gridded 'bgws' variable, plotted as background.
    :param ds: xarray Dataset indexed by 'region' with 'names' and 'abbrevs' coordinates.
        It is read only; no variable is added to it.
    :param clustered_df: DataFrame with columns 'Region' (region name) and 'Cluster'.
    :raises KeyError: if a region name in ``ds`` has no entry in ``clustered_df``.
    """
    # Initialize the plot with a cartopy projection
    fig = plt.figure(figsize=(30, 15))
    ax_main = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())

    # Plot the 'bgws' variable from the dataset
    ds_ensmed_glob['bgws'].plot(ax=ax_main, vmin=-0.2, vmax=0.2, cmap=bgws_cm, transform=ccrs.PlateCarree(), add_colorbar=False)

    # Add coastlines and gridlines
    ax_main.coastlines()
    ax_main.tick_params(axis='both', which='major', labelsize=20)
    gridlines = ax_main.gridlines(draw_labels=True, color='black', alpha=0.2, linestyle='--')
    gridlines.top_labels = gridlines.right_labels = False
    gridlines.xlabel_style = {'size': 18}
    gridlines.ylabel_style = {'size': 18}

    # Get region bounds using regionmask
    land_regions = regionmask.defined_regions.ar6.land

    # Cluster label of every region of ds (in ds order). This is kept in a local list:
    # assigning a plain list to ds would create a new 'Cluster' dimension on the
    # caller's Dataset.
    region_to_cluster_map = dict(zip(clustered_df['Region'], clustered_df['Cluster']))
    region_clusters = [int(region_to_cluster_map[name]) for name in ds.names.values]

    # One colour per distinct cluster label. Labels are mapped to colour positions rather
    # than used as indices, so non-contiguous labels (e.g. an empty cluster) stay valid.
    unique_clusters = np.unique(region_clusters)
    cluster_colors = plt.cm.tab20b(np.linspace(0, 1, len(unique_clusters)))
    color_of_cluster = {int(label): cluster_colors[i] for i, label in enumerate(unique_clusters)}

    # Outline and label every region
    for reg_num, region_abbr, cluster_number in zip(ds.region.values, ds.abbrevs.values, region_clusters):
        cluster_color = color_of_cluster[cluster_number]

        # Fetch the polygon or polygons for this region
        region_obj = land_regions[reg_num]
        if hasattr(region_obj, 'polygons'):
            region_polygons = region_obj.polygons
        elif hasattr(region_obj, 'polygon'):
            # A single (Multi)Polygon; wrap it in a list to make it iterable
            region_polygons = [region_obj.polygon]
        else:
            raise AttributeError("The region object does not have 'polygons' or 'polygon' attribute.")

        plotted_parts = []
        for region_polygon in region_polygons:
            # If the polygon crosses the antimeridian, split it
            for poly in split_polygon(region_polygon):
                # Handle both Polygon and MultiPolygon types after splitting
                if isinstance(poly, Polygon):
                    parts = [poly]
                elif isinstance(poly, MultiPolygon):
                    parts = list(poly.geoms)
                else:
                    raise TypeError(f"Unhandled geometry type: {type(poly)}")

                for part in parts:
                    feature = ShapelyFeature([part], ccrs.PlateCarree(), edgecolor=cluster_color, facecolor='none', linewidth=2)
                    ax_main.add_feature(feature)
                plotted_parts.extend(parts)

        if not plotted_parts:
            continue  # nothing drawn for this region, so there is nothing to label

        # Label once per region, at the easternmost centroid of all its plotted parts
        easternmost = max((part.centroid for part in plotted_parts), key=lambda c: c.x)
        text_lon, text_lat = easternmost.coords[0]
        ax_main.text(text_lon, text_lat, f"{region_abbr}\n{cluster_number}",
                     horizontalalignment='center', verticalalignment='center', transform=ccrs.PlateCarree(),
                     fontsize=20, bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))

    plt.show()

In [None]:
# Map the clusters: global ensemble-median 'bgws' change as background, AR6 land regions
# outlined in their cluster colour and labelled "<abbrev>\n<cluster>"
plot_clusters_and_regions(ds_ensmed_glob, ds, clustered_df)

In [None]:
def calculate_ensemble_average_per_cluster(clustered_df, ds_dict):
    """
    Calculate the ensemble average for each variable across all models, grouped by cluster,
    and append the abbreviations of the regions belonging to each cluster.

    Steps:
    1. For every variable present in all models, average the per-region values over the
       models.
    2. Match each region to its cluster by region *name*, using the 'names' coordinate of
       the first dataset and clustered_df['Region'].
    3. Average the region means within each cluster.

    :param clustered_df: DataFrame with columns 'Region' (region name) and 'Cluster'.
    :param ds_dict: Dictionary of xarray.Dataset objects (one per model) indexed by 'region',
        with 'names' and 'abbrevs' coordinates.
    :return: DataFrame indexed by 'Cluster', with the mean of each variable and a 'Regions'
        column listing the abbreviations of the cluster's regions (in clustered_df order).
    :raises ValueError: if ``ds_dict`` is empty or the models share no variable.
    """
    datasets = list(ds_dict.values())
    if not datasets:
        raise ValueError("ds_dict must contain at least one dataset")
    reference_ds = datasets[0]

    # Only variables that every model provides can be averaged across the ensemble
    variables = [var for var in reference_ds.data_vars
                 if all(var in ds.data_vars for ds in datasets)]
    if not variables:
        raise ValueError("The datasets in ds_dict have no variable in common")

    # Per-region ensemble mean of each variable (one region-indexed frame per variable)
    ensemble_data = []
    for var in variables:
        var_all_models = [ds[var].to_dataframe().reset_index()[['region', var]] for ds in datasets]
        ensemble_data.append(pd.concat(var_all_models, axis=0).groupby('region').mean())
    combined_ensemble = pd.concat(ensemble_data, axis=1).reset_index()

    # Join the cluster assignments on the region name. Joining on clustered_df's row index
    # would only be correct if that index happened to equal the region numbers.
    region_lookup = pd.DataFrame({
        'region': reference_ds.region.values,
        'Region': reference_ds.names.values,
    })
    combined_ensemble = (combined_ensemble
                         .merge(region_lookup, on='region')
                         .merge(clustered_df[['Region', 'Cluster']], on='Region'))

    # Group by cluster and calculate the mean of every variable
    final_stats_df = (combined_ensemble
                      .drop(columns=['region', 'Region'])
                      .groupby('Cluster')
                      .mean(numeric_only=True))

    # Abbreviations of the member regions of each cluster, looked up by region name
    name_to_abbrev = dict(zip(reference_ds.names.values, reference_ds.abbrevs.values))
    final_stats_df['Regions'] = clustered_df.groupby('Cluster')['Region'].apply(
        lambda names: [name_to_abbrev[name] for name in names])

    return final_stats_df

In [None]:
# Ensemble-mean regional change of every variable, averaged within each cluster
average_changes_per_cluster = calculate_ensemble_average_per_cluster(clustered_df, ds_dict_region_change_ssp126)

In [None]:
# Display the per-cluster averages (with the member regions of each cluster)
average_changes_per_cluster

### Plot Cluster Spread

In [None]:
import seaborn as sns

In [None]:
def plot_cluster_distributions(clustered_df, ds_dict, num_cols=3):
    """
    Box-plot, per cluster, the distribution of every variable over all models and regions.

    :param clustered_df: DataFrame with columns 'Region' (region name) and 'Cluster'.
    :param ds_dict: Mapping of model name -> xarray Dataset indexed by 'region', with
        'abbrevs' and 'names' coordinates. Every other column of ``to_dataframe()`` is
        treated as a variable. A 'Cluster' variable, if present, is ignored.
    :param num_cols: Number of subplot columns (default 3).
    :raises ValueError: if ``clustered_df`` contains no clusters.
    """
    # Stack all models into one table, tagging each row with its model name
    combined_df = pd.concat([
        ds.drop_vars('Cluster', errors='ignore').to_dataframe().reset_index().assign(model=model)
        for model, ds in ds_dict.items()
    ])

    # Melt the dataframe to long-form
    long_df = combined_df.melt(id_vars=['region', 'abbrevs', 'names', 'model'],
                               var_name='Variable', value_name='Value')
    # Merge with cluster assignments
    long_df = long_df.merge(clustered_df, left_on='names', right_on='Region')

    # Iterate over the cluster labels that actually occur instead of assuming 0..k-1
    cluster_labels = sorted(clustered_df['Cluster'].unique())
    num_clusters = len(cluster_labels)
    if num_clusters == 0:
        raise ValueError("clustered_df contains no clusters to plot")
    num_rows = -(-num_clusters // num_cols)  # ceiling division

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    # One box plot per cluster, titled with the abbreviations of its regions
    for ax, cluster_num in zip(axes, cluster_labels):
        cluster_data = long_df[long_df['Cluster'] == cluster_num]
        sns.boxplot(data=cluster_data, x='Variable', y='Value', ax=ax)
        region_abbrevs = cluster_data['abbrevs'].unique()
        ax.set_title(f'Cluster {cluster_num}: {", ".join(region_abbrevs)}')
        ax.tick_params(axis='x', rotation=45)

    # Hide any unused subplots
    for ax in axes[num_clusters:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Example usage: box plots of all models' regional changes per cluster.
# NOTE: this calls the box-plot version defined in the previous cell; the following cells
# redefine `plot_cluster_distributions` with a different signature (ensemble Dataset).
plot_cluster_distributions(clustered_df, ds_dict_region_change_ssp126)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

def plot_cluster_distributions(clustered_df, ensemble_ds, num_cols=3):
    """
    Scatter-plot, per cluster, the value of every variable for each member region.

    'bgws' (if present) is multiplied by 100 before plotting; all other variables are
    plotted as stored. Each region gets its own hue and marker; the 12 markers are reused
    cyclically.

    :param clustered_df: DataFrame with columns 'Region' (region name) and 'Cluster'.
    :param ensemble_ds: xarray Dataset indexed by 'region' with 'abbrevs' and 'names'
        coordinates (e.g. the ensemble-median regional change). A 'Cluster' variable or
        coordinate, if present, is ignored; the Dataset itself is not modified.
    :param num_cols: Number of subplot columns (default 3).
    :raises ValueError: if ``clustered_df`` contains no clusters.
    """
    # Drop any 'Cluster' entry before converting. If it is a dimension coordinate, this also
    # removes its dimension, so to_dataframe() does not repeat every region along it.
    ensemble_df = ensemble_ds.drop_vars('Cluster', errors='ignore').to_dataframe().reset_index()

    # Multiply 'bgws' by 100
    if 'bgws' in ensemble_df:
        ensemble_df['bgws'] *= 100

    # Melt the DataFrame to long-form for plotting
    long_df = ensemble_df.melt(id_vars=['region', 'abbrevs', 'names'],
                               var_name='Variable', value_name='Value')

    # Merge with cluster assignments
    long_df = long_df.merge(clustered_df, left_on='names', right_on='Region')

    # Iterate over the cluster labels that actually occur instead of assuming 0..k-1
    cluster_labels = sorted(clustered_df['Cluster'].unique())
    num_clusters = len(cluster_labels)
    if num_clusters == 0:
        raise ValueError("clustered_df contains no clusters to plot")
    num_rows = -(-num_clusters // num_cols)  # ceiling division

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    # Define markers for regions
    markers = ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', 'h', 'H', 'X']  # All filled markers
    unique_abbrevs = long_df['abbrevs'].unique()
    marker_dict = {abbrev: markers[i % len(markers)] for i, abbrev in enumerate(unique_abbrevs)}

    # One scatter plot per cluster, titled with the abbreviations of its regions
    for ax, cluster_num in zip(axes, cluster_labels):
        cluster_data = long_df[long_df['Cluster'] == cluster_num]
        sns.scatterplot(data=cluster_data, x='Variable', y='Value', ax=ax, hue='abbrevs',
                        style='abbrevs', markers=marker_dict, legend='brief')
        region_abbrevs = cluster_data['abbrevs'].unique()
        ax.set_title(f'Cluster {cluster_num}: {", ".join(region_abbrevs)}')
        ax.tick_params(axis='x', rotation=45)

        # Setting the legend outside the plot
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Hide any unused subplots
    for ax in axes[num_clusters:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_cluster_distributions(clustered_df, ensemble_ds, num_cols=3):
    """
    Scatter-plot, per cluster, the value of every variable for each member region.

    'bgws' (if present) is multiplied by 100 before plotting; all other variables are
    plotted as stored. Each region gets its own hue and marker; the 12 markers are reused
    cyclically.

    :param clustered_df: DataFrame with columns 'Region' (region name) and 'Cluster'.
    :param ensemble_ds: xarray Dataset indexed by 'region' with 'abbrevs' and 'names'
        coordinates (e.g. the ensemble-median regional change). A 'Cluster' variable or
        coordinate, if present, is ignored; the Dataset itself is not modified.
    :param num_cols: Number of subplot columns (default 3).
    :raises ValueError: if ``clustered_df`` contains no clusters.
    """
    # Exclude 'Cluster' already at the xarray level. Dropping only the DataFrame column
    # would leave the rows repeated along a 'Cluster' dimension, if the Dataset has one.
    ensemble_df = ensemble_ds.drop_vars('Cluster', errors='ignore').to_dataframe().reset_index()

    # Multiply 'bgws' by 100
    if 'bgws' in ensemble_df:
        ensemble_df['bgws'] *= 100

    # Melt the DataFrame to long-form for plotting
    long_df = ensemble_df.melt(id_vars=['region', 'abbrevs', 'names'],
                               var_name='Variable', value_name='Value')

    # Merge with cluster assignments
    long_df = long_df.merge(clustered_df, left_on='names', right_on='Region')

    # Iterate over the cluster labels that actually occur instead of assuming 0..k-1
    cluster_labels = sorted(clustered_df['Cluster'].unique())
    num_clusters = len(cluster_labels)
    if num_clusters == 0:
        raise ValueError("clustered_df contains no clusters to plot")
    num_rows = -(-num_clusters // num_cols)  # ceiling division

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows), squeeze=False)
    axes = axes.flatten()

    # Define markers for regions
    markers = ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', 'h', 'H', 'X']  # All filled markers
    unique_abbrevs = long_df['abbrevs'].unique()
    marker_dict = {abbrev: markers[i % len(markers)] for i, abbrev in enumerate(unique_abbrevs)}

    # One scatter plot per cluster, titled with the abbreviations of its regions
    for ax, cluster_num in zip(axes, cluster_labels):
        cluster_data = long_df[long_df['Cluster'] == cluster_num]
        sns.scatterplot(data=cluster_data, x='Variable', y='Value', ax=ax, hue='abbrevs',
                        style='abbrevs', markers=marker_dict, legend='brief')
        region_abbrevs = cluster_data['abbrevs'].unique()
        ax.set_title(f'Cluster {cluster_num}: {", ".join(region_abbrevs)}')
        ax.tick_params(axis='x', rotation=45)

        # Setting the legend outside the plot
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Hide any unused subplots
    for ax in axes[num_clusters:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Scatter plot of the regional ensemble-median change per cluster (uses the most recent
# definition of plot_cluster_distributions, which takes a single ensemble Dataset)
plot_cluster_distributions(clustered_df, ds)

In [None]:
# Regions that form a cluster on their own and should be dropped from the clustering
# (AR6 region abbreviations). Kept as data: a bare, undefined name here raises a
# NameError and stops "Run all".
regions_to_drop = ['RAR']

In [None]:
# 