# Master Utility Notebook for Lisa's plot

**Used in conjunction with other notebooks for analysis and plotting** <br>
Phuong Loan Nguyen et al. 2025

In [None]:
# Import modules 
import pandas as pd
import os
import fnmatch
import xarray as xr
import xskillscore as xs
import numpy as np
import math
from scipy import stats
import matplotlib.pyplot as plt
import matplotlib.colors as colors

## File Input Functions

To get data from different directories within my CMIP6 database, f-strings are used in the analysis notebooks. <br>
**This will likely need to be updated based on your database structure** <br>
- The **CMIP6** database is set-up so that the Climpact indices follow the path: <br>
    - parent_directory/{indice-keyword}/{ANN-or-MON}/dataset_file.nc <br><br>
- The **observational** dataset (APHRODITE) is organized differently in the database: <br>
    - parent_directory/CMIP6/{obs}/{grid_type}/{climpact-or-daily-data}/{indice-keyword}/dataset_file_MON-or-ANN.nc

**Functions Include:**
- Get data model names from dictionaries below
- Get subset of models from dictionary defined in analysis notebooks
- Get model data paths
- Get data from paths
    - This includes exceptions to handle the 4-dimensionality of the SPI variable and the "days" unit of the CDD/CWD, R*mm (R10mm, R20mm, R30mm), and fracprday variables

### Define names of models
I select datasets from my database based on the dataset names (forcing GCMs and the non-GCM reference products listed below) included in the file name. This will likely need to be updated for use in other databases.

In [1]:
# Dataset names recognised in file names by get_data_name, which returns the first
# entry whose "<name>_" appears in the file name. Each name is listed exactly once.
gcm_names = [
    # CMIP6 forcing GCMs
    'ACCESS-CM2', 'ACCESS-ESM1-5', 'BCC-CSM2-MR', 'CNRM-CM6-1', 'CNRM-CM6-1-HR',
    'CNRM-ESM2-1', 'EC-Earth3-Veg', 'FGOALS-g3', 'GFDL-CM4', 'GFDL-ESM4',
    'HadGEM3-GC31-LL', 'HadGEM3-GC31-MM', 'INM-CM4-8', 'INM-CM5-0', 'KACE-1-0-G',
    'KIOST-ESM', 'MIROC-ES2L', 'MPI-ESM1-2-HR', 'MRI-ESM2-0', 'NESM3',
    'NorESM2-LM', 'NorESM2-MM', 'UKESM1-0-LL', 'CanESM5', 'MPI-ESM1-2-LR',
    # Non-GCM reference datasets
    'CPC_v1.0', 'GPCC_FDD_v2022', 'REGEN_ALL_2019', '3B42_v7.0', 'GIRAFE',
    'GSMAP-NRT-gauges-v8.0', 'IMERG-v07B-FC', 'PERSIANN_v1_r1', 'CHIRPS_v2.0',
    'GPCP_V3.2', 'CMORPH_v1.0_CRT', 'MERRA2', 'JRA-55', 'ERA5', 'CFSR',
]

### Get Model Names

In [2]:
# Use the dictionaries above to get the model names from each file
def get_data_name(names, file):
    
    # Check the model name is in one of the above model lists
    for name in names:
        if fnmatch.fnmatch(file, "*" + name + "_*"):
            return name
    
    return None

### Get Model Data Paths

In [3]:
# Get model file paths and add to a Pandas Dataframe
def get_model_files(model_data_master_path):
    """List every entry of a directory with the dataset name found in its file name.

    Parameters
    ----------
    model_data_master_path : str
        Directory holding the dataset files. It may be given with or without a
        trailing path separator.

    Returns
    -------
    pandas.DataFrame
        One row per directory entry, ordered by file name, with columns:

        - 'driving_gcm': the name from gcm_names matched by get_data_name, or
          None when no name matches. Such rows are kept, not filtered out.
        - 'dataset_path': the full path to the file.
    """
    rows = []

    # os.listdir returns entries in arbitrary order; sort for reproducible output
    for model_file in sorted(os.listdir(model_data_master_path)):

        # Identify which dataset this file belongs to (None if unrecognised)
        driving_gcm_name = get_data_name(gcm_names, model_file)

        # os.path.join works whether or not the directory ends with a separator
        # (plain string concatenation silently produced a wrong path without one)
        model_complete_file_path = os.path.join(model_data_master_path, model_file)

        rows.append([driving_gcm_name, model_complete_file_path])

    # Build the DataFrame once instead of enlarging it row by row
    return pd.DataFrame(rows, columns=['driving_gcm', 'dataset_path'])

### Get Model Data Paths for Subset

In [None]:
# Get model file paths for subset of models from list defined above
def get_model_files_subset(model_paths, subset_names):
    """Keep only the rows of *model_paths* whose dataset name is in *subset_names*.

    Parameters
    ----------
    model_paths : pandas.DataFrame
        Table whose first column holds dataset names and second column holds
        file paths, such as the output of get_model_files. Columns are read by
        position, not by label.
    subset_names : container of str
        Names to keep, tested with ``in``.

    Returns
    -------
    pandas.DataFrame
        Columns 'model_name' and 'dataset_path', in input row order. Both values
        are converted to str, so a None name becomes 'None'.
    """
    kept_rows = []

    # Use .iloc for positional access: plain row[0] on a label-indexed Series
    # relies on pandas' deprecated positional fallback
    for name_value, path_value in zip(model_paths.iloc[:, 0], model_paths.iloc[:, 1]):
        model_name = f'{name_value}'

        if model_name in subset_names:
            kept_rows.append([model_name, f'{path_value}'])

    return pd.DataFrame(kept_rows, columns=['model_name', 'dataset_path'])

### Get Data from File Paths

In [4]:
# Extract data from files based on user specifications

# Climpact count-type indices stored with units of "days". Reading them requires
# a workaround so xarray does not decode them as timedelta64 values.
_DAY_COUNT_VARIABLES = ('cdd', 'cwd', 'r10mm', 'r20mm', 'r30mm', 'fracprday')


def _open_day_count_variable(file_path, variable):
    """Open *variable* from *file_path* as a numeric DataArray whose units are 'days'.

    The dataset is opened undecoded and the variable's units are set to '1'.
    xr.decode_cf then decodes the rest of the dataset (e.g. the time coordinate)
    without treating "days" as a timedelta unit. Finally, the units attribute
    is set back to 'days'.
    """
    data_ds = xr.open_dataset(file_path, decode_cf=False)
    data_ds[variable].attrs['units'] = '1'
    data_ds = xr.decode_cf(data_ds)
    # NOTE(review): fracprday is relabelled 'days' as well, as in the original code;
    # confirm that is its intended unit.
    data_ds[variable].attrs['units'] = 'days'
    return data_ds[variable]


def get_data_from_file(file_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None):
    """Open one dataset file and return the requested variable, subset in space and time.

    Parameters
    ----------
    file_path : str
        Path of the dataset file to open with xarray.
    variable : str
        Name of the variable to extract. Ignored when *iscale* is given; the
        'spi' variable is read instead.
    time_slice, lat_slice, lon_slice
        Selectors passed to ``.sel(time=..., lat=..., lon=...)``.
    season : iterable of int, optional
        Month numbers (1-12) to keep. All time steps are kept when None.
    iscale : optional
        Value selected on the 'scale' dimension of the 'spi' variable.
        NOTE(review): the original comments map 1, 2 and 3 to the 3-, 6- and
        12-month SPI; confirm against the files' scale coordinate.

    Returns
    -------
    xarray.DataArray
        The selected variable.
    """
    spatiotemporal = dict(time=time_slice, lat=lat_slice, lon=lon_slice)

    if iscale is not None:
        # SPI files carry an extra 'scale' dimension (averaging period)
        data_ds = xr.open_dataset(file_path)
        data_var = data_ds['spi'].sel(scale=iscale, **spatiotemporal)

    elif variable in _DAY_COUNT_VARIABLES:
        data_var = _open_day_count_variable(file_path, variable).sel(**spatiotemporal)

    else:
        data_ds = xr.open_dataset(file_path)
        # Index with [] rather than getattr so names that clash with Dataset
        # attributes or methods are still looked up as variables
        data_var = data_ds.sel(**spatiotemporal)[variable]

    # Optional seasonal subset by calendar month
    if season is not None:
        data_var = data_var.sel(time=data_var.time.dt.month.isin(season))

    return data_var

## Plotting Functions

Functions used when creating plots/figures. <br>
**Functions Include:** <br>
- Truncating a predefined colormap
- Define vertical bands to shade along a time series

### Truncate a color map (skip colors in pre-defined colormap)

In [4]:
# Build a colormap from a sub-range of an existing one (skip colors at either end)
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Return a new colormap made from the [minval, maxval] portion of *cmap*.

    Parameters
    ----------
    cmap : colormap
        Source colormap. It must be callable on an array of fractions and have
        a ``.name``.
    minval, maxval : float
        Fractions (0-1) of the source colormap where the new one starts and ends.
    n : int
        Number of evenly spaced colors sampled from *cmap*.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Named ``trunc(<source name>,<minval>,<maxval>)``, with both bounds
        shown to 2 decimals.
    """
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sample_points = np.linspace(minval, maxval, n)
    sampled_colors = cmap(sample_points)
    return colors.LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Code sourced from https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib

### Define vertical bands (regions) to shade along a time series

In [2]:
# Define function to find areas that need to be shaded using a boolean mask as input
def fill_vertical_columns(boolean_fill_mask):
    """Return half-open [start, stop) index pairs for every run of True values in a mask.

    For each row ``[start, stop]`` of the result, ``mask[start:stop]`` is all
    True. ``mask[start - 1]`` and ``mask[stop]`` are False wherever those
    indices exist. Runs touching the edges get ``start == 0`` or
    ``stop == len(mask)``.

    Parameters
    ----------
    boolean_fill_mask : array-like of bool
        1-D mask; True marks positions to shade.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (n_runs, 2). It has shape (0, 2) when the mask
        is empty or has no True values.
    """
    mask = np.asarray(boolean_fill_mask, dtype=bool)

    # Guard: indexing mask[0] below would fail on an empty mask
    if len(mask) == 0:
        return np.empty((0, 2), dtype=np.intp)

    # For booleans, np.diff flags positions i where mask[i] != mask[i + 1]
    boolean_switch = np.diff(mask)
    region_to_shade, = boolean_switch.nonzero()

    # Shift by one so each index points at the first element AFTER a change:
    # the first True of a run (start) or the first False after it (stop).
    # This matches the 0 / len(mask) conventions used for edge runs below.
    region_to_shade = region_to_shade + 1

    # Handle runs that start at the first element or end at the last element
    if mask[0]:
        region_to_shade = np.r_[0, region_to_shade]

    if mask[-1]:
        region_to_shade = np.r_[region_to_shade, len(mask)]

    # Pair up consecutive boundaries into (start, stop) rows
    return region_to_shade.reshape((-1, 2))

## Spatially Averaged Functions

The following functions calculate spatially averaged metrics, primarily for plotting time series and scatterplots. <br>
'season=None' sets an optional argument for seasonal subsetting based on numeric month values <br>
'mask=None' sets an optional argument for masking, where the default is None <br>
'iscale=None' sets an optional argument for selecting the scale of the 4-D SPI index. This should only be assigned a value <br>
    if the variable is SPI. The scale can be 3, 6, or 12 months (selected with iscale = 1, 2, or 3 respectively). <br>
**Functions Include:** <br>
- Annual Time Series
- Weighted Spatial Average

### Weighted Spatial Mean at Default Time Step (No Temporal Averaging)

In [1]:
# Get monthly averages to use for a line plot over a user-specified time period and spatial domain
def get_weighted_ts(data_path, variable, time_slice, lat_slice, lon_slice, season=None, iscale=None, mask=None):
    
    # Get data
    # Include SPI scale in data extraction if a scale is provided by the user
    if iscale is None:
        data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, season)
    
    else:
        data = get_data_from_file(data_path, variable, time_slice, lat_slice, lon_slice, season, iscale)
              
    # Extract unmasked data if no mask provided
    if mask is None:
        pass

    else:
       
        # Apply mask to data
        data = data.where(mask==1)
    
    # Compute latitudinally weighted spatially averaged data
    # Create latitudinal weights
    weights = np.cos(np.deg2rad(data.lat))
    weights.name = "weights"

    # Weight: apply weighted lats to data
    data_weighted = data.weighted(weights)

    # Get weighted averages at each time step
    data_weighted_mean = data_weighted.mean(dim=('lon', 'lat'))
    
    
    # Return weighted spatial average at default time step
    return data_weighted_mean