In [1]:
# import functions
# OS interaction and time
import os
import sys
import cftime
import datetime
import time
import glob
import dask
import dask.bag as db
import calendar

# math and data
import numpy as np
import netCDF4 as nc
import xarray as xr
import scipy as sp
import scipy.linalg
from scipy.signal import detrend
import pandas as pd
import pickle as pickle
from sklearn import linear_model
import matplotlib.patches as mpatches
from shapely.geometry.polygon import LinearRing
import statsmodels.stats.multitest as multitest

# plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import ticker
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
import matplotlib.image as mpimg
from matplotlib.colors import TwoSlopeNorm

from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1.axes_divider import HBoxDivider
import mpl_toolkits.axes_grid1.axes_size as Size
from mpl_toolkits.axes_grid1 import make_axes_locatable

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point

# random
from IPython.display import display
from IPython.display import HTML
import IPython.core.display as di # Example: di.display_html('<h3>%s:</h3>' % str, raw=True)

In [2]:
# Data / output locations. The defaults are the original GLADE paths; each can be
# overridden with an environment variable so the notebook can run on another machine.
# os.path.join(..., '') guarantees the trailing slash that the f-string path building
# in get_var_files etc. relies on.
my_era5_path = os.path.join(os.environ.get('NAM_ERA5_PATH', '/glade/u/home/zcleveland/scratch/ERA5/'), '')  # path to subset data
misc_data_path = os.path.join(os.environ.get('NAM_MISC_DATA_PATH', '/glade/u/home/zcleveland/scratch/misc_data/'), '')  # path to misc data
plot_out_path = os.path.join(os.environ.get('NAM_PLOT_OUT_PATH', '/glade/u/home/zcleveland/NAM_soil-moisture/ERA5_analysis/plots/'), '')  # path to generated plots
scripts_main_path = os.path.join(os.environ.get('NAM_SCRIPTS_MAIN_PATH', '/glade/u/home/zcleveland/NAM_soil-moisture/scripts_main/'), '')  # path to my dicts, lists, and functions

In [3]:
# Make my helper scripts importable, then bind the shared variable lists and dictionaries.
if scripts_main_path not in sys.path:
    sys.path.insert(0, scripts_main_path)  # directory that holds my_dictionaries.py
import my_dictionaries
from my_dictionaries import (
    # my lists
    sfc_instan_list,   # instantaneous surface variables
    sfc_accumu_list,   # accumulated surface variables
    pl_var_list,       # pressure level variables
    invar_var_list,    # invariant variables
    NAM_var_list,      # NAM-based variables
    region_avg_list,   # region IDs for regional averages
    flux_var_list,     # flux variables that need to be flipped (e.g., sensible heat so that it's positive up instead of down)
    misc_var_list,     # misc variables
    # my dictionaries
    var_dict,            # variables and their names
    var_units,           # variable units
    region_avg_dict,     # region IDs and names
    region_avg_coords,   # coordinates for regions
    region_colors_dict,  # colors to plot for each region
)

In [4]:
# define a function to get var files, open dataset, and subset if needed
def get_var_data(var, region='dsw', months=None, **kwargs):
    r"""
    Retrieve the data for a given variable from my subset ERA5 dataset.

    Finds the variable's files, opens them, and (by default) subsets the result in
    space/time and averages it within time groups via subset_var_data.

    Parameters
    ----------
    var : str
            The variable desired
    region : str
            The region desired (a region ID from region_avg_list, 'dsw', or 'global')
    months : list of int, optional
            Months to keep, e.g. [6, 7, 8]. Defaults to all twelve months (1-12).

    Returns
    -------
    var_data : xarray DataArray (or Dataset, see ``type``)
            The requested data, either in full (``subset_flag=False``) or subset and
            grouped/averaged by subset_var_data.

    Kwargs
    ------
    subset_flag : bool
            Whether to subset/average the data (default True).
    type : str
            Passed to open_var_data: 'ds' requests the whole Dataset, 'da' (default) a DataArray.
    level : int
            Pressure level to select. Only applied for pressure-level variables.
    group_type : str
            Time component to group by before averaging, e.g. 'year' (default), 'month',
            'dayofyear'. Ignored for NAM variables, which are always grouped by year.
    dim_means : list of str
            Extra dimensions to average over within each group (see subset_var_data).
    lat_sub, lon_sub : list
            [start, stop] latitude/longitude bounds used when ``region`` is not a
            region ID from region_avg_list.

    See Also
    --------
    get_var_files : returns all files for specified variable
    open_var_data : opens the variable dataset or data array
    subset_var_data : subsets data array based on user input
    """
    if months is None:
        # build a fresh list per call rather than sharing one mutable default list
        months = list(range(1, 13))

    files = get_var_files(var, region, **kwargs)
    var_data = open_var_data(files, var, **kwargs)
    if kwargs.get('subset_flag', True):
        return subset_var_data(var_data, var, months, region, **kwargs)
    return var_data

In [5]:
# define a function to get the files for a given variable/region
def get_var_files(var, region, **kwargs):
    """
    Return the sorted list of netCDF files holding ``var``.

    The directory/filename pattern depends on which variable list ``var`` belongs
    to (surface, pressure-level, NAM, misc, or invariant). Surface variables read
    from the 'global' tree when ``region == 'global'`` and from 'dsw' otherwise.
    An unrecognized variable prints a message and yields an empty list.

    NOTE(review): the 'global' surface pattern still ends in '_dsw.nc' -- confirm
    the global files really carry that suffix.
    """
    if var in sfc_instan_list or var in sfc_accumu_list:
        domain = 'global' if region == 'global' else 'dsw'
        pattern = f'{my_era5_path}{domain}/*/{var.lower()}_*_dsw.nc'
    elif var in pl_var_list:
        pattern = f'{my_era5_path}dsw/*/pl/{var.lower()}_*_dsw.nc'
    elif var in NAM_var_list:
        pattern = f'{my_era5_path}dsw/NAM_{var}.nc'
    elif var in misc_var_list:
        pattern = f'{misc_data_path}{var}/{var}*.nc'
    elif var in invar_var_list:
        pattern = f'{my_era5_path}invariants/{var}_invariant.nc'
    else:
        print('something went wrong finding files')
        return []

    return sorted(glob.glob(pattern))

In [6]:
# define a function to open variable datasets
def open_var_data(files, var, **kwargs):
    """
    Open a set of netCDF files as one xarray object.

    Parameters
    ----------
    files : list of str
        Paths to open together with ``xr.open_mfdataset``.
    var : str
        Variable short name, matched case-insensitively against the data variables.

    Kwargs
    ------
    type : str
        'ds' returns the whole Dataset; anything else (default 'da') returns the
        single DataArray for ``var``.

    Returns
    -------
    xarray Dataset or DataArray

    Raises
    ------
    IndexError
        If no data variable in the files matches ``var``.
    """
    var_type = kwargs.get('type', 'da')  # default to returning a data array

    ds = xr.open_mfdataset(files)

    # compare the *kwarg value*; comparing the builtin `type` to 'ds' is always False
    if var_type == 'ds':
        return ds

    # names in the files can differ in capitalization (or carry extra text), so prefer an
    # exact case-insensitive match and only fall back to a substring match
    names = list(ds.data_vars)
    target = var.upper()
    matches = [v for v in names if v.upper() == target] or \
              [v for v in names if target in v.upper()]
    if not matches:
        raise IndexError(f"no data variable matching '{var}' found; available: {names}")
    return ds[matches[0]]

In [7]:
# define a function to subset an input data array (or dataset) by:
# latitude/longitude
# time
# averages
def subset_var_data(var_data, var, months, region, **kwargs):
    """
    Subset ``var_data`` in space (and level/months) and average it within time groups.

    Parameters
    ----------
    var_data : xarray DataArray or Dataset
        Data as returned by open_var_data.
    var : str
        Variable short name; decides NAM vs. non-NAM and pressure-level handling.
    months : list of int
        Months to keep (not applied to NAM variables).
    region : str
        Region ID. IDs in region_avg_list use region_avg_coords; anything else
        (e.g. 'dsw', 'global') uses lat_sub/lon_sub.

    Kwargs
    ------
    lat_sub, lon_sub : list
        [start, stop] bounds for non-region-ID regions (defaults [40, 20] and [240, 260]).
    level : int
        Pressure level to select (pressure-level variables only).
    group_type : str
        Time component to group by, e.g. 'year' (default), 'month', 'dayofyear'.
        NAM variables are always grouped by 'year'.
    dim_means : list of str
        Dims to average over within each group. Defaults to ['latitude', 'longitude']
        for region IDs and [] otherwise.

    Returns
    -------
    The grouped mean, or None (after printing a message) when no recognized time
    dimension is present.
    """
    # spatial window
    if region in region_avg_list:
        coords = region_avg_coords[region]  # indexed as [lon_start, lon_stop, lat_start, lat_stop]
        lats = slice(coords[2], coords[3])
        lons = slice(coords[0], coords[1])
    else:
        # NOTE(review): this default box also applies to region='global' unless
        # lat_sub/lon_sub are passed -- confirm that is intended. The [40, 20] order
        # presumes latitude is stored north-to-south.
        lat_sub = kwargs.get('lat_sub', [40, 20])
        lon_sub = kwargs.get('lon_sub', [240, 260])
        lats = slice(lat_sub[0], lat_sub[1])
        lons = slice(lon_sub[0], lon_sub[1])

    if 'latitude' in var_data.dims and 'longitude' in var_data.dims:
        var_data = var_data.sel(latitude=lats, longitude=lons)

    # subset to level if var is a pl var
    if var.lower() in pl_var_list:
        level = kwargs.get('level', None)
        if level is not None:
            var_data = var_data.sel(level=level)

    # bail out before any time-based selection if there is no recognizable time dim
    # (indexing [0] into an empty match list used to raise IndexError instead)
    time_dims = ('time', 'year', 'month', 'day', 'dayofyear', 'season')
    if not any(d in var_data.dims for d in time_dims):
        print('Something wrong with time index')
        return None

    group_type = kwargs.get('group_type', 'year')

    if var.lower() in NAM_var_list:
        # NAM variables are grouped by year; onset/retreat dates become day-of-year integers
        if var.lower() in ('onset', 'retreat'):
            var_data = var_data.dt.dayofyear
        groupby_type = 'year'
    else:
        var_data = var_data.sel(time=var_data['time.month'].isin(months))
        groupby_type = f'time.{group_type}'

    # regional data is averaged over lat/lon by default; dsw/global data is not
    if region in region_avg_list:
        dim_means = kwargs.get('dim_means', ['latitude', 'longitude'])
    else:
        dim_means = kwargs.get('dim_means', [])

    return var_data.groupby(groupby_type).mean(dim=dim_means)

In [8]:
# define a function to check if inputs are list or not
def ensure_var_list(x):
    """Return ``x`` unchanged if it is a list (same object, no copy); otherwise wrap it as ``[x]``."""
    return x if isinstance(x, list) else [x]

In [9]:
# define a function to turn a list of integers into months
def month_num_to_name(var, months, **kwargs):
    """
    Build a short label for a set of month numbers, e.g. [6, 7, 8] -> 'JJA'.

    Parameters
    ----------
    var : str
        Variable name. NAM variables (onset, retreat, length) are not month-based and get ''.
    months : list of int
        Month numbers 1-12.

    Kwargs
    ------
    mean_flag : bool
        When True (default), a full 12-month list is labelled 'YEAR'.

    Returns
    -------
    str
        '' for NAM variables, the full month name for a single month, 'YEAR' for all
        twelve months (with mean_flag), otherwise the concatenated month initials.

    Raises
    ------
    ValueError
        If ``months`` is empty for a non-NAM variable.
    """
    if var in NAM_var_list:
        return ''  # don't use months for onset, retreat, length
    if len(months) == 0:
        raise ValueError('months must contain at least one month number (1-12)')
    if len(months) == 1:
        return calendar.month_name[months[0]]  # full month name for a single month
    if len(months) == 12 and kwargs.get('mean_flag', True):
        return 'YEAR'
    # initials of each month, e.g. 3, 4, 5 -> 'MAM'
    return ''.join(calendar.month_name[m][0] for m in months)

In [10]:
# define a function to detrend the data

# MANUALLY DETREND WITH LINEAR REGRESSION
def detrend_data(arr):
    """
    Subtract the least-squares linear trend from a 1-D series.

    The line is fit against the sample index (0, 1, ..., n-1) using only the finite
    entries of ``arr``; it is then evaluated at every index and subtracted, so
    non-finite entries stay non-finite. An input with no finite values yields an
    all-NaN array of the same length.
    """
    idx = np.arange(0, len(arr))
    finite = np.isfinite(arr)

    if not finite.any():
        return np.full(len(arr), np.nan)

    fit = sp.stats.linregress(idx[finite], arr[finite])
    trend = fit.slope * idx + fit.intercept
    return arr - trend


# define a function to linearly detrend every series of a data array along one dimension
def apply_detrend(da, **kwargs):
    """
    Linearly detrend ``da`` along one dimension at every other point (via detrend_data).

    Parameters
    ----------
    da : xarray DataArray
        Loaded into memory in place (``da.load()``) before the computation.

    Kwargs
    ------
    input_dims : str
        Name of the dimension to detrend along (default 'time').

    Returns
    -------
    xarray DataArray of detrended values. apply_ufunc places the core dimension last.
    """
    input_dims = kwargs.get('input_dims', 'time')
    da.load()

    # With vectorize=True, output_dtypes is handed to np.vectorize as `otypes`, so an
    # integer input dtype would truncate the detrended values and mangle NaNs.
    # Keep float inputs' dtype, but emit float64 for anything non-floating.
    out_dtype = da.dtype if np.issubdtype(da.dtype, np.floating) else np.float64

    da_detrend = xr.apply_ufunc(
        detrend_data, da,
        input_core_dims=[[input_dims]],
        output_core_dims=[[input_dims]],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[out_dtype]
    )

    return da_detrend

In [11]:
# define a function to regress data
def regress_data(arr1, arr2):
    """
    Ordinary least-squares fit of ``arr2`` on ``arr1``, ignoring pairs where either value is non-finite.

    Returns
    -------
    tuple of 6 floats
        (slope, intercept, rvalue, pvalue, stderr, intercept_stderr) from
        scipy.stats.linregress, or six NaNs when fewer than two valid pairs remain or
        when all valid ``arr1`` values are identical. linregress raises ValueError in
        the latter case, which would otherwise abort an entire vectorized apply_ufunc
        call over a grid.
    """
    mask = np.isfinite(arr1) & np.isfinite(arr2)
    x = arr1[mask]
    y = arr2[mask]

    nan_result = (np.nan,) * 6
    if len(x) < 2:  # not enough points for a line
        return nan_result
    if np.all(x == x[0]):  # vertical "line": slope undefined
        return nan_result

    res = sp.stats.linregress(x, y)
    return res.slope, res.intercept, res.rvalue, res.pvalue, res.stderr, res.intercept_stderr


# define a function to regress da2 onto da1 point-by-point along a shared dimension
def apply_regression(da1, da2, **kwargs):
    """
    Point-wise linear regression of ``da2`` on ``da1`` (via regress_data).

    Both inputs are loaded into memory in place, then regress_data is applied to each
    pair of 1-D series along ``input_dims``.

    Kwargs
    ------
    input_dims : str
        Dimension to regress along (default 'time').

    Returns
    -------
    xarray Dataset with variables 'slope', 'intercept', 'rvalue', 'pvalue',
    'stderr' and 'intercept_stderr' over the remaining dimensions.
    """
    core_dim = kwargs.get('input_dims', 'time')
    da1.load()
    da2.load()

    stat_names = ['slope', 'intercept', 'rvalue', 'pvalue', 'stderr', 'intercept_stderr']
    outputs = xr.apply_ufunc(
        regress_data, da1, da2,
        input_core_dims=[[core_dim], [core_dim]],
        output_core_dims=[[] for _ in stat_names],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float] * len(stat_names),
    )
    return xr.Dataset(dict(zip(stat_names, outputs)))

In [12]:
# define a function to calculate the Pearson correlation and p-value statistic
def compute_corr_pval(arr1, arr2, **kwargs):
    """
    Pearson correlation of two 1-D series over the positions where both are finite.

    Returns
    -------
    (r, p) : tuple of float
        Pearson's r and its p-value from scipy.stats.pearsonr, or ``(nan, nan)``
        when fewer than two valid pairs remain.
    """
    both_finite = np.isfinite(arr1) & np.isfinite(arr2)
    x = arr1[both_finite]
    y = arr2[both_finite]

    if len(x) >= 2:
        r, p = sp.stats.pearsonr(x, y)
        return r, p
    return np.nan, np.nan


# define a function to apply the ufunc to the data
def apply_correlation(da1, da2, **kwargs):
    """
    Point-wise Pearson correlation between two DataArrays along a shared dimension.

    Both inputs are loaded into memory in place, then compute_corr_pval is applied
    to every pair of 1-D series along ``input_dims``.

    Kwargs
    ------
    input_dims : str
        Dimension to correlate along (default 'year').

    Returns
    -------
    xarray Dataset with variables 'pearson_r' and 'p_value' over the remaining dims.
    """
    da1.load()
    da2.load()
    core_dim = kwargs.get('input_dims', 'year')

    r_da, p_da = xr.apply_ufunc(
        compute_corr_pval, da1, da2,
        input_core_dims=[[core_dim], [core_dim]],
        output_core_dims=[[], []],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float, float],
    )
    return xr.merge([r_da.rename('pearson_r'), p_da.rename('p_value')])

In [13]:
# define a function to compute the magnitude-squared coherence between two series
def compute_coherence(arr1, arr2):
    """
    Magnitude-squared coherence between two 1-D series, ignoring non-finite pairs.

    Thin wrapper around ``scipy.signal.coherence`` with its default settings
    (fs=1, nperseg=256; scipy reduces nperseg to the series length, with a warning,
    for shorter inputs).

    Returns
    -------
    freqs, coherence : tuple
        Two equal-length arrays (sample frequencies and the coherence at each
        frequency) -- NOT a correlation coefficient and p-value -- or the scalars
        ``(nan, nan)`` when fewer than two valid pairs remain.
    """
    mask = np.isfinite(arr1) & np.isfinite(arr2)
    x = arr1[mask]
    y = arr2[mask]

    if len(x) < 2:  # check if there are enough data points
        return np.nan, np.nan

    freqs, coherence = sp.signal.coherence(x, y)
    return freqs, coherence


# define a function to apply the ufunc to the data
def apply_coherence(da1, da2):
    """
    Apply compute_coherence point-wise along the 'year' dimension of two DataArrays.

    Both inputs are loaded into memory in place. The two outputs are stored in a
    Dataset under the names 'pearson_r' and 'p_value'.

    FIXME(review): compute_coherence returns (frequencies, coherence) *arrays* from
    scipy.signal.coherence, but this call declares two scalar outputs
    (output_core_dims=[[], []], output_dtypes=[float, float]). With vectorize=True
    this presumably fails when np.vectorize tries to store an array into a float
    output -- confirm. The output names 'pearson_r'/'p_value' also look copied
    from apply_correlation and do not describe coherence. Decide whether a
    frequency-resolved result (a 'frequency' core dim) or a scalar summary is wanted.
    """
    da1.load()
    da2.load()
    result = xr.apply_ufunc(
        compute_coherence, da1, da2,
        input_core_dims=[['year'], ['year']],
        output_core_dims=[[],[]],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[float, float]
    )
    # NOTE(review): first/second outputs are frequencies/coherence, not r/p (see docstring)
    corr_da = result[0]
    pval_da = result[1]

    corr_ds = xr.merge([corr_da.rename('pearson_r'), pval_da.rename('p_value')])
    return corr_ds

In [14]:
# define a function to calculate the principal components of a data array
def calc_pcs(da, **kwargs):
    """
    Principal components of a 2-D (samples x variables) array via eigen-decomposition of its covariance.

    Rows of ``da`` are observations and columns are variables (``np.cov(..., rowvar=False)``).

    Returns
    -------
    pcs : ndarray
        ``da`` projected onto the eigenvectors. NOTE(review): ``da`` is projected as
        given (not mean-removed), while np.cov works on anomalies, so each PC carries a
        constant offset -- confirm whether anomalies were intended.
    eigenvalues : ndarray
        Covariance eigenvalues, largest first.
    eigenvectors : ndarray
        Eigenvectors as columns, ordered to match ``eigenvalues``.
    """
    cov = np.cov(da, rowvar=False)
    evals, evecs = sp.linalg.eigh(cov)

    order = np.argsort(evals)[::-1]  # eigh returns ascending order; we want largest first
    evals = evals[order]
    evecs = evecs[:, order]

    return np.dot(da, evecs), evals, evecs


# define a function to calculate the explained variance of one varialbe by another
def calc_explained_variance(da, pcs, **kwargs):
    """
    Variance of ``da``'s least-squares fit on ``pcs`` relative to each PC's variance.

    ``da`` is regressed onto the columns of ``pcs`` (``np.linalg.lstsq``), the fit is
    reconstructed, and var(fit) along axis 0 is divided by var(pcs) along axis 0.

    NOTE(review): this is not a standard explained-variance (R^2) ratio. For a 1-D
    ``da`` the single fitted variance is broadcast against every PC's variance, giving
    one value per PC. Confirm this is the intended quantity.
    """
    coeffs = np.linalg.lstsq(pcs, da, rcond=None)[0]  # weights of each PC in the fit of da
    fitted = np.dot(pcs, coeffs)                      # da reconstructed from the PCs
    return np.var(fitted, axis=0) / np.var(pcs, axis=0)


# define the main function to calculate the EOF that identifies the
# variance of da2 explained by da1
def calc_eof(da1, da2, **kwargs):
    """
    Relate ``da1`` to the principal components of the gridded field ``da2``.

    ``da2``'s latitude/longitude are stacked into a trailing 'space' dim, PCs are
    computed with calc_pcs, ``da1`` is regressed onto them with calc_explained_variance,
    and the resulting vector is reshaped onto ``da2``'s (latitude, longitude) grid.

    Assumes ``da2`` has exactly one non-spatial (sample) dimension, placed first --
    TODO confirm.

    NOTE(review): the ratio vector is indexed by PC *mode*, not by grid point, so
    reshaping it onto the lat/lon grid only yields a map if that mapping is intended.
    Confirm.
    """
    da2_flat = da2.stack(space=('latitude', 'longitude'))
    pcs, _eigenvalues, _eigenvectors = calc_pcs(da2_flat)

    ratio = calc_explained_variance(da1, pcs)
    grid_shape = (da2.sizes['latitude'], da2.sizes['longitude'])

    return xr.DataArray(
        ratio.reshape(grid_shape),
        dims=['latitude', 'longitude'],
        coords={'latitude': da2.coords['latitude'], 'longitude': da2.coords['longitude']},
    )

In [15]:
# define a function to zscore a variable
def calc_zscore_monthly(da, **kwargs):
    """
    Standardize ``da`` against its statistics over the 'year' dimension.

    Every element becomes (value - mean over years) / std over years. With a
    (year, month, ...) input, each calendar month is therefore standardized separately.
    """
    climatology = da.mean(dim='year')
    spread = da.std(dim='year')
    return (da - climatology) / spread


# define a function to apply the calc_zscore function to an xarray dataset
def apply_zscore(da, **kwargs):
    """
    Placeholder -- not implemented yet; always returns None.

    TODO: convert the time index to year and month indices, then apply the z-score
    calculation (calc_zscore_monthly).
    """
    return None

In [None]:
# test cell
# NOTE(review): time_to_year_month_avg is defined in a later cell, so this cell only runs
# after that cell has been executed (it fails under Restart & Run All as ordered).
# FIXME(review): get_var_data() already groups/averages by year (group_type='year' default),
# so var1/var2 have no 'time' dimension left for time_to_year_month_avg() to split into
# year x month. The monthly/z-score steps presumably need un-grouped input -- confirm order.

v1 = 'swvl1'
v1_months = [3, 4, 5]
v2 = 'onset'
v2_months = [6, 7, 8]

var1 = get_var_data(v1, 'cp')
var2 = get_var_data(v2)

# detrend the data
var1_detrend = apply_detrend(var1, input_dims='year')
var2_detrend = apply_detrend(var2, input_dims='year')

# compute monthly means
var1_detrend_monthly = time_to_year_month_avg(var1_detrend)
var2_detrend_monthly = time_to_year_month_avg(var2_detrend)

# compute zscores
var1_zscore = calc_zscore_monthly(var1_detrend_monthly)
var2_zscore = calc_zscore_monthly(var2_detrend_monthly)

# correlation and linear regression of summer v2 onto spring v1
var1_spring = var1_zscore.sel(month=v1_months).mean(dim='month')
var2_summer = var2_zscore.sel(month=v2_months).mean(dim='month')
corr_ds = apply_correlation(var1_spring, var2_summer, input_dims='year')
regression_ds = apply_regression(var1_spring, var2_summer, input_dims='year')

data_crs = ccrs.PlateCarree()  # data and map share a regular lat/lon projection

# ---- correlation map with significance markers ----
corr_da = corr_ds['pearson_r']
pval_da = corr_ds['p_value']
fig, ax = plt.subplots(figsize=(10, 6), subplot_kw=dict(projection=data_crs))

corr_levels = np.arange(-1, 1.05, 0.05)
corr_cf = ax.contourf(corr_da.longitude, corr_da.latitude, corr_da,
                      levels=corr_levels, cmap='RdBu_r', extend='both',
                      transform=data_crs)

# marker grids; the two masks are disjoint, so each grid cell gets at most one marker
# and the categories match the legend labels exactly
lat, lon = np.meshgrid(pval_da.latitude, pval_da.longitude, indexing='ij')
pvals = np.asarray(pval_da)
mask_dots = (pvals >= 0.05) & (pvals < 0.1)
mask_triangles = pvals < 0.05

ax.scatter(lon[mask_dots], lat[mask_dots], color='black', marker='.',
           s=5, transform=data_crs, label='0.05 <= p < 0.1')
ax.scatter(lon[mask_triangles], lat[mask_triangles], color='black', marker='^',
           s=8, transform=data_crs, label='p < 0.05')

ax.coastlines(linewidth=0.5)
ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
ax.add_feature(cfeature.STATES, linewidth=0.5)

ax.set_title(f'{v2} {v2_months} vs. {v1} {v1_months}: Pearson correlation')
ax.legend(loc='lower left', fontsize=8)
fig.colorbar(corr_cf, ax=ax, label='pearson r', pad=0.02)
fig.tight_layout()

plt.show()
plt.close(fig)


# ---- regression: predicted v2 change across the observed range of spring v1 ----
var1_min = var1_spring.min(dim=['year'])
var1_max = var1_spring.max(dim=['year'])  # both ends of the predictor range come from var1_spring
var2_min = regression_ds['slope'] * var1_min + regression_ds['intercept']
var2_max = regression_ds['slope'] * var1_max + regression_ds['intercept']
var2_spread = var2_max - var2_min

fig, ax = plt.subplots(figsize=(12, 10), subplot_kw=dict(projection=data_crs))

vmin = np.nanmin(var2_spread)
vmax = np.nanmax(var2_spread)
cf_levels = np.linspace(vmin, vmax, 50)

# diverging colormap centred on zero only when the field changes sign
if vmin < 0 and vmax > 0:
    norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
    cmap = 'RdBu_r'
else:
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    cmap = 'Blues' if vmax <= 0 else 'Reds'

# filled: predicted change in v2; dashed contours: regression slope m
regress_cf = ax.contourf(var2_spread.longitude, var2_spread.latitude, var2_spread,
                         levels=cf_levels, cmap=cmap, norm=norm, extend='both',
                         transform=data_crs)
regress_cs = ax.contour(regression_ds.longitude, regression_ds.latitude, regression_ds['slope'],
                        levels=10, linewidths=0.5, linestyles='--', colors='black',
                        transform=data_crs)

ax.coastlines(linewidth=0.5)
ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
ax.add_feature(cfeature.STATES, linewidth=0.5)

ax.set_title(f'{v2} regressed on {v1}: shading = predicted change, dashed = slope m')
fig.colorbar(regress_cf, ax=ax, label=f'predicted {v2} z-score change across {v1} range', pad=0.02)
ax.clabel(regress_cs, inline=True, fontsize=8, fmt='%1.1f')
fig.tight_layout()

plt.show()
plt.close(fig)

In [None]:
def time_to_year_dayofyear(ds):
    """
    Reshape the 'time' dimension into separate 'year' and 'dayofyear' dimensions.

    Each time step is labelled with its calendar year and day of year, the label pair
    becomes a MultiIndex on 'time', and that index is unstacked. Assumes at most one
    time step per (year, dayofyear) pair (e.g. daily data) -- TODO confirm for callers.
    """
    time_labels = {
        'year': ('time', ds.time.dt.year.data),
        'dayofyear': ('time', ds.time.dt.dayofyear.data),
    }
    labelled = ds.assign_coords(time_labels)
    return labelled.set_index(time=('year', 'dayofyear')).unstack('time')

In [20]:
def time_to_year_month_avg(ds):
    """
    Average ``ds`` within each calendar month and reshape 'time' into 'year' x 'month'.

    Parameters
    ----------
    ds : xarray Dataset or DataArray with a datetime 'time' dimension

    Returns
    -------
    Same type as ``ds`` with 'time' replaced by 'year' and 'month' dimensions.
    (year, month) combinations with no data in ``ds`` are filled with NaN.
    """
    # label each time step with a sortable year-month key, e.g. 198106
    year_month = (ds['time'].dt.year * 100 + ds['time'].dt.month).rename('year_month')

    # One mean per (year, month) that actually has data. Unlike resample(time='1M')
    # combined with an assumed complete years x months product, this does not require
    # complete or consecutive months (e.g. data already subset to MAM), and it avoids
    # the deprecated '1M' frequency alias.
    monthly = ds.groupby(year_month).mean(dim='time')

    keys = monthly['year_month'].values
    monthly = monthly.drop_vars('year_month').assign_coords(
        year=('year_month', keys // 100),
        month=('year_month', keys % 100),
    )
    return monthly.set_index(year_month=['year', 'month']).unstack('year_month')

In [21]:
def remove_non_leap_days(ds):
    """
    Intended to drop 29 February entries that fall in non-leap years from ``ds``'s 'time' axis.

    NOTE(review): a pandas DatetimeIndex cannot represent 29 Feb of a non-leap year,
    so the mask below should always be empty and nothing is actually removed.
    ``Index.difference`` is a set operation, so duplicate timestamps are dropped
    (depending on the pandas version the result may also be sorted). The effective
    behavior is therefore ``ds`` re-selected on its unique times. If the goal was a
    365-day calendar, presumably *all* 29 Feb dates should be dropped instead --
    confirm intent.
    NOTE(review): non-standard cftime calendars (e.g. 360_day) may not convert to a
    DatetimeIndex at all -- confirm against the data's calendar.
    """
    # convert time dimension to pandas
    time_index = pd.DatetimeIndex(ds['time'].values)

    # identify dates that are feb 29 on non leap years (see NOTE above: expected to be empty)
    feb_29_non_leap_years = time_index[(time_index.month == 2) & (time_index.day == 29) & (~time_index.is_leap_year)]

    # filter out feb 29 on non leap years
    filtered_time_index = time_index.difference(feb_29_non_leap_years)

    # index data array to exclude feb 29 on non leap years
    filtered_ds = ds.sel(time=filtered_time_index)
    return filtered_ds

In [22]:
def year_dayofyear_to_time(ds):
    """
    Collapse 'year' and 'dayofyear' dimensions back into a single datetime 'time' dimension.

    Inverse of time_to_year_dayofyear. Dates are built arithmetically as
    1 January of the year + (dayofyear - 1) days. Pairs that are not real dates
    (e.g. day 366 of a non-leap year, which can appear after unstacking) are dropped
    rather than being mapped onto a date in the following year.

    Returns
    -------
    Same type as ``ds``, with a DatetimeIndex 'time' dimension in (year, dayofyear) order.
    """
    stacked = ds.stack(time=('year', 'dayofyear'))
    years = np.asarray(stacked['year'].values, dtype=int)
    days = np.asarray(stacked['dayofyear'].values, dtype=int)

    dates = pd.to_datetime(years.astype(str), format='%Y') + pd.to_timedelta(days - 1, unit='D')

    # Replace the (year, dayofyear) MultiIndex with plain datetimes. The level coordinates
    # are dropped together with 'time' because newer xarray refuses to break up a
    # MultiIndex piecemeal. errors='ignore' covers versions where levels are virtual.
    stacked = stacked.drop_vars(['time', 'year', 'dayofyear'], errors='ignore')
    stacked = stacked.assign_coords(time=dates)

    # keep only pairs whose date stays inside its own year
    return stacked.isel(time=np.flatnonzero(dates.year == years))

In [23]:
def year_day_to_time(ds):
    """
    Collapse 'year' and 'day' dimensions back into a datetime 'time' dimension.

    Here 'day' is the day of the year (the original implementation parsed it with
    the '%j' day-of-year format). The conversion is identical to
    year_dayofyear_to_time, so this renames 'day' -> 'dayofyear' and delegates rather
    than keeping a second copy of the same logic.
    """
    return year_dayofyear_to_time(ds.rename({'day': 'dayofyear'}))

In [24]:
def time_to_split_time(ds, dim_to_split, new_dim1, new_dim2):
    """
    Split datetime dimension ``dim_to_split`` into two dims named after ``.dt`` components.

    E.g. ``(ds, 'time', 'year', 'dayofyear')`` -> dims 'year' x 'dayofyear'. Each
    (new_dim1, new_dim2) pair is expected to occur at most once along ``dim_to_split``.

    NOTE(review): the following cell redefines time_to_split_time, so this version is
    shadowed once both cells have run. Consider deleting one of the two.
    """
    accessor = ds[dim_to_split].dt
    labels = {name: (dim_to_split, getattr(accessor, name).data) for name in (new_dim1, new_dim2)}
    return (
        ds.assign_coords(**labels)
        .set_index({dim_to_split: (new_dim1, new_dim2)})
        .unstack(dim_to_split)
    )

In [25]:
def time_to_split_time(ds, dim_to_split, new_dim1, new_dim2, extra_dim='day'):
    """
    Split a datetime dimension into dimensions named after its ``.dt`` components.

    Each step along ``dim_to_split`` is labelled with ``.dt.<new_dim1>``,
    ``.dt.<new_dim2>`` and, by default, ``.dt.day``. The labels form a MultiIndex that
    is unstacked into separate dimensions. The extra level keeps labels unique when
    the two requested components alone would repeat (e.g. year and month on daily data).

    Parameters
    ----------
    ds : xarray Dataset or DataArray
    dim_to_split : str
        Datetime dimension to split, e.g. 'time'.
    new_dim1, new_dim2 : str
        ``.dt`` attribute names, e.g. 'year' and 'month'.
    extra_dim : str or None, optional
        Additional ``.dt`` component added for uniqueness (default 'day', matching the
        previous behavior). Pass None to split into just the two requested dims. The
        extra component is skipped when it equals new_dim1 or new_dim2, since a
        duplicated MultiIndex level cannot be built.

    Returns
    -------
    Same type as ``ds`` with ``dim_to_split`` replaced by the new dimensions.
    """
    levels = [new_dim1, new_dim2]
    if extra_dim is not None and extra_dim not in levels:
        levels.append(extra_dim)

    accessor = ds[dim_to_split].dt
    ds = ds.assign_coords(**{name: (dim_to_split, getattr(accessor, name).data) for name in levels})

    ds = ds.set_index({dim_to_split: tuple(levels)})
    return ds.unstack(dim_to_split)

In [None]:
# Standardize `sd` against its day-of-year climatology across years.
# FIXME(review): `sd` is not defined in any earlier cell -- load it first
# (e.g. with get_var_data('sd', ...)) before running this cell.
# `time_to_year_day` is not defined anywhere; time_to_year_dayofyear is the year x
# day-of-year reshaper defined above, so it is used here.
sd_reshaped = time_to_year_dayofyear(sd)
sd_mean = sd_reshaped.mean(dim='year')
sd_std = sd_reshaped.std(dim='year')
sd_norm = (sd_reshaped - sd_mean) / sd_std