# Documentation

**Authors:** Mu-Ting Chien & Spencer Ressel

**Created:** June 25th, 2020

This Jupyter notebook contains functions for analyzing geospatial data relevant to the Madden–Julian Oscillation (MJO).

# Imports 

In [8]:
import numpy as np
from scipy.signal import butter, filtfilt

# Functions

## Basic Functions

In [9]:
def inpaint_nans(x, critical_val, big_or_small):  # assume x is 3-dim, assume nan is the same in time
    """
    Replace flagged grid points with the mean of their spatial neighbours.

    Parameters
    ----------
    x : numpy.ndarray
        3-D array indexed as (time, j1, j2). The points to fill are detected
        at time index 1 and assumed to be the same at every time step.
    critical_val : float
        Threshold used when ``big_or_small`` is 1 or 0.
    big_or_small : int
        1: points whose value is > critical_val are filled;
        0: points whose value is < critical_val are filled;
        -1: points that are already NaN are filled.

    Returns
    -------
    x_nonan : numpy.ndarray
        Float copy of ``x`` in which every flagged point (at all time steps)
        is replaced by the nan-mean of its up-to-four edge neighbours. Points
        are filled in row-major order, so a point filled earlier can serve as
        a neighbour of a later one. A point with no finite neighbour stays NaN.

    Raises
    ------
    ValueError
        If ``big_or_small`` is not 1, 0, or -1.
    """
    # Work on a float copy: NaN needs a float dtype and the caller's array
    # must not be modified in place.
    x_nonan = np.array(x, dtype=float)
    nt = np.size(x_nonan, 0)
    nlat = np.size(x_nonan, 1)
    nlon = np.size(x_nonan, 2)
    ref = x_nonan[1, :, :]

    if big_or_small == 1:  # >critical value-->nan
        mask = ref > critical_val
    elif big_or_small == 0:  # <critical value-->nan
        mask = ref < critical_val
    elif big_or_small == -1:  # already nan
        mask = np.isnan(ref)
    else:
        raise ValueError("big_or_small must be 1, 0, or -1")

    # (n_points, 2) array of (j1, j2) indices; deliberately not squeezed so a
    # single flagged point still yields a 2-D index array.
    i = np.argwhere(mask)
    if np.size(i, 0) == 0:
        print("no nan data")
        return x_nonan

    print("has nan data!!!")
    x_nonan[:, i[:, 0], i[:, 1]] = np.nan

    # Placeholder for a neighbour that falls outside the grid.
    outside = np.full(nt, np.nan)
    for j1, j2 in i:
        J1L = x_nonan[:, j1 - 1, j2] if j1 > 0 else outside
        J1R = x_nonan[:, j1 + 1, j2] if j1 < nlat - 1 else outside
        J2L = x_nonan[:, j1, j2 - 1] if j2 > 0 else outside
        J2R = x_nonan[:, j1, j2 + 1] if j2 < nlon - 1 else outside
        x_nonan[:, j1, j2] = np.nanmean(np.array([J1L, J1R, J2L, J2R]), 0)
    return x_nonan

In [10]:
def inpaint_nans_2(
    x, critical_val, big_or_small
):  # assume x is 3-dim, assume nan is the same in time dimension
    """
    Fill NaN (or thresholded) grid points with the mean of their neighbours.

    Parameters
    ----------
    x : numpy.ndarray
        3-D array indexed as (time, j1, j2). NaN positions are detected at
        time index 1 and assumed to be the same at every time step.
    critical_val : float
        Threshold used when ``big_or_small`` is 1 or 0.
    big_or_small : int
        1: values > critical_val are set to NaN first;
        0: values < critical_val are set to NaN first;
        any other value: only existing NaNs are filled.

    Returns
    -------
    x_nonan : numpy.ndarray
        Float copy of ``x``. Each flagged (j1, j2) column is replaced, at all
        time steps, by the nan-mean of its up-to-four edge neighbours. Points
        are filled in row-major order, so earlier fills can feed later ones.
    """
    # Work on a float copy so the caller's array is never modified in place.
    x = np.array(x, dtype=float)
    if big_or_small == 1:  # >critical value-->nan
        x = np.where(x > critical_val, np.nan, x)
    elif big_or_small == 0:  # <critical value-->nan
        x = np.where(x < critical_val, np.nan, x)

    nt = np.size(x, 0)
    nlat = np.size(x, 1)
    nlon = np.size(x, 2)
    i = np.argwhere(np.isnan(x[1, :, :]))
    print("portion of nan data:" + str(np.size(i, 0) / ((nlat * nlon))))

    x_nonan = x
    # An empty argwhere result has shape (0, 2), so test the row count.
    if np.size(i, 0) == 0:
        print("no nan data")
    else:
        print("has nan data!!!")
        outside = np.full(nt, np.nan)  # stand-in for off-grid neighbours
        for j1, j2 in i:
            J1L = x_nonan[:, j1 - 1, j2] if j1 > 0 else outside
            J1R = x_nonan[:, j1 + 1, j2] if j1 < nlat - 1 else outside
            J2L = x_nonan[:, j1, j2 - 1] if j2 > 0 else outside
            J2R = x_nonan[:, j1, j2 + 1] if j2 < nlon - 1 else outside
            x_nonan[:, j1, j2] = np.nanmean(np.array([J1L, J1R, J2L, J2R]), 0)
    # Report any points that could not be filled (no finite neighbours).
    print(np.argwhere(np.isnan(x_nonan) == 1))
    return x_nonan

In [11]:
def inpaint_nans_3(
    x, critical_val, big_or_small
):  # assume x is 3-dim, assume nan is the same in time dimension
    """
    Fill NaN (or thresholded) points time step by time step from neighbours.

    Unlike ``inpaint_nans_2``, NaN positions are located separately at every
    time step, so they may differ between time steps.

    Parameters
    ----------
    x : numpy.ndarray
        3-D array indexed as (time, j1, j2).
    critical_val : float
        Threshold used when ``big_or_small`` is 1 or 0.
    big_or_small : int
        1: values > critical_val are set to NaN first;
        0: values < critical_val are set to NaN first;
        any other value: only existing NaNs are filled.

    Returns
    -------
    x_nonan : numpy.ndarray
        Float copy of ``x`` with each NaN replaced by the nan-mean of its
        up-to-four edge neighbours at the same time step. Points are filled
        in row-major order, so earlier fills can feed later ones. Points with
        no finite neighbour stay NaN, and their (t, j1, j2) are printed.
    """
    # Work on a float copy so the caller's array is never modified in place.
    x = np.array(x, dtype=float)
    if big_or_small == 1:  # >critical value-->nan
        x = np.where(x > critical_val, np.nan, x)
    elif big_or_small == 0:  # <critical value-->nan
        x = np.where(x < critical_val, np.nan, x)
    x_nonan = x
    nlat = np.size(x, 1)
    nlon = np.size(x, 2)

    if not np.isnan(x).any():
        print("no nan data")
    else:
        print("has nan data!!!")
        for t in range(0, np.size(x, 0)):
            i = np.argwhere(np.isnan(x_nonan[t, :, :]))
            if t == 0:
                print(
                    "portion of nan data:"
                    + str(np.size(i, 0) / ((nlat * nlon)))
                )
            for j1, j2 in i:
                J1L = x_nonan[t, j1 - 1, j2] if j1 > 0 else np.nan
                J1R = x_nonan[t, j1 + 1, j2] if j1 < nlat - 1 else np.nan
                J2L = x_nonan[t, j1, j2 - 1] if j2 > 0 else np.nan
                J2R = x_nonan[t, j1, j2 + 1] if j2 < nlon - 1 else np.nan
                x_nonan[t, j1, j2] = np.nanmean(np.array([J1L, J1R, J2L, J2R]))
                if np.isnan(x_nonan[t, j1, j2]):
                    print("t=" + str(t))
                    print("j1=" + str(j1))
                    print("j2=" + str(j2))
        # Check once, after every time step has been processed.
        if not np.isnan(x_nonan).any():
            print("succesfully inpaint nan")
    return x_nonan

In [12]:
def filled_to_nan(x, critical_val, big_or_small):
    """
    Replace fill values beyond a threshold with NaN.

    Parameters
    ----------
    x : array-like
        Input data.
    critical_val : float
        Threshold value.
    big_or_small : int
        1: values > critical_val become NaN;
        0: values < critical_val become NaN.

    Returns
    -------
    x_nan : numpy.ndarray
        New array with the selected values replaced by NaN. A message is
        printed if the result contains any NaN.

    Raises
    ------
    ValueError
        If ``big_or_small`` is neither 1 nor 0.
    """
    if big_or_small == 1:  # >critical value-->nan
        x_nan = np.where(x > critical_val, np.nan, x)
    elif big_or_small == 0:  # <critical value-->nan
        x_nan = np.where(x < critical_val, np.nan, x)
    else:
        raise ValueError(
            "big_or_small must be 1 (mask > critical_val) or 0 (mask < critical_val)"
        )
    if np.isnan(x_nan).any():
        print("Signal has nan data!!!")
    return x_nan

In [13]:
def remove_annual_cycle(data, n_harmonics=3):
    """
    This function removes the mean and the first ``n_harmonics`` harmonics of
    the 365-day annual cycle from a given signal.
    Note: The first dimension must be time. The number and order of other dimensions is arbitrary.

    The cycle is estimated by Fourier projection onto a constant plus
    cos/sin(2*pi*k*t/365), k = 1..n_harmonics. The mean coefficient is
    normalized by 1/N and each cos/sin coefficient by 2/N (N = number of time
    steps). The projection is exact when the record spans a whole number of
    365-day years; otherwise it is an approximation.

    Parameters
    ----------
    data : xarray.core.dataarray.DataArray or numpy.ndarray
        An array containing gridded data from which to remove the annual cycle.
    n_harmonics : int, (default 3)
        The number of annual harmonics to remove from the signal.

    Returns
    -------
    data_deannualized : same type as ``data`` (via ``data - annual_cycle``)
        The processed signal with the annual cycle removed.
    annual_cycle : numpy.ndarray
        The removed annual cycle (mean plus harmonics), same shape as ``data``.

    """
    # Number of basis functions: the mean plus a cos and a sin per harmonic.
    mmax = 2 * n_harmonics + 1
    harmonics = np.arange(1, n_harmonics + 1)

    # Length of the time axis
    n_times = np.size(data, axis=0)

    # An array increasing over the time axis
    t = np.arange(1, n_times + 1, 1)

    # Rows of A: 0 = constant, odd = cosines, even = sines
    odds = np.arange(1, mmax, 2)
    evens = np.arange(2, mmax + 1, 2)
    A = np.ones([mmax, n_times])
    phase = 2 * np.pi * harmonics[:, np.newaxis] * t[np.newaxis, :] / 365
    A[odds] = np.cos(phase)
    A[evens] = np.sin(phase)

    # Fourier normalization: sum(cos^2) = sum(sin^2) = N/2 over whole periods,
    # so harmonic coefficients need 2/N, while the mean needs 1/N.
    norm = np.full(mmax, 2.0 / n_times)
    norm[0] = 1.0 / n_times
    C = np.einsum("ij, j...->i...", A * norm[:, np.newaxis], data)

    # Reconstruct the annual cycle from its coefficients
    annual_cycle = np.einsum("ij,i...->j...", A, C)

    # Remove the annual cycle from the input data
    data_deannualized = data - annual_cycle

    return data_deannualized, annual_cycle

In [14]:
def mean_var(x, x_ano, x_f):  # caution: assume time in the 0th direction
    """
    Summarize a field along its time axis (axis 0), ignoring NaNs.

    Parameters
    ----------
    x : array-like
        Original data; its time mean is returned.
    x_ano : array-like
        Anomaly data; its time variance is returned.
    x_f : array-like
        Filtered (e.g. intraseasonal) data; its time variance is returned.

    Returns
    -------
    tuple
        (nan-mean of x, nan-variance of x_ano, nan-variance of x_f), each
        reduced over axis 0.
    """
    time_axis = 0
    time_mean = np.nanmean(x, time_axis)
    anomaly_variance = np.nanvar(x_ano, time_axis)
    filtered_variance = np.nanvar(x_f, time_axis)
    return time_mean, anomaly_variance, filtered_variance

In [15]:
def yyyymmdd_y_m_d(dates):  # transfer yyyymmdd into year mon day (3 matrix)
    """
    Split yyyymmdd-formatted dates into year, month and day arrays.

    Each element is converted with ``str`` and sliced by character position,
    so integer or string inputs both work.

    Returns
    -------
    year, mon, day : numpy.ndarray
        Float arrays of length ``np.size(dates)``.
    """
    stamps = [str(dates[idx]) for idx in range(np.size(dates))]
    year = np.array([int(s[0:4]) for s in stamps], dtype=float)
    mon = np.array([int(s[4:6]) for s in stamps], dtype=float)
    day = np.array([int(s[6:8]) for s in stamps], dtype=float)
    return year, mon, day

In [16]:
def yyyymmdd_to_datetime64(dates):
    """
    Convert yyyymmdd-formatted dates into an array of ``numpy.datetime64``.

    Each element is converted with ``str`` and sliced by character position
    into year, month and day, then built via ``datetime.datetime``.
    """
    from datetime import datetime

    def _convert(stamp):
        text = str(stamp)  # original format of time is yyyymmdd
        moment = datetime(int(text[0:4]), int(text[4:6]), int(text[6:8]))
        return np.datetime64(moment)

    return np.array([_convert(dates[idx]) for idx in range(np.size(dates))])

In [17]:
def datetime64_to_yyyymmdd(dates):
    """
    Convert ``numpy.datetime64`` values into integer yyyymmdd dates.

    Relies on the ISO string form of each value ('YYYY-MM-DD...'), taking
    characters 0-3, 5-6 and 8-9.
    """
    count = np.size(dates)
    int_dates = np.zeros(count, dtype=int)
    for pos in range(count):
        iso = dates[pos].astype(str)
        int_dates[pos] = int(iso[0:4] + iso[5:7] + iso[8:10])
    return int_dates

In [18]:
def mer_ave(x, lat, latdim):  # assume x is (time,lat,lon)
    """
    Cosine-latitude weighted meridional average.

    Parameters
    ----------
    x : array-like
        Data containing a latitude axis (typically (time, lat, lon) with
        ``latdim=1``).
    lat : array-like
        Latitudes in degrees; length must equal ``x.shape[latdim]``.
    latdim : int
        Axis of ``x`` that holds latitude.

    Returns
    -------
    x_mer_ave : numpy.ndarray
        ``x`` averaged over ``latdim`` with weights cos(lat). When ``x``
        contains NaNs, they are skipped and their weights dropped. Positions
        where every latitude is NaN (or all remaining weights are zero)
        return NaN.
    """
    cos_lat = np.cos(np.deg2rad(lat))
    if not np.isnan(x).any():  # no nan
        return np.average(x, latdim, weights=cos_lat)

    print("has nan, but use nanmean")
    x_arr = np.asarray(x, dtype=float)
    # Reshape the weights so they broadcast along the latitude axis only.
    shape = [1] * x_arr.ndim
    shape[latdim] = -1
    weights = np.reshape(cos_lat, shape)
    valid = ~np.isnan(x_arr)
    weighted_sum = np.sum(np.where(valid, x_arr * weights, 0.0), axis=latdim)
    weight_total = np.sum(np.where(valid, weights, 0.0), axis=latdim)
    # All-NaN columns give 0/0 -> NaN instead of raising.
    with np.errstate(invalid="ignore", divide="ignore"):
        x_mer_ave = weighted_sum / weight_total
    return x_mer_ave

In [19]:
def mer_ave_2d(x, lat):  # assume x is (lat,lon)
    """
    Cosine-latitude weighted average over axis 0 (latitude).

    Parameters
    ----------
    x : array-like
        Array with latitude on axis 0, e.g. (lat, lon) or (lat, lon, time).
    lat : array-like
        Latitudes in degrees; length must equal ``x.shape[0]``.

    Returns
    -------
    x_mer_ave : numpy.ndarray
        Array of shape ``x.shape[1:]``. NaNs are skipped and their weights
        dropped. Positions where every latitude is NaN return NaN.
    """
    cos_lat = np.cos(np.deg2rad(lat))
    if not np.isnan(x).any():
        return np.average(x, 0, weights=cos_lat)

    print("has nan, but use nanmean")
    x_arr = np.asarray(x, dtype=float)
    # Weights broadcast along axis 0 for any number of trailing dimensions.
    weights = np.reshape(cos_lat, (-1,) + (1,) * (x_arr.ndim - 1))
    valid = ~np.isnan(x_arr)
    weighted_sum = np.sum(np.where(valid, x_arr * weights, 0.0), axis=0)
    weight_total = np.sum(np.where(valid, weights, 0.0), axis=0)
    # All-NaN columns give 0/0 -> NaN instead of raising.
    with np.errstate(invalid="ignore", divide="ignore"):
        x_mer_ave = weighted_sum / weight_total
    return x_mer_ave

## Filtering Functions

In [20]:
def butter_bandpass(lowcut, highcut, fs, order=4):
    """
    Design a Butterworth band-pass filter.

    ``lowcut`` and ``highcut`` are in the same units as the sampling
    frequency ``fs``; they are normalized by the Nyquist frequency (fs / 2)
    before calling ``scipy.signal.butter``.

    Returns
    -------
    b, a : numpy.ndarray
        Numerator and denominator filter coefficients.
    """
    half_rate = 0.5 * fs
    edges = [lowcut / half_rate, highcut / half_rate]
    numer, denom = butter(order, edges, btype="band")
    return numer, denom

In [21]:
def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """
    Apply a zero-phase Butterworth band-pass filter along the last axis.

    The filter coefficients come from ``butter_bandpass``. The filter is
    applied forwards and backwards with ``scipy.signal.filtfilt`` (default
    axis=-1).
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(*coefficients, data)

In [22]:
def butter_highpass(cut, fs, order=4):
    """
    Design a Butterworth high-pass filter.

    ``cut`` is in the same units as ``fs`` and is normalized by the Nyquist
    frequency (fs / 2) before calling ``scipy.signal.butter``.

    Returns
    -------
    b, a : numpy.ndarray
        Numerator and denominator filter coefficients.
    """
    normalized_cut = cut / (0.5 * fs)
    numer, denom = butter(order, normalized_cut, btype="highpass")
    return numer, denom

In [23]:
def butter_highpass_filter(data, cut, fs, order=5):
    """
    Apply a zero-phase Butterworth high-pass filter along the last axis.

    Coefficients come from ``butter_highpass``. Note that the default order
    here is 5, while ``butter_highpass`` itself defaults to 4.
    """
    coefficients = butter_highpass(cut, fs, order=order)
    return filtfilt(*coefficients, data)

In [24]:
def butter_lowpass(cut, fs, order=4):
    """
    Design a Butterworth low-pass filter.

    ``cut`` is in the same units as ``fs`` and is normalized by the Nyquist
    frequency (fs / 2) before calling ``scipy.signal.butter``.

    Returns
    -------
    b, a : numpy.ndarray
        Numerator and denominator filter coefficients.
    """
    normalized_cut = cut / (0.5 * fs)
    numer, denom = butter(order, normalized_cut, btype="lowpass")
    return numer, denom

In [25]:
def butter_lowpass_filter(data, cut, fs, order=5):
    """
    Apply a zero-phase Butterworth low-pass filter along the last axis.

    Coefficients come from ``butter_lowpass``. Note that the default order
    here is 5, while ``butter_lowpass`` itself defaults to 4.
    """
    coefficients = butter_lowpass(cut, fs, order=order)
    return filtfilt(*coefficients, data)

In [26]:
def lanczos_lowpass_filter(data, cut, fs, order=101):
    """
    This function returns a filtered signal in which all frequencies above the
    cut-off have been attenuated. This is done using a Lanczos filter. The
    filter works on any dimensional data, as long as the filtering dimension
    is along axis=-1.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be filtered.
    cut : float
        The frequency cut-off, in the same units as ``fs``. All frequencies
        above 'cut' are attenuated.
    fs : float
        The sampling frequency of the data.
    order : int
        The number of weights to use in the Lanczos filter. Must be odd.

    Returns
    -------
    data_filtered : numpy.ndarray
        The filtered data. ``np.convolve(..., mode="same")`` zero-pads, so
        about ``order // 2`` samples at each end are affected by edge effects.

    Raises
    ------
    ValueError
        If ``order`` is not a positive odd integer.
    """
    if order < 1 or order % 2 == 0:
        raise ValueError("order must be a positive odd integer")

    # 2n-1 total weights, indexed k = -(n-1) .. n-1
    n = int((order + 1) // 2)
    k = np.arange(-(n - 1), n)

    # The weight formulas are in cycles per sample, so normalize by fs.
    fc = cut / fs

    # Lanczos sigma factor sinc(k/n), with np.sinc(x) = sin(pi x) / (pi x)
    sigma = np.sinc(k / n)

    # Ideal low-pass response sin(2 pi fc k) / (pi k), equal to 2 fc at k = 0
    w = 2 * fc * np.sinc(2 * fc * k)

    # Combine the effects of the ideal response factor and the Lanczos factor
    w_bar = w * sigma

    # Filter the data along axis=-1
    data_filtered = np.apply_along_axis(
        lambda m: np.convolve(m, w_bar, mode="same"), axis=-1, arr=data
    )

    return data_filtered

In [27]:
def lanczos_highpass_filter(data, cut, fs, order=101):
    """
    This function returns a filtered signal in which all frequencies below the
    cut-off have been attenuated. This is done using a Lanczos filter. The
    filter works on any dimensional data, as long as the filtering dimension
    is along axis=-1.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be filtered.
    cut : float
        The frequency cut-off, in the same units as ``fs``. All frequencies
        below 'cut' are attenuated.
    fs : float
        The sampling frequency of the data.
    order : int
        The number of weights to use in the Lanczos filter. Must be odd.

    Returns
    -------
    data_filtered : numpy.ndarray
        The filtered data. ``np.convolve(..., mode="same")`` zero-pads, so
        about ``order // 2`` samples at each end are affected by edge effects.

    Raises
    ------
    ValueError
        If ``order`` is not a positive odd integer.
    """
    if order < 1 or order % 2 == 0:
        raise ValueError("order must be a positive odd integer")

    # 2n-1 total weights, indexed k = -(n-1) .. n-1
    n = int((order + 1) // 2)
    k = np.arange(-(n - 1), n)

    # The weight formulas are in cycles per sample, so normalize by fs.
    fc = cut / fs

    # Lanczos sigma factor sinc(k/n), with np.sinc(x) = sin(pi x) / (pi x)
    sigma = np.sinc(k / n)

    # Ideal high-pass response: unit impulse minus the ideal low-pass,
    # i.e. -sin(2 pi fc k) / (pi k) off-centre and 1 - 2 fc at k = 0
    w = (k == 0).astype(float) - 2 * fc * np.sinc(2 * fc * k)

    # Combine the effects of the ideal response factor and the Lanczos factor
    w_bar = w * sigma

    # Filter the data along axis=-1
    data_filtered = np.apply_along_axis(
        lambda m: np.convolve(m, w_bar, mode="same"), axis=-1, arr=data
    )

    return data_filtered

In [28]:
def lanczos_bandpass_filter(data, lowcut, highcut, fs, filter_axis, order=101):
    """
    This function returns a filtered signal in which all frequencies outside
    the cut-off range have been attenuated. This is done using a Lanczos filter.
    The filter works on any dimensional data; the filtering dimension is
    selected with ``filter_axis``.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be filtered.
    lowcut : float
        The low frequency cut-off, in the same units as ``fs``. All
        frequencies below 'lowcut' are attenuated.
    highcut : float
        The high frequency cut-off, in the same units as ``fs``. All
        frequencies above 'highcut' are attenuated.
    fs : float
        The sampling frequency of the data.
    filter_axis : int
        The axis of ``data`` along which to filter (e.g. -1 for the last).
    order : int
        The number of weights to use in the Lanczos filter. Must be odd.

    Returns
    -------
    data_filtered : numpy.ndarray
        The filtered data. ``np.convolve(..., mode="same")`` zero-pads, so
        about ``order // 2`` samples at each end are affected by edge effects.

    Raises
    ------
    ValueError
        If ``order`` is not a positive odd integer.
    """
    if order < 1 or order % 2 == 0:
        raise ValueError("order must be a positive odd integer")

    # 2n-1 total weights, indexed k = -(n-1) .. n-1
    n = int((order + 1) // 2)
    k = np.arange(-(n - 1), n)

    # The weight formulas are in cycles per sample, so normalize by fs.
    f_low = lowcut / fs
    f_high = highcut / fs

    # Lanczos sigma factor sinc(k/n), with np.sinc(x) = sin(pi x) / (pi x)
    sigma = np.sinc(k / n)

    # Ideal band-pass response: difference of two ideal low-pass responses,
    # sin(2 pi fh k)/(pi k) - sin(2 pi fl k)/(pi k), equal to 2 (fh - fl) at k = 0
    w = 2 * f_high * np.sinc(2 * f_high * k) - 2 * f_low * np.sinc(2 * f_low * k)

    # Combine the effects of the ideal response factor and the Lanczos factor
    w_bar = w * sigma

    # Filter the data along the requested axis
    data_filtered = np.apply_along_axis(
        lambda m: np.convolve(m, w_bar, mode="same"), axis=filter_axis, arr=data
    )

    return data_filtered

In [29]:
def fft_lowpass_filter(data, cut, fs):
    """
    Low-pass filter ``data`` along its last axis by zeroing FFT coefficients.

    Parameters
    ----------
    data : array-like
        Signal(s) to filter; time must be the last axis. Any number of
        leading dimensions is supported.
    cut : float
        Cut-off frequency in the same units as ``fs``. Components with
        ``|f| < cut`` (including the mean, f = 0) are kept; components with
        ``|f| >= cut`` are removed. This matches the pass band convention of
        ``fft_bandpass_filter``.
    fs : float
        Sampling frequency.

    Returns
    -------
    data_filtered : numpy.ndarray
        Real part of the inverse FFT of the masked spectrum.
    """
    data_fft = np.fft.fft(data, axis=-1)
    frequencies = np.fft.fftfreq(data_fft.shape[-1], 1 / fs)

    # Remove everything at or above the cut-off; the Ellipsis handles any ndim.
    data_fft[..., np.abs(frequencies) >= cut] = 0
    data_filtered = np.real(np.fft.ifft(data_fft, axis=-1))

    return data_filtered

In [30]:
def fft_highpass_filter(data, cut, fs):
    """
    High-pass filter ``data`` along its last axis by zeroing FFT coefficients.

    Parameters
    ----------
    data : array-like
        Signal(s) to filter; time must be the last axis. Any number of
        leading dimensions is supported.
    cut : float
        Cut-off frequency in the same units as ``fs``. Components with
        ``|f| > cut`` are kept; components with ``|f| <= cut`` (including
        the mean, f = 0) are removed. This matches the pass band convention
        of ``fft_bandpass_filter``.
    fs : float
        Sampling frequency.

    Returns
    -------
    data_filtered : numpy.ndarray
        Real part of the inverse FFT of the masked spectrum.
    """
    data_fft = np.fft.fft(data, axis=-1)
    frequencies = np.fft.fftfreq(data_fft.shape[-1], 1 / fs)

    # Remove everything at or below the cut-off; the Ellipsis handles any ndim.
    data_fft[..., np.abs(frequencies) <= cut] = 0
    data_filtered = np.real(np.fft.ifft(data_fft, axis=-1))

    return data_filtered

In [31]:
def fft_bandpass_filter(data, lowcut, highcut, fs):
    """
    Band-pass filter ``data`` along its last axis by zeroing FFT coefficients.

    Parameters
    ----------
    data : array-like
        Signal(s) to filter; time must be the last axis. Any number of
        leading dimensions is supported.
    lowcut, highcut : float
        Band edges in the same units as ``fs``. Only components with
        ``lowcut < |f| < highcut`` are kept.
    fs : float
        Sampling frequency.

    Returns
    -------
    data_filtered : numpy.ndarray
        Real part of the inverse FFT of the masked spectrum.
    """
    data_fft = np.fft.fft(data, axis=-1)
    frequencies = np.fft.fftfreq(data_fft.shape[-1], 1 / fs)

    abs_freq = np.abs(frequencies)
    stop_band = (abs_freq <= lowcut) | (abs_freq >= highcut)
    # The Ellipsis applies the mask to the last axis for any ndim.
    data_fft[..., stop_band] = 0
    data_filtered = np.real(np.fft.ifft(data_fft, axis=-1))

    return data_filtered

## EOF Functions

In [32]:
def normalize_before_ceof(u850_f_merave, u200_f_merave, olr_f_merave):  # (time,lon)
    """
    Standardize three (time, lon) fields and stack them for a combined EOF.

    The variables have different units, so each field is transposed to
    (lon, time) and standardized by its own overall nan-mean and nan-std
    (one scalar per field, over all points).

    Returns
    -------
    X : numpy.ndarray
        (3 * nlon, nt) array; rows [0, nlon) hold u850, [nlon, 2*nlon) u200
        and [2*nlon, 3*nlon) olr.
    mu_u850, std_u850, mu_u200, std_u200, mu_olr, std_olr : float
        The means and standard deviations used for the standardization.
    """
    fields = [np.transpose(f) for f in (u850_f_merave, u200_f_merave, olr_f_merave)]
    nlon = np.size(fields[0], 0)
    nt = np.size(fields[0], 1)

    X = np.zeros([nlon * 3, nt])
    stats = []
    for slot, field in enumerate(fields):
        mu = np.nanmean(field)
        sd = np.nanstd(field)
        X[slot * nlon : (slot + 1) * nlon, :] = (field - mu) / sd
        stats.extend((mu, sd))
    return (X, *stats)

In [33]:
def eof(xx):  # x=x(structure dim, sampling dim)
    """
    EOF decomposition of ``xx`` via SVD, with a North-test error estimate.

    NOTE: an identical ``eof`` is defined again in the following cell; that
    later definition is the one in effect once both cells have run.

    Parameters
    ----------
    xx : numpy.ndarray
        2-D array of shape (structure dim, sampling dim).

    Returns
    -------
    EOF : rows are the EOFs (EOF[0] is EOF1, ...).
    PC : rows are the principal components, EOF @ xx (PC[0] is PC1, ...).
    eigval : squared singular values divided by the number of samples.
    eigval_explained_var : eigval as a percentage of their sum.
    eigval_err : North-test error, explained variance * sqrt(2 / dof).
    dof : effective degrees of freedom from the lag-1 autocorrelation.
    phi_0 : lag-0 autocovariance, sum(xx**2) / nt.
    phi_L : lag-1 autocovariance, normalized by (nt - 2).
    """
    left, sing, _ = np.linalg.svd(xx, full_matrices=False)
    EOF = left.T
    PC = np.matmul(left.T, xx)
    n_samp = np.size(xx, 1)
    eigval = sing ** 2 / n_samp
    eigval_explained_var = eigval / np.sum(eigval) * 100  # percent

    # Effective degrees of freedom from the lag-1 autocorrelation (North test)
    lag = 1
    lagged_sum = sum(
        np.sum(xx[:, col] * xx[:, col + lag]) for col in range(lag - 1, n_samp - lag)
    )
    phi_L = 1 / (n_samp - 2 * lag) * lagged_sum
    phi_0 = 1 / n_samp * np.sum(xx ** 2)
    r_L = phi_L / phi_0
    dof = (1 - r_L ** 2) / (1 + r_L ** 2) * n_samp

    eigval_err = eigval_explained_var * np.sqrt(2 / dof)
    return EOF, PC, eigval, eigval_explained_var, eigval_err, dof, phi_0, phi_L

In [34]:
def eof(xx):  # x=x(structure dim, sampling dim)
    """
    EOF decomposition of ``xx`` via SVD, with a North-test error estimate.

    Parameters
    ----------
    xx : numpy.ndarray
        2-D array of shape (structure dim, sampling dim).

    Returns
    -------
    EOF : rows are the EOFs (EOF[0] is EOF1, ...).
    PC : rows are the principal components, EOF @ xx (PC[0] is PC1, ...).
    eigval : squared singular values divided by the number of samples.
    eigval_explained_var : eigval as a percentage of their sum.
    eigval_err : North-test error, explained variance * sqrt(2 / dof).
    dof : effective degrees of freedom from the lag-1 autocorrelation.
    phi_0 : lag-0 autocovariance, sum(xx**2) / nt.
    phi_L : lag-1 autocovariance, normalized by (nt - 2).
    """
    left, sing, _ = np.linalg.svd(xx, full_matrices=False)
    EOF = left.T
    PC = np.matmul(left.T, xx)
    n_samp = np.size(xx, 1)
    eigval = sing ** 2 / n_samp
    eigval_explained_var = eigval / np.sum(eigval) * 100  # percent

    # Effective degrees of freedom from the lag-1 autocorrelation (North test)
    lag = 1
    lagged_sum = sum(
        np.sum(xx[:, col] * xx[:, col + lag]) for col in range(lag - 1, n_samp - lag)
    )
    phi_L = 1 / (n_samp - 2 * lag) * lagged_sum
    phi_0 = 1 / n_samp * np.sum(xx ** 2)
    r_L = phi_L / phi_0
    dof = (1 - r_L ** 2) / (1 + r_L ** 2) * n_samp

    eigval_err = eigval_explained_var * np.sqrt(2 / dof)
    return EOF, PC, eigval, eigval_explained_var, eigval_err, dof, phi_0, phi_L


def rmm_eight_phase_index(rmm1, rmm2, time_f, rmm1_ann, rmm2_ann):
    """
    Classify each day into one of the eight RMM phases or a weak-MJO group.

    RMM1/RMM2 are standardized by the mean and std of ``rmm1_ann`` /
    ``rmm2_ann``. Days with amplitude RMM1**2 + RMM2**2 < 1 go to the weak
    group. Other days are assigned by 45-degree sectors of the (RMM1, RMM2)
    plane, phase 1 starting at RMM1 < 0, RMM2 = 0 and proceeding
    counterclockwise. Each sector includes its starting boundary, so days
    lying exactly on a boundary (|RMM1| == |RMM2|, or one component zero)
    are assigned instead of being dropped. Days with NaN values are not
    assigned.

    Parameters
    ----------
    rmm1, rmm2 : array-like
        Daily RMM1 and RMM2 values.
    time_f : array-like
        Time axis; only its size is used.
    rmm1_ann, rmm2_ann : array-like
        Reference series whose mean/std are used for standardization.

    Returns
    -------
    n : numpy.ndarray of int, shape (9,)
        Number of days in phases 1-8 (rows 0-7) and the weak group (row 8).
    RMM_ind : numpy.ndarray, shape (9, len(time_f))
        Row p lists the time indices assigned to group p in its first n[p]
        entries; remaining entries are NaN.
    rmm1_norm, rmm2_norm : numpy.ndarray
        The standardized RMM1 and RMM2.
    """
    rmm1_norm = (rmm1 - np.mean(rmm1_ann)) / np.std(rmm1_ann)  # normalized rmm1
    rmm2_norm = (rmm2 - np.mean(rmm2_ann)) / np.std(rmm2_ann)  # normalized rmm2
    n_times = np.size(time_f)
    n = np.zeros(9, dtype=int)
    RMM_ind = np.empty((9, n_times))
    RMM_ind[:] = np.nan  # np.NaN was removed in NumPy 2.0
    for i in range(0, n_times):
        RMM1 = rmm1_norm[i]
        RMM2 = rmm2_norm[i]
        A = RMM1 ** 2 + RMM2 ** 2
        if (A < 1).any():  # weak MJO
            ph = 8
        elif RMM1 < 0 and RMM2 <= 0 and np.abs(RMM1) > np.abs(RMM2):  # PHASE1
            ph = 0
        elif RMM1 < 0 and RMM2 < 0 and np.abs(RMM1) <= np.abs(RMM2):  # 2
            ph = 1
        elif RMM1 >= 0 and RMM2 < 0 and np.abs(RMM1) < np.abs(RMM2):  # 3
            ph = 2
        elif RMM1 > 0 and RMM2 < 0 and np.abs(RMM1) >= np.abs(RMM2):  # 4
            ph = 3
        elif RMM1 > 0 and RMM2 >= 0 and np.abs(RMM1) > np.abs(RMM2):  # 5
            ph = 4
        elif RMM1 > 0 and RMM2 > 0 and np.abs(RMM1) <= np.abs(RMM2):  # 6
            ph = 5
        elif RMM1 <= 0 and RMM2 > 0 and np.abs(RMM1) < np.abs(RMM2):  # 7
            ph = 6
        elif RMM1 < 0 and RMM2 > 0 and np.abs(RMM1) >= np.abs(RMM2):  # 8
            ph = 7
        else:  # only reachable with NaN input
            continue
        RMM_ind[ph, n[ph]] = i
        n[ph] = n[ph] + 1
    return n, RMM_ind, rmm1_norm, rmm2_norm

In [35]:
def eight_phase_composite(x_f, RMM_ind):
    """
    Composite a field over the days assigned to each of the eight RMM phases.

    Parameters
    ----------
    x_f : array-like
        Field with time on axis 0 (e.g. (time, lat, lon)).
    RMM_ind : numpy.ndarray
        (9, nt) index array as returned by ``rmm_eight_phase_index``; row p
        holds time indices for phase p+1, padded with NaN. Only rows 0-7 are
        used.

    Returns
    -------
    x_f_8ph : numpy.ndarray
        Shape (8,) + x_f.shape[1:]; entry p is the time mean of ``x_f`` over
        the days in phase p+1. A phase with no days yields NaN.
    """
    x_f = np.asarray(x_f)
    x_f_8ph = np.zeros((8,) + x_f.shape[1:])
    for ph in range(0, 8):
        members = RMM_ind[ph, :]
        # Boolean masking always yields a 1-D index array, so a phase with a
        # single day still averages over time (not over the next axis).
        ii = members[~np.isnan(members)].astype(int)
        x_f_8ph[ph] = np.mean(x_f[ii], 0)
    return x_f_8ph

## Miscellaneous Functions

In [36]:
def modified_colormap(colormap, central_color, central_width, blend_strength):
    '''
    Return a copy of a colormap whose central region is a solid colour.

    The centre of the colormap is set to ``central_color``. On each side, a
    band of the original colormap is blended linearly towards that colour.

    Parameters:
        colormap (str): Name of a registered matplotlib colormap.
        central_color (color): Any matplotlib colour specification accepted
            by ``matplotlib.colors.to_rgba`` (e.g. 'white', '#ffffff', (1, 1, 1)).
        central_width (float): Half-width of the solid central region, as a
            fraction in [0, 1] of half the colormap.
        blend_strength (float): Width of each blending band, as a fraction in
            [0, 1] of the half-colormap left outside the central region.

    Returns:
        modified_colormap (matplotlib.colors.LinearSegmentedColormap): The
        modified colormap, named ``colormap + '_modified'``.

    Raises:
        ValueError: If ``central_width`` or ``blend_strength`` is outside [0, 1].
        KeyError: If ``central_color`` is not a valid matplotlib colour.
    '''
    # Validate numeric arguments first; these checks need no matplotlib.
    if central_width < 0 or central_width > 1:
        raise ValueError('Central width must be in range [0, 1]')
    if blend_strength < 0 or blend_strength > 1:
        raise ValueError('Blend strength must be in range [0, 1]')

    import matplotlib
    from matplotlib import colors as mcolors

    # Keep raising KeyError for bad colours so existing callers' handling works.
    try:
        central_color = list(mcolors.to_rgba(central_color))
    except (ValueError, TypeError) as err:
        raise KeyError('Not a valid matplotlib color') from err

    # Convert the widths to numbers of colormap entries in [0, 127]
    central_width = int(127*central_width)
    blend_strength = int(blend_strength*(127-central_width))

    # plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # the matplotlib.colormaps registry (3.5+) is the supported lookup.
    try:
        registry = matplotlib.colormaps
    except AttributeError:  # matplotlib < 3.5
        import matplotlib.pyplot as plt
        original_colormap = plt.cm.get_cmap(colormap)
    else:
        original_colormap = registry[colormap]
    newcolors = original_colormap(np.linspace(0, 1, 256))

    # Blend from the colour 'width' entries left of centre to the central colour
    newcolors[128-central_width-blend_strength:128-central_width, :] = np.linspace(
        newcolors[128-central_width-blend_strength, :],
        central_color,
        blend_strength
    )

    newcolors[128-central_width:128+central_width, :] = central_color

    # Blend from the central colour to the colour 'width' entries right of centre
    newcolors[128+central_width:128+central_width+blend_strength, :] = np.linspace(
        central_color,
        newcolors[128+central_width+blend_strength, :],
        blend_strength
    )

    # Create a new colormap object from the modified map
    modified_colormap = mcolors.LinearSegmentedColormap.from_list(colormap+'_modified', newcolors)

    return modified_colormap

In [37]:
def parabolic_cylinder_function(y, order):
    """
    Evaluate H_order(y) * exp(-y**2 / 2).

    H_order is the Hermite polynomial returned by ``scipy.special.hermite``
    (physicists' convention).
    """
    from scipy import special

    hermite_poly = special.hermite(order)
    envelope = np.exp(-y**2/2)
    return hermite_poly(y) * envelope

In [1]:
def tick_labeller(ticks, direction, degree_symbol=True):
    """
    Build axis tick labels for longitude or latitude values.

    Parameters
    ----------
    ticks : sequence of numbers
        Tick positions in degrees.
    direction : str
        'lon' (suffix W/E; 0 and |tick| >= 180 get no suffix) or
        'lat' (suffix S/N; 0 gets no suffix). Any other value yields no labels.
    degree_symbol : bool
        If True, a degree sign is inserted after the number.

    Returns
    -------
    list of str
        One label per tick, e.g. '90°W'. Ticks that fit no case (e.g. NaN)
        are skipped.
    """
    labels = []
    for idx in range(len(ticks)):
        value = ticks[idx]
        mark = "°" if degree_symbol == True else ""
        if direction == 'lon':
            negative, positive = "W", "E"
            unsuffixed = value == 0 or np.abs(value) >= 180
        elif direction == 'lat':
            negative, positive = "S", "N"
            unsuffixed = value == 0
        else:
            continue

        if unsuffixed:
            suffix = ""
        elif value < 0:
            suffix = negative
        elif value > 0:
            suffix = positive
        else:
            continue
        labels.append(f"{np.abs(value):.0f}{mark}{suffix}")
    return labels