# Master functions for Lisa's paper on the inter-comparison of precipitation extremes at regional scale

In [1]:
# Import modules 
import pandas as pd
import os
import fnmatch
import xarray as xr
import xskillscore as xs
import numpy as np
import math
#from scipy import stats
from scipy.stats import ks_2samp
from scipy.stats import mannwhitneyu
import matplotlib.pyplot as plt
import matplotlib.colors as colors

### Define names of datasets

In [2]:
# List names of all precipitation datasets.
# get_data_name() returns the FIRST name whose "*<name>_*" pattern matches a file,
# so the order of this list matters. dict.fromkeys() drops accidental duplicates
# (e.g. "PERSIANN_v1_r1" and "CFSR" were listed twice) while keeping the
# first-occurrence order, so matching results are unchanged.
data_names = list(dict.fromkeys([
    "3B42_IR_v7.0", "3B42_MW_v7.0", "3B42RT_UNCAL_v7.0", "3B42RT_v7.0", "3B42_v7.0", "ARC2", "CFSR",
    "CHIRPS_v2.0", "CHIRP_V1", "CMORPH_v1.0_CRT", "CMORPH_v1.0_RAW",
    "COSCH", "CPC_v1.0", "ERA5", "ERAi", "GPCC_FDD_v1.0", "GPCC_FDD_v2018", "GPCC_FDD_v2020",
    "GPCC_FDD_v2022", "GPCC_FG_v1.0", "GPCP_CDR_v1.3_not-enforced", "GPCP_CDR_v1.3_yes-enforced",
    "GPCP_IP", "GPCP_V3.2", "GSMAP-gauges-NRT-v6.0", "GSMAP-nogauges-NRT-v6.0", "GSWP3", "HOAPSv4.0",
    "IMERG_V06_EU", "IMERG_V06_FC", "IMERG_V06_FU", "IMERG_V06_LU", "JRA-55", "MERRA1",
    "MERRA2", "PERSIANN_CCS_CDR", "PERSIANN_v1_r1", "REGEN_ALL_2019", "REGEN_LONG_V1", "SM2RAIN-ASCAT",
    "TAMSAT_v3.1", "TAMSAT_v3", "TAPEER_v1.5", "GSMAP-NRT-gauges-v8.0", "IMERG-v07B-FC", "GIRAFE",
    "IMERG_V07B_FC", "GsMAP-gauges-NRT-v8", "PERSIANN_v1_r1", "CFSR",
]))

### Get Model Names

In [3]:
# Use the dataset-name list above to identify which dataset a file belongs to
def get_data_name(names, file):
    """Return the first entry of ``names`` that appears in ``file`` as ``*<name>_*``.

    Parameters
    ----------
    names : iterable of str
        Candidate dataset names, tried in the given order.
    file : str
        File name to classify.

    Returns
    -------
    str or None
        The first name whose ``fnmatch`` pattern ``"*" + name + "_*"`` matches
        ``file``, or ``None`` when none of them match.
    """
    matches = (
        candidate
        for candidate in names
        if fnmatch.fnmatch(file, "*" + candidate + "_*")
    )
    return next(matches, None)


In [4]:
# Get dataset file paths and add them to a pandas DataFrame
def get_data_files(data_master_path):
    """List every file in ``data_master_path`` together with its dataset name.

    Parameters
    ----------
    data_master_path : str or path-like
        Directory containing the dataset files. Paths are built with
        ``os.path.join``, so a trailing separator is not required.

    Returns
    -------
    pandas.DataFrame
        Columns ``'dataset'`` and ``'dataset_path'``:

        - ``'dataset'`` is the name matched by ``get_data_name`` against the
          module-level ``data_names`` list, or ``None`` when no name matches.
          Unmatched files are still listed.
        - ``'dataset_path'`` is the full path to the file.

        Rows are in sorted file-name order so results are reproducible, since
        ``os.listdir`` order is arbitrary.
    """
    rows = []
    for data_file in sorted(os.listdir(data_master_path)):
        data_name = get_data_name(data_names, data_file)
        rows.append([data_name, os.path.join(data_master_path, data_file)])

    # Build the frame once instead of growing it row by row (quadratic)
    return pd.DataFrame(rows, columns=['dataset', 'dataset_path'])

### Get Model Data Paths for Subset

In [5]:
# Get dataset file paths for a subset of datasets from the list defined above
def get_data_files_subset(data_paths, subset_names):
    """Keep only the rows of ``data_paths`` whose dataset name is in ``subset_names``.

    Parameters
    ----------
    data_paths : pandas.DataFrame
        Frame whose first column holds dataset names and whose second column
        holds file paths, e.g. the output of ``get_data_files``.
    subset_names : container of str
        Dataset names to keep. Membership is tested with ``in``.

    Returns
    -------
    pandas.DataFrame
        Columns ``'data_name'`` and ``'dataset_path'``. Both values are
        converted to ``str``, as before.
    """
    # Access columns by position with .iloc. Positional row[0] on a
    # label-indexed Series is deprecated in pandas 2 and removed in pandas 3.
    names_column = data_paths.iloc[:, 0]
    paths_column = data_paths.iloc[:, 1]

    rows = []
    for name_value, path_value in zip(names_column, paths_column):
        data_name = f'{name_value}'
        if data_name in subset_names:
            rows.append([data_name, f'{path_value}'])

    # Build the frame once instead of growing it row by row
    return pd.DataFrame(rows, columns=['data_name', 'dataset_path'])

### Get Data from File Paths

In [6]:
# Extract data from files based on user specifications
def get_data_from_file(file_path, variable, lat_slice, lon_slice, season=None, iscale=None):
    """Open ``file_path`` and return one variable subset to a lat/lon box and optional season.

    Parameters
    ----------
    file_path : str
        Path of a file readable by ``xarray.open_dataset``.
    variable : str
        Name of the variable to extract. Ignored when ``iscale`` is given.
    lat_slice, lon_slice : slice
        Label-based selections along the ``lat`` / ``lon`` coordinates.
    season : list of int, optional
        Month numbers (1-12) to keep along ``time``. ``None`` keeps all time steps.
    iscale : optional
        When not ``None``, the ``spi`` variable is read instead of ``variable``.
        The value itself is not otherwise used here.

    Returns
    -------
    xarray.DataArray
    """
    # Index variables whose 'units' attribute is neutralised ('1') before CF
    # decoding and then set to 'days'. Per the original notes this is needed to
    # get a float dtype; presumably it stops xarray decoding 'days' units as
    # timedeltas. TODO confirm.
    day_count_variables = ('cdd', 'cwd', 'r10mm', 'r20mm', 'r30mm', 'fracprday')

    if iscale is not None:
        # SPI files: always read the 'spi' variable
        data_ds = xr.open_dataset(file_path)
        data_var = data_ds['spi'].sel(lat=lat_slice, lon=lon_slice)

    elif variable in day_count_variables:
        data_ds = xr.open_dataset(file_path, decode_cf=False)
        data_ds[variable].attrs['units'] = '1'
        data_ds = xr.decode_cf(data_ds)
        data_ds[variable].attrs['units'] = 'days'
        data_var = data_ds[variable].sel(lat=lat_slice, lon=lon_slice)

    elif variable == 'tas':
        data_ds = xr.open_dataset(file_path)
        data_var = data_ds['tas'].sel(lat=lat_slice, lon=lon_slice)

    else:
        # Subset the whole Dataset first, then pick the variable. Item access
        # (rather than getattr) also works for names that clash with Dataset
        # attributes or methods.
        data_ds = xr.open_dataset(file_path)
        data_var = data_ds.sel(lat=lat_slice, lon=lon_slice)[variable]

    if season is not None:
        data_var = data_var.sel(time=data_var.time.dt.month.isin(season))

    return data_var

## Spatially Averaged Functions

The following functions calculate spatially averaged metrics, primarily for plotting time series and scatterplots. <br>
'season=None' sets an optional argument for seasonal subsetting based on numeric month values <br>
'mask=None' sets an optional argument for masking, where the default is None <br>
'iscale=None' sets an optional argument flagging SPI data (the SPI scale index: 1, 2, or 3 for the 3-, 6-, or 12-month averaging periods). This should only be assigned a value <br>
    if the variable is SPI. <br>
**Functions Include:** <br>
- Weighted Spatial Mean at Default Time Step
- Monthly Averages over time/space to gauge seasonality (Homogeneous Variable Names, like Climpact)
- Annual Averages
- Annual Time Series
- Weighted Spatial Average

### Weighted Spatial Mean at Default Time Step (No Temporal Averaging)

In [7]:
# Get the latitude-weighted spatial mean at each native time step (for line plots)
def get_weighted_spatial_average_at_default_time_step(data_path, variable, lat_slice, lon_slice, season=None, iscale=None, mask=None, time_slice=None):
    """Return the cos(latitude)-weighted mean over lat/lon at every time step.

    Parameters
    ----------
    data_path : str
        File to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    lat_slice, lon_slice : slice
        Spatial selection.
    season : list of int, optional
        Month numbers to keep.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    time_slice : slice, optional
        Label-based selection along ``time``. The original body referenced an
        undefined ``time_slice`` and passed it in the ``lat_slice`` position;
        ``None`` keeps every time step.

    Returns
    -------
    xarray.DataArray
        The spatially averaged series along the remaining (time) dimension.
    """
    # Pass season/iscale by keyword so each lands in the right parameter
    data = get_data_from_file(data_path, variable, lat_slice, lon_slice, season=season, iscale=iscale)

    if time_slice is not None:
        data = data.sel(time=time_slice)

    if mask is not None:
        data = data.where(mask == 1)

    # Latitudinal (area) weights
    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    return data.weighted(weights).mean(('lon', 'lat'))

### Get Annual Time Series

In [9]:
# Get annual averages to use for a line plot over a user-specified spatial domain
def get_annual_average(data_path, variable, lat_slice, lon_slice, iscale=None, mask=None, region_mask=None, region=None):
    """Return the cos(latitude)-weighted spatial mean, averaged per calendar year.

    Parameters
    ----------
    data_path : str
        File to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    lat_slice, lon_slice : slice
        Spatial selection.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask : array-like, optional
        Region-id grid; used together with ``region``.
    region : optional
        When given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        Indexed by ``year`` (from ``groupby('time.year')``).
    """
    # iscale must go by keyword: positionally it would land in the `season`
    # slot, which disables SPI and filters to unintended months.
    data = get_data_from_file(data_path, variable, lat_slice, lon_slice, iscale=iscale)

    # Optional IPCC sub-region selection
    if region is not None:
        data = data.where(region_mask == region)

    if mask is not None:
        data = data.where(mask == 1)

    # Latitudinal (area) weights
    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    # Weighted spatial mean at each time step, then the mean within each year
    data_weighted_mean = data.weighted(weights).mean(('lon', 'lat'))
    return data_weighted_mean.groupby('time.year').mean(dim='time')

## Plotting Functions

### Truncate a color map (skip colors in pre-defined colormap)

In [10]:
# Build a new colormap from a sub-range of an existing one (skip colors at either end)
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a ``LinearSegmentedColormap`` sampled from ``cmap`` between ``minval`` and ``maxval``.

    Parameters
    ----------
    cmap : matplotlib colormap
        Source colormap; it is called on ``n`` evenly spaced points.
    minval, maxval : float
        Fractional start/end positions within ``cmap``.
    n : int
        Number of samples taken from ``cmap``.
    """
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

### Define vertical bands (regions) to shade along a time series

In [11]:
# Find contiguous runs of True in a boolean mask (e.g. periods to shade on a time series)
def fill_vertical_columns(boolean_fill_mask):
    """Return half-open ``[start, end)`` index pairs for each run of True values.

    Parameters
    ----------
    boolean_fill_mask : array-like of bool, 1-D
        Positions to shade. Lists and pandas Series are accepted; they are
        converted with ``np.asarray``, so indexing is positional.

    Returns
    -------
    numpy.ndarray of shape (n_runs, 2)
        ``[start, end)`` pairs: ``start`` is the first True index of a run and
        ``end`` is one past its last True index (can equal ``len(mask)``).
        Empty mask or no True values gives a ``(0, 2)`` array.

    Examples
    --------
    >>> fill_vertical_columns([False, True, True, False]).tolist()
    [[1, 3]]
    """
    mask = np.asarray(boolean_fill_mask, dtype=bool)

    if mask.size == 0:
        return np.empty((0, 2), dtype=np.intp)

    # Index k is returned where mask[k] != mask[k + 1]
    region_to_shade, = np.diff(mask).nonzero()

    # Shift to the first index AFTER each switch, so starts are the first True
    # and ends are one past the last True. This matches the 0 / len(mask)
    # conventions used for the edge cases below.
    region_to_shade += 1

    # Runs that start at the first element or end at the last element
    if mask[0]:
        region_to_shade = np.r_[0, region_to_shade]

    if mask[-1]:
        region_to_shade = np.r_[region_to_shade, mask.size]

    # Pair up start/end indices
    return region_to_shade.reshape((-1, 2))

## Get spatial map

In [12]:
def get_climatology(data_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=True, region_mask=None, region=None):
    """Return the time-mean (climatology) map of ``variable`` over ``time_slice``.

    Parameters
    ----------
    data_path : str
        File to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to average over. ``None`` uses every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like or True, optional
        Only cells where ``mask == 1`` are kept. With the default ``True``,
        ``mask == 1`` is ``True`` and no cell is removed.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        (lat, lon) map averaged over time. The cos(lat) weights are constant
        along ``time``, so this equals a NaN-skipping time mean.
    """
    # iscale must go by keyword; positionally it would become `season`
    data = get_data_from_file(data_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        data = data.where(region_mask == region)

    if mask is not None:
        data = data.where(mask == 1)

    # Select the season on the DataArray BEFORE weighting: a Weighted object
    # has no .sel(), so the original seasonal branch raised AttributeError.
    if season is not None:
        data = data.sel(time=data.time.dt.month.isin(season))

    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    return data.weighted(weights).mean(dim='time')

In [13]:
def get_map(data_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=True, region_mask=None, region=None):
    """Return the (time, lat, lon) field for mapping, optionally masked and season-filtered.

    Parameters
    ----------
    data_path : str
        File to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to keep. ``None`` keeps every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like or True, optional
        Only cells where ``mask == 1`` are kept. The default ``True`` keeps all cells.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        The selected data. No time averaging is applied.
    """
    # iscale must go by keyword; positionally it would become `season`
    data = get_data_from_file(data_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        data = data.where(region_mask == region)

    if mask is not None:
        data = data.where(mask == 1)  # (mask>50, drop=True) - for only landmask in obs-only plot

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        data = data.sel(time=data.time.dt.month.isin(season))

    return data

In [14]:
def get_map2(data_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=True, region_mask=None, region=None):
    """Like ``get_map`` but for data indexed by ``year``; ``time_slice`` selects along ``year``.

    Parameters
    ----------
    data_path : str
        File to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice : slice
        Label-based selection along the ``year`` dimension.
    lat_slice, lon_slice : slice
        Spatial selection.
    season : list of int, optional
        Month numbers to keep. NOTE(review): this uses the ``time``
        coordinate, which year-indexed data may not have. Confirm before
        passing a season.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like or True, optional
        Only cells where ``mask == 1`` are kept. The default ``True`` keeps all cells.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
    """
    # iscale must go by keyword; positionally it would become `season`
    data = get_data_from_file(data_path, variable, lat_slice, lon_slice, iscale=iscale).sel(year=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        data = data.where(region_mask == region)

    if mask is not None:
        data = data.where(mask == 1)  # (mask>50, drop=True) - for only landmask in obs-only plot

    if season is not None:
        data = data.sel(time=data.time.dt.month.isin(season))

    return data

### Get bias map and K-S test

In [15]:
def get_bias(global_path, regional_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the climatology bias map ``global - regional``.

    Parameters
    ----------
    global_path, regional_path : str
        Files to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to average over. ``None`` uses every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        (lat, lon) map of time-mean(global) minus time-mean(regional).
    """
    # iscale must go by keyword; positionally it would become `season`
    global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)
    regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        global_data = global_data.sel(time=global_data.time.dt.month.isin(season))
        regional_data = regional_data.sel(time=regional_data.time.dt.month.isin(season))

    global_climatology = global_data.mean(dim='time')
    regional_climatology = regional_data.mean(dim='time')

    return global_climatology - regional_climatology
    

## Relative bias

In [1]:
def get_bias_rel(global_path, regional_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the relative climatology bias ``(global - regional) / regional * 100`` in percent.

    Parameters
    ----------
    global_path, regional_path : str
        Files to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to average over. ``None`` uses every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        (lat, lon) map in percent. Cells where the regional climatology is 0
        are NaN instead of +/-inf.
    """
    # iscale must go by keyword; positionally it would become `season`
    global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)
    regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        global_data = global_data.sel(time=global_data.time.dt.month.isin(season))
        regional_data = regional_data.sel(time=regional_data.time.dt.month.isin(season))

    global_climatology = global_data.mean(dim='time')
    regional_climatology = regional_data.mean(dim='time')

    # A zero reference climatology (e.g. dry cells) would give +/-inf, which
    # pollutes maps and spatial statistics; mark those cells as NaN instead.
    denominator = regional_climatology.where(regional_climatology != 0)

    return (global_climatology - regional_climatology) / denominator * 100
    

### K-S test comparing the distributions of extremes between a regional and a global dataset

Two-sample K-S test for 3-D xarray data.

In [16]:
import numpy as np
import xarray as xr
from scipy.stats import ks_2samp

def ks_test_3d_xarray(data1, data2):
    """
    Apply the two-sample K-S test at every grid cell of two 3-D DataArrays.

    Each (lat, lon) cell's time series in ``data1`` is compared with the same
    cell in ``data2``. Time steps where either input is NaN are dropped from
    both samples before testing.

    Parameters:
        data1: xarray DataArray with dims ('time', 'lat', 'lon') in any order
        data2: xarray DataArray with the same dims and sizes as data1

    Returns:
        p_values: xarray DataArray of p-values with dims ('lat', 'lon').
                  Cells with no valid pairs are NaN.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    # Force (time, lat, lon) order so axis 0 is always time, whatever the file layout
    arr1 = data1.transpose('time', 'lat', 'lon')
    arr2 = data2.transpose('time', 'lat', 'lon')

    # Explicit check instead of `assert`, which is stripped under `python -O`
    if arr1.shape != arr2.shape:
        raise ValueError("Input data arrays must have the same shape")

    # Pull the numpy arrays once instead of per grid cell
    values1 = arr1.values
    values2 = arr2.values

    n_lat, n_lon = arr1.sizes['lat'], arr1.sizes['lon']
    p_values = np.full((n_lat, n_lon), np.nan)

    for i in range(n_lat):
        for j in range(n_lon):
            d1 = values1[:, i, j]
            d2 = values2[:, i, j]

            # Keep only time steps valid in both samples
            valid = ~np.isnan(d1) & ~np.isnan(d2)
            if valid.any():
                p_values[i, j] = ks_2samp(d1[valid], d2[valid]).pvalue

    return xr.DataArray(p_values, dims=('lat', 'lon'), coords={'lat': arr1['lat'], 'lon': arr1['lon']})

# Example usage: compare two synthetic (time, lat, lon) cubes cell by cell.
# Replace the random sample data below with real data.
time_steps = 100
lat = 10
lon = 10

# Build both sample cubes; data1 is drawn first, so the np.random stream is
# consumed in the same order as two separate calls would consume it.
data1, data2 = (
    xr.DataArray(
        np.random.rand(time_steps, lat, lon),
        dims=('time', 'lat', 'lon'),
        coords={'lat': np.linspace(-90, 90, lat), 'lon': np.linspace(-180, 180, lon)},
    )
    for _ in range(2)
)

# Put the same NaN into both cubes to exercise the NaN handling
data1[0, 2, 3] = np.nan
data2[0, 2, 3] = np.nan

# Run the grid-cell K-S test and report the output grid size
p_values = ks_test_3d_xarray(data1, data2)

print("Shape of p-values array:", p_values.shape)


Shape of p-values array: (10, 10)


## Mann-Whitney U test for differences between global and regional datasets

In [17]:
def get_mk(regional_path, global_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None, egion_mask=None):
    """Return per-grid-cell Mann-Whitney U p-values comparing regional vs global time series.

    Parameters
    ----------
    regional_path, global_path : str
        Files to read via ``get_data_from_file``. Note the order: regional first.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to keep before testing.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.
        This parameter used to be misspelled ``egion_mask``, which made
        ``region_mask`` an undefined name whenever ``region`` was set.
    egion_mask : optional
        Deprecated alias for ``region_mask``, kept so existing keyword
        callers keep working.

    Returns
    -------
    xarray.DataArray
        (lat, lon) p-values from ``scipy.stats.mannwhitneyu`` along ``time``,
        using SciPy's default ``alternative``.
    """
    if region_mask is None:
        region_mask = egion_mask

    # iscale must go by keyword; positionally it would become `season`
    global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)
    regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        regional_data = regional_data.sel(time=regional_data.time.dt.month.isin(season))
        global_data = global_data.sel(time=global_data.time.dt.month.isin(season))

    # Guarantee axis 0 is time and the result axes are (lat, lon), matching the labels below
    regional_data = regional_data.transpose('time', 'lat', 'lon')
    global_data = global_data.transpose('time', 'lat', 'lon')

    # Mann-Whitney statistic and p-value at each grid point, across time (axis 0)
    mann_whitney_stat = mannwhitneyu(regional_data, global_data, axis=0)

    return xr.DataArray(mann_whitney_stat.pvalue, coords={'lat': regional_data['lat'], 'lon': regional_data['lon']}, dims=['lat', 'lon'])

## Find a common shape (time, lat, lon) between two datasets
Returns the common time, lat, and lon slices.

In [18]:
import xarray as xr

def find_common_slices(data1, data2):
    """
    Find the common slices (time, lat, lon) between two 3-D xarray DataArrays (time, lat, lon)
    with NaN values and different time steps, numbers of latitudes, and numbers of longitudes.

    Each slice spans from the smallest to the largest coordinate value present
    in BOTH inputs. xarray label slicing includes both end points, so no
    padding is added. The original added 1 to the maximum, which overshoots
    the common range (and fails for non-numeric time types).

    Parameters:
        data1: xarray DataArray of shape (time, lat, lon)
        data2: xarray DataArray of shape (time, lat, lon)

    Returns:
        common_slices: dictionary containing slices for each dimension
                       {'time_slice': slice, 'lat_slice': slice, 'lon_slice': slice}
                       or None if a dimension is missing or has no common coordinate values
    """
    required_dims = ('time', 'lat', 'lon')

    common_dims = set(data1.dims) & set(data2.dims)
    if not set(required_dims) <= common_dims:
        return None

    common_slices = {}
    for dim in required_dims:
        common_coords = np.intersect1d(data1[dim].values, data2[dim].values)

        # min()/max() of an empty array raise ValueError; report "no overlap" as documented
        if common_coords.size == 0:
            return None

        common_slices[dim + '_slice'] = slice(common_coords.min(), common_coords.max())

    return common_slices

In [1]:
import xarray as xr
import numpy as np

def find_common_slices2(data1, data2):
    """
    Find the common slices (year, lat, lon) between two 3-D xarray DataArrays (year, lat, lon)
    with NaN values and different time steps, numbers of latitudes, and numbers of longitudes.

    Each slice spans from the smallest to the largest coordinate value present
    in BOTH inputs. xarray label slicing includes both end points, so no
    padding is added. The original added 1 to the maximum, which pulled in a
    year or degree that is not shared.

    Parameters:
        data1: xarray DataArray of shape (year, lat, lon)
        data2: xarray DataArray of shape (year, lat, lon)

    Returns:
        common_slices: dictionary containing slices for each dimension
                       {'year_slice': slice, 'lat_slice': slice, 'lon_slice': slice}
                       or None if a dimension is missing or has no common coordinate values
    """
    required_dims = ('year', 'lat', 'lon')

    common_dims = set(data1.dims) & set(data2.dims)
    if not set(required_dims) <= common_dims:
        return None

    common_slices = {}
    for dim in required_dims:
        common_coords = np.intersect1d(data1[dim].values, data2[dim].values)

        # min()/max() of an empty array raise ValueError; report "no overlap" as documented
        if common_coords.size == 0:
            return None

        common_slices[dim + '_slice'] = slice(common_coords.min(), common_coords.max())

    return common_slices

## RMSE

In [19]:
# Get the area-weighted RMSE between two datasets' climatologies. Assumes a homogeneous variable name between datasets
def get_rmse(global_path, regional_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the cos(lat)-weighted spatial RMSE between regional and global climatologies.

    Parameters
    ----------
    global_path, regional_path : str
        Files to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to average over. ``None`` uses every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        Scalar RMSE over (lat, lon), rounded to 2 decimals (NaNs skipped).
    """
    # iscale must go by keyword; positionally it would become `season`
    global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)
    regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        global_data = global_data.sel(time=global_data.time.dt.month.isin(season))
        regional_data = regional_data.sel(time=regional_data.time.dt.month.isin(season))

    global_climatology = global_data.mean(dim='time')
    regional_climatology = regional_data.mean(dim='time')

    # 2-D (lat, lon) cos(lat) weights: the latitude weights repeated along lon
    lat_weights = np.cos(np.deg2rad(regional_climatology.lat)).to_numpy()
    weights2d = np.repeat(np.expand_dims(lat_weights, axis=1), len(regional_climatology.lon), axis=1)
    weights2d_xr = xr.DataArray(weights2d, coords={'lat': regional_climatology.lat, 'lon': regional_climatology.lon}, dims=['lat', 'lon'])

    # Weighted spatial RMSE, rounded to 2 decimals
    rmse_i = xs.rmse(regional_climatology, global_climatology, dim=['lat', 'lon'], weights=weights2d_xr, skipna=True)
    return rmse_i.astype(float).round(2)

In [4]:
# Get the area-weighted MAPE between two datasets' climatologies. Assumes a homogeneous variable name between datasets
def get_mape(global_path, regional_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the cos(lat)-weighted spatial MAPE between regional and global climatologies.

    Parameters
    ----------
    global_path, regional_path : str
        Files to read via ``get_data_from_file``.
    variable : str
        Variable name (ignored for SPI when ``iscale`` is given).
    time_slice, lat_slice, lon_slice : slice
        Label-based selections along ``time`` / ``lat`` / ``lon``.
    season : list of int, optional
        Month numbers to average over. ``None`` uses every time step.
    iscale : optional
        Non-``None`` selects the SPI variable.
    mask : array-like, optional
        Only cells where ``mask == 1`` are kept.
    region_mask, region : optional
        When ``region`` is given, only cells where ``region_mask == region`` are kept.

    Returns
    -------
    xarray.DataArray
        Scalar ``xskillscore.mape`` over (lat, lon), rounded to 2 decimals.
        The argument order (regional, global) is kept from the original.
    """
    # iscale must go by keyword; positionally it would become `season`
    global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)
    regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale=iscale).sel(time=time_slice)

    # Optional IPCC sub-region selection
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    # Seasonal subset; the season is a list of month numbers
    if season is not None:
        global_data = global_data.sel(time=global_data.time.dt.month.isin(season))
        regional_data = regional_data.sel(time=regional_data.time.dt.month.isin(season))

    global_climatology = global_data.mean(dim='time')
    regional_climatology = regional_data.mean(dim='time')

    # 2-D (lat, lon) cos(lat) weights: the latitude weights repeated along lon
    lat_weights = np.cos(np.deg2rad(regional_climatology.lat)).to_numpy()
    weights2d = np.repeat(np.expand_dims(lat_weights, axis=1), len(regional_climatology.lon), axis=1)
    weights2d_xr = xr.DataArray(weights2d, coords={'lat': regional_climatology.lat, 'lon': regional_climatology.lon}, dims=['lat', 'lon'])

    # Weighted spatial mean absolute percentage error, rounded to 2 decimals
    mape_i = xs.mape(regional_climatology, global_climatology, dim=['lat', 'lon'], weights=weights2d_xr, skipna=True)
    return mape_i.astype(float).round(2)

## MAE

## Timing of extremes

In [20]:
# Get the area-weighted MAE (mean absolute error) between the climatologies of two datasets.
# Assumes both files share the same variable name and grid.
def _mae_climatology(data, season):
    """Average ``data`` over ``time``; keep only months listed in ``season`` when it is given."""
    if season is not None:
        # The season is a user-supplied list of month numbers
        data = data.sel(time=data.time.dt.month.isin(season))
    return data.mean(dim='time')


def _mae_lat_weights(field):
    """Return a ('lat', 'lon') DataArray of cos(latitude) weights on ``field``'s grid."""
    weights = np.cos(np.deg2rad(field.lat))
    weights.name = "weights"
    # Repeat the 1-D latitude weights along longitude to get a 2-D weight field
    weights2d_0 = np.expand_dims(weights.to_numpy(), axis=1)
    weights2d = np.repeat(weights2d_0, len(field.lon), axis=1)
    return xr.DataArray(weights2d, coords={'lat': field.lat, 'lon': field.lon}, dims=['lat', 'lon'])


def get_mae(global_path, regional_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the cos(latitude)-weighted spatial MAE between two time-mean fields.

    Parameters
    ----------
    global_path, regional_path : str
        Files read with ``get_data_from_file``.
    variable : str
        Variable name read from both files.
    time_slice : slice
        Selection applied to the ``time`` coordinate of both datasets.
    lat_slice, lon_slice : slice
        Spatial subset forwarded to ``get_data_from_file``.
    season : list of int, optional
        Month numbers kept before averaging over time; all months when None.
    iscale : optional
        SPI scale forwarded to ``get_data_from_file`` only when given.
    mask : DataArray, optional
        Grid points where ``mask == 1`` are kept; all others become NaN.
    region_mask : DataArray, optional
        Region-label field; required whenever ``region`` is given.
    region : optional
        Label in ``region_mask`` selecting the grid points to keep.

    Returns
    -------
    DataArray
        ``xs.mae`` over ``lat``/``lon`` (NaNs skipped), cast to float and
        rounded to 2 decimals.

    Raises
    ------
    ValueError
        If ``region`` is given without ``region_mask``.
    """
    # Without this check, ``where(None == region)`` reduces to ``where(False)``
    # and silently blanks every grid point, so the MAE would come back NaN.
    if region is not None and region_mask is None:
        raise ValueError("region_mask must be provided when region is set.")

    # Include the SPI scale in data extraction only if the user provided one
    if iscale is None:
        global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice).sel(time=time_slice)
        regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice).sel(time=time_slice)
    else:
        global_data = get_data_from_file(global_path, variable, lat_slice, lon_slice, iscale).sel(time=time_slice)
        regional_data = get_data_from_file(regional_path, variable, lat_slice, lon_slice, iscale).sel(time=time_slice)

    # Restrict both datasets to the requested region
    if region is not None:
        global_data = global_data.where(region_mask == region)
        regional_data = regional_data.where(region_mask == region)

    # Apply the optional user mask (points where mask == 1 are kept)
    if mask is not None:
        global_data = global_data.where(mask == 1)
        regional_data = regional_data.where(mask == 1)

    global_climatology = _mae_climatology(global_data, season)
    regional_climatology = _mae_climatology(regional_data, season)

    # Area weights built on the regional grid
    weights2d_xr = _mae_lat_weights(regional_climatology)

    # Weighted mean absolute error over space, rounded to 2 decimals
    mae_i = xs.mae(regional_climatology, global_climatology, dim=['lat', 'lon'], weights=weights2d_xr, skipna=True)
    return mae_i.astype(float).round(2)

In [21]:
import xarray as xr
import numpy as np

def timming_grid(precip1, precip2, dim="year"):
    """
    Percentage of steps along ``dim`` on which two datasets hold identical values, per grid point.

    Intended for comparing the timing of annual precipitation maxima from two products
    (presumably each value is the month of that year's maximum; confirm with the caller).

    Parameters:
    - precip1: xarray DataArray containing the dimension ``dim`` plus spatial dimensions
    - precip2: xarray DataArray with the same shape as ``precip1``
    - dim: name of the dimension to average over (default "year", the value previously hard-coded)

    Returns:
    - xarray DataArray over the remaining dimensions: percentage (0-100) of matching steps

    Raises:
    - ValueError: if the two inputs have different shapes

    Note: NaN never compares equal to NaN, so a step that is missing in either input
    counts as a mismatch here.
    """

    # Only the shapes are checked; coordinates are aligned by xarray during comparison
    if precip1.shape != precip2.shape:
        raise ValueError("Datasets must have the same shape and coordinates.")

    # Element-wise boolean comparison
    matches = precip1 == precip2

    # The mean of a boolean array is the fraction of True values; scale to percent
    return matches.mean(dim=dim) * 100


In [1]:
import xarray as xr
import numpy as np

def timming_grid2(precip1, precip2, dim="year"):
    """
    Percentage of steps along ``dim`` with identical values, per grid point, ignoring NaNs.

    Only steps where BOTH inputs are non-NaN are counted. Grid points with no such step
    are returned as NaN.

    Parameters:
    - precip1: xarray DataArray containing the dimension ``dim`` plus spatial dimensions
    - precip2: xarray DataArray with the same shape as ``precip1``
    - dim: name of the dimension to reduce over (default "year", the value previously hard-coded)

    Returns:
    - xarray DataArray over the remaining dimensions: percentage (0-100) of matching
      valid steps, NaN where there were no valid steps

    Raises:
    - ValueError: if the two inputs have different shapes
    """

    # Only the shapes are checked; coordinates are aligned by xarray during comparison
    if precip1.shape != precip2.shape:
        raise ValueError("Datasets must have the same shape and coordinates.")

    # A step is usable only where both datasets have data
    valid_mask = ~np.isnan(precip1) & ~np.isnan(precip2)

    # Matches are counted only on usable steps
    matching_values = (precip1 == precip2) & valid_mask

    valid_counts = valid_mask.sum(dim=dim)

    # Turn zero counts into NaN before dividing. Cells without valid steps then come out
    # NaN directly, instead of via a 0/0 division that emits a RuntimeWarning.
    denominator = valid_counts.where(valid_counts > 0)

    return matching_values.sum(dim=dim) / denominator * 100