# Master Utility Notebook

**Used in Conjunction with other notebooks for analysis and plotting** <br>
This notebook contains the definitions for the functions used in the CORDEX-Australasia Benchmarking Framework Analysis which has been accepted as the following manuscript: <br>
Isphording, R.N., L.V. Alexander, M. Bador, D. Green, J. P. Evans, and S. Wales. A Standardized Benchmarking Framework to Assess Downscaled Precipitation Simulations. Journal of Climate (accepted; in revision). <br>
These functions largely assume a common variable name between model and observation files (such as Climpact output files; see https://climpact-sci.org/). <br>
Author: Rachael N. Isphording

In [4]:
# Import modules 
import pandas as pd
import os
import fnmatch
import xarray as xr
import xskillscore as xs
import numpy as np
import math
import itertools
from astropy.stats import circcorrcoef
from astropy import units as u
from scipy.stats import bootstrap
import matplotlib.pyplot as plt
import matplotlib.colors as colors

## File Input Functions

To get data from different directories within my CORDEX-Australasia database, f-strings are used in the analysis notebooks. <br>
**This will likely need to be updated based on your database structure** <br>
- The **CORDEX-Australasia** database is set-up so that the Climpact indices follow the path: <br>
    - parent_directory/CORDEX-Australasia/data/{rcm-or-gcm}/{grid_type}/{climpact-or-daily-data}/{historical-or-future}/{indice-keyword}/{yr-or-mon}/dataset_file.nc <br><br>
- The **observational** dataset (AGCD) is organized differently in the database: <br>
    - parent_directory/CORDEX-Australasia/data/{obs}/{grid_type}/{climpact-or-daily-data}/{indice-keyword}/dataset_file_MON-or-ANN.nc

**Functions Include:**
- Get data model names from dictionaries below
- Get subset of models from dictionary defined in analysis notebooks
- Get model data paths
- Get data from paths
    - This includes exceptions to handle the 4-dimensionality of the SPI variable and the "day" unit of the CD(W)D and R_mm variables

### Define names of models
I select models from my database based on the GCM and RCM names included in the file name. This will likely need to be updated for use in other databases or with other ensembles.

In [1]:
# Forcing (driving) GCM names. get_data_name matches a name when it appears
# in a file name immediately followed by an underscore.
gcm_names = [
    "ACCESS1-0",
    "CanESM2",
    "CNRM-CM5",
    "GFDL-ESM2M",
    "HadGEM2-CC",
    "HadGEM2-ES",
    "MIROC5",
    "MPI-ESM-LR",
    "MPI-ESM-MR",
    "NorESM1-M",
]

# RCM names, matched against file names in the same way.
rcm_names = [
    "CCAM-1704",
    "CCAM-2008",
    "CCLM5-0-15",
    "RegCM4-7",
    "REMO2015",
    "WRF360J",
    "WRF360K",
]

### Get Model Names

In [2]:
# Use the dictionaries above to get the model names from each file
def get_data_name(names, file):
    """Return the first entry of ``names`` that appears in ``file`` followed by '_'.

    Parameters
    ----------
    names : iterable of str
        Candidate model names (e.g. ``gcm_names`` or ``rcm_names``), tried in order.
    file : str
        File name to search, matched with the glob pattern ``*<name>_*``.

    Returns
    -------
    str or None
        The first matching name, or None when no candidate matches.
    """
    matching = (candidate for candidate in names
                if fnmatch.fnmatch(file, "*" + candidate + "_*"))
    return next(matching, None)

### Get Model Data Paths

In [4]:
# Get model file paths and add to a Pandas Dataframe
def get_model_files(model_data_master_path):
    """Collect files in a directory whose names contain a known RCM and driving GCM.

    Parameters
    ----------
    model_data_master_path : str or os.PathLike
        Directory to scan. A trailing path separator is optional.

    Returns
    -------
    pandas.DataFrame
        Columns 'rcm', 'driving_gcm', 'dataset_path', with one row per file whose
        name matches an entry of both ``rcm_names`` and ``gcm_names`` (via
        get_data_name). Rows follow ``os.listdir`` order.
    """
    rows = []

    for model_file in os.listdir(model_data_master_path):

        # Get the RCM and GCM names contained in the file name
        rcm_name = get_data_name(rcm_names, model_file)
        driving_gcm_name = get_data_name(gcm_names, model_file)

        # Skip files that do not name a known RCM and driving GCM
        if rcm_name is None or driving_gcm_name is None:
            continue

        # os.path.join is correct with or without a trailing separator on the
        # directory; plain concatenation produced a wrong path without one.
        model_complete_file_path = os.path.join(model_data_master_path, model_file)

        rows.append([rcm_name, driving_gcm_name, model_complete_file_path])

    # Build the DataFrame once instead of growing it row by row with .loc
    return pd.DataFrame(rows, columns=['rcm', 'driving_gcm', 'dataset_path'])

### Get Model Data Paths for Subset

In [6]:
# Get model file paths for subset of models from list defined above
def get_model_files_subset(model_paths, subset_names):
    """Keep only the models whose '<driving_gcm>   <rcm>' label is in ``subset_names``.

    Parameters
    ----------
    model_paths : pandas.DataFrame
        Table whose first three columns are (rcm, driving_gcm, dataset_path),
        as returned by get_model_files.
    subset_names : container of str
        Labels of the form '<GCM>   <RCM>' (three spaces) to keep.

    Returns
    -------
    pandas.DataFrame
        Columns 'model_name' and 'dataset_path' for the matching rows, in input order.
    """
    rows = []

    # itertuples yields values by position. The old row[0]/row[1]/row[2] relied
    # on pandas' deprecated positional fallback of Series.__getitem__ (the row
    # index holds column labels), which warns in pandas 2.x and fails in 3.0.
    for rcm_name, gcm_name, dataset_path, *_ in model_paths.itertuples(index=False, name=None):
        model_name = f'{gcm_name}   {rcm_name}'

        if model_name in subset_names:
            rows.append([model_name, f'{dataset_path}'])

    return pd.DataFrame(rows, columns=['model_name', 'dataset_path'])

### Get Data from File Paths

In [5]:
# Climpact indices whose "units" attribute is handled specially on read (see
# get_data_from_file). Extend this tuple when adding more such indices (e.g.
# temperature indices).
_DAY_COUNT_VARIABLES = ('cdd', 'cwd', 'r10mm', 'r20mm', 'r30mm', 'fracprday')


# Extract data from files based on user specifications
def get_data_from_file(file_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None):
    """Open a netCDF file and return one variable subset in time and space.

    Parameters
    ----------
    file_path : str
        Path to the netCDF file.
    variable : str
        Name of the data variable to extract. It is ignored when ``iscale`` is
        given; the 'spi' variable is read instead.
    time_slice, lat_slice, lon_slice :
        Selections applied with ``.sel`` on the 'time', 'lat' and 'lon' dimensions.
    season : list of int, optional
        Month numbers to keep (e.g. [12, 1, 2]). None keeps all months.
    iscale : optional
        Value selected along the 'scale' dimension of the SPI variable. Per the
        original notes: 1 for a 3-month, 2 for a 6-month and 3 for a 12-month
        averaging period.

    Returns
    -------
    xarray.DataArray
        The selected data. The underlying dataset is left open (lazy loading).
    """
    if iscale is not None:
        # SPI is 4-D; pick the requested scale along with the spatiotemporal subset
        data_ds = xr.open_dataset(file_path)
        data_var = data_ds.spi.sel(scale=iscale, time=time_slice, lat=lat_slice, lon=lon_slice)

    elif variable in _DAY_COUNT_VARIABLES:
        # These variables must be read differently so their values come out as
        # float (per the original notes). Presumably xarray would otherwise
        # CF-decode the 'days' units as a timedelta. Hide the units while the
        # rest of the file (e.g. time) is decoded, then restore them.
        data_ds = xr.open_dataset(file_path, decode_cf=False)
        data_ds[variable].attrs['units'] = '1'
        data_ds = xr.decode_cf(data_ds)
        # NOTE(review): 'days' is also applied to fracprday, as before; confirm
        # that this is the intended unit for a fraction index.
        data_ds[variable].attrs['units'] = 'days'
        data_var = data_ds[variable].sel(time=time_slice, lat=lat_slice, lon=lon_slice)

    else:
        # Open file, subset to the user-specified bounds, then extract the variable
        data_ds = xr.open_dataset(file_path)
        data_subset = data_ds.sel(time=time_slice, lat=lat_slice, lon=lon_slice)
        data_var = getattr(data_subset, variable)

    # Optional seasonal subset by month number
    if season is not None:
        data_var = data_var.sel(time=data_var.time.dt.month.isin(season))

    return data_var

## Plotting Functions

Functions used when creating plots/figures. <br>
**Functions Include:** <br>
- Truncating a predefined colormap
- Define vertical bands to shade along a time series

### Truncate a color map (skip colors in pre-defined colormap)

In [4]:
# Function to be able to skip colors/subset colors within a predefined colormap
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a colormap covering only the [minval, maxval] portion of ``cmap``.

    ``n`` evenly spaced samples of ``cmap`` between ``minval`` and ``maxval``
    become the color list of a new LinearSegmentedColormap. The new map is
    named 'trunc(<cmap name>,<minval>,<maxval>)', with the bounds written to
    two decimals.
    """
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Code sourced from https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib

### Define vertical bands (regions) to shade along a time series

In [2]:
# Define function to find areas that need to be shaded using a boolean mask as input
def fill_vertical_columns(boolean_fill_mask):
    """Return [start, end) index pairs of the contiguous True runs in a 1-D mask.

    Parameters
    ----------
    boolean_fill_mask : array-like of bool
        1-D mask; True marks positions to shade. Any array-like (list, numpy
        array, pandas Series) is accepted and indexed by position.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (k, 2). Row ``[start, end]`` covers positions
        ``start .. end - 1`` (half-open). The shape is (0, 2) when there are no
        True values or the mask is empty.

    Example: [False, True, True, False] -> [[1, 3]].
    """
    mask = np.asarray(boolean_fill_mask, dtype=bool)

    if mask.size == 0:
        return np.empty((0, 2), dtype=np.intp)

    # np.diff on booleans marks i where mask[i] != mask[i + 1]
    region_to_shade, = np.diff(mask).nonzero()

    # A run starts/ends AFTER the change point, so shift by one. Without this
    # shift, interior regions were off by one and a leading single True gave
    # the zero-width region [0, 0].
    region_to_shade = region_to_shade + 1

    # Handle edge cases where the mask starts or ends with True
    if mask[0]:
        region_to_shade = np.r_[0, region_to_shade]

    if mask[-1]:
        region_to_shade = np.r_[region_to_shade, mask.size]

    # Pair up consecutive boundaries as (start, end)
    return region_to_shade.reshape((-1, 2))

# Pattern follows the widely used numpy "contiguous_regions" recipe (diff -> nonzero -> +1 -> edge handling)

## Spatially Averaged Functions

The following functions calculate spatially averaged metrics, primarily for plotting time series and scatterplots. Areal weighting is incorporated into the functions. <br>
'season=None' sets an optional argument for seasonal subsetting based on numeric month values <br>
'mask=None' sets an optional argument for masking, where the default is None <br>
'iscale=None' sets an optional argument for selecting the scale of the 4-D SPI index. This should only be assigned a value if the variable is SPI. The scale can be a 3-, 6-, or 12-month averaging period (iscale = 1, 2, or 3, respectively). <br>
'region_mask=None' sets an optional additional mask for subregions, such as the IPCC regions or NRM regions (assuming the region mask is also a netCDF file). The default is None. <br>
'region=None' should be used with the region_mask to specify which region in the region mask. This was created using the IPCC subregions, where each region has a numerical value. The default is None. <br>
**Functions Include:** <br>
- Weighted Spatial Average from Pre-processed Data (such as a climatology)
- Spatial Pattern Correlation (Homogeneous Variable Names, like Climpact output)
- Mean Absolute Percentage Error (Homogeneous Variable Names, like Climpact output)
- Monthly Averages over time/space to gauge seasonality (Homogeneous Variable Names, like Climpact output)
- Annual Averages
- Normalized Root Mean Square Error for pre-processed data (i.e. maps of climatologies or other calculated values)
- Paired Bootstrapping function to calculate the Circular Correlation Coefficient on a random subset of two sets of data
- Get Weighted Spatial Average at Default Time Step

### Get Weighted Spatial Average from Preprocessed Data (such as a climatology or model bias)

In [1]:
def get_weighted_spatial_average_from_data(data, region_mask=None, region=None):
    """Return the cos(latitude)-weighted mean of ``data`` over 'lat' and 'lon'.

    Parameters
    ----------
    data : xarray.DataArray
        Pre-processed field with 'lat' and 'lon' dimensions (e.g. a climatology
        or a model bias).
    region_mask : xarray.DataArray, optional
        Region-id field. Used only when ``region`` is given.
    region : optional
        Keep only points where ``region_mask == region``.

    Returns
    -------
    xarray.DataArray
        The weighted spatial mean; any other dimensions are kept.
    """
    # Optional sub-region restriction
    if region is not None:
        data = data.where(region_mask == region)

    # Latitude weights for areal weighting
    lat_weights = np.cos(np.deg2rad(data.lat))
    lat_weights.name = "weights"

    return data.weighted(lat_weights).mean(('lon', 'lat'))

### Get Spatial Correlation

In [2]:
# Get spatial pattern correlation (using the Pearson Correlation Coefficient) between model simulation and observational dataset
def get_spatial_correlation(model_path, obs_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, centered=True, region_mask=None, region=None):
    """Return the area-weighted spatial pattern correlation of model vs. observed climatologies.

    Parameters
    ----------
    model_path, obs_path : str
        Files read with get_data_from_file.
    variable : str
        Variable name shared by both files.
    time_slice, lat_slice, lon_slice :
        Selections passed to get_data_from_file. When ``time_slice`` is None
        and no season is given, the data are used as-is (not averaged in time).
    season : list of int, optional
        Month numbers for a seasonal climatology.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    centered : bool, default True
        Truthy gives a centered (Pearson) correlation of anomalies from the
        weighted means. Falsy gives an uncentered correlation.
    region_mask, region : optional
        Keep points where ``region_mask == region``.

    Returns
    -------
    xarray.DataArray
        0-d correlation rounded to 2 decimal places.
    """
    # iscale must be passed by keyword. get_data_from_file's 6th positional
    # parameter is `season`, so passing iscale positionally fed the SPI scale
    # into the month filter.
    obs_data = get_data_from_file(obs_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)
    model_data = get_data_from_file(model_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional sub-region restriction
    if region is not None:
        model_data = model_data.where(region_mask == region)
        obs_data = obs_data.where(region_mask == region)

    # Optional quality/land mask
    if mask is not None:
        model_data = model_data.where(mask == 1)
        obs_data = obs_data.where(mask == 1)

    if season is None:
        # Annual climatology (or the data as-is when no time selection is given)
        if time_slice is None:
            obs_climatology = obs_data
            model_climatology = model_data
        else:
            obs_climatology = obs_data.mean(dim='time')
            model_climatology = model_data.mean(dim='time')
    else:
        # Seasonal climatology; the season is a list of month numbers
        obs_climatology = obs_data.sel(time=obs_data.time.dt.month.isin(season)).mean(dim='time')
        model_climatology = model_data.sel(time=model_data.time.dt.month.isin(season)).mean(dim='time')

    # Use only grid points valid in BOTH fields. Otherwise the covariance and
    # the two variances are summed over different point sets, and the result
    # is not a correlation.
    valid = obs_climatology.notnull() & model_climatology.notnull()
    obs_climatology = obs_climatology.where(valid)
    model_climatology = model_climatology.where(valid)

    # Latitudinal (area) weights. These stay 1-D over 'lat'; xarray broadcasts
    # them over 'lon' by dimension name, whatever the fields' dimension order.
    weights = np.cos(np.deg2rad(obs_data.lat))
    weights.name = "weights"

    if centered:
        # Centered correlation: anomalies from the weighted spatial means
        obs_input = obs_climatology - obs_climatology.weighted(weights).mean()
        model_input = model_climatology - model_climatology.weighted(weights).mean()
    else:
        # Uncentered correlation: use the climatologies directly
        obs_input = obs_climatology
        model_input = model_climatology

    # Weighted covariance and variances
    cov = (obs_input * model_input * weights).sum(skipna=True)
    obs_var = ((obs_input ** 2) * weights).sum(skipna=True)
    model_var = ((model_input ** 2) * weights).sum(skipna=True)

    spatial_cor_i = cov / (np.sqrt(obs_var) * np.sqrt(model_var))

    # Round spatial correlation to 2 decimal places
    return spatial_cor_i.astype(float).round(2)

### Get Mean Absolute Percentage Error

In [3]:
# Get Mean Absolute Percentage Error (MAPE) between model simulation and observational dataset. This function assumes area weighting and a homogeneous variable name between datasets
def get_mape(model_path, obs_path, variable, time_slice, lat_slice, lon_slice, dataype, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the area-weighted MAPE between observed and model climatologies.

    Parameters
    ----------
    model_path, obs_path : str
        Files read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    dataype : str or None
        Data type flag (the parameter name is kept as-is for caller
        compatibility). "model" multiplies the model data by 86400
        (kg/m^2/s -> mm/day).
    season : list of int, optional
        Month numbers for a seasonal climatology.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region``.

    Returns
    -------
    xarray.DataArray
        0-d MAPE from ``xskillscore.mape`` (obs passed first), rounded to 2
        decimal places.
    """
    # The body previously read an undefined name `data_type` (NameError, or a
    # stray notebook global). Read the actual parameter instead.
    data_type = dataype

    # iscale must be passed by keyword: get_data_from_file's 6th positional
    # parameter is `season`.
    obs_data = get_data_from_file(obs_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)
    model_data = get_data_from_file(model_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Convert RCM precip data to mm/day from kg/m^2/s
    if data_type == "model":
        model_data = model_data * 86400

    # Optional sub-region restriction
    if region is not None:
        model_data = model_data.where(region_mask == region)
        obs_data = obs_data.where(region_mask == region)

    # Optional quality/land mask
    if mask is not None:
        model_data = model_data.where(mask == 1)
        obs_data = obs_data.where(mask == 1)

    if season is None:
        # Annual climatology
        obs_climatology = obs_data.mean(dim='time')
        model_climatology = model_data.mean(dim='time')
    else:
        # Seasonal climatology; the season is a list of month numbers
        obs_climatology = obs_data.sel(time=obs_data.time.dt.month.isin(season)).mean(dim='time')
        model_climatology = model_data.sel(time=model_data.time.dt.month.isin(season)).mean(dim='time')

    # 2-D (lat, lon) latitudinal weights
    weights = np.cos(np.deg2rad(obs_data.lat))
    weights.name = "weights"
    weights2d_0 = np.expand_dims(weights.to_numpy(), axis=1)
    weights2d = np.repeat(weights2d_0, len(obs_data.lon), axis=1)

    # Wrap the weights as xarray so xskillscore can align them by dimension
    weights2d_xr = xr.DataArray(weights2d, coords={'lat': obs_data.lat, 'lon': obs_data.lon}, dims=['lat', 'lon'])

    # Weighted MAPE over the spatial dimensions, rounded to 2 decimals
    mape_i = xs.mape(obs_climatology, model_climatology, dim=['lat', 'lon'], weights=weights2d_xr, skipna=True)
    mape = mape_i.astype(float).round(2)

    return mape

### Get Monthly Weighted Averages

In [4]:
# Get monthly averages to use for a line plot over a user-specified time period and spatial domain
def get_monthly_averages(data_path, variable, time_slice, lat_slice, lon_slice, data_type, iscale=None, mask=None, region_mask=None, region=None):
    """Return the area-weighted, spatially averaged mean annual cycle.

    Parameters
    ----------
    data_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    data_type : str or None
        "model" multiplies the data by 86400 (kg/m^2/s -> mm/day).
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region``.

    Returns
    -------
    xarray.DataArray
        One value per calendar month present (dimension 'month').
    """
    # iscale must be passed by keyword: get_data_from_file's 6th positional
    # parameter is `season`, which previously received the SPI scale.
    data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional sub-region restriction
    if region is not None:
        data = data.where(region_mask == region)

    # Optional quality/land mask
    if mask is not None:
        data = data.where(mask == 1)

    # Convert RCM precip data to mm/day from kg/m^2/s
    if data_type == "model":
        data = data * 86400

    # Latitudinal (area) weights
    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    # Weighted spatial mean at each time step
    data_weighted_mean = data.weighted(weights).mean(('lon', 'lat'))

    # Average each calendar month over the period
    data_weighted_mean_mon = data_weighted_mean.groupby('time.month').mean(dim='time')

    return data_weighted_mean_mon

### Get Annual Time Series

In [2]:
# Get annual averages to use for a line plot over a user-specified time period and spatial domain
def get_annual_average(data_path, variable, time_slice, lat_slice, lon_slice, data_type, iscale=None, mask=None, region_mask=None, region=None):
    """Return the area-weighted, spatially averaged value for each year.

    Parameters
    ----------
    data_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    data_type : str or None
        "model" multiplies the data by 86400 (kg/m^2/s -> mm/day).
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region`` (e.g. an IPCC sub-region).

    Returns
    -------
    xarray.DataArray
        One value per year (dimension 'year').
    """
    # iscale must be passed by keyword: get_data_from_file's 6th positional
    # parameter is `season`, which previously received the SPI scale.
    data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Convert daily precip data to mm/day from kg/m^2/s (model data)
    if data_type == "model":
        data = data * 86400

    # Optional sub-region restriction
    if region is not None:
        data = data.where(region_mask == region)

    # Optional quality/land mask
    if mask is not None:
        data = data.where(mask == 1)

    # Latitudinal (area) weights
    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    # Weighted spatial mean at each time step
    data_weighted_mean = data.weighted(weights).mean(('lon', 'lat'))

    # Average the time steps within each year
    data_weighted_mean_ann = data_weighted_mean.groupby('time.year').mean(dim='time')

    return data_weighted_mean_ann

### Get the Normalized Root Mean Square Error for pre-processed Data

In [5]:
def get_nrmse_mean_ppdata(obs_data, model_data):
    """Return the area-weighted RMSE of model vs. obs, normalized by the weighted obs mean.

    Parameters
    ----------
    obs_data, model_data : xarray.DataArray
        Pre-processed fields (e.g. climatology maps) with a 'lat' coordinate.

    Returns
    -------
    xarray.DataArray
        ``sqrt(weighted_mean((model - obs)^2)) / weighted_mean(obs)``, where the
        weights are cos(latitude), rounded to 2 decimal places.
    """
    lat_weights = np.cos(np.deg2rad(obs_data.lat))
    lat_weights.name = "weights"

    squared_error = np.square(np.subtract(model_data, obs_data))
    rmse = math.sqrt(squared_error.weighted(lat_weights).mean())

    # Normalize by the weighted observational mean
    obs_weighted_mean = obs_data.weighted(lat_weights).mean()
    return np.round(rmse / obs_weighted_mean, 2)

### Circular Correlation Coefficient for Bootstrap Confidence Interval

In [3]:
def paired_bootstrap(data1, data2, resample_percentage, rng=None):
    """Compute the circular correlation coefficient of a random paired subsample.

    Draws ``int(len(data1) * resample_percentage)`` distinct positions without
    replacement. ``circcorrcoef`` is then evaluated on the same positions of
    both inputs, so the pairs stay intact.

    Parameters
    ----------
    data1, data2 : array-like supporting integer-array indexing
        Paired samples passed to astropy's ``circcorrcoef``. They must have
        equal length.
    resample_percentage : float
        Fraction of the data to draw (0-1). Values above 1 fail because
        sampling is without replacement.
    rng : numpy.random.Generator or RandomState, optional
        Random source for reproducible draws. None uses numpy's global RNG
        (``np.random.choice``), as before.

    Returns
    -------
    The value returned by ``circcorrcoef`` for the resampled pairs.

    Raises
    ------
    ValueError
        If ``data1`` and ``data2`` differ in length.
    """
    n = len(data1)
    if len(data2) != n:
        raise ValueError(f"paired_bootstrap requires equal-length inputs, got {n} and {len(data2)}")

    resample_size = int(n * resample_percentage)

    # Same random indices for both datasets to preserve the pairing
    choose = np.random.choice if rng is None else rng.choice
    indices = choose(np.arange(n), size=resample_size, replace=False)

    return circcorrcoef(data1[indices], data2[indices])

### Get Weighted Spatial Average at Default Time Step

In [7]:
def get_weighted_spatial_average_at_default_time_step(data_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the area-weighted spatial mean at each time step of the file.

    Parameters
    ----------
    data_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    season : list of int, optional
        Month numbers kept by get_data_from_file.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region`` (e.g. an IPCC sub-region).

    Returns
    -------
    xarray.DataArray
        The cos(latitude)-weighted mean over 'lat' and 'lon', keeping 'time'.
    """
    data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, season=season, iscale=iscale)

    # Optional sub-region restriction
    if region is not None:
        data = data.where(region_mask == region)

    # Optional quality/land mask
    if mask is not None:
        data = data.where(mask == 1)

    lat_weights = np.cos(np.deg2rad(data.lat))
    lat_weights.name = "weights"

    return data.weighted(lat_weights).mean(('lon', 'lat'))

## Metrics for Maps (Temporally Averaged)

The following functions calculate temporally averaged metrics, primarily for plotting maps. <br>
'data_type=None' assumes rainfall units of mm/day. If data_type = 'model' is used, the function converts rainfall units of kg/m^2/s to mm/day. <br>
'season=None' sets an optional argument for seasonal subsetting based on numeric month values <br>
'iscale=None' sets an optional argument for selecting the scale of the 4-D SPI index. This should only be assigned a value if the variable is SPI. The scale can be a 3-, 6-, or 12-month averaging period (iscale = 1, 2, or 3, respectively). <br>
'mask=None' sets an optional argument for masking, where the default is None <br>
'region_mask=None' sets an optional additional mask for subregions, such as the IPCC regions or NRM regions (assuming the region mask is also a netCDF file). The default is None. <br>
'region=None' should be used with the region_mask to specify which region in the region mask. This was created using the IPCC subregions where each region has a numerical value. The default is None. <br>
**Functions Include:** <br>
- Get Bias
- Get Climatology
- Get the Amplitude of the Annual Cycle at Each Grid Point
- Get the Phase of the Annual Cycle at Each Grid Point

### Get Model Bias (Model - Obs)

In [10]:
def get_bias(model_path, obs_path, variable, time_slice, lat_slice, lon_slice, data_type=None, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the model-minus-observation climatology bias at each grid point.

    Parameters
    ----------
    model_path, obs_path : str
        Files read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    data_type : str, optional
        "model" multiplies the model data by 86400 (kg/m^2/s -> mm/day).
    season : list of int, optional
        Month numbers for a seasonal climatology.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region``.

    Returns
    -------
    xarray.DataArray
        ``model_climatology - obs_climatology``.
    """
    # iscale must be passed by keyword: get_data_from_file's 6th positional
    # parameter is `season`, which previously received the SPI scale.
    obs_data = get_data_from_file(obs_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)
    model_data = get_data_from_file(model_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional sub-region restriction
    if region is not None:
        model_data = model_data.where(region_mask == region)
        obs_data = obs_data.where(region_mask == region)

    # Optional mask
    if mask is not None:
        model_data = model_data.where(mask == 1) # This is for the combined quality mask; for only land-mask use (mask>50, drop=True)
        obs_data = obs_data.where(mask == 1)

    # Convert model units to mm/day from kg/m^2/s
    if data_type == "model":
        model_data = model_data * 86400

    if season is None:
        # Annual climatology
        obs_climatology = obs_data.mean(dim='time')
        model_climatology = model_data.mean(dim='time')
    else:
        # Seasonal climatology; the season is a list of month numbers
        obs_climatology = obs_data.sel(time=obs_data.time.dt.month.isin(season)).mean(dim='time')
        model_climatology = model_data.sel(time=model_data.time.dt.month.isin(season)).mean(dim='time')

    # Model bias
    bias = (model_climatology - obs_climatology)

    return bias
    

### Get Climatology

In [7]:
def get_climatology(data_path, variable, time_slice, lat_slice, lon_slice, data_type=None, season=None, iscale=None, mask=None, region_mask=None, region=None):
    """Return the annual or seasonal time-mean climatology at each grid point.

    Parameters
    ----------
    data_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    data_type : str, optional
        "model" multiplies the data by 86400 (kg/m^2/s -> mm/day).
    season : list of int, optional
        Month numbers for a seasonal climatology.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.
    region_mask, region : optional
        Keep points where ``region_mask == region`` (e.g. an IPCC sub-region).

    Returns
    -------
    xarray.DataArray
        Time-mean field.
    """
    # iscale must be passed by keyword. Positionally it landed in `season`,
    # get_data_from_file's 6th parameter.
    data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional sub-region restriction
    if region is not None:
        data = data.where(region_mask == region)

    # Optional mask
    if mask is not None:
        data = data.where(mask == 1) #(mask>50,drop=True) - for only landmask in obs only plot

    # Convert model units to mm/day from kg/m^2/s
    if data_type == "model":
        data = data * 86400

    if season is None:
        # Annual climatology
        data_climatology = data.mean(dim='time')
    else:
        # Seasonal climatology; the season is a list of month numbers
        data_climatology = data.sel(time=data.time.dt.month.isin(season)).mean(dim='time')

    return data_climatology

### Get Amplitude of Annual Cycle at Each Grid Point

In [1]:
def get_amplitude_of_annual_cycle(file_path, variable, time_slice, lat_slice, lon_slice, data_type=None, iscale=None, mask=None):
    """Return the amplitude (max minus mean) of the mean annual cycle at each grid point.

    Parameters
    ----------
    file_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    data_type : str, optional
        "model" multiplies the data by 86400 (kg/m^2/s -> mm/day).
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.

    Returns
    -------
    xarray.DataArray
        Max over calendar-month means minus the mean of those monthly means.
    """
    # iscale must be passed by keyword. Positionally it landed in `season`,
    # get_data_from_file's 6th parameter.
    data = get_data_from_file(file_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional mask
    if mask is not None:
        data = data.where(mask == 1)

    # Convert model data to mm/day from kg/m^2/s
    if data_type == "model":
        data = data * 86400

    # Mean annual cycle: average of each calendar month over the period
    data_ann_cycle = data.groupby('time.month').mean(dim='time')

    # Amplitude at each grid point (Max - Mean)
    data_amplitude = data_ann_cycle.max(('month')) - data_ann_cycle.mean(('month'))

    return data_amplitude

### Get Phase of Annual Cycle at Each Grid Point

In [2]:
def get_phase_of_annual_cycle(data_path, variable, time_slice, lat_slice, lon_slice, iscale=None, mask=None):
    """Return, at each grid point, the position of the maximum in the mean annual cycle.

    The annual cycle is the mean of each calendar month over the period
    (``groupby('time.month')``). The phase is the 0-based position of its
    maximum along the month axis. When all twelve months are present this is
    0 for January through 11 for December. Ties resolve to the earliest month,
    NaN months are skipped, and grid points with no valid month are NaN.

    Parameters
    ----------
    data_path : str
        File read with get_data_from_file.
    variable, time_slice, lat_slice, lon_slice :
        Passed to get_data_from_file.
    iscale : optional
        SPI scale selector, forwarded to get_data_from_file.
    mask : xarray.DataArray, optional
        Keep points where ``mask == 1``.

    Returns
    -------
    xarray.DataArray
        Float phase field over the non-time dimensions, named like the input variable.
    """
    # iscale must be passed by keyword. Positionally it landed in `season`,
    # get_data_from_file's 6th parameter.
    data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, iscale=iscale)

    # Optional mask
    if mask is not None:
        data = data.where(mask == 1)

    # Mean annual cycle at each grid point
    data_ann_cycle = data.groupby('time.month').mean(dim='time')

    # Vectorized replacement for the former per-grid-point loop and bare
    # `except`. Filling NaN with -inf lets argmax skip missing months (as the
    # NaN-skipping .max() did). argmax returns the first maximum, like the
    # first hit of np.where(... == max).
    phase = data_ann_cycle.fillna(-np.inf).argmax(dim='month')

    # Grid points with no valid month have no phase
    data_phase = phase.where(data_ann_cycle.notnull().any(dim='month'))
    data_phase.name = data.name

    return data_phase