In [148]:
import xarray as xr
import numpy as np
import xcdat as xc
import xskillscore as xscore
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import cmcrameri.cm as cmc
from scipy.stats import linregress
from scipy.stats import f
from scipy import stats
from sklearn import linear_model
import statsmodels.api as sm
import statsmodels.regression as regression
import statsmodels.formula.api as smf

In [184]:
def make_stationary(ts: xr.DataArray, sanity_check: bool = False):
    """ Make time series stationary by removing mean, dividing by variance and linear detrending

    Args:
        ts (xr.DataArray): 1D time series
    """

    tss = (ts - ts.mean())/ts.std()
    reg = linregress(tss.time, tss)
    tsd = tss - (reg.slope*tss.time + reg.intercept)

    if sanity_check:
        # Log the original and stationary mean and variance
        print("Original Variance: ", ts.std().values, "and Original Mean: ", ts.mean().values)
        print("Stationary Variance: ", tsd.std().values, "and Stationary Mean: ", tsd.mean().values)
        # Plot the original and stationary time series
        plt.plot(ts.time, ts, label="original")
        plt.plot(tsd.time, tsd, label="detrended")
        plt.legend(); plt.grid(); plt.xlabel("time"); plt.ylabel("Variable")
    
    return tsd

In [278]:
def multivariable_regression(X: xr.DataArray, Y: xr.DataArray, verbose: bool = False):
    # Perform multivariable ordinaty least squares regression
    # model = linear_model.LinearRegression()
    # model.fit(X, Y)

    model = regression.linear_model.OLS(Y, X).fit()

    if verbose:
        print(model.summary())
    
    return model

def create_dataset(ds1: xr.DataArray, nlag: int = 10, include_zero_lag: bool = False):

    nsamples = ds1.shape[0] - nlag 
    if not include_zero_lag:
        nsamples = nsamples - 1

    X = np.full((nsamples, nlag), np.nan)
    Y = np.full((nsamples), np.nan)

    for i in range(nsamples):
        if not include_zero_lag:
            X[i,:] = ds1[i+1:i+1+nlag]
        else: 
            X[i,:] = ds1[i:i+nlag]
        Y[i] = ds1[i]
    return X, Y

def get_AIC_lag(x):
    return 10
    # NLAG = 50
    # print(Y1.shape[0] - NLAG - 1)
    # if NLAG*2 > (Y1.shape[0] - NLAG - 1): 
    #     NLAG = int(np.floor((Y1.shape[0] - NLAG - 1)/2))

def granger_causality(Xr: np.ndarray, X: np.ndarray, Y:np.ndarray, unrestricted_model, restricted_model, nlag: int = 10):
    """_summary_

    Args:
        Y1 (np.ndarray): _description_
        Y2 (np.ndarray): _description_
        unrestricted_model (_type_): _description_
        restricted_model (_type_): _description_
    """
    # Get Predictions
    Yp_restricted = restricted_model.predict(Xr) # Equivalent to np.matmul(X1, restricted_model.coef_) + restricted_model.intercept_
    Yp_unrestricted = unrestricted_model.predict(X) # Equivalent to np.matmul(X2, unrestricted_model.coef_) + unrestricted_model.intercept_

    # Get Residuals
    err_y1 = Y - Yp_restricted
    err_y2 = Y - Yp_unrestricted
    
    # F-test on residuals
    F = np.var(err_y1)/np.var(err_y2) # Ratio of variances
    df1, df2 = len(err_y1) - 1, len(err_y2) - 1 # Number of samples - 1
    p_value = f.sf(F, df1, df2) # Survival function (1 - cdf) at F

    # Signifigance of regression coefficients
    ALPHA = 0.05
    sig_reg_coeffs = unrestricted_model.pvalues[nlag:]

    # Test Granger Causality
    if np.any(sig_reg_coeffs <= ALPHA) and p_value < 0.05:
        magnitude = np.sum(unrestricted_model.params[nlag:])
        print(f"Reject Null Hypothesis: Granger Causality DOES exist: {magnitude} F-test p={p_value}, p-values of regression coefficients: {sig_reg_coeffs}")
    else: 
        print(f"Fail to Reject Null Hypothesis: Granger Causality DOES NOT exist: F-test p={p_value}, p-values of regression coefficients: {sig_reg_coeffs}")

    return Yp_restricted, Yp_unrestricted

In [281]:
def Granger_Pipeline(Y1, Y2, sanity_check: bool = False, include_zero_lag: bool = False, verbose: bool = True):
    Y1 = make_stationary(Y1, sanity_check=sanity_check)
    Y2 = make_stationary(Y2, sanity_check=sanity_check)

    MAXLAG = int(30)
    Y2_pred_Y1, Y1_pred_Y2 = {}, {}
    aic_Y2_pred_Y1, aic_Y1_pred_Y2 = np.inf, np.inf
    for NLAG in range(1, MAXLAG):

        ##### Does Y2 Granger Cause Y1? #####
        # Restricted Model
        X1_ds, Y1_ds = create_dataset(ds1=Y1, nlag=NLAG, include_zero_lag=include_zero_lag)
        restricted_model = multivariable_regression(X1_ds, Y1_ds, verbose=verbose)
        # Unrestricted Model
        X2_ds, Y2_ds = create_dataset(ds1=Y2, nlag=NLAG, include_zero_lag=include_zero_lag)
        X_ds = np.concatenate([X1_ds, X2_ds], axis=1) 
        unrestricted_model = multivariable_regression(X=X_ds, Y=Y1_ds, verbose=verbose) 

        if aic_Y2_pred_Y1 > unrestricted_model.aic:
            print("New best model found: ", unrestricted_model.aic)
            Y2_pred_Y1["restricted_model"] = restricted_model
            Y2_pred_Y1["unrestricted_model"] = unrestricted_model
            Y2_pred_Y1["nlag"] = NLAG
            Y2_pred_Y1["Xr"] = X1_ds
            Y2_pred_Y1["X"] = X_ds
            Y2_pred_Y1["Y"] = Y1_ds

        ##### Does Y1 Granger Cause Y2? #####
        # Restricted Model
        X2_ds, Y2_ds = create_dataset(ds1=Y2, nlag=NLAG, include_zero_lag=include_zero_lag)
        restricted_model = multivariable_regression(X1_ds, Y1_ds, verbose=verbose)
        # Unrestricted Model
        X1_ds, Y1_ds = create_dataset(ds1=Y1, nlag=NLAG, include_zero_lag=include_zero_lag)
        X_ds = np.concatenate([X2_ds, X1_ds], axis=1) 
        unrestricted_model = multivariable_regression(X=X_ds, Y=Y2_ds, verbose=verbose) 

        if aic_Y1_pred_Y2 > unrestricted_model.aic:
            print("New best model found: ", unrestricted_model.aic)
            Y1_pred_Y2["restricted_model"] = restricted_model
            Y1_pred_Y2["unrestricted_model"] = unrestricted_model
            Y1_pred_Y2["nlag"] = NLAG
            Y1_pred_Y2["Xr"] = X2_ds
            Y1_pred_Y2["X"] = X_ds
            Y1_pred_Y2["Y"] = Y2_ds

    Y1p_restricted, Y1p_unrestricted = granger_causality(
        Xr=Y2_pred_Y1["Xr"],
        X=Y2_pred_Y1["X"],
        Y=Y2_pred_Y1["Y"],
        unrestricted_model=Y2_pred_Y1["unrestricted_model"], 
        restricted_model=Y2_pred_Y1["restricted_model"], 
        nlag=Y2_pred_Y1["nlag"]
    )

    Y2p_restricted, Y2p_unrestricted = granger_causality(
        Xr=Y1_pred_Y2["Xr"],
        X=Y1_pred_Y2["X"],
        Y=Y1_pred_Y2["Y"],
        unrestricted_model=Y1_pred_Y2["unrestricted_model"], 
        restricted_model=Y1_pred_Y2["restricted_model"], 
        nlag=Y1_pred_Y2["nlag"]
    )
    

In [282]:
# Use for TESTING
SST_T_WE_CMIP6 = xr.open_dataarray("data/piControl/rolling_gradient_cmip6_WE-Trend.nc")
SST_T_EPT_CMIP6 = xr.open_dataarray("data/piControl/rolling_gradient_cmip6_eastPacificTriangle_trend.nc")
Granger_Pipeline(SST_T_WE_CMIP6.sel(model="CanESM5-1"), SST_T_EPT_CMIP6.sel(model="CanESM5-1"), verbose=False)

New best model found:  233.63080094618562
New best model found:  143.9432161266346
New best model found:  182.64179730351532
New best model found:  34.59429974586047
New best model found:  179.50022473484603
New best model found:  32.86241156991994
New best model found:  172.63635870196484
New best model found:  36.547162043104095
New best model found:  170.55039209529843
New best model found:  40.4987091100308
New best model found:  172.73483008763142
New best model found:  42.282232790632605
New best model found:  170.40932947497072
New best model found:  32.94479890551952
New best model found:  172.9089853502337
New best model found:  26.502853336901637
New best model found:  172.8711605084292
New best model found:  27.5693677941143
New best model found:  174.15817584826556
New best model found:  26.32340494239034
New best model found:  174.659034292872
New best model found:  21.27245300142914
New best model found:  177.87659263617942
New best model found:  20.586823020980205
New be