# CMIP6 Statistics and Plots

**This notebook includes the following steps:**

1. Load netCDF files
2. Compute statistics
3. Plot statistics

### Import Packages

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
import matplotlib.cm
from matplotlib import rcParams
import math
import multiprocessing as mp
from cftime import DatetimeNoLeap
import glob
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, pearsonr, kendalltau
from matplotlib.lines import Line2D
from sklearn.metrics import r2_score
import matplotlib.colors as colors
from matplotlib.patches import Patch
import matplotlib.patches as mpatches



%matplotlib inline

rcParams["mathtext.default"] = 'regular'

### Functions

#### Open files

In [None]:
# ========= Helper function to open the dataset ========
def open_dataset(filename):
    """Open ``filename`` with ``xr.open_dataset`` and return the result."""
    return xr.open_dataset(filename)

In [None]:
# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables):
    """
    Open the regridded file of every requested variable and merge them.

    For each variable the file
    ``../../data/CMIP6/{experiment_id}/{folder}/{var}/CMIP.{model}.{experiment_id}.{var}_regridded.nc``
    is looked up with ``glob``. Variables without a matching file are
    reported on stdout and left out of the merge.

    Parameters:
    folder (str): Sub-folder below the experiment directory.
    model (str): Model name used in the file name.
    experiment_id (str): Experiment name used in the directory and file name.
    variables (iterable of str): Variable names to load.

    Returns:
    The result of ``xr.merge`` over all datasets that were found.
    """
    filepaths = []
    for var in variables:
        path = f'../../data/CMIP6/{experiment_id}/{folder}/{var}'
        pattern = os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc')
        matches = glob.glob(pattern)
        if matches:
            filepaths.append(matches[0])
        else:
            # Name the missing variable instead of printing the empty glob result
            print(f"No file found for variable '{var}' in model '{model}'.")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

#### Helper functions

In [None]:
def select_period(ds_dict, start_year=None, end_year=None):
    '''
    Helper function to select periods.

    The input dictionary is not modified; a new dictionary with the sliced
    datasets is returned, so callers must use the return value.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.
    start_year (int, optional): The start year of the period. ``None`` leaves
        the period open at the start.
    end_year (int, optional): The end year of the period. ``None`` leaves
        the period open at the end.

    Returns:
    dict: Same keys as ``ds_dict``, each value sliced along ``time``.
    '''
    # Bounds are no-leap calendar dates: 16 January 12:00 of the start year
    # and 16 December 12:00 of the end year.
    start = None
    if start_year is not None:
        start = DatetimeNoLeap(start_year, 1, 16, 12, 0, 0, 0, has_year_zero=True)
    end = None
    if end_year is not None:
        end = DatetimeNoLeap(end_year, 12, 16, 12, 0, 0, 0, has_year_zero=True)

    return {k: v.sel(time=slice(start, end)) for k, v in ds_dict.items()}

In [None]:
# ======== Standardize ========
def standardize(ds_dict):
    """
    Standardize every dataset of a dictionary.

    Each dataset is transformed as ``(ds - ds.mean()) / ds.std()``. The
    dataset-level attributes and the attributes of every variable that is
    still present in the result are copied over from the input dataset.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets.

    Returns:
    dict: New dictionary with the same keys and the standardized datasets.
    """
    standardized = {}

    for key, dataset in ds_dict.items():
        z_scores = (dataset - dataset.mean()) / dataset.std()

        # Restore the per-variable attributes on the result
        for var_name in dataset.variables:
            if var_name not in z_scores.variables:
                continue
            z_scores[var_name].attrs = dataset[var_name].attrs

        z_scores.attrs = dataset.attrs
        standardized[key] = z_scores

    return standardized

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate the arguments and read plot metadata from the first dataset.

    NOTE: a function with the same name is defined again in a later cell of
    this notebook; once that cell has run, it replaces this definition.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets.
    variable (str): Name of the variable whose metadata is read.

    Returns:
    tuple: (var_long_name, period, unit, statistic_dim, statistic,
        experiment_id, titles)

    Raises:
    TypeError: If ``ds_dict`` is not a dict, one of its values is not an
        ``xr.Dataset``, or ``variable`` is not a string.
    """
    # Argument validation
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    for candidate in ds_dict.values():
        if not isinstance(candidate, xr.Dataset):
            raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')

    # Plot titles for each statistic
    titles = {
        "mean": "Mean",
        "std": "Standard deviation of yearly means",
        "min": "Minimum",
        "max": "Maximum",
        "median": "Median",
        "time": "Time",
        "space": "Space",
    }

    # Display names keyed by the variable's ``long_name`` attribute
    long_name = {
        'Precipitation': 'Precipitation',
        'Total Runoff': 'Total Runoff',
        'Vapor Pressure Deficit': 'Vapor Pressure Deficit',
        'Evaporation Including Sublimation and Transpiration': 'Evapotranspiration',
        'Transpiration': 'Transpiration',
        'Leaf Area Index': 'Leaf Area Index',
        'Carbon Mass Flux out of Atmosphere Due to Gross Primary Production on Land [kgC m-2 s-1]': 'Gross Primary Production',
        'Total Liquid Soil Moisture Content of 1 m Column': '1 m Soil Moisture',
        'Total Liquid Soil Moisture Content of 2 m Column': '2 m Soil Moisture',
        'Runoff - Precipitation': 'Runoff - Precipitation',
        'Transpiration - Precipitation': 'Transpiration - Precipitation',
        '(Runoff + Transpiration) - Precipitation':  '(Runoff + Transpiration) - Precipitation'
    }

    # All metadata is taken from the first dataset in the dictionary
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    var_long_name = long_name[first_ds[variable].long_name]
    period = f"{first_ds.attrs['period']}"
    experiment_id = first_ds.experiment_id
    unit = first_ds[variable].units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate the arguments and read plot metadata from the first dataset.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets.
    variable (str): Name of the variable whose metadata is read.

    Returns:
    tuple: (var_long_name, period, unit, statistic_dim, statistic,
        experiment_id, titles, frequency). ``var_long_name`` is returned
        exactly as stored in the variable's ``long_name`` attribute,
        ``period`` is ``"<first>-<second>"`` built from the first two entries
        of the dataset's ``period`` attribute, and ``frequency`` is the
        display label of the dataset's ``frequency`` attribute (the raw
        attribute value if no label is known).

    Raises:
    TypeError: If ``ds_dict`` is not a dict, one of its values is not an
        ``xr.Dataset``, or ``variable`` is not a string.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')

    # Dictionary to store plot titles for each statistic
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum", "median": "Median", "time": "Time", "space": "Space"}
    freq = {"mon": "Monthly"}

    # All metadata is taken from the first dataset in the dictionary
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    data_var = first_ds[variable]

    var_long_name = data_var.long_name
    period_bounds = first_ds.attrs['period']
    period = f"{period_bounds[0]}-{period_bounds[1]}"
    experiment_id = first_ds.experiment_id
    unit = data_var.units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']

    # Fall back to the raw value so that non-monthly data does not raise KeyError
    raw_frequency = first_ds.frequency
    frequency = freq.get(raw_frequency, raw_frequency)

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency

In [None]:
def slice_to_regions(ds_dict, regions):
    """
    Cut every dataset down to every region.

    Parameters:
    ds_dict (dict): Dictionary of datasets keyed by dataset name.
    regions (dict): Maps a region name to a dict with the keys ``'lat'`` and
        ``'lon'``; both values are passed to ``ds.sel`` unchanged.

    Returns:
    dict: ``{region: {dataset name: selected dataset}}``.
    """
    return {
        region: {
            ds_name: ds.sel(lat=bounds['lat'], lon=bounds['lon'])
            for ds_name, ds in ds_dict.items()
        }
        for region, bounds in regions.items()
    }

In [None]:
def compute_ensemble(ds_dict):
    """
    Add the ensemble mean and ensemble median to a dictionary of datasets.

    All member datasets are concatenated along a new ``ensemble`` dimension
    and reduced over it. The results are stored in ``ds_dict`` (in place)
    under the keys ``'Ensemble mean'`` and ``'Ensemble median'``. Entries
    already stored under these two keys are not treated as ensemble members,
    so calling the function twice gives the same result.

    Variable attributes are copied from the first member dataset for both
    results. The dataset attributes are rebuilt; ``experiment_id`` is always
    set to ``"ssp370-historical"`` and ``statistic_dimension`` to ``"time"``.

    Parameters:
    ds_dict (dict): Dictionary of xarray datasets; the first member must
        carry the attributes ``period`` and ``frequency``.

    Returns:
    dict: The same ``ds_dict`` object, extended by the two ensemble entries.

    Raises:
    ValueError: If ``ds_dict`` contains no member dataset.
    """
    ensemble_statistics = {'Ensemble mean': 'mean', 'Ensemble median': 'median'}

    # Leave out results of an earlier call so they do not count as members
    members = [ds for name, ds in ds_dict.items() if name not in ensemble_statistics]
    if not members:
        raise ValueError("ds_dict must contain at least one dataset to build an ensemble.")
    reference = members[0]

    # Combine all datasets into one larger dataset
    combined = xr.concat(members, dim='ensemble')

    for key, statistic in ensemble_statistics.items():
        # use getattr to call the reduction method by its string name
        ensemble = getattr(combined, statistic)(dim='ensemble')

        # Preserve variable attributes from the first member dataset
        for var in ensemble.variables:
            ensemble[var].attrs = reference[var].attrs

        ensemble.attrs = {"period": reference.attrs['period'],
                          "statistic": statistic,
                          "statistic_dimension": "time",
                          "experiment_id": "ssp370-historical",
                          "source_id": key,
                          "frequency": reference.attrs['frequency']}

        ds_dict[key] = ensemble

    return ds_dict

#### Compute statistics

In [None]:
def compute_statistic_single(ds, statistic, dimension, yearly_mean=True):
    """
    Compute one statistic of a single dataset over time or over space.

    For ``dimension == "time"`` the reduction method named ``statistic`` is
    called on the dataset along ``time``. For ``dimension == "space"`` the
    dataset is optionally averaged per year first, then weighted with
    ``cos(lat)`` and the method named ``statistic`` is called on the weighted
    object over ``("lon", "lat")``; the statistic must therefore exist on
    xarray's weighted object.

    The input dataset is not modified.

    Parameters:
    ds (xarray.Dataset): Dataset with a ``time`` coordinate (and ``lat`` for
        the "space" case).
    statistic (str): Name of the reduction method to call.
    dimension (str): Either ``"time"`` or ``"space"``.
    yearly_mean (bool): Only used for "space": average per year before the
        spatial reduction.

    Returns:
    The reduced dataset with the attributes ``period`` (first and last year
    of the input as strings), ``statistic`` and ``statistic_dimension`` set.

    Raises:
    ValueError: If ``dimension`` is neither "time" nor "space".
    """
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    years = ds.time.dt.year
    period = [str(years[0].values), str(years[-1].values)]

    if dimension == "time":
        stat_ds = getattr(ds, statistic)("time", keep_attrs=True, skipna=True)
        stat_ds.attrs['period'] = period
    else:
        # Attach the period to a new object before grouping by year, so that
        # the caller's dataset is left untouched and keep_attrs carries it on
        ds = ds.assign_attrs(period=period)

        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            ds.attrs['mean'] = 'yearly mean'

        # get the weights, apply on data, and compute statistic
        weights = np.cos(np.deg2rad(ds.lat))
        weights.name = "weights"
        ds_weighted = ds.weighted(weights)
        stat_ds = getattr(ds_weighted, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)

    stat_ds.attrs['statistic'] = statistic
    stat_ds.attrs['statistic_dimension'] = dimension

    return stat_ds

In [None]:
def compute_statistic(ds_dict, statistic, dimension, start_year=None, end_year=None, yearly_mean=True):
    """
    Computes the specified statistic for each dataset in the dictionary.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        statistic (str): The statistic to compute, which can be one of 'mean', 'std', 'min', 'max',
            'var', or 'median'.
        dimension (str): The dimension to compute over, which can be 'time' or 'space'.
        start_year (int, optional): The start year of the period to compute the statistic over.
        end_year (int, optional): The end year of the period to compute the statistic over.
            The period is only applied when both start_year and end_year are given.
        yearly_mean (bool, optional): Passed on to compute_statistic_single.

    Returns:
        dict: A dictionary with computed statistic for each dataset.

    Raises:
        TypeError: If ds_dict is not a dictionary of xarray datasets.
        ValueError: If statistic or dimension is not one of the supported values.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if statistic not in ["mean", "std", "min", "max", "var", "median"]:
        raise ValueError(f"Invalid statistic '{statistic}' specified.")
    if dimension not in ["time", "space"]:
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # Select period; select_period returns a new dictionary, so the result
    # has to be kept for the selection to take effect
    if start_year is not None and end_year is not None:
        ds_dict = select_period(ds_dict, start_year=start_year, end_year=end_year)

    # Use multiprocessing to compute the statistic for each dataset in parallel
    tasks = [(ds, statistic, dimension, yearly_mean) for ds in ds_dict.values()]
    with mp.Pool() as pool:
        results = pool.starmap(compute_statistic_single, tasks)

    return dict(zip(ds_dict.keys(), results))

In [None]:
from multiprocessing import Pool
from itertools import combinations

# Define a function to compute the metrics
def compute_metrics_for_pair(args):
    """
    Compute the requested metrics between two columns of a DataFrame.

    Only rows where both columns are finite are used. ``var1`` is the
    predictor and ``var2`` the target of the regression-based metrics; these
    are evaluated on the same data the model was fitted on.

    Parameters:
    args (tuple): ``(df, var1, var2, metrics)`` packed into one argument so
        that the function can be used with ``Pool.map``. ``metrics`` is a
        collection that may contain 'rmse_rf', 'rmse_lr', 'r2_lr', 'pearson',
        'spearman' and 'kendalltau'.

    Returns:
    tuple: ``(var1, var2, metric_dict)`` where ``metric_dict`` maps each
        requested metric name to its value.
    """
    df, var1, var2, metrics = args
    valid_values = np.logical_and(np.isfinite(df[var1]), np.isfinite(df[var2]))

    X = df[var1][valid_values].values.reshape(-1, 1)
    y = df[var2][valid_values].values
    x_flat = X.flatten()
    metric_dict = {}

    if 'rmse_rf' in metrics:
        rf = RandomForestRegressor()
        rf.fit(X, y)
        y_pred_rf = rf.predict(X)
        metric_dict['rmse_rf'] = np.sqrt(mean_squared_error(y, y_pred_rf))

    # Fit the linear model only when one of its metrics is requested.
    # (`'rmse_lr' or 'r2_lr' in metrics` would always be true.)
    if 'rmse_lr' in metrics or 'r2_lr' in metrics:
        lr = LinearRegression()
        lr.fit(X, y)
        y_pred_lr = lr.predict(X)
        if 'rmse_lr' in metrics:
            metric_dict['rmse_lr'] = np.sqrt(mean_squared_error(y, y_pred_lr))
        if 'r2_lr' in metrics:
            # R^2 (coefficient of determination)
            metric_dict['r2_lr'] = r2_score(y, y_pred_lr)

    if 'pearson' in metrics:
        metric_dict['pearson'] = pearsonr(x_flat, y)[0]

    if 'spearman' in metrics:
        metric_dict['spearman'] = spearmanr(x_flat, y)[0]

    if 'kendalltau' in metrics:
        metric_dict['kendalltau'] = kendalltau(x_flat, y)[0]

    return (var1, var2, metric_dict)

In [None]:
from itertools import permutations

def precompute_metrics(ds_dict, variables, metrics=None):
    """
    Compute pairwise metrics between variables for every dataset.

    For each dataset the variables are flattened into DataFrame columns and
    every ordered pair of distinct variables is passed to
    ``compute_metrics_for_pair`` through a multiprocessing pool.

    Parameters:
    ds_dict (dict): Dictionary of datasets keyed by dataset name.
    variables (list of str): Variables to pair up.
    metrics (list of str, optional): Metric names understood by
        ``compute_metrics_for_pair``. Defaults to ``['pearson']``.

    Returns:
    dict: ``{metric: {dataset name: {"<var1>_<var2>": value}}}``.
    """
    metrics = ['pearson'] if metrics is None else list(metrics)

    # Initialize the results dictionary
    results_dict = {metric: {} for metric in metrics}
    if not ds_dict:
        return results_dict

    # All ordered pairs of variables; identical for every dataset
    pairs = list(permutations(variables, 2))

    # One pool serves all datasets instead of starting a new one per dataset
    with Pool() as pool:
        for name, ds in ds_dict.items():
            # Create a DataFrame with all the variables
            df = pd.DataFrame({var: ds[var].values.flatten() for var in variables})
            args = [(df, var1, var2, metrics) for var1, var2 in pairs]

            results = pool.map(compute_metrics_for_pair, args)

            # Store the results in the results_dict
            for var1, var2, metric_dict in results:
                for metric, value in metric_dict.items():
                    # Ensure the keys exist in the dictionary
                    results_dict[metric].setdefault(name, {}).setdefault(f'{var1}_{var2}', value)

    return results_dict

In [None]:
def calculate_yearly_correlations(ds_dict, variable_pairs, start_year=None, end_year=None, corr_type='pearson'):
    """
    Calculates the yearly spatial correlation for the given pairs of variables from the same model.

    Every dataset is averaged per year; for each year the correlation between
    the two variables is computed over the 'lon' and 'lat' dimensions. Grid
    cells where either variable is NaN are ignored for all correlation types.

    Parameters:
    ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
    variable_pairs (list of tuples): The pairs of variables to calculate the correlation for.
    start_year (int): The start year of the period to compute the correlation over.
    end_year (int): The end year of the period to compute the correlation over.
            The period is only applied when both start_year and end_year are given.
    corr_type (str): The type of correlation coefficient to compute. Can be 'pearson', 'spearman',
            or 'kendall' ('kendalltau' is accepted as an alias).

    Returns:
    dict: A dictionary with an xarray Dataset for each input dataset. Each Dataset holds one
        variable '<var1>-<var2>' per pair along a 'time' dimension whose coordinate is the year,
        and the attributes 'Metric' and 'Metric_sign'.

    Raises:
    ValueError: If corr_type is not one of the supported values.
    """
    # Map for complete metric names and symbols
    metric_map = {
        'pearson': ('Pearson Correlation Coefficient', 'r'),
        'spearman': ('Spearman Rank Correlation Coefficient', 'ρ'),
        'kendall': ('Kendall Rank Correlation Coefficient', 'τ'),
        'kendalltau': ('Kendall Rank Correlation Coefficient', 'τ'),
    }
    if corr_type not in metric_map:
        raise ValueError("Invalid correlation type. Expected 'pearson', 'spearman', or 'kendall'.")

    # Select period; the sliced dictionary is returned, not applied in place
    if start_year is not None and end_year is not None:
        ds_dict = select_period(ds_dict, start_year=start_year, end_year=end_year)

    # Rank correlations come from scipy; Pearson is computed by xarray
    rank_function = None
    if corr_type == 'spearman':
        rank_function = spearmanr
    elif corr_type in ('kendall', 'kendalltau'):
        rank_function = kendalltau

    yearly_correlations = {}

    for name, ds in ds_dict.items():
        yearly_corr_dict = {}

        # Resample to yearly data
        ds_yearly = ds.resample(time='1Y').mean()
        years = np.unique(ds_yearly['time'].dt.year)

        for var1, var2 in variable_pairs:
            yearly_correlations_values = []

            for year in years:
                # Select the data for this year
                ds_year = ds_yearly.sel(time=f'{year}')

                if rank_function is None:
                    corr_value = xr.corr(ds_year[var1], ds_year[var2], dim=['lon', 'lat'])
                    # .item() extracts the scalar from the single-element result
                    value = float(corr_value.values.item())
                else:
                    df = ds_year.stack(z=('lon', 'lat')).to_dataframe()
                    # 'omit' drops NaN cells, matching what xr.corr does for Pearson
                    statistic, _ = rank_function(df[var1], df[var2], nan_policy='omit')
                    value = float(statistic)

                yearly_correlations_values.append(value)

            # Store in the yearly_corr_dict
            yearly_corr_dict[f'{var1}-{var2}'] = xr.DataArray(
                yearly_correlations_values, dims='time', coords={'time': list(years)}
            )

        # Create a Dataset from the yearly_corr_dict and store in the yearly_correlations dict
        yearly_correlations[name] = xr.Dataset(yearly_corr_dict)
        yearly_correlations[name].attrs = {'Metric': metric_map[corr_type][0],
                                           'Metric_sign': metric_map[corr_type][1]
                                          }

    return yearly_correlations

In [None]:
def compute_stats(ds_dict):
    """
    Compute the yearly, spatially averaged value of each data variable.

    Every dataset is resampled to yearly means (``time='1Y'``); each data
    variable of the result is then averaged over 'lat' and 'lon' (plain,
    unweighted mean).

    Parameters:
    ds_dict (dict): The input dictionary of xarray.Dataset.

    Returns:
    dict: ``{dataset name: {variable name: DataArray of yearly means}}``.
    """
    stats = {}
    for model in ds_dict:
        annual = ds_dict[model].resample(time='1Y').mean()
        stats[model] = {
            name: annual[name].mean(dim=['lat', 'lon'])
            for name in annual.data_vars
        }
    return stats

In [None]:
def compute_yearly_means(ds_dict_region):
    """
    Compute the yearly, latitude-weighted spatial mean per region and dataset.

    Parameters:
    ds_dict_region (dict): ``{region: {dataset name: dataset}}``.

    Returns:
    dict: Same nesting as the input; each dataset is replaced by its yearly
        mean (grouped by ``time.year``) averaged over 'lat' and 'lon' with
        ``cos(lat)`` weights.
    """
    def _weighted_yearly_mean(dataset):
        # Weights are derived from the dataset's own latitudes
        lat_weights = np.cos(np.deg2rad(dataset.lat))
        per_year = dataset.groupby('time.year').mean('time')
        return per_year.weighted(lat_weights).mean(('lat', 'lon'))

    return {
        region: {
            ds_name: _weighted_yearly_mean(ds)
            for ds_name, ds in ds_dict.items()
        }
        for region, ds_dict in ds_dict_region.items()
    }

In [None]:
def compute_correlation_coefficients(ds_dict, variables, time_dim='time', yearly_corr=False):
    """
    Compute the correlation coefficients for different variable combinations and
    store them in a new dictionary.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        variables (list): A list of variables for which to compute the correlation coefficients.
        time_dim (str): The name of the time dimension in the datasets. Default is 'time'.
        yearly_corr (bool): If True, each dataset is resampled to yearly means along time_dim
            before correlating; otherwise the data is used as is. Default is False.

    Returns:
        dict: A new dictionary with one entry per input dataset that contains at least one complete
              variable pair. Each value is an xarray Dataset with one variable '<var1> x <var2>'
              per pair, the attributes of the input dataset, and the additional attributes
              'Metric', 'Metric_sign' and 'means' ('yearly means' or 'monthly means').
    """
    pairs = list(combinations(variables, 2))
    means_label = 'yearly means' if yearly_corr else 'monthly means'

    # Create a new dictionary to store the correlation data
    ds_dict_corr = {}

    for ds_name, ds in ds_dict.items():
        # Keep only the pairs for which both variables exist in this dataset
        available_pairs = [(var1, var2) for var1, var2 in pairs if var1 in ds and var2 in ds]
        if not available_pairs:
            continue

        # Resample once per dataset rather than once per variable pair
        ds_corr = ds.resample({time_dim: 'Y'}).mean() if yearly_corr else ds

        result = xr.Dataset()
        for var1, var2 in available_pairs:
            result[f'{var1} x {var2}'] = xr.corr(ds_corr[var1], ds_corr[var2], dim=time_dim)

        # Copy the attributes so the input dataset's attributes stay untouched
        result.attrs = dict(ds.attrs)
        result.attrs['Metric'] = 'Pearson Correlation Coefficient'
        result.attrs['Metric_sign'] = 'r'
        result.attrs['means'] = means_label

        ds_dict_corr[ds_name] = result

    return ds_dict_corr

In [None]:
def compute_change_corr(ds_dict_hist_corr, ds_dict_ssp370_corr):
    """
    Compute the change in correlation between SSP370 and the historical period.

    For every dataset name of the historical dictionary the historical
    correlations are subtracted from the SSP370 correlations of the same name.

    Parameters:
    ds_dict_hist_corr (dict): Historical correlation datasets; each needs
        the attribute ``means``.
    ds_dict_ssp370_corr (dict): SSP370 correlation datasets with the same keys
        and the attribute ``means``.

    Returns:
    dict: ``{dataset name: SSP370 - historical}`` for all datasets, with
        freshly built attributes.

    Raises:
    ValueError: If the two datasets of a name were not both computed from
        yearly means or both from monthly means.
    """
    ds_dict_corr_change = {}

    for name, ds in ds_dict_hist_corr.items():
        ds_dict_corr_change[name] = ds_dict_ssp370_corr[name] - ds
        ds_dict_corr_change[name].attrs = {'period': 'Change Correlation SSP370 - Historical',
                                      'statistic': 'mean',
                                      'statistic_dimension':  'time',
                                      'experiment_id': 'ssp370-historical',
                                      'source_id': name,
                                      'Metric': 'Pearson Correlation Coefficient',
                                      'Metric_sign': 'r'
                                    }

        means_hist = ds.attrs['means']
        means_ssp370 = ds_dict_ssp370_corr[name].attrs['means']
        if means_hist == 'yearly means' and means_ssp370 == 'yearly means':
            ds_dict_corr_change[name].attrs['means'] = 'yearly means'
        elif means_hist == 'monthly means' and means_ssp370 == 'monthly means':
            ds_dict_corr_change[name].attrs['means'] = 'monthly means'
        else:
            raise ValueError("Computing change between yearly and monthly mean data.")

    # Return only after every dataset has been processed
    return ds_dict_corr_change

In [None]:
def compute_change(ds_dict_hist_mean, ds_dict_ssp370_mean, relative_change=False, iqr=False):
    """
    Compute the absolute or relative change between SSP370 and historical data.

    Absolute change is ``ssp370 - historical``. Relative change is
    ``(ssp370 - historical) / |historical| * 100`` and is NaN wherever the
    historical value is 0. For the relative change a 'tas' variable, if
    present, is shifted by +273.15 and labelled 'K' first; this is done on
    copies, the input datasets are not modified.

    Parameters:
    ds_dict_hist_mean (dict): Historical datasets keyed by name.
    ds_dict_ssp370_mean (dict): SSP370 datasets with the same keys; they must
        carry the attributes ``statistic``, ``statistic_dimension``,
        ``source_id`` and ``frequency``.
    relative_change (bool): Compute the relative change in percent instead of
        the absolute change.
    iqr (bool): Only used for the relative change: mask values outside
        ``[q1 - 4 * IQR, q3 + 4 * IQR]``.

    Returns:
    dict: ``{dataset name: change dataset}``. Variable attributes are taken
        from the SSP370 dataset; for the relative change each variable
        additionally gets ``units_rel = '%'``. The dataset attribute
        ``change`` is 'Relative Change' or 'Absolute Change'.
    """
    ds_dict_change = {}

    for name, ds_hist in ds_dict_hist_mean.items():
        ds_future = ds_dict_ssp370_mean[name]

        # Compute either absolute or relative change
        if relative_change:
            # Shift temperature so that it does not mix negative and positive values
            ds_hist = _tas_to_kelvin(ds_hist)
            ds_future = _tas_to_kelvin(ds_future)

            # Compute relative change only where the historical value is not 0
            nonzero = ds_hist != 0
            baseline = ds_hist.where(nonzero)
            change = ((ds_future - baseline) / abs(baseline)) * 100

            # Where the historical value was 0, the relative change is NaN
            change = change.where(nonzero)

            if iqr:
                # Compute the IQR and use it for filtering. The spread is kept
                # in its own name so the `iqr` flag stays a bool for the next
                # dataset of the loop.
                q1 = change.quantile(0.25)
                q3 = change.quantile(0.75)
                spread = q3 - q1
                lower_bound = q1 - 4 * spread
                upper_bound = q3 + 4 * spread

                # Apply the condition
                condition = (change >= lower_bound) & (change <= upper_bound)
                change = change.where(condition)
        else:
            change = ds_future - ds_hist

        change.attrs = {'period': 'Change SSP370 - Historical',
                        'statistic': ds_dict_ssp370_mean[name].statistic,
                        'statistic_dimension':  ds_dict_ssp370_mean[name].statistic_dimension,
                        'experiment_id': 'ssp370-historical',
                        'source_id': ds_dict_ssp370_mean[name].source_id,
                        'frequency': ds_dict_ssp370_mean[name].frequency,
                        'change': 'Relative Change' if relative_change else 'Absolute Change'
                       }

        for variable in ds_hist:
            variable_attrs = dict(ds_future[variable].attrs)
            if relative_change:
                variable_attrs['units_rel'] = '%'
            change[variable].attrs = variable_attrs

        ds_dict_change[name] = change

    return ds_dict_change


def _tas_to_kelvin(ds):
    """Return ``ds`` with 'tas' shifted by +273.15 and units 'K'; the input is left unchanged."""
    if 'tas' not in ds:
        return ds

    tas_attrs = dict(ds['tas'].attrs)
    tas_attrs['units'] = 'K'

    converted = ds.copy()
    converted['tas'] = (ds['tas'] + 273.15).assign_attrs(tas_attrs)
    return converted

In [None]:
def compute_nbwfp(ds_dict):
    """
    Add the variable 'nbwfp' to every dataset of the dictionary (in place).

    'nbwfp' is computed as ``(mrro - tran) / pr``. Infinite values and values
    above 2 or below -2 are replaced by NaN. The new variable gets the
    long name 'Net Blue Water Flux / Precipitation' and empty units.

    Parameters:
    ds_dict (dict): Datasets containing the variables 'mrro', 'tran' and 'pr'.

    Returns:
    dict: The same ``ds_dict`` object with the extended datasets.
    """
    missing = float('nan')
    # Applied in this order: infinite values, then the upper and lower limit
    invalid_checks = (
        np.isinf,
        lambda values: values > 2,
        lambda values: values < -2,
    )

    for model in ds_dict:
        dataset = ds_dict[model]
        ratio = (dataset['mrro'] - dataset['tran']) / dataset['pr']

        for is_invalid in invalid_checks:
            ratio = xr.where(is_invalid(ratio), missing, ratio)

        dataset['nbwfp'] = ratio
        dataset['nbwfp'].attrs = {'long_name': 'Net Blue Water Flux / Precipitation',
                             'units': ''}

    return ds_dict

#### Plot metrics

In [None]:
def plot_time_series_correlations(yearly_correlations, variable_pairs, target_variable):
    """
    Plots the time series of correlations for each variable pair that includes the target variable.

    One subplot is drawn per matching pair on a square grid. Every model is
    drawn as a line, and the mean over all models that provide the pair is
    added as a dashed black line.

    Parameters:
    yearly_correlations (dict): The output from calculate_yearly_correlations.
    variable_pairs (list of tuples): The pairs of variables that the correlations were calculated for.
    target_variable (str): The variable that must be included in a pair for it to be plotted.

    Raises:
    ValueError: If no pair in variable_pairs contains target_variable.
    """
    # Keep only the pairs that involve the target variable
    selected_pairs = [
        (var1, var2) for var1, var2 in variable_pairs
        if var1 == target_variable or var2 == target_variable
    ]
    if not selected_pairs:
        raise ValueError(f"No variable pair contains the target variable '{target_variable}'.")

    # Calculate the dimensions of the grid of subplots
    grid_size = math.ceil(math.sqrt(len(selected_pairs)))

    # squeeze=False always returns a 2-D array of axes, also for a 1x1 grid
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15), sharex=True, sharey=True, squeeze=False)
    axs = axs.flatten()

    # Handles and labels for the shared legend
    handles, labels = [], []
    mean_line = None

    for ax, (var1, var2) in zip(axs, selected_pairs):
        # Construct the correlation variable name
        corr_var = f'{var1}-{var2}'

        # Correlations of all models that provide this pair
        all_corrs = []

        for name, ds in yearly_correlations.items():
            if corr_var not in ds:
                continue

            # Plot the time series of the correlation
            line, = ax.plot(ds['time'], ds[corr_var], label=name)
            all_corrs.append(ds[corr_var])

            # Each model appears only once in the legend
            if name not in labels:
                handles.append(line)
                labels.append(name)

        # Compute the mean correlation across all models (if any provide the pair)
        if all_corrs:
            mean_corr = xr.concat(all_corrs, dim='model').mean(dim='model')
            mean_line, = ax.plot(mean_corr['time'], mean_corr, color='black', linestyle='--')

        ax.set_xlabel('Year')
        ax.set_ylabel('Correlation')
        ax.set_title(f'{var1} vs {var2}')
        ax.grid(True)

    # Add 'Mean' to the legend when a mean line was drawn
    if mean_line is not None:
        handles.append(mean_line)
        labels.append('Mean')

    # Show the figure with a legend
    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.1, 0.5))
    plt.tight_layout()
    plt.show()

In [None]:
def plot_mean_correlations(correlations_hist, correlations_ssp370, variable_pairs, target_variable, scale_axis=False, variable_captions=None):
    """
    Draws side-by-side box plots of the yearly metric values of every model for
    the historical and the SSP370 period, one subplot per variable pair that
    includes the target variable, then shows the figure and saves it as PNG.

    Parameters:
    correlations_hist (dict): The output from calculate_correlations for the historical period.
        Maps a model name to a dataset that holds one '<var1>-<var2>' variable per
        pair and the attributes `Metric` and `Metric_sign`.
    correlations_ssp370 (dict): The output from calculate_correlations for the SSP370 scenario.
        Must contain every model name that is present in correlations_hist.
    variable_pairs (list of tuples): The pairs of variables that the correlations were calculated for.
    target_variable (str): The variable that must be included in a pair for it to be plotted.
    scale_axis (bool): If True, fix the y-axis to [-1, 1] ([0, 1] when the metric is 'r2_lr'). Default is False.
    variable_captions (dict, optional): Maps variable names to the captions used in the
        subplot titles. Variables without an entry (or a value of None for the whole
        argument) fall back to the raw variable name.

    Returns:
    None. The figure is written to '../../results/CMIP6/yearly_metrics_comparison'.
    """

    # Captions are optional: without this guard the `.get` calls below fail on None
    if variable_captions is None:
        variable_captions = {}

    # Get info from the first dataset
    first_name = list(correlations_hist.keys())[0]
    metric = correlations_hist[first_name].Metric
    metric_sign = correlations_hist[first_name].Metric_sign

    # Filter variable pairs
    variable_pairs = [(var1, var2) for var1, var2 in variable_pairs if var1 == target_variable or var2 == target_variable]

    # Calculate the number of plots and dimensions of the grid of subplots
    n_plots = len(variable_pairs)
    if n_plots == 0:
        # An empty grid would otherwise lead to a division by zero below
        print(f"No variable pair contains '{target_variable}'. Nothing to plot.")
        return

    n_cols = min(n_plots, 3)  # Maximum 3 plots in a row
    n_rows = int(np.ceil(n_plots / n_cols))

    # Create the figure
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
    axs = axs.flatten()  # Flatten the axes array

    # The model order, box positions and offset are the same for every subplot
    model_names = list(correlations_hist.keys())
    positions = np.arange(len(model_names))
    offset = 0.15  # shifts the two periods left/right of the model's tick

    for ax, (var1, var2) in zip(axs, variable_pairs):
        pair_name = f'{var1}-{var2}'

        # Yearly metric values per model for both periods
        yearly_corr_hist = [correlations_hist[name][pair_name].values for name in model_names]
        yearly_corr_ssp370 = [correlations_ssp370[name][pair_name].values for name in model_names]

        # Plot the box plots for the historical period
        ax.boxplot(yearly_corr_hist, positions=positions-offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='cornflowerblue'), medianprops=dict(color='black'),
                   showfliers=True)

        # Plot the box plots for the SSP370 scenario
        ax.boxplot(yearly_corr_ssp370, positions=positions+offset, widths=0.3, patch_artist=True,
                   boxprops=dict(facecolor='sandybrown'), medianprops=dict(color='black'),
                   showfliers=True)

        # Set the x-ticks labels and the title
        ax.set_ylabel(f'{metric_sign}')
        ax.set_title(f'{variable_captions.get(var1, var1)} x {variable_captions.get(var2, var2)}', fontsize=9)
        ax.set_xticks(positions)
        ax.set_xticklabels(model_names, rotation=90)

    fig.suptitle(f'Yearly {metric} for Historical (1985-2014) and SSP370 (2071-2100) Period', fontsize=12, y=1.0)

    # Set a legend
    axs[0].legend([Patch(facecolor='cornflowerblue'), Patch(facecolor='sandybrown')], ['Historical', 'SSP370'])

    # Handle empty subplots in case n_plots is less than n_rows * n_cols
    for i in range(n_plots, n_rows*n_cols):
        fig.delaxes(axs[i])

    # Handle y-axis scaling (only on the axes that are still part of the figure)
    if scale_axis:
        for ax in axs[:n_plots]:
            ax.set_ylim([-1, 1] if metric != 'r2_lr' else [0, 1])

    # Adjust the layout
    plt.tight_layout()
    plt.show()

    # Save figure
    suffix = "_scaled_axis" if scale_axis else ""
    filename = f"{metric}_changes_{target_variable}{suffix}.png"

    savepath = f'../../results/CMIP6/yearly_metrics_comparison'
    os.makedirs(savepath, exist_ok=True)

    filepath = os.path.join(savepath, filename)

    fig.savefig(filepath, dpi=300, bbox_inches='tight')

In [None]:
def ensemble_corr_change_plot(ds_dict, target_variable, full_var_names_and_unit, cmap='viridis', save_fig=False, file_format='png'):
    """
    Plots one map per variable combination of the Ensemble_mean dataset that
    involves the target variable, with a shared colorbar fixed to [-1, 1].

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. Must contain the key 'Ensemble_mean'; the
            attributes `experiment_id`, `Metric`, `Metric_sign` and `means` are read from the
            first dataset in the dictionary.
        target_variable (str): The target variable to plot correlations with. Variables named
            '<target> x <other>' or '<other> x <target>' are plotted.
        full_var_names_and_unit (dict): Maps a variable name to a sequence whose first element
            is the long name used in the titles.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: The path of the saved figure, or a message explaining why nothing was
        saved. None if there is no Ensemble_mean dataset or nothing to plot.
    """
    # Get the Ensemble_mean dataset before anything is created, so that the
    # early return does not leave an empty figure behind
    ensemble_ds = ds_dict.get("Ensemble_mean", None)

    if ensemble_ds is None:
        print("Ensemble_mean dataset not found.")
        return None

    # Info
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']
    target_var_long_name = full_var_names_and_unit[target_variable][0]

    # Collect the variable combinations that involve the target variable
    as_first = f'{target_variable} x '
    as_second = f' x {target_variable}'
    pair_names = [name for name in ensemble_ds.variables if as_first in name or as_second in name]

    if not pair_names:
        # Without any map there is no mappable for the shared colorbar
        print(f"No variable combinations with '{target_variable}' found in Ensemble_mean.")
        return None

    # Create a figure: at least 3 x 3, more rows if there are more than 9 maps
    n_cols = 3
    n_rows = max(3, math.ceil(len(pair_names) / n_cols))

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    for subplot_index, variable in enumerate(pair_names, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, subplot_index, projection=ccrs.Robinson())

        im = ensemble_ds[variable].plot(ax=ax, cmap=cmap, vmin=-1, vmax=1, transform=ccrs.PlateCarree(), add_colorbar=False)

        # Strip the target variable from the name to get the partner variable
        if as_first in variable:
            corr_var = variable.replace(as_first, '')
        else:
            corr_var = variable.replace(as_second, '')
        corr_var_long_name = full_var_names_and_unit[corr_var][0]
        ax.set_title(f'{target_var_long_name} x {corr_var_long_name}', fontsize=18)  # Use the long names in the title
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)

    # Set colorbar ticks
    cbar.set_ticks([-0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75])

    # Set tick size
    cbar.ax.tick_params(labelsize=20)  # Adjust size as needed

    # Set colorbar label
    cbar.set_label(f'{metric_sign} change', size=26)  # Adjust size as needed

    # Set figure title depending on the experiment
    if experiment_id == 'historical' or experiment_id == 'ssp370':
        fig.suptitle(f"{metric} ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)
    elif experiment_id == 'ssp370-historical':
        fig.suptitle(f"{metric} Change ({experiment_id}) of Ensemble Mean for Variable Combinations with {target_var_long_name} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'
    elif experiment_id in ('historical', 'ssp370', 'ssp370-historical'):
        # Only the file name differs between the single-period and the change plots
        name_suffix = '_change' if experiment_id == 'ssp370-historical' else ''
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        filename = f'ensmean_{target_variable}_correlations{name_suffix}.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        # No target directory is defined for other experiments
        filepath = f"Figure not saved: unsupported experiment_id '{experiment_id}'."

    return filepath

In [None]:
def corr_maps(ds_dict, target_variable_combination, cmap='viridis', save_fig=False, file_format='png'):
    """
    Plots a map of the specified variable combination for each dataset in the dictionary
    that contains it, with one shared colorbar.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. The attributes `experiment_id`, `Metric`,
            `Metric_sign` and `means` are read from the first dataset in the dictionary.
        target_variable_combination (str): The target variable combination to plot.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.

    Returns:
        str or None: The path of the saved figure, or a message explaining why nothing was
        saved. None if no dataset contains the variable combination.
    """
    # Info
    first_ds = ds_dict[list(ds_dict.keys())[0]]
    experiment_id = first_ds.experiment_id
    metric = first_ds.Metric
    metric_sign = first_ds.Metric_sign
    means = first_ds.attrs['means']

    # Keep only the datasets that hold the requested variable combination
    datasets_with_var = [(name, ds) for name, ds in ds_dict.items() if target_variable_combination in ds.variables]

    if not datasets_with_var:
        # A figure with zero rows cannot be created and there would be no
        # mappable for the colorbar
        print(f"No dataset contains '{target_variable_combination}'. Nothing to plot.")
        return None

    # Create a figure
    n_cols = 3  # Set number of columns to 3
    n_rows = math.ceil(len(datasets_with_var) / n_cols)  # Calculate rows

    fig = plt.figure(figsize=[12 * n_cols, 6.5 * n_rows])  # Start with a blank figure, without subplots

    plt.style.use('default')

    # Loop over datasets and plot the requested variable combination
    for subplot_index, (name, ds) in enumerate(datasets_with_var, start=1):
        # Add a new subplot with a cartopy projection
        ax = fig.add_subplot(n_rows, n_cols, subplot_index, projection=ccrs.Robinson())

        im = ds[target_variable_combination].plot(ax=ax, cmap=cmap, transform=ccrs.PlateCarree(), add_colorbar=False)
        ax.set_title(name, fontsize=18)
        ax.coastlines()  # Adds lines around the continents

    # Adjust spacing between subplots
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    # Add a common colorbar at the bottom of the plots (taken from the last map)
    cbar = fig.colorbar(im, ax=fig.axes, orientation='horizontal', fraction=0.02, pad=0.04, aspect=75, shrink=0.4)

    # Set tick size
    cbar.ax.tick_params(labelsize=14)  # Adjust size as needed

    # Set colorbar label
    cbar.set_label(metric_sign, size=18)  # Adjust size as needed

    # `x == 'a' or 'b'` is always true, so membership tests are used to tell
    # the single-period experiments from the change experiment
    is_single_period = experiment_id in ('historical', 'ssp370')
    is_change = experiment_id == 'ssp370-historical'

    # Set figure title depending on the experiment
    if is_single_period:
        fig.suptitle(f"{metric} ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)
    elif is_change:
        fig.suptitle(f"{metric} Change ({experiment_id}) for Variable Combination {target_variable_combination} ({means})", fontsize=20, y=0.9)

    # Show plot
    plt.show()

    # Save figure
    if not save_fig:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'
    elif is_single_period:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, 'time', 'mean', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        filename = f'{target_variable_combination}_correlation.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    elif is_change:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'comparison', 'corr_maps', means)
        os.makedirs(savepath, exist_ok=True)
        filename = f'{target_variable_combination}_correlation_change.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        # No target directory is defined for other experiments
        filepath = f"Figure not saved: unsupported experiment_id '{experiment_id}'."

    return filepath

### Load netCDF files

In [None]:
# ========= Define period, models and path ==============
# Variables to load, the experiment, the model names and the data sub-folder;
# all four are passed to open_and_merge_datasets below.
# NOTE(review): other cells look up the key 'Ensemble_mean' (with underscore),
# while this list uses 'Ensemble mean' -- confirm which spelling the files use.
variables=['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'tran', 'lai', 'gpp', 'EI', 'wue']
experiment_id = 'historical'
source_id = ['TaiESM1', 'BCC-CSM2-MR',  'CanESM5', 'CNRM-CM6-1', 'CNRM-ESM2-1', 'IPSL-CM6A-LR', 'UKESM1-0-LL', 'MPI-ESM1-2-LR', 'CESM2-WACCM', 'NorESM2-MM', 'Ensemble mean', 'Ensemble median'] #
folder='preprocessed'

# ========= Use Dask to parallelize computations ==========
dask.config.set(scheduler='processes')

# Create dictionary using a dictionary comprehension and Dask
# NOTE(review): open_and_merge_datasets is called directly (it is not wrapped
# in dask.delayed), so the comprehension is evaluated before dask.compute
# receives the dictionary -- confirm whether parallel loading was intended.
ds_dict_hist = dask.compute({model: open_and_merge_datasets(folder, model, experiment_id, variables) for model in source_id})[0]

In [None]:
# ============= Have a look into the data ==============
# Print the model names, then display the first dataset as the cell output
print(ds_dict_hist.keys())
ds_dict_hist[list(ds_dict_hist.keys())[0]]

In [None]:
# Define the lat/lon bounds for each region
# Keys are region names; each value holds the latitude and longitude bounds as
# slice objects. The negative longitudes presumably assume coordinates in
# -180..180 and the slices an ascending latitude axis -- TODO confirm against
# the dataset coordinates.
regions = {
        'Greenland': {'lat': slice(59, 89), 'lon': slice(-65, -12)},
        'Northern North America': {'lat': slice(40, 73), 'lon': slice(-170, -53)},
        'Southern North America': {'lat': slice(15, 40), 'lon': slice(-130, -55)},
        'Northern South America': {'lat': slice(-25, 15), 'lon': slice(-85, -30)},
        'Southern South America': {'lat': slice(-59, -25), 'lon': slice(-80, -40)},
        'Northern Europe': {'lat': slice(45, 89), 'lon': slice(-12, 40)},
        'Mediterranean and Middle East': {'lat': slice(30, 45), 'lon': slice(-12, 55)},
        'Sahara': {'lat': slice(12, 30), 'lon': slice(-20, 55)},
        'Sub-Sahara Africa': {'lat': slice(-35, 12), 'lon': slice(-20, 55)},
        'Northern Asia': {'lat': slice(45, 89), 'lon': slice(40, 179)},
        'Southwest Asia': {'lat': slice(0, 45), 'lon': slice(55, 90)},
        'Southeast Asia': {'lat': slice(-11, 45), 'lon': slice(90, 165)},
        'Oceania': {'lat': slice(-50, -11), 'lon': slice(110, 180)}
    }

In [None]:
# Apply slice_to_regions (defined elsewhere in this notebook) to the
# historical datasets using the `regions` bounds
ds_dict_hist_regions = slice_to_regions(ds_dict_hist, regions)

In [None]:
# Same region slicing for the SSP370 datasets; ds_dict_ssp370 is expected to
# have been loaded by another cell
ds_dict_ssp370_regions = slice_to_regions(ds_dict_ssp370, regions)

### Select period

In [None]:
# Restrict the historical datasets to the years 1985-2014
ds_dict_hist_period = select_period(ds_dict_hist, start_year=1985, end_year=2014)

In [None]:
# Restrict the SSP370 datasets to the years 2071-2100
ds_dict_ssp370_period = select_period(ds_dict_ssp370, start_year=2071, end_year=2100)

### Compute statistics

In [None]:
# ========= Compute statistic for plot ===============
# Calls compute_statistic with 'mean' and 'time' on the historical period
# (presumably the mean over the time dimension -- verify in compute_statistic)
ds_dict_hist_period_metric = compute_statistic(ds_dict_hist_period, 'mean', 'time')

In [None]:
# ========= Compute statistic for plot ===============
# Calls compute_statistic with 'mean' and 'time' on the SSP370 period
# (presumably the mean over the time dimension -- verify in compute_statistic)
ds_dict_ssp370_period_metric = compute_statistic(ds_dict_ssp370_period, 'mean', 'time')

In [None]:
# ====== Plot timeseries for each variable combination
# Only pairs that include 'pr' are drawn; yearly_correlations_hist and
# variable_pairs are expected to exist from other cells
plot_time_series_correlations(yearly_correlations_hist, variable_pairs, 'pr')

In [None]:
# Run compute_stats (defined elsewhere in this notebook) on both experiments;
# the results are passed to plot_mean_values
stats_hist = compute_stats(ds_dict_hist)
stats_ssp370 = compute_stats(ds_dict_ssp370)

In [None]:
# Plot the statistics of both experiments; full_var_names_and_unit is expected
# to exist from another cell
plot_mean_values(stats_hist, stats_ssp370, scaling_both=False, scaling_hist=True, full_variable_names=full_var_names_and_unit)

In [None]:
def plot_consecutive_periods(ds_dict_hist, ds_dict_ssp370, save_fig=False, fig_name='plot.png'):
    """
    Plots, for a fixed list of nine variables, the ensemble-mean time series of the
    historical and the SSP370 period with a shaded band of one standard deviation
    of the individual models.

    Args:
        ds_dict_hist (dict): Maps a dataset name to the historical dataset. Must contain the
            key 'Ensemble_mean'; the variables are expected to have a `year` coordinate.
        ds_dict_ssp370 (dict): Same structure for the SSP370 scenario.
        save_fig (bool): If True, save the figure below
            '../../results/CMIP6/ssp370-historical/space/<statistic>/line_plot_variable_change',
            where <statistic> is read from the `statistic` attribute of the first
            historical dataset. Default is False.
        fig_name (str): Currently not used, the file name is fixed to
            'Line_plot_variable_change.png'. Kept so that existing calls keep working.
    """
    # Define the order of the variables
    variable_order = ['pr', 'vpd', 'mrro', 'evspsbl', 'lmrso_1m', 'lmrso_2m', 'tran', 'gpp', 'lai']

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 15))
    axes = axes.flatten()  # Flatten the axes array to easily iterate over
    fig.suptitle('Variable Change for Historical and SSP370 Period', fontsize=14)

    for ax, variable in zip(axes, variable_order):
        # Check if variable exists in the dictionary
        if variable not in ds_dict_hist['Ensemble_mean'].variables:
            print(f"Variable {variable} not found in dataset. Skipping.")
            continue

        # Get the data of the individual models for each period. The filter has
        # to test the dictionary key: comparing the dataset itself with the
        # string never excludes the ensemble mean from the spread.
        data_hist = [ds[variable] for name, ds in ds_dict_hist.items() if name != 'Ensemble_mean']
        data_ssp370 = [ds[variable] for name, ds in ds_dict_ssp370.items() if name != 'Ensemble_mean']

        # Extract ensemble mean from the dictionaries
        mean_hist = ds_dict_hist['Ensemble_mean'][variable]
        mean_ssp370 = ds_dict_ssp370['Ensemble_mean'][variable]

        # Calculate the standard deviation for each period
        # NOTE(review): this yields one scalar over all models and all years,
        # so the band has a constant width -- confirm whether a per-year
        # spread across models was intended.
        std_dev_hist = np.std(data_hist, ddof=1)
        std_dev_ssp370 = np.std(data_ssp370, ddof=1)

        # Plot the means with shaded areas for the model spread
        ax.plot(mean_hist.year, mean_hist, label='Historical')
        ax.fill_between(mean_hist.year, mean_hist - std_dev_hist, mean_hist + std_dev_hist, alpha=0.3)
        ax.plot(mean_ssp370.year, mean_ssp370, label='SSP370')
        ax.fill_between(mean_ssp370.year, mean_ssp370 - std_dev_ssp370, mean_ssp370 + std_dev_ssp370, alpha=0.3)

        # Get units and long_name, or use '' if not present
        unit = mean_hist.attrs.get('units', '')
        long_name = mean_hist.attrs.get('long_name', '')

        # Add labels and legend
        ax.set_xlabel('Year')
        ax.set_ylabel(unit)
        ax.set_title(long_name)
        ax.legend()

    # Remove unused subplots
    for i in range(len(variable_order), len(axes)):
        fig.delaxes(axes[i])

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # adjust subplot positions so the title doesn't overlap

    # Save the figure if specified
    if save_fig:
        # Define save path and filename
        statistic_dim = 'space'
        statistic = ds_dict_hist[list(ds_dict_hist.keys())[0]].statistic
        savepath = f'../../results/CMIP6/ssp370-historical/{statistic_dim}/{statistic}/line_plot_variable_change'
        filename = f'Line_plot_variable_change.png'
        filepath = os.path.join(savepath, filename)

        # Actually write the file; the target directory may not exist yet
        os.makedirs(savepath, exist_ok=True)
        fig.savefig(filepath, dpi=300, bbox_inches='tight')

    plt.show()

In [None]:
# Call the function
# ds_dict_hist_mean and ds_dict_ssp370_mean are expected to exist from another cell
plot_consecutive_periods(ds_dict_hist_mean, ds_dict_ssp370_mean, save_fig=True)

### Compute Ensemble metric

In [None]:
# Replace the historical period dictionary with the result of compute_ensemble
# (presumably adds the ensemble entries -- verify in compute_ensemble)
ds_dict_hist_period = compute_ensemble(ds_dict_hist_period)

In [None]:
# Replace the SSP370 period dictionary with the result of compute_ensemble
# (presumably adds the ensemble entries -- verify in compute_ensemble)
ds_dict_ssp370_period = compute_ensemble(ds_dict_ssp370_period)

### Compute NBWFP

In [None]:
# Replace the historical period dictionary with the result of compute_nbwfp
# (presumably adds the NBWF variables used further below -- verify)
ds_dict_hist_period = compute_nbwfp(ds_dict_hist_period)

In [None]:
# Replace the SSP370 period dictionary with the result of compute_nbwfp
# (presumably adds the NBWF variables used further below -- verify)
ds_dict_ssp370_period = compute_nbwfp(ds_dict_ssp370_period)

### Select regions

### Regionmask

In [None]:
import regionmask

In [None]:
# Predefined AR6 land regions shipped with regionmask
land_regions = regionmask.defined_regions.ar6.land

In [None]:
# Draw the region outlines for a quick visual check
land_regions.plot()

In [None]:
# Take tas of one model; it is passed to the mask call below
tas = ds_dict_hist_period['TaiESM1'].tas

In [None]:
# NOTE(review): despite its name, mask_2D holds the result of mask_3D(), which
# has a `region` dimension (see the isel(region=1) call in the next cell)
mask_2D = land_regions.mask_3D(tas)

In [None]:
# Plot the mask of the region at index 1
mask_2D.isel(region=1).plot()

In [None]:
# Map of the AR6 land-region mask on a Robinson projection
proj = ccrs.Robinson()
f, ax = plt.subplots(subplot_kw=dict(projection=proj))

# pcolormesh needs a 2D (lat, lon) field. `mask_2D` from the earlier cell is
# created with mask_3D() and carries an extra region dimension, so the 2D
# mask is built here with .mask() on the same grid.
h = land_regions.mask(tas).plot.pcolormesh(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False)

ax.coastlines()

# Overlay the region outlines (drawn on the current axes)
land_regions.plot_regions(line_kws=dict(lw=0.5), add_label=False);

In [None]:
# ABSOLUTE CHANGE
# SSP370 minus historical; ds and ds_ssp are expected to exist from another cell
ds_change = ds_ssp - ds

In [None]:
# RELATIVE CHANGE
# Change in percent of the historical value
# NOTE(review): where ds is 0 the division gives inf/NaN, which also affects
# the global statistics printed below -- confirm this is acceptable
ds_change_rel = ((ds_ssp - ds)/ds) * 100

In [None]:
# Map of historical NBWF with a fixed colour range of -0.5..0.5; `fig` holds
# the object returned by .plot(), whose figure gets a boxed title text
fig = ds['NBWF'].plot(cmap=cmap, vmin=-0.5, vmax=0.5)
fig.figure.text(0.5, 0.9, "Historical Net Blue Water Flux (NBWF)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of SSP370 NBWF with the same colour range as the historical map
fig = ds_ssp['NBWF'].plot(cmap=cmap, vmin=-0.5, vmax=0.5)
fig.figure.text(0.5, 0.9, "Future Net Blue Water Flux (NBWF)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the absolute NBWF change (SSP370 minus historical), range -0.5..0.5
fig = ds_change['NBWF'].plot(cmap=cmap, vmin=-0.5, vmax=0.5)
fig.figure.text(0.5, 0.9, "Net Blue Water Flux Absolute Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the relative NBWF change in percent, range -100..100
fig = ds_change_rel['NBWF'].plot(cmap=cmap, vmin=-100, vmax=100)
fig.figure.text(0.5, 0.9, "Net Blue Water Flux Relative Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Summary statistics of NBWF (mean, median, standard deviation, minimum,
# maximum over all values) for the historical field, the SSP370 field, the
# absolute change and the relative change
print("Historical:")
print(f"Global Mean of Historical NBWF: {ds['NBWF'].mean().values} mm/day")
print(f"Global Median of Historical NBWF: {ds['NBWF'].median().values} mm/day")
print(f"Global Standard Deviation of Historical NBWF: {ds['NBWF'].std().values} mm/day")
print(f"Global Minimum of Historical NBWF: {ds['NBWF'].min().values} mm/day")
print(f"Global Maximum of Historical NBWF: {ds['NBWF'].max().values} mm/day\n")

print("SSP370:")
print(f"Global Mean of SSP370 NBWF: {ds_ssp['NBWF'].mean().values} mm/day")
print(f"Global Median of SSP370 NBWF: {ds_ssp['NBWF'].median().values} mm/day")
print(f"Global Standard Deviation of SSP370 NBWF: {ds_ssp['NBWF'].std().values} mm/day")
print(f"Global Minimum of SSP370 NBWF: {ds_ssp['NBWF'].min().values} mm/day")
print(f"Global Maximum of SSP370 NBWF: {ds_ssp['NBWF'].max().values} mm/day\n")

print("Absolute Change:")
print(f"Global Mean Absolute Change of NBWF: {ds_change['NBWF'].mean().values} mm/day")
print(f"Global Median Absolute Change of NBWF: {ds_change['NBWF'].median().values} mm/day")
print(f"Global Standard Deviation of NBWF Absolute Change: {ds_change['NBWF'].std().values} mm/day")
print(f"Global Minimum of NBWF Absolute Change: {ds_change['NBWF'].min().values} mm/day")
print(f"Global Maximum of NBWF Absolute Change: {ds_change['NBWF'].max().values} mm/day\n")

print("Relative Change:")
print(f"Global Mean Relative Change of NBWF: {ds_change_rel['NBWF'].mean().values} %")
print(f"Global Median Relative Change of NBWF: {ds_change_rel['NBWF'].median().values} %")
print(f"Global Standard Deviation of NBWF Relative Change: {ds_change_rel['NBWF'].std().values} %")
print(f"Global Minimum of NBWF Relative Change: {ds_change_rel['NBWF'].min().values} %")
print(f"Global Maximum of NBWF Relative Change: {ds_change_rel['NBWF'].max().values} %")

In [None]:
# Map of historical NBWF/P with a fixed colour range of -0.5..0.5
fig = ds['NBWF/P'].plot(cmap=cmap, vmin=-0.5, vmax=0.5)
fig.figure.text(0.5, 0.9, "Historical Net Blue Water Flux / Precipitation (NBWF/P)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of SSP370 NBWF/P with the same colour range as the historical map
fig = ds_ssp['NBWF/P'].plot(cmap=cmap, vmin=-0.5, vmax=0.5)
fig.figure.text(0.5, 0.9, "SSP370 Net Blue Water Flux / Precipitation (NBWF/P)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the absolute NBWF/P change, colour range -0.2..0.2
fig = ds_change['NBWF/P'].plot(cmap=cmap, vmin=-0.2, vmax=0.2)
fig.figure.text(0.5, 0.9, "Net Blue Water Flux / Precipitation (NBWF/P) Absolute Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the relative NBWF/P change in percent, range -100..100
fig = ds_change_rel['NBWF/P'].plot(cmap=cmap, vmin=-100, vmax=100)
fig.figure.text(0.5, 0.9, "Net Blue Water Flux / Precipitation (NBWF/P) Relative Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Summary statistics of NBWF/P, the variable mapped in the four preceding
# cells (mean, median, standard deviation, minimum, maximum over all values)
# for the historical field, the SSP370 field, the absolute change and the
# relative change. The BWFB statistics are printed after the BWFB maps.
print("Historical:")
print(f"Global Mean of Historical NBWF/P: {ds['NBWF/P'].mean().values}")
print(f"Global Median of Historical NBWF/P: {ds['NBWF/P'].median().values}")
print(f"Global Standard Deviation of Historical NBWF/P: {ds['NBWF/P'].std().values}")
print(f"Global Minimum of Historical NBWF/P: {ds['NBWF/P'].min().values}")
print(f"Global Maximum of Historical NBWF/P: {ds['NBWF/P'].max().values}\n")

print("SSP370:")
print(f"Global Mean of SSP370 NBWF/P: {ds_ssp['NBWF/P'].mean().values}")
print(f"Global Median of SSP370 NBWF/P: {ds_ssp['NBWF/P'].median().values}")
print(f"Global Standard Deviation of SSP370 NBWF/P: {ds_ssp['NBWF/P'].std().values}")
print(f"Global Minimum of SSP370 NBWF/P: {ds_ssp['NBWF/P'].min().values}")
print(f"Global Maximum of SSP370 NBWF/P: {ds_ssp['NBWF/P'].max().values}\n")

print("Absolute Change:")
print(f"Global Mean Absolute Change of NBWF/P: {ds_change['NBWF/P'].mean().values}")
print(f"Global Median Absolute Change of NBWF/P: {ds_change['NBWF/P'].median().values}")
print(f"Global Standard Deviation of NBWF/P Absolute Change: {ds_change['NBWF/P'].std().values}")
print(f"Global Minimum of NBWF/P Absolute Change: {ds_change['NBWF/P'].min().values}")
print(f"Global Maximum of NBWF/P Absolute Change: {ds_change['NBWF/P'].max().values}\n")

print("Relative Change:")
print(f"Global Mean Relative Change of NBWF/P: {ds_change_rel['NBWF/P'].mean().values} %")
print(f"Global Median Relative Change of NBWF/P: {ds_change_rel['NBWF/P'].median().values} %")
print(f"Global Standard Deviation of NBWF/P Relative Change: {ds_change_rel['NBWF/P'].std().values} %")
print(f"Global Minimum of NBWF/P Relative Change: {ds_change_rel['NBWF/P'].min().values} %")
print(f"Global Maximum of NBWF/P Relative Change: {ds_change_rel['NBWF/P'].max().values} %")

In [None]:
# Map of historical BWFB with a fixed colour range of 0..5
fig = ds['BWFB'].plot(cmap=cmap, vmin=0, vmax=5)
fig.figure.text(0.5, 0.9, "Historical Blue Water Flux Balance (BWFB)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of SSP370 BWFB with the same colour range as the historical map
fig = ds_ssp['BWFB'].plot(cmap=cmap, vmin=0, vmax=5)
fig.figure.text(0.5, 0.9, "SSP370 Blue Water Flux Balance (BWFB)", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the absolute BWFB change
# NOTE(review): the colour range starts at 0, unlike the other change maps in
# this notebook, which use limits symmetric around 0; negative changes fall
# below vmin -- confirm whether 0..5 is intended here
fig = ds_change['BWFB'].plot(cmap=cmap, vmin=0, vmax=5)
fig.figure.text(0.5, 0.9, "Blue Water Flux Balance (BWFB) Absolute Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Map of the relative BWFB change in percent, range -100..100
fig = ds_change_rel['BWFB'].plot(cmap=cmap, vmin=-100, vmax=100)
fig.figure.text(0.5, 0.9, "Blue Water Flux Balance (BWFB) Relative Change", ha="center", fontsize=12, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

plt.show()

In [None]:
# Summary statistics of BWFB (mean, median, standard deviation, minimum,
# maximum over all values) for the historical field, the SSP370 field, the
# absolute change and the relative change
print("Historical:")
print(f"Global Mean of Historical BWFB: {ds['BWFB'].mean().values}")
print(f"Global Median of Historical BWFB: {ds['BWFB'].median().values}")
print(f"Global Standard Deviation of Historical BWFB: {ds['BWFB'].std().values}")
print(f"Global Minimum of Historical BWFB: {ds['BWFB'].min().values}")
print(f"Global Maximum of Historical BWFB: {ds['BWFB'].max().values}\n")

print("SSP370:")
print(f"Global Mean of SSP370 BWFB: {ds_ssp['BWFB'].mean().values}")
print(f"Global Median of SSP370 BWFB: {ds_ssp['BWFB'].median().values}")
print(f"Global Standard Deviation of SSP370 BWFB: {ds_ssp['BWFB'].std().values}")
print(f"Global Minimum of SSP370 BWFB: {ds_ssp['BWFB'].min().values}")
print(f"Global Maximum of SSP370 BWFB: {ds_ssp['BWFB'].max().values}\n")

print("Absolute Change:")
print(f"Global Mean Absolute Change of BWFB: {ds_change['BWFB'].mean().values}")
print(f"Global Median Absolute Change of BWFB: {ds_change['BWFB'].median().values}")
print(f"Global Standard Deviation of BWFB Absolute Change: {ds_change['BWFB'].std().values}")
print(f"Global Minimum of BWFB Absolute Change: {ds_change['BWFB'].min().values}")
print(f"Global Maximum of BWFB Absolute Change: {ds_change['BWFB'].max().values}\n")

print("Relative Change:")
print(f"Global Mean Relative Change of BWFB: {ds_change_rel['BWFB'].mean().values} %")
print(f"Global Median Relative Change of BWFB: {ds_change_rel['BWFB'].median().values} %")
print(f"Global Standard Deviation of BWFB Relative Change: {ds_change_rel['BWFB'].std().values} %")
print(f"Global Minimum of BWFB Relative Change: {ds_change_rel['BWFB'].min().values} %")
print(f"Global Maximum of BWFB Relative Change: {ds_change_rel['BWFB'].max().values} %")

### Timeline and historical ensmean map of variables for regions

In [None]:
# Run compute_yearly_means (defined elsewhere in this notebook) on the
# region-sliced historical datasets
ds_dict_hist_yearly = compute_yearly_means(ds_dict_hist_regions)

In [None]:
# Run compute_yearly_means (defined elsewhere in this notebook) on the
# region-sliced SSP370 datasets
ds_dict_ssp370_yearly = compute_yearly_means(ds_dict_ssp370_regions)

### Plot Regional Data

In [None]:
# Opacity of the shaded percentile bands; shared by the plots and the legend
# patches so that both always agree.
_OUTER_BAND_ALPHA = 0.1  # 10th-90th percentile band
_INNER_BAND_ALPHA = 0.2  # 25th-75th percentile band


def _regional_percentile_bands(ds_dict_yearly, region, target_var, excluded_model):
    '''
    Compute the inter-model 10th, 25th, 75th and 90th percentiles for one region.

    Parameters:
    ds_dict_yearly (dict): Nested mapping indexed as [region][model][target_var].
    region (str): Key of the region to evaluate.
    target_var (str): Name of the variable to evaluate.
    excluded_model (str): Model key that is left out of the percentile computation.

    Returns:
    tuple: The four percentile arrays in the order (10th, 25th, 75th, 90th).
    '''
    region_models = ds_dict_yearly[region]
    stacked = xr.concat(
        [region_models[name][target_var] for name in region_models if name != excluded_model],
        dim='models',
    )
    return tuple(stacked.quantile(q, dim='models') for q in (0.1, 0.25, 0.75, 0.9))


def _shade_percentile_bands(ax, bands, color):
    '''
    Shade the 10th-90th and 25th-75th percentile ranges on the given axes.

    Parameters:
    ax: Matplotlib axes to draw on.
    bands (tuple): Percentile arrays in the order (10th, 25th, 75th, 90th).
    color (str): Fill color of both bands.
    '''
    p10, p25, p75, p90 = bands
    # Use the percentile arrays' own `year` coordinate so x and y always have
    # matching lengths, independent of the time axis of any single model.
    ax.fill_between(p10.year, p10, p90, color=color, alpha=_OUTER_BAND_ALPHA)
    ax.fill_between(p25.year, p25, p75, color=color, alpha=_INNER_BAND_ALPHA)


def plot_data_with_regions(ds_dict_metric, ds_dict_hist_yearly, ds_dict_ssp370_yearly, regions, model, metric='mean', target_var='pr', save_fig=False):
    '''
    Plot a global map of one variable with the region boxes drawn on it, surrounded
    by one yearly line plot per region (historical and SSP370 with inter-model
    percentile shading).

    Parameters:
    ds_dict_metric (dict): Indexed as [model][target_var] for the map data. It must also
                           contain the key f'Ensemble_{metric}', whose target_var entry
                           provides the attributes `units`, `long_name` and `units_rel`.
    ds_dict_hist_yearly (dict): Indexed as [region][model][target_var]; each entry needs a `year` coordinate.
    ds_dict_ssp370_yearly (dict): Same structure as ds_dict_hist_yearly for SSP370.
    regions (dict): Maps region name to a dict with 'lat' and 'lon' entries exposing `.start`
                    and `.stop` (e.g. slices). Every region name must be one of the names
                    with a predefined line plot location in this function.
    model (str): Model key used for the map and for the lines in the regional plots.
    metric (str): Ensemble metric; f'Ensemble_{metric}' is excluded from the percentiles.
    target_var (str): Name of the variable to plot.
    save_fig (bool): If True, the figure is written as PNG below ../../results/CMIP6/.

    Returns:
    str: Path of the saved figure, or an explanatory message if the figure was not saved.
    '''
    fig = plt.figure(figsize=(22, 10))

    # Main map axes: [left, bottom, width, height] in figure coordinates
    ax_main = fig.add_axes([0.25, 0.25, 0.7, 0.7], projection=ccrs.PlateCarree())

    # Data shown on the map (plotted as stored; no mean is subtracted)
    map_data = ds_dict_metric[model][target_var].copy()

    # Symmetric color limits; a narrower range is used for these three variables
    if target_var in ('lmrso_1m', 'lmrso_2m', 'WUE'):
        vmin, vmax = -50, 50
    else:
        vmin, vmax = -100, 100

    # Plot the global data
    img = map_data.plot(ax=ax_main, vmin=vmin, vmax=vmax, cmap='bwr', transform=ccrs.PlateCarree(), add_colorbar=False)

    # Add gridlines, labelled on the bottom and left side only
    gridlines = ax_main.gridlines(draw_labels=True, color='black', alpha=0.2, linestyle='--')
    gridlines.top_labels = gridlines.right_labels = False

    # Title and coastlines belong to the map as a whole, so they are set once here
    ax_main.set_title(f"{model} ({metric})")
    ax_main.coastlines()

    # Short display names keyed by the `long_name` attribute of the variable
    long_name = {
        'Precipitation': 'Precipitation',
        'Total Runoff': 'Total Runoff',
        'Vapor Pressure Deficit': 'Vapor Pressure Deficit',
        'Evaporation Including Sublimation and Transpiration': 'Evapotranspiration',
        'Transpiration': 'Transpiration',
        'Leaf Area Index': 'Leaf Area Index',
        'Carbon Mass Flux out of Atmosphere Due to Gross Primary Production on Land [kgC m-2 s-1]': 'Gross Primary Production',
        'Total Liquid Soil Moisture Content of 1 m Column': '1 m Soil Moisture',
        'Total Liquid Soil Moisture Content of 2 m Column': '2 m Soil Moisture',
        'Water Use Efficiency': 'Water Use Efficiency'
    }

    # Location of each regional line plot: [left, bottom, width, height]
    line_plot_locations = {
        'Greenland': [0.25, 1.0, 0.2, 0.2],
        'Northern North America': [0, 1.0, 0.2, 0.2],
        'Southern North America': [0, 0.66, 0.2, 0.2],
        'Northern South America': [0, 0.33, 0.2, 0.2],
        'Southern South America': [0, 0, 0.2, 0.2],
        'Northern Europe': [0.75, 1.0, 0.2, 0.2],
        'Mediterranean and Middle East': [0.5, 1.0, 0.2, 0.2],
        'Sahara': [0.25, 0, 0.2, 0.2],
        'Sub-Sahara Africa': [0.75, 0, 0.2, 0.2],
        'Northern Asia': [1.0, 1.0, 0.2, 0.2],
        'Southwest Asia': [1.0, 0.33, 0.2, 0.2],
        'Southeast Asia': [1.0, 0.66, 0.2, 0.2],
        'Oceania': [1.0, 0, 0.2, 0.2]
    }

    # Custom legend patches for the shaded bands; alphas match the plotted bands
    band_handles = [
        mpatches.Patch(color='blue', alpha=_OUTER_BAND_ALPHA, label='10th-90th Percentile Historical'),
        mpatches.Patch(color='blue', alpha=_INNER_BAND_ALPHA, label='25th-75th Percentile Historical'),
        mpatches.Patch(color='orange', alpha=_OUTER_BAND_ALPHA, label='10th-90th Percentile SSP370'),
        mpatches.Patch(color='orange', alpha=_INNER_BAND_ALPHA, label='25th-75th Percentile SSP370'),
    ]
    band_labels = [handle.get_label() for handle in band_handles]

    # Letters that link each box on the map to its line plot
    identifiers = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

    # Offset of the identifier from the upper-left box corner, in degrees
    offset_lon = 1
    offset_lat = -1

    ensemble_key = f'Ensemble_{metric}'
    ensemble_da = ds_dict_metric[ensemble_key][target_var]

    region_ax = None  # stays None when `regions` is empty

    # Loop over the regions and plot each one
    for identifier, (region, coord) in zip(identifiers, regions.items()):

        lat_min, lat_max = coord['lat'].start, coord['lat'].stop
        lon_min, lon_max = coord['lon'].start, coord['lon'].stop

        # Draw the box
        ax_main.plot([lon_min, lon_max, lon_max, lon_min, lon_min],
                     [lat_min, lat_min, lat_max, lat_max, lat_min],
                     color='black', transform=ccrs.PlateCarree())

        # Add identifier to main plot
        ax_main.text(lon_min + offset_lon, lat_max + offset_lat, identifier,
                     horizontalalignment='left', verticalalignment='top',
                     transform=ccrs.PlateCarree(), fontsize=14, color='black', weight='bold')

        # Inter-model percentiles of both scenarios, without the ensemble entry
        hist_bands = _regional_percentile_bands(ds_dict_hist_yearly, region, target_var, ensemble_key)
        ssp370_bands = _regional_percentile_bands(ds_dict_ssp370_yearly, region, target_var, ensemble_key)

        # Create the regional axes on this figure at the specified location
        region_ax = fig.add_axes(line_plot_locations[region])

        # Add identifier to subplot, 5% in from the left and top edge
        region_ax.text(0.05, 0.95, identifier,
                       horizontalalignment='left', verticalalignment='top',
                       transform=region_ax.transAxes, fontsize=14, color='black', weight='bold')

        # Add the percentile shading
        _shade_percentile_bands(region_ax, hist_bands, 'blue')
        _shade_percentile_bands(region_ax, ssp370_bands, 'orange')

        # Plot the lines of the selected model
        hist_series = ds_dict_hist_yearly[region][model][target_var]
        ssp370_series = ds_dict_ssp370_yearly[region][model][target_var]
        region_ax.plot(hist_series.year, hist_series, label='Historical')
        region_ax.plot(ssp370_series.year, ssp370_series, label='SSP370')

        region_ax.set_title(region)
        region_ax.set_ylabel(f"{ensemble_da.units}")

    # The line handles are identical in every subplot, so take them from the last one
    if region_ax is not None:
        line_handles, line_labels = region_ax.get_legend_handles_labels()
    else:
        line_handles, line_labels = [], []

    # Add one common legend with the lines followed by the percentile bands
    fig.legend(handles=line_handles + band_handles, labels=line_labels + band_labels,
               loc='center', bbox_to_anchor=(0.6, 0.1), ncol=2)

    # Place the colorbar below the map
    cbar_ax = fig.add_axes([0.475, 0.23, 0.25, 0.02])  # left, bottom, width, height
    cbar = fig.colorbar(img, cax=cbar_ax, orientation='horizontal')

    # Label the colorbar; fall back to the raw long_name if no short name is defined
    raw_long_name = ensemble_da.long_name
    display_name = long_name.get(raw_long_name, raw_long_name)
    cbar.set_label(f"Relative Change (SSP370-historical) of {display_name} [{ensemble_da.units_rel}]", fontsize=12)

    # Save the figure before showing it, so the file does not depend on what the
    # active backend does with the figure in plt.show()
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', 'ssp370-historical', 'time-space', metric, f'var_line_plots_and_anomaly_of_{metric}')
        os.makedirs(savepath, exist_ok=True)
        filename = f'{model}.{target_var}.line_plot.relative_change.png'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300, bbox_inches='tight')
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    plt.show()

    return filepath

In [None]:
# Map of the relative-change data plus regional yearly line plots for
# target_var='lai' and model 'TaiESM1'; f'Ensemble_median' is left out of the
# percentile shading. With save_fig=True the figure is written to disk and the
# returned file path is shown as the cell output.
plot_data_with_regions(ds_dict_change_rel, ds_dict_hist_yearly, ds_dict_ssp370_yearly, regions, model='TaiESM1', metric='median', target_var='lai', save_fig=True)

In [None]:
# ========= Compute Correlations ========

In [None]:
# Correlation coefficients for the historical data, computed by
# compute_correlation_coefficients (defined outside this section) for the
# entries of `variables` -- presumably pairwise; confirm against the helper.
ds_dict_hist_corr = compute_correlation_coefficients(ds_dict_hist, variables)

In [None]:
# Correlation coefficients for the SSP370 data, computed by
# compute_correlation_coefficients (defined outside this section) for the
# entries of `variables` -- presumably pairwise; confirm against the helper.
ds_dict_ssp370_corr = compute_correlation_coefficients(ds_dict_ssp370, variables)

In [None]:
# ========= Compute Change ========
# Combine the historical and SSP370 correlation results into their change; the
# exact definition lives in compute_change_corr (defined outside this section).
ds_dict_corr_change = compute_change_corr(ds_dict_hist_corr, ds_dict_ssp370_corr)

In [None]:
# ======= Create unique variable pairs ========
from itertools import combinations

# Every unordered pair (variables[i], variables[j]) with i < j, in the same
# order as a nested index loop would produce them.
variable_pairs = list(combinations(variables, 2))

In [None]:
# ======== Create Plots for all variable pairs ========
# One correlation-change map per pair, labelled "<first> x <second>"
for var in variable_pairs:
    first_var, second_var = var[0], var[1]
    pair_label = f'{first_var} x {second_var}'
    corr_maps(ds_dict_corr_change, pair_label, 'coolwarm', save_fig=True)

In [None]:
# Plot the correlation change for 'mrro' with the 'coolwarm' colormap via
# ensemble_corr_change_plot (defined outside this section); save_fig=True asks
# the helper to save the figure.
# NOTE(review): 'mrro' is presumably total runoff -- confirm.
ensemble_corr_change_plot(ds_dict_corr_change, 'mrro', full_var_names_and_unit, 'coolwarm', save_fig=True)