# CMIP6 Statistics and Plots

**The following steps are included in this script:**

1. Load netCDF files
2. Compute statistics
3. Plot statistics

### Import Packages

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
import matplotlib.cm
from matplotlib import rcParams
import math
import multiprocessing as mp
from cftime import DatetimeNoLeap
import glob
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr, kendalltau
from matplotlib.lines import Line2D
from sklearn.metrics import r2_score
import matplotlib.colors as colors
from matplotlib.patches import Patch
import matplotlib.patches as mpatches



%matplotlib inline

rcParams["mathtext.default"] = 'regular'

### Functions

#### Open files

In [None]:
# ========= Helper function to open the dataset ========
def open_dataset(filename):
    """Open ``filename`` with ``xr.open_dataset`` and return the resulting dataset."""
    return xr.open_dataset(filename)

In [None]:
# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables):
    """
    Open the regridded file of every requested variable for one model and merge them.

    Files are looked up as
    ``../../data/CMIP6/<experiment_id>/<folder>/<var>/CMIP.<model>.<experiment_id>.<var>_regridded.nc``.
    Variables without a matching file are reported and skipped.

    Args:
        folder (str): Sub-folder below the experiment directory.
        model (str): Model name as used in the file names.
        experiment_id (str): Experiment identifier used in the path and file names.
        variables (list): Names of the variables to load.

    Returns:
        The result of ``xr.merge`` over all files that were found.
    """
    filepaths = []
    for var in variables:
        path = f'../../data/CMIP6/{experiment_id}/{folder}/{var}'
        pattern = os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc')
        matches = glob.glob(pattern)
        if matches:
            filepaths.append(matches[0])
        else:
            # Say what is missing instead of printing the empty glob result
            print(f"No file found for variable '{var}' in model '{model}' (searched: {pattern}).")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

#### Helper functions

In [None]:
def select_period(ds_dict, start_year=None, end_year=None):
    """
    Helper function to select a period from every dataset of a dictionary.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.
    start_year (int, optional): The start year of the period. ``None`` leaves the start open.
    end_year (int, optional): The end year of the period. ``None`` leaves the end open.

    Returns:
    dict: New dictionary with the same keys, each dataset sliced along ``time``.
          The input dictionary is not modified, so callers must use the return value.
    """
    # Build a bound only when a year was given; a ``None`` bound leaves that side of the slice open
    start = None
    if start_year is not None:
        start = DatetimeNoLeap(start_year, 1, 16, 12, 0, 0, 0, has_year_zero=True)  # 16th of January of start year

    end = None
    if end_year is not None:
        end = DatetimeNoLeap(end_year, 12, 16, 12, 0, 0, 0, has_year_zero=True)  # 16th of December of end year

    period = slice(start, end)
    return {name: ds.sel(time=period) for name, ds in ds_dict.items()}

In [None]:
# ======== Standardize ========
def standardize(ds_dict):
    """
    Standardize every dataset of a dictionary.

    Each dataset is transformed as ``(ds - ds.mean()) / ds.std()``. Dataset-level
    and per-variable attributes of the original are copied onto the result.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.

    Returns:
    dict: New dictionary with the same keys holding the standardized datasets.
    """
    standardized = {}

    for key, dataset in ds_dict.items():
        scaled = (dataset - dataset.mean()) / dataset.std()

        # Copy the per-variable attributes back from the original dataset
        for var_name in dataset.variables:
            if var_name not in scaled.variables:
                continue
            scaled[var_name].attrs = dataset[var_name].attrs

        scaled.attrs = dataset.attrs
        standardized[key] = scaled

    return standardized

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate the inputs and collect the metadata used to label plots.

    All metadata is read from the first dataset in ``ds_dict`` only.

    Args:
        ds_dict (dict): Dictionary of xarray datasets.
        variable (str): Name of the variable whose ``long_name`` and ``units`` are read.

    Returns:
        tuple: ``(var_long_name, period, unit, statistic_dim, statistic,
        experiment_id, titles, frequency)``.

    Raises:
        TypeError: If ``ds_dict`` is not a dict of xarray datasets or ``variable`` is not a string.
        ValueError: If ``ds_dict`` is empty.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')
    if not ds_dict:
        # Without this check the lookup below fails with an unhelpful IndexError
        raise ValueError("ds_dict must contain at least one dataset.")

    # Dictionary to store plot titles for each statistic
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum", "median": "Median", "time": "Time", "space": "Space"}
    freq = {"mon": "Monthly"}

    # Data information, taken from the first dataset
    first_ds = next(iter(ds_dict.values()))
    var_long_name = first_ds[variable].long_name
    period = f"{first_ds.attrs['period'][0]}-{first_ds.attrs['period'][1]}"
    experiment_id = first_ds.experiment_id
    unit = first_ds[variable].units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']
    frequency = freq[first_ds.frequency]  # only "mon" is mapped; other frequencies raise KeyError

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency

In [None]:
def compute_ensemble(ds_dict):
    """
    Add the entries 'Ensemble mean' and 'Ensemble median' to ``ds_dict``.

    All member datasets are concatenated along a new ``ensemble`` dimension and
    reduced over it. Variable attributes are copied from the first member, and
    the dataset attributes are rebuilt from the first member's ``period`` and
    ``frequency``.

    Args:
        ds_dict (dict): Dictionary of xarray datasets; it is modified in place.

    Returns:
        dict: The same dictionary, extended by the two ensemble entries.

    Raises:
        ValueError: If ``ds_dict`` contains no member dataset.
    """
    ensemble_keys = ('Ensemble mean', 'Ensemble median')

    # Leave out ensemble entries added by an earlier call so they are not
    # treated as additional ensemble members
    members = {name: ds for name, ds in ds_dict.items() if name not in ensemble_keys}
    if not members:
        raise ValueError("ds_dict must contain at least one dataset to build an ensemble.")

    reference = next(iter(members.values()))

    # Combine all datasets into one larger dataset
    combined = xr.concat(list(members.values()), dim='ensemble')

    for statistic in ('mean', 'median'):
        key = f'Ensemble {statistic}'
        ensemble = getattr(combined, statistic)(dim='ensemble')

        # Preserve variable attributes from the first member (for mean and median alike)
        for var in ensemble.variables:
            ensemble[var].attrs = reference[var].attrs

        ensemble.attrs = {"period": reference.attrs['period'],
                          "statistic": statistic,
                          "statistic_dimension": "time",
                          "experiment_id": "ssp370-historical",
                          "source_id": key,
                          "frequency": reference.attrs['frequency']}
        ds_dict[key] = ensemble

    return ds_dict

#### Compute statistics

In [None]:
def compute_statistic_single(ds, statistic, dimension, yearly_mean=True):
    """
    Compute one statistic of a dataset over time or over space.

    Args:
        ds: xarray dataset with ``time``, ``lat`` and ``lon``.
        statistic (str): Name of the reduction method to call (e.g. 'mean', 'std').
        dimension (str): 'time' reduces over ``time``; 'space' reduces over
            ``lon``/``lat`` using cos(latitude) weights.
        yearly_mean (bool): Only used for 'space'. If True, the data are first
            averaged per year.

    Returns:
        The reduced dataset with the attributes ``period``, ``statistic`` and
        ``statistic_dimension`` set (plus ``mean`` when yearly means were taken).

    Raises:
        ValueError: If ``dimension`` is neither 'time' nor 'space'.
    """
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # Read the period from the original time axis, before grouping by year removes it
    period = [str(ds.time.dt.year[0].values), str(ds.time.dt.year[-1].values)]

    if dimension == "time":
        stat_ds = getattr(ds, statistic)("time", keep_attrs=True, skipna=True)
    else:
        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            # Assign a fresh dict so the caller's dataset attributes are never touched
            ds.attrs = {**ds.attrs, 'mean': 'yearly mean'}

        # Get the weights, apply them to the data, and compute the statistic
        # NOTE(review): weighted objects may not offer every reduction
        # (e.g. min/max/median) - confirm the requested statistic is available.
        weights = np.cos(np.deg2rad(ds.lat))
        weights.name = "weights"
        ds_weighted = ds.weighted(weights)
        stat_ds = getattr(ds_weighted, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)

    # Assign a new attribute dict instead of editing one that may be shared with the input
    stat_ds.attrs = {**stat_ds.attrs,
                     'period': period,
                     'statistic': statistic,
                     'statistic_dimension': dimension}

    return stat_ds

In [None]:
def compute_statistic(ds_dict, statistic, dimension, start_year=None, end_year=None, yearly_mean=True):
    """
    Computes the specified statistic for each dataset in the dictionary.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        statistic (str): The statistic to compute, which can be one of 'mean', 'std', 'min', 'max',
            'var', or 'median'.
        dimension (str): The dimension to compute over, which can be 'time' or 'space'.
        start_year (int, optional): The start year of the period to compute the statistic over.
        end_year (int, optional): The end year of the period to compute the statistic over.
            The period is only applied when both years are given.
        yearly_mean (bool, optional): Passed on to ``compute_statistic_single``.

    Returns:
        dict: A dictionary with computed statistic for each dataset.

    Raises:
        TypeError: If ``ds_dict`` is not a dictionary of xarray datasets.
        ValueError: If ``statistic`` or ``dimension`` is not one of the supported values.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if statistic not in ["mean", "std", "min", "max", "var", "median"]:
        raise ValueError(f"Invalid statistic '{statistic}' specified.")
    if dimension not in ["time", "space"]:
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # Select period; select_period returns a new dictionary, so its result must be kept
    if start_year is not None and end_year is not None:
        ds_dict = select_period(ds_dict, start_year=start_year, end_year=end_year)

    # Use multiprocessing to compute the statistic for each dataset in parallel
    tasks = [(ds, statistic, dimension, yearly_mean) for ds in ds_dict.values()]
    with mp.Pool() as pool:
        results = pool.starmap(compute_statistic_single, tasks)

    return dict(zip(ds_dict.keys(), results))

In [None]:
from multiprocessing import Pool
from itertools import combinations

# Define a function to compute the metrics
def compute_metrics_for_pair(args):
    """
    Compute the requested relationship metrics between two columns of a DataFrame.

    Rows where either column is not finite are dropped before any metric is computed.

    Args:
        args (tuple): ``(df, var1, var2, metrics)`` where ``df`` is a pandas
            DataFrame, ``var1``/``var2`` are column names (``var1`` is the
            predictor, ``var2`` the target) and ``metrics`` is a collection
            containing any of 'rmse_rf', 'rmse_lr', 'r2_lr', 'pearson',
            'spearman', 'kendalltau'.

    Returns:
        tuple: ``(var1, var2, metric_dict)`` with one entry per requested metric.
    """
    df, var1, var2, metrics = args
    valid_values = np.logical_and(np.isfinite(df[var1]), np.isfinite(df[var2]))

    # Predictor as a column vector (scikit-learn layout) and as a flat vector (scipy layout)
    X = df[var1][valid_values].values.reshape(-1, 1)
    y = df[var2][valid_values].values
    x_flat = X.flatten()
    metric_dict = {}

    if 'rmse_rf' in metrics:
        rf = RandomForestRegressor()
        rf.fit(X, y)
        y_pred_rf = rf.predict(X)
        metric_dict['rmse_rf'] = np.sqrt(mean_squared_error(y, y_pred_rf))

    # Fit the linear model only when one of its metrics is requested.
    # (``'rmse_lr' or 'r2_lr' in metrics`` was always true.)
    if 'rmse_lr' in metrics or 'r2_lr' in metrics:
        lr = LinearRegression()
        lr.fit(X, y)
        y_pred_lr = lr.predict(X)
        if 'rmse_lr' in metrics:
            metric_dict['rmse_lr'] = np.sqrt(mean_squared_error(y, y_pred_lr))
        if 'r2_lr' in metrics:
            # R^2 (coefficient of determination)
            metric_dict['r2_lr'] = r2_score(y, y_pred_lr)

    if 'pearson' in metrics:
        metric_dict['pearson'] = pearsonr(x_flat, y)[0]

    if 'spearman' in metrics:
        metric_dict['spearman'] = spearmanr(x_flat, y)[0]

    if 'kendalltau' in metrics:
        metric_dict['kendalltau'] = kendalltau(x_flat, y)[0]

    return (var1, var2, metric_dict)

In [None]:
from itertools import permutations

def precompute_metrics(ds_dict, variables, metrics=['pearson']):
    """
    Compute pairwise metrics between variables for every dataset.

    For each dataset the variables are flattened into the columns of a
    DataFrame, and every ordered pair of variables is handed to
    ``compute_metrics_for_pair`` through a multiprocessing pool.

    Args:
        ds_dict (dict): Dictionary of datasets keyed by name.
        variables (list): Names of the variables to compare.
        metrics (list): Metric names forwarded to ``compute_metrics_for_pair``.

    Returns:
        dict: ``{metric: {dataset_name: {'<var1>_<var2>': value}}}``.
    """
    results_dict = {}
    for metric in metrics:
        results_dict[metric] = {}

    for name, ds in ds_dict.items():
        # One flattened column per variable
        flat_columns = {}
        for var in variables:
            flat_columns[var] = ds[var].values.flatten()
        df = pd.DataFrame(flat_columns)

        # Every ordered pair of variables becomes one task
        tasks = [(df, first, second, metrics) for first, second in permutations(variables, 2)]

        with Pool() as worker_pool:
            pair_results = worker_pool.map(compute_metrics_for_pair, tasks)

        # File each value under metric -> dataset name -> variable pair
        for first, second, metric_values in pair_results:
            pair_key = f'{first}_{second}'
            for metric, value in metric_values.items():
                per_dataset = results_dict[metric].setdefault(name, {})
                per_dataset.setdefault(pair_key, value)

    return results_dict

In [None]:
def calculate_yearly_correlations(ds_dict, variable_pairs, start_year=None, end_year=None, corr_type='pearson'):
    """
    Calculates the yearly spatial correlation for the given pairs of variables from the same model.

    The data are resampled to yearly means; for every year the correlation between
    the two variables is computed across all ``lon``/``lat`` grid cells.

    Parameters:
    ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
    variable_pairs (list of tuples): The pairs of variables to calculate the correlation for.
    start_year (int): The start year of the period to compute the correlation over.
    end_year (int): The end year of the period to compute the correlation over.
            The period is only applied when both years are given.
    corr_type (str): The type of correlation coefficient to compute. Can be 'pearson', 'spearman',
            or 'kendall' ('kendalltau' is accepted as an alias).

    Returns:
    dict: A dictionary with an xarray Dataset for each input dataset. Each Dataset holds one
        variable per pair (named '<var1>-<var2>') with the yearly correlation along ``time``.

    Raises:
    ValueError: If ``corr_type`` is not one of the supported types.
    """
    # Map for complete metric names and symbols
    metric_map = {
        'r2_lr': ('Coefficient of Determination', 'R²'),
        'pearson': ('Pearson Correlation Coefficient', 'r'),
        'spearman': ('Spearman Rank Correlation Coefficient', 'ρ'),
        'kendall': ('Kendall Rank Correlation Coefficient', 'τ'),
        'kendalltau': ('Kendall Rank Correlation Coefficient', 'τ')
    }
    rank_correlations = {'spearman': spearmanr, 'kendall': kendalltau, 'kendalltau': kendalltau}

    # Fail early instead of in the middle of the computation
    if corr_type != 'pearson' and corr_type not in rank_correlations:
        raise ValueError("Invalid correlation type. Expected 'pearson', 'spearman', or 'kendall'.")

    # Select period; select_period returns a new dictionary, so its result must be kept
    if start_year is not None and end_year is not None:
        ds_dict = select_period(ds_dict, start_year=start_year, end_year=end_year)

    yearly_correlations = {}

    for name, ds in ds_dict.items():
        yearly_corr_dict = {}

        # Resample to yearly data
        ds_yearly = ds.resample(time='1Y').mean()

        # The unique years are the same for every variable pair
        years = np.unique(ds_yearly['time'].dt.year)

        for var1, var2 in variable_pairs:
            yearly_correlations_values = []
            yearly_correlations_years = []

            for year in years:
                # Select the data for this year
                ds_year = ds_yearly.sel(time=f'{year}')

                if corr_type == 'pearson':
                    corr_value = xr.corr(ds_year[var1], ds_year[var2], dim=['lon', 'lat'])
                    # .item() extracts the scalar whether the result is 0-d or has length one
                    corr_scalar = float(corr_value.values.item())
                else:
                    df = ds_year.stack(z=('lon', 'lat')).to_dataframe()
                    # Use only grid cells where both variables are finite; otherwise a
                    # single NaN cell turns the rank correlation into NaN
                    valid = np.isfinite(df[var1]) & np.isfinite(df[var2])
                    corr_value, _ = rank_correlations[corr_type](df[var1][valid], df[var2][valid])
                    corr_scalar = float(corr_value)

                yearly_correlations_values.append(corr_scalar)
                yearly_correlations_years.append(year)

            # Store in the yearly_corr_dict
            yearly_corr_dict[f'{var1}-{var2}'] = xr.DataArray(yearly_correlations_values, dims='time', coords={'time': yearly_correlations_years})

        # Create a Dataset from the yearly_corr_dict and store in the yearly_correlations dict
        yearly_correlations[name] = xr.Dataset(yearly_corr_dict)
        yearly_correlations[name].attrs = {'Metric': metric_map[corr_type][0],
                                           'Metric_sign': metric_map[corr_type][1]
                                          }

    return yearly_correlations

In [None]:
def compute_stats(ds_dict):
    """
    Compute the yearly, spatially averaged value of each variable in each dataset.

    Parameters:
    ds_dict (dict): The input dictionary of xarray.Dataset.

    Returns:
    dict: A dictionary keyed by dataset name. Each value is a dictionary keyed by
          variable name holding the DataArray of yearly means averaged over
          ``lat`` and ``lon``.
    """
    def _yearly_spatial_means(dataset):
        # Yearly mean first, then the mean over the spatial dimensions
        annual = dataset.resample(time='1Y').mean()
        return {var: annual[var].mean(dim=['lat', 'lon']) for var in annual.data_vars}

    return {model: _yearly_spatial_means(ds) for model, ds in ds_dict.items()}

In [None]:
def compute_yearly_means(ds_dict_region):
    """
    Compute latitude-weighted yearly spatial means for every region and dataset.

    Parameters:
    ds_dict_region (dict): Dictionary keyed by region; each value is a dictionary
        of datasets keyed by dataset name.

    Returns:
    dict: Same nesting as the input, holding for each dataset the yearly mean
          reduced over ``lat`` and ``lon`` with cos(latitude) weights.
    """
    result = {}

    for region, datasets in ds_dict_region.items():
        region_means = {}
        for ds_name, ds in datasets.items():
            # Average per year, then reduce over space with cos(latitude) weights
            annual = ds.groupby('time.year').mean('time')
            lat_weights = np.cos(np.deg2rad(ds.lat))
            region_means[ds_name] = annual.weighted(lat_weights).mean(('lat', 'lon'))
        result[region] = region_means

    return result

In [None]:
def compute_correlation_coefficients(ds_dict, variables, time_dim='time', yearly_corr=False, target_var=None):
    """
    Compute the correlation coefficients for different variable combinations and
    store them in a new dictionary.

    NOTE(review): a second function with this same name is defined later in this
    file; when the cells run in order, that later definition replaces this one.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        variables (list): A list of variables for which to compute the correlation coefficients.
        time_dim (str): The name of the time dimension in the datasets. Default is 'time'.
        yearly_corr (bool): If True, compute correlation of yearly mean data. Default is False.
        target_var (str, optional): If provided, only compute correlations with this variable. Default is None.

    Returns:
        dict: A new dictionary where keys are the same as in the input dictionary, and each value is an xarray Dataset
              containing correlation coefficients for different variable combinations.

    Raises:
        ValueError: If ``target_var`` is given but not contained in ``variables``.
    """
    # Determine the combinations of variables to compute correlations for
    if target_var is not None:
        if target_var not in variables:
            raise ValueError(f"Target variable '{target_var}' not found in the list of variables.")
        var_combinations = [(target_var, var) for var in variables if var != target_var]
    else:
        var_combinations = combinations(variables, 2)

    means_label = 'yearly means' if yearly_corr else 'monthly means'

    # Create a new dictionary to store the correlation data
    ds_dict_corr = {}
    # Yearly resampled datasets, computed once per dataset instead of once per variable pair
    yearly_cache = {}

    # Iterate over all combinations of two variables
    for var1, var2 in var_combinations:
        # Iterate over all datasets in the dictionary
        for ds_name, ds in ds_dict.items():
            # Skip datasets that do not contain both variables
            if var1 not in ds or var2 not in ds:
                continue

            if yearly_corr:
                if ds_name not in yearly_cache:
                    yearly_cache[ds_name] = ds.resample({time_dim: 'Y'}).mean()
                ds_corr = yearly_cache[ds_name]
            else:
                ds_corr = ds

            # If the dataset is not yet in the new dictionary, create a new xarray Dataset for it
            if ds_name not in ds_dict_corr:
                ds_dict_corr[ds_name] = xr.Dataset(attrs={**ds.attrs,
                                                          'Metric': 'Pearson Correlation Coefficient',
                                                          'Metric_sign': 'r',
                                                          'means': means_label})

            # Compute the correlation coefficients and add them as a new DataArray to the Dataset
            ds_dict_corr[ds_name][f'{var1} x {var2}'] = xr.corr(ds_corr[var1], ds_corr[var2], dim=time_dim)

    return ds_dict_corr

In [None]:
def compute_correlation_coefficients(ds_dict, variables, time_dim='time', yearly_corr=False, target_var=None):
    """
    Compute the correlation coefficients for different variable combinations and
    store them in a new dictionary.

    This definition replaces the earlier function of the same name in this file
    when the cells run in order, so it accepts the same optional ``target_var``
    argument as that earlier definition.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        variables (list): A list of variables for which to compute the correlation coefficients.
        time_dim (str): The name of the time dimension in the datasets. Default is 'time'.
        yearly_corr (bool): If True, compute correlation of yearly mean data. Default is False.
        target_var (str, optional): If provided, only compute correlations with this variable. Default is None.

    Returns:
        dict: A new dictionary where keys are the same as in the input dictionary, and each value is an xarray Dataset
              containing correlation coefficients for different variable combinations.

    Raises:
        ValueError: If ``target_var`` is given but not contained in ``variables``.
    """
    # Determine the combinations of variables to compute correlations for
    if target_var is None:
        var_combinations = combinations(variables, 2)
    elif target_var in variables:
        var_combinations = [(target_var, var) for var in variables if var != target_var]
    else:
        raise ValueError(f"Target variable '{target_var}' not found in the list of variables.")

    means_label = 'yearly means' if yearly_corr else 'monthly means'

    # Create a new dictionary to store the correlation data
    ds_dict_corr = {}
    # Yearly resampled datasets, computed once per dataset instead of once per variable pair
    yearly_cache = {}

    # Iterate over all combinations of two variables
    for var1, var2 in var_combinations:
        # Iterate over all datasets in the dictionary
        for ds_name, ds in ds_dict.items():
            # Skip datasets that do not contain both variables
            if var1 not in ds or var2 not in ds:
                continue

            if yearly_corr:
                if ds_name not in yearly_cache:
                    yearly_cache[ds_name] = ds.resample({time_dim: 'Y'}).mean()
                ds_corr = yearly_cache[ds_name]
            else:
                ds_corr = ds

            # If the dataset is not yet in the new dictionary, create a new xarray Dataset for it
            if ds_name not in ds_dict_corr:
                ds_dict_corr[ds_name] = xr.Dataset(attrs={**ds.attrs,
                                                          'Metric': 'Pearson Correlation Coefficient',
                                                          'Metric_sign': 'r',
                                                          'means': means_label})

            # Compute the correlation coefficients and add them as a new DataArray to the Dataset
            ds_dict_corr[ds_name][f'{var1} x {var2}'] = xr.corr(ds_corr[var1], ds_corr[var2], dim=time_dim)

    return ds_dict_corr

In [None]:
def compute_change_corr(ds_dict_hist_corr, ds_dict_ssp370_corr):
    """
    Compute the change in correlation (SSP370 minus historical) for every model.

    Args:
        ds_dict_hist_corr (dict): Correlation datasets of the historical period, keyed by model name.
        ds_dict_ssp370_corr (dict): Correlation datasets of the SSP370 period with the same keys.

    Returns:
        dict: One difference dataset per model, with attributes describing the change.

    Raises:
        ValueError: If the two datasets of a model were not both computed from
            yearly means or both from monthly means.
    """
    ds_dict_corr_change = {}

    for name, ds_hist in ds_dict_hist_corr.items():
        ds_ssp370 = ds_dict_ssp370_corr[name]

        # Both periods must be based on the same kind of means
        hist_means = ds_hist.attrs['means']
        ssp370_means = ds_ssp370.attrs['means']
        if hist_means != ssp370_means or hist_means not in ('yearly means', 'monthly means'):
            raise ValueError("Computing change between yearly and monthly mean data.")

        change = ds_ssp370 - ds_hist
        change.attrs = {'period': 'Change Correlation SSP370 - Historical',
                        'statistic': 'mean',
                        'statistic_dimension': 'time',
                        'experiment_id': 'ssp370-historical',
                        'source_id': name,
                        'Metric': 'Pearson Correlation Coefficient',
                        'Metric_sign': 'r',
                        'means': hist_means
                        }
        ds_dict_corr_change[name] = change

    # Return after ALL models are processed (the return used to sit inside the loop)
    return ds_dict_corr_change

In [None]:
def _tas_to_kelvin(ds):
    """Return ``ds`` with 'tas' shifted by +273.15 and units 'K'; the input is left untouched."""
    if 'tas' not in ds or ds['tas'].attrs.get('units') == 'K':
        return ds
    converted = ds.copy()
    tas_attrs = dict(ds['tas'].attrs)
    tas_attrs['units'] = 'K'
    converted['tas'] = ds['tas'] + 273.15
    converted['tas'].attrs = tas_attrs
    return converted


def compute_change(ds_dict_hist_mean, ds_dict_ssp370_mean, relative_change=False, iqr=False):
    """
    Compute the change between the SSP370 and the historical statistics of every model.

    Args:
        ds_dict_hist_mean (dict): Historical datasets keyed by model name.
        ds_dict_ssp370_mean (dict): SSP370 datasets with the same keys.
        relative_change (bool): If True, return the relative change in percent,
            ``(ssp370 - hist) / |hist| * 100``, with NaN wherever the historical
            value is 0. For this computation 'tas' (if present and not already
            in 'K') is shifted by 273.15 so all values share one sign. If False,
            return the absolute difference.
        iqr (bool): Only used for relative change. If True, values outside
            ``[q1 - 4 * IQR, q3 + 4 * IQR]`` are set to NaN.

    Returns:
        dict: One change dataset per model. The input datasets are not modified.
    """
    ds_dict_change = {}

    for name, ds_hist in ds_dict_hist_mean.items():
        ds_ssp370 = ds_dict_ssp370_mean[name]

        if relative_change:
            # Work on converted copies so repeated calls never shift the inputs again
            ds_hist = _tas_to_kelvin(ds_hist)
            ds_ssp370 = _tas_to_kelvin(ds_ssp370)

            # Compute relative change only where the historical value is not 0
            baseline = ds_hist.where(ds_hist != 0)
            change = ((ds_ssp370 - baseline) / abs(baseline)) * 100

            # Where the historical value was 0, the relative change is NaN
            change = change.where(ds_hist != 0)

            if iqr:
                # Compute the IQR and use it for filtering; a separate name keeps
                # the ``iqr`` flag intact for the next model
                q1 = change.quantile(0.25)
                q3 = change.quantile(0.75)
                iqr_range = q3 - q1
                lower_bound = q1 - 4 * iqr_range
                upper_bound = q3 + 4 * iqr_range

                within_bounds = (change >= lower_bound) & (change <= upper_bound)
                change = change.where(within_bounds)
        else:
            change = ds_ssp370 - ds_hist

        change.attrs = {'period': 'Change SSP370 - Historical',
                        'statistic': ds_ssp370.attrs['statistic'],
                        'statistic_dimension': ds_ssp370.attrs['statistic_dimension'],
                        'experiment_id': 'ssp370-historical',
                        'source_id': ds_ssp370.attrs['source_id'],
                        'frequency': ds_ssp370.attrs['frequency'],
                        # Set once per dataset, even when it holds no variables
                        'change': 'Relative Change' if relative_change else 'Absolute Change'
                        }

        for var in ds_hist:
            var_attrs = dict(ds_ssp370[var].attrs)
            if relative_change:
                var_attrs['units_rel'] = '%'
            change[var].attrs = var_attrs

        ds_dict_change[name] = change

    return ds_dict_change

In [None]:
def compute_nbwfp(ds_dict):
    """
    Add the variable 'nbwfp' = (mrro - tran) / pr to every dataset in ``ds_dict``.

    Infinite values and values above 2 or below -2 are replaced by NaN. The
    datasets are modified in place and the same dictionary is returned.
    """
    nan = float('nan')

    # Applied one after the other: infinite values, values above 2, values below -2
    invalid_tests = (
        np.isinf,
        lambda arr: arr > 2,
        lambda arr: arr < -2,
    )

    for ds in ds_dict.values():
        ratio = (ds['mrro'] - ds['tran']) / ds['pr']

        for is_invalid in invalid_tests:
            ratio = xr.where(is_invalid(ratio), nan, ratio)

        ds['nbwfp'] = ratio
        ds['nbwfp'].attrs = {'long_name': 'Net Blue Water Flux / Precipitation',
                             'units': ''}

    return ds_dict

#### Plot data

In [None]:
def plot_mean_values(stats_hist, stats_ssp370, scaling_both=False, scaling_hist=False, full_variable_names=None):
    """
    Plots the boxplots for each variable in each model and saves the figure.

    One subplot is drawn per variable; within it every model gets a pair of
    boxes (historical and SSP370) side by side.

    Parameters:
    stats_hist (dict): The statistics for the historical period
        (``{model: {variable: yearly values}}``).
    stats_ssp370 (dict): The statistics for the SSP370 scenario, same layout.
    scaling_both (bool): Standardize each period of each model with its own
        mean and standard deviation. Ignored when ``scaling_hist`` is True.
    scaling_hist (bool): Standardize both periods of each model with the mean
        and standard deviation of that model's historical data.
    full_variable_names (dict, optional): Maps a variable name to a sequence
        whose first item is used as subplot title and whose second item is
        used as y-axis label (in brackets).

    Returns:
    str: The file path where the figure was saved.

    Raises:
    ValueError: If the first entry of ``stats_hist`` contains no variables.
    """
    if scaling_hist or scaling_both:
        # Imported here because the module-level imports do not provide StandardScaler
        from sklearn.preprocessing import StandardScaler

    # Get the list of variables and models
    variables = list(stats_hist[next(iter(stats_hist))].keys())
    if not variables:
        raise ValueError("stats_hist contains no variables to plot.")
    model_names = list(stats_hist.keys())

    # Title and file suffix depend only on the scaling mode, not on the variable
    if scaling_hist:
        title = f'Yearly Means of Variables for Historical (1985-2014) and SSP370 (2071-2100) Period Normalized by Historical Mean and Standard Deviation of Historical Data'
        suffix = "_scaled_historical"
    elif scaling_both:
        title = f'Yearly Means of Variables for Historical (1985-2014) and SSP370 (2071-2100) Period Normalized by Respective Mean and Standard Deviation'
        suffix = "_scaled_both"
    else:
        title = f'Yearly Means of Variables for Historical (1985-2014) and SSP370 (2071-2100)'
        suffix = ""

    # Calculate the number of plots and dimensions of the grid of subplots
    n_plots = len(variables)
    n_cols = min(n_plots, 3)  # Maximum 3 plots in a row
    n_rows = int(np.ceil(n_plots / n_cols))

    # Create the figure
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
    axs = axs.flatten()  # Flatten the axes array
    fig.suptitle(title, fontsize=12, y=1.0)

    # Box plot positions and the offset that places the two periods side by side
    positions = np.arange(len(model_names))
    offset = 0.15

    for ax, var in zip(axs, variables):
        # Yearly mean values per model as column vectors
        hist_columns = [np.array(stats_hist[name][var].values).reshape(-1, 1) for name in model_names]
        ssp370_columns = [np.array(stats_ssp370[name][var].values).reshape(-1, 1) for name in model_names]

        # Standardize the data if requested
        if scaling_hist:
            yearly_means_hist, yearly_means_ssp370 = [], []
            for hist_col, ssp370_col in zip(hist_columns, ssp370_columns):
                # One scaler per model, fitted on that model's historical data only
                scaler = StandardScaler().fit(hist_col)
                yearly_means_hist.append(scaler.transform(hist_col).flatten())
                yearly_means_ssp370.append(scaler.transform(ssp370_col).flatten())
        elif scaling_both:
            yearly_means_hist = [StandardScaler().fit_transform(col).flatten() for col in hist_columns]
            yearly_means_ssp370 = [StandardScaler().fit_transform(col).flatten() for col in ssp370_columns]
        else:
            yearly_means_hist = [col.flatten() for col in hist_columns]
            yearly_means_ssp370 = [col.flatten() for col in ssp370_columns]

        # Plot the box plots for the historical period
        ax.boxplot(yearly_means_hist, positions=positions-offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='cornflowerblue'), medianprops=dict(color='black'),
                   showfliers=True)

        # Plot the box plots for the SSP370 scenario
        ax.boxplot(yearly_means_ssp370, positions=positions+offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='sandybrown'), medianprops=dict(color='black'),
                   showfliers=True)

        # Set the x-tick labels
        ax.set_xticks(positions)
        ax.set_xticklabels(model_names, rotation=90)

        # Use the full variable name and unit label if provided, else the plain variable name
        if full_variable_names and var in full_variable_names:
            ax.set_title(full_variable_names[var][0], fontsize=9)
            ax.set_ylabel(f'[{full_variable_names[var][1]}]', fontsize=9)
        else:
            ax.set_title(var)

    # Hide grid cells that have no variable to show
    for unused_ax in axs[n_plots:]:
        unused_ax.set_visible(False)

    # Set a legend
    axs[0].legend([Patch(facecolor='cornflowerblue'), Patch(facecolor='sandybrown')], ['Historical', 'SSP370'])

    # Adjust the layout
    plt.tight_layout()

    # Save the figure before showing it, so the file is written from the live figure
    filename = f"Variable_changes{suffix}.png"

    savepath = f'../../results/CMIP6/yearly_mean_comparison'
    os.makedirs(savepath, exist_ok=True)

    filepath = os.path.join(savepath, filename)

    fig.savefig(filepath, dpi=300, bbox_inches='tight')

    plt.show()

    return filepath

In [None]:
def plot_mean_change_map(ds_dict, variable, vmin, vmax, cmap='viridis', metric='mean', save_fig=False, file_format='png'):
    """
    Plots a map of the change of the given variable for each dataset in the dictionary.

    All maps share the colour limits vmin/vmax and one horizontal colorbar. Whether the
    figure is labelled as relative or absolute change is decided by the 'change'
    attribute of the first dataset in ds_dict.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        variable (str): The name of the variable to plot.
        vmin (float): Lower limit of the colour scale.
        vmax (float): Upper limit of the colour scale.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        metric (str): Name of the sub-folder (below 'time') the figure is saved in. Default is 'mean'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: The file path where the figure was saved, a hint message if save_fig
            is False, or None if no dataset contains the variable.
    """

    # The first dataset decides whether relative or absolute change is shown
    is_relative = ds_dict[list(ds_dict.keys())[0]].attrs['change'] == 'Relative Change'

    # Check arguments and get info
    var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency = check_args_and_get_info(ds_dict, variable)
    if is_relative:
        unit = '%'

    # Without at least one map there is no mappable for the shared colorbar
    n_datasets_with_var = sum(1 for candidate in ds_dict.values() if variable in candidate)
    if n_datasets_with_var == 0:
        print(f"Variable '{variable}' not found in any dataset, nothing to plot.")
        return None

    # Create a figure
    n_cols = 4  # Set number of columns to 4
    n_rows = math.ceil(n_datasets_with_var / n_cols)  # Calculate rows

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    # Loop over datasets and plot the requested statistic
    subplot_counter = 0
    for name, ds in ds_dict.items():
        if variable not in ds:
            print(f"Variable '{variable}' not found in dataset '{name}', skipping.")
            continue

        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, subplot_counter + 1, projection=ccrs.Robinson())

        # Data are given on a regular lat/lon grid, hence the PlateCarree transform
        im = ds[variable].plot(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree(), add_colorbar=False)
        ax.set_title(name, fontsize=18)
        ax.coastlines()  # Adds lines around the continents

        subplot_counter += 1

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots (uses the last drawn map as mappable)
    cbar = fig.colorbar(im, ax=fig.axes, extend='both', orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)

    # Set tick size
    cbar.ax.tick_params(labelsize=20)  # Adjust size as needed

    # Set colorbar label
    cbar.set_label(f"{var_long_name} Change [{unit}]", size=26)  # Adjust size as needed

    # Set figure title
    change_label = 'Relative' if is_relative else 'Absolute'
    fig.suptitle(f"{change_label} Change of {var_long_name} {titles[statistic]} ({period})", fontsize=26, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'ssp370-historical', 'time', metric, 'change_maps')
        os.makedirs(savepath, exist_ok=True)
        change_prefix = 'relative_change' if is_relative else 'absolute_change'
        # ds is the last dataset of ds_dict at this point; its experiment_id attribute names the file
        filename = f'{change_prefix}.{statistic}.{variable}.{ds.experiment_id}.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    return filepath

#### Plot metrics

In [None]:
def plot_mean_correlations(correlations_hist, correlations_ssp370, variable_pairs, target_variable, scale_axis=False, variable_captions=None):
    """
    Plots the yearly correlations for each variable pair that includes the target variable.

    For every pair one subplot is drawn that shows, per dataset, a box plot for the
    historical period next to a box plot for the SSP370 scenario. The figure is always
    saved to '../../results/CMIP6/yearly_metrics_comparison'.

    Parameters:
    correlations_hist (dict): The output from calculate_correlations for the historical period.
    correlations_ssp370 (dict): The output from calculate_correlations for the SSP370 scenario.
    variable_pairs (list of tuples): The pairs of variables that the correlations were calculated for.
    target_variable (str): The variable that must be included in a pair for it to be plotted.
    scale_axis (bool): Whether to scale the y-axis according to metric value ranges. Default is False.
    variable_captions (dict): Optional mapping from variable name to the caption used in the
        subplot titles. Variables without an entry (or all of them if None) keep their raw name.

    Returns:
    str or None: The file path where the figure was saved, or None if no pair contains
        the target variable.
    """

    # Fall back to the raw variable names when no captions are supplied
    captions = {} if variable_captions is None else variable_captions

    # Get info
    model_names = list(correlations_hist.keys())
    first_ds = correlations_hist[model_names[0]]
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign

    # Filter variable pairs
    variable_pairs = [(var1, var2) for var1, var2 in variable_pairs if target_variable in (var1, var2)]

    # Calculate the number of plots and dimensions of the grid of subplots
    n_plots = len(variable_pairs)
    if n_plots == 0:
        # A grid with zero columns cannot be built
        print(f"No variable pair contains '{target_variable}', nothing to plot.")
        return None
    n_cols = min(n_plots, 3)  # Maximum 3 plots in a row
    n_rows = math.ceil(n_plots / n_cols)

    # Create the figure
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
    axs = axs.flatten()  # Flatten the axes array

    # Box positions (one slot per dataset) and the offset that puts both periods side by side
    positions = np.arange(len(model_names))
    offset = 0.15

    for ax, (var1, var2) in zip(axs, variable_pairs):
        pair_key = f'{var1}-{var2}'

        # Yearly correlation values of every dataset for both periods
        yearly_corr_hist = [correlations_hist[name][pair_key].values for name in model_names]
        yearly_corr_ssp370 = [correlations_ssp370[name][pair_key].values for name in model_names]

        # Plot the box plots for the historical period
        ax.boxplot(yearly_corr_hist, positions=positions-offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='cornflowerblue'), medianprops=dict(color='black'),
                   showfliers=True)

        # Plot the box plots for the SSP370 scenario
        ax.boxplot(yearly_corr_ssp370, positions=positions+offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='sandybrown'), medianprops=dict(color='black'),
                   showfliers=True)

        # Set the axis label, the title and the x-tick labels
        ax.set_ylabel(f'{metric_sign}')
        ax.set_title(f'{captions.get(var1, var1)} x {captions.get(var2, var2)}', fontsize=9)
        ax.set_xticks(positions)
        ax.set_xticklabels(model_names, rotation=90)

    fig.suptitle(f'Yearly {metric} for Historical (1985-2014) and SSP370 (2071-2100) Period', fontsize=12, y=1.0)

    # Set a legend
    axs[0].legend([Patch(facecolor='cornflowerblue'), Patch(facecolor='sandybrown')], ['Historical', 'SSP370'])

    # Handle empty subplots in case n_plots is less than n_rows * n_cols
    for unused_ax in axs[n_plots:]:
        fig.delaxes(unused_ax)

    # Handle y-axis scaling (only the axes that are actually used)
    if scale_axis:
        y_limits = [0, 1] if metric == 'r2_lr' else [-1, 1]
        for ax in axs[:n_plots]:
            ax.set_ylim(y_limits)

    # Adjust the layout
    plt.tight_layout()
    plt.show()

    # Save figure
    suffix = "_scaled_axis" if scale_axis else ""
    filename = f"{metric}_changes_{target_variable}{suffix}.png"

    savepath = '../../results/CMIP6/yearly_metrics_comparison'
    os.makedirs(savepath, exist_ok=True)

    filepath = os.path.join(savepath, filename)

    fig.savefig(filepath, dpi=300, bbox_inches='tight')

    return filepath

In [None]:
def ensemble_corr_change_plot(ds_dict, target_variable, full_var_names_and_unit, cmap='viridis', save_fig=False, file_format='png'):
    """
    Plots maps of all variable combinations with the target variable for the Ensemble_mean dataset.

    Every variable of the 'Ensemble_mean' dataset named '<target> x <other>' or
    '<other> x <target>' gets its own map. All maps use the fixed colour range -1 to 1
    and share one horizontal colorbar.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. Must contain the key 'Ensemble_mean'.
        target_variable (str): The target variable to plot correlations with.
        full_var_names_and_unit (dict): Maps a variable name to a sequence whose first element
            is the long name used in the titles.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: The file path where the figure was saved, a message if it was not saved,
            or None if there is no 'Ensemble_mean' dataset or no matching variable.
    """
    # Info
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']
    target_var_long_name = full_var_names_and_unit[target_variable][0]

    # Get the Ensemble_mean dataset (checked before a figure is created)
    ensemble_ds = ds_dict.get("Ensemble_mean", None)

    if ensemble_ds is None:
        print("Ensemble_mean dataset not found.")
        return None

    # Collect (variable name, partner variable) for every combination with the target variable.
    # Matching on the start/end of the name avoids picking up variables that merely
    # contain the target name as a substring.
    prefix = f'{target_variable} x '
    suffix = f' x {target_variable}'
    combinations = []
    for variable in ensemble_ds.variables:
        if not isinstance(variable, str):
            continue
        if variable.startswith(prefix):
            combinations.append((variable, variable[len(prefix):]))
        elif variable.endswith(suffix):
            combinations.append((variable, variable[:-len(suffix)]))

    if not combinations:
        print(f"No variable combination with '{target_variable}' found in Ensemble_mean dataset.")
        return None

    # Create a figure: 3 x 3 grid, extended by further rows if more than 9 maps are needed
    n_cols = 3
    n_rows = max(3, math.ceil(len(combinations) / n_cols))

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    # Plot one map per variable combination
    for subplot_index, (variable, corr_var) in enumerate(combinations, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, subplot_index, projection=ccrs.Robinson())

        im = ensemble_ds[variable].plot(ax=ax, cmap=cmap, vmin=-1, vmax=1, transform=ccrs.PlateCarree(), add_colorbar=False)

        corr_var_long_name = full_var_names_and_unit[corr_var][0]
        ax.set_title(f'{target_var_long_name} x {corr_var_long_name}', fontsize=18)  # Use the long names in the title
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)

    # Set colorbar ticks
    cbar.set_ticks([-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75])

    # Set tick size
    cbar.ax.tick_params(labelsize=20)  # Adjust size as needed

    # Set colorbar label
    cbar.set_label(f'{metric_sign} change', size=26)  # Adjust size as needed

    # Set figure title
    is_single_experiment = experiment_id in ('historical', 'ssp370')
    is_change = experiment_id == 'ssp370-historical'
    if is_single_experiment:
        fig.suptitle(f"{metric} ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)
    elif is_change:
        fig.suptitle(f"{metric} Change ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'
    elif is_single_experiment or is_change:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        if is_change:
            filename = f'ensmean_{target_variable}_correlations_change.{file_format}'
        else:
            filename = f'ensmean_{target_variable}_correlations.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        # No output location is defined for other experiments; report instead of failing
        filepath = f"Figure not saved. No save location is defined for experiment_id '{experiment_id}'"

    return filepath

In [None]:
def corr_maps(ds_dict, target_variable_combination, cmap='viridis', save_fig=False, file_format='png'):
    """
    Plots a map of the specified variable combination for each dataset in the dictionary.

    Datasets that do not contain the combination are left out. All maps share one
    horizontal colorbar, which is built from the last map drawn.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        target_variable_combination (str): The target variable combination to plot.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: The file path where the figure was saved, a message if it was not saved,
            or None if no dataset contains the variable combination.
    """
    # Info
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']

    # A plain membership test is required here: "x == 'historical' or 'ssp370'" is always true
    is_single_experiment = experiment_id in ('historical', 'ssp370')
    is_change = experiment_id == 'ssp370-historical'

    # Only datasets that contain the requested combination get a map
    datasets_to_plot = [(name, ds) for name, ds in ds_dict.items() if target_variable_combination in ds.variables]
    if not datasets_to_plot:
        print(f"Variable combination '{target_variable_combination}' not found in any dataset, nothing to plot.")
        return None

    # Create a figure
    n_cols = 3  # Set number of columns to 3
    n_rows = math.ceil(len(datasets_to_plot) / n_cols)  # Calculate rows

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    # Loop over datasets and plot the requested variable combination
    for subplot_index, (name, ds) in enumerate(datasets_to_plot, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, subplot_index, projection=ccrs.Robinson())

        im = ds[target_variable_combination].plot(ax=ax, cmap=cmap, transform=ccrs.PlateCarree(), add_colorbar=False)
        ax.set_title(name, fontsize=18)
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)

    # Set tick size
    cbar.ax.tick_params(labelsize=14)  # Adjust size as needed

    # Set colorbar label
    cbar.set_label(metric_sign, size=18)  # Adjust size as needed

    # Set figure title
    if is_single_experiment:
        fig.suptitle(f"{metric} ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)
    elif is_change:
        fig.suptitle(f"{metric} Change ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'
    elif is_single_experiment:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        filename = f'{target_variable_combination}_correlation.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    elif is_change:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        filename = f'{target_variable_combination}_correlation_change.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        # No output location is defined for other experiments; report instead of failing
        filepath = f"Figure not saved. No save location is defined for experiment_id '{experiment_id}'"

    return filepath

### Load netCDF files

In [None]:
# ========= Define period, models and path ==============
# Variables to load; open_and_merge_datasets looks each one up in its own sub-folder
variables=['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'tran', 'lai', 'gpp', 'EI', 'wue']
# Experiment whose files are loaded by this cell
experiment_id = 'historical'
# Names that are inserted as the model part of the file name pattern
source_id = ['TaiESM1', 'BCC-CSM2-MR',  'CanESM5', 'CNRM-CM6-1', 'CNRM-ESM2-1', 'IPSL-CM6A-LR', 'UKESM1-0-LL', 'MPI-ESM1-2-LR', 'CESM2-WACCM', 'NorESM2-MM', 'Ensemble mean', 'Ensemble median'] #
# Sub-folder below the experiment folder that holds the files
folder='preprocessed'

# ========= Use Dask to parallelize computations ==========
# Sets the global dask scheduler; this also applies to dask computations in later cells
dask.config.set(scheduler='processes')

# Create dictionary (model name -> merged dataset) using a dictionary comprehension and Dask
# NOTE(review): open_and_merge_datasets is called directly inside the comprehension, i.e. the
# dictionary is fully built before dask.compute receives it, so the files are opened one after
# another here. Wrapping the call in dask.delayed would be needed for parallel loading - confirm.
ds_dict_hist = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, variables) for model in source_id})[0]

In [None]:
# ============= Have a look into the data ==============
# Print the names under which the datasets are stored
print(ds_dict_hist.keys())
# Display the first dataset of the dictionary as cell output
ds_dict_hist[list(ds_dict_hist.keys())[0]]

### Select period

In [None]:
# Select the historical reference period (start_year=1985, end_year=2014) for every dataset
ds_dict_hist_period = select_period(ds_dict_hist, start_year=1985, end_year=2014)

In [None]:
# Select the scenario period (start_year=2071, end_year=2100) for every dataset
# NOTE(review): ds_dict_ssp370 is not created by the loading cell shown in this section -
# presumably that cell is re-run with experiment_id = 'ssp370' and a different target name; confirm.
ds_dict_ssp370_period = select_period(ds_dict_ssp370, start_year=2071, end_year=2100)

### Compute statistics

In [None]:
# ========= Compute statistic for plot ===============
# 'mean' over 'time' for the historical period (presumably one mean field per variable - confirm in compute_statistic)
ds_dict_hist_period_metric = compute_statistic(ds_dict_hist_period, 'mean', 'time')

In [None]:
# ========= Compute statistic for plot ===============
# 'mean' over 'time' for the SSP370 period (presumably one mean field per variable - confirm in compute_statistic)
ds_dict_ssp370_period_metric = compute_statistic(ds_dict_ssp370_period, 'mean', 'time')

### Compute Ensemble metric

In [None]:
# Replace the historical statistics dictionary by the result of compute_ensemble
# (presumably the same dictionary extended by ensemble entries - confirm in compute_ensemble)
ds_dict_hist_period_metric = compute_ensemble(ds_dict_hist_period_metric)

In [None]:
# Replace the SSP370 statistics dictionary by the result of compute_ensemble
# (presumably the same dictionary extended by ensemble entries - confirm in compute_ensemble)
ds_dict_ssp370_period_metric = compute_ensemble(ds_dict_ssp370_period_metric)

### Compute NBWFP

In [None]:
# Run compute_nbwfp on the historical period data (not on the statistics computed above);
# presumably this adds the 'nbwfp' variable used in the correlation section - confirm
ds_dict_hist_period = compute_nbwfp(ds_dict_hist_period)

In [None]:
# Run compute_nbwfp on the SSP370 period data (not on the statistics computed above);
# presumably this adds the 'nbwfp' variable used in the correlation section - confirm
ds_dict_ssp370_period = compute_nbwfp(ds_dict_ssp370_period)

### Compute change

##### Absolute change

In [None]:
# Change between the historical and the SSP370 time means with relative_change=False.
# (The former empty-dict initialisation was overwritten immediately and has been dropped.)
ds_dict_abs_change = compute_change(ds_dict_hist_period_metric, ds_dict_ssp370_period_metric, relative_change=False)

##### Relative change

In [None]:
# Change between the historical and the SSP370 time means with relative_change=True.
# (The former empty-dict initialisation was overwritten immediately and has been dropped.)
ds_dict_rel_change = compute_change(ds_dict_hist_period_metric, ds_dict_ssp370_period_metric, relative_change=True, iqr=False)

### Plot Variable Change Maps

In [None]:
# ========= Plot Change ==========
# Map the 'tas' change of every dataset with colour limits 0 (vmin) and 5 (vmax) and save it as png
plot_mean_change_map(ds_dict_rel_change, 'tas', 0, 5, cmap='YlOrRd', metric='mean', save_fig=True, file_format='png')

### Plot Correlation Maps

In [None]:
# ========= Compute Correlations ========

In [None]:
# Variables used for the correlation analysis. Overwrites the list used for loading:
# 'EI' is no longer included and 'nbwfp' is added.
variables=['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'tran', 'lai', 'gpp', 'wue', 'nbwfp']

In [None]:
# Correlation coefficients between the listed variables for the historical period (yearly_corr=True)
ds_dict_hist_corr = compute_correlation_coefficients(ds_dict_hist_period, variables, yearly_corr=True)

In [None]:
# Correlation coefficients between the listed variables for the SSP370 period (yearly_corr=True)
ds_dict_ssp370_corr = compute_correlation_coefficients(ds_dict_ssp370_period, variables, yearly_corr=True)

In [None]:
# ======= Build every unordered pair of variables exactly once ========
# Each variable is combined with all variables that follow it in the list.
variable_pairs = [
    (first, second)
    for position, first in enumerate(variables)
    for second in variables[position + 1:]
]

In [None]:
# Restrict the plots to pairs of the target variable with every other variable.
# This overwrites any variable_pairs list created before.
target_variable = 'nbwfp'
variable_pairs = [(target_variable, var) for var in variables if var != target_variable]

In [None]:
# ======== Draw and save one figure of correlation maps per variable pair (historical) ========
for var in variable_pairs:
    # The variable combination is addressed as '<first> x <second>'
    corr_maps(ds_dict_hist_corr, '{} x {}'.format(*var), cmap='coolwarm', save_fig=True)

In [None]:
# ======== Draw and save one figure of correlation maps per variable pair (SSP370) ========
for var in variable_pairs:
    # The variable combination is addressed as '<first> x <second>'
    corr_maps(ds_dict_ssp370_corr, '{} x {}'.format(*var), cmap='coolwarm', save_fig=True)

### Plot Correlation Change Maps

In [None]:
# ========= Compute Change ========
# Combine the historical and SSP370 correlation results into a change dictionary
# (presumably SSP370 minus historical per variable combination - confirm in compute_change_corr)
ds_dict_corr_change = compute_change_corr(ds_dict_hist_corr, ds_dict_ssp370_corr)

In [None]:
# ======= Build every unordered pair of variables exactly once ========
# Each variable is combined with all variables that follow it in the list.
variable_pairs = [
    (first, second)
    for position, first in enumerate(variables)
    for second in variables[position + 1:]
]

In [None]:
# ======== Draw and save one figure of correlation change maps per variable pair ========
for var in variable_pairs:
    # The variable combination is addressed as '<first> x <second>'
    corr_maps(ds_dict_corr_change, '{} x {}'.format(*var), cmap='coolwarm', save_fig=True)

### Plot Correlation Change Maps for Ensemble mean

In [None]:
# Map all variable combinations with 'mrro' for the 'Ensemble_mean' entry of the change dictionary.
# NOTE(review): full_var_names_and_unit is not defined in the cells of this section - it has to
# exist in the notebook namespace before this cell is run; confirm where it is created.
ensemble_corr_change_plot(ds_dict_corr_change, 'mrro', full_var_names_and_unit, 'coolwarm', save_fig=True)