In [2]:
import xarray as xr
import numpy as np
import xcdat as xc
import xskillscore as xscore
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import cmcrameri.cm as cmc
from scipy.stats import linregress
from scipy.stats import f
from scipy import stats
from sklearn import linear_model
import statsmodels.api as sm
import statsmodels.regression as regression
import statsmodels.formula.api as smf

# Granger Causality Core Functionality

In [3]:
def make_stationary(ts: xr.DataArray, sanity_check: bool = False):
    """ Make time series stationary by removing mean, dividing by variance and linear detrending

    Args:
        ts (xr.DataArray): 1D time series
    """

    tss = (ts - ts.mean())/ts.std()
    reg = linregress(tss.time, tss)
    tsd = tss - (reg.slope*tss.time + reg.intercept)

    if sanity_check:
        # Log the original and stationary mean and variance
        print("Original Variance: ", ts.std().values, "and Original Mean: ", ts.mean().values)
        print("Stationary Variance: ", tsd.std().values, "and Stationary Mean: ", tsd.mean().values)
        # Plot the original and stationary time series
        plt.plot(ts.time, ts, label="original")
        plt.plot(tsd.time, tsd, label="detrended")
        plt.legend(); plt.grid(); plt.xlabel("time"); plt.ylabel("Variable")
    
    return tsd

In [54]:
def multivariable_regression(X: xr.DataArray, Y: xr.DataArray, verbose: bool = False):
    # Perform multivariable ordinaty least squares regression
    # model = linear_model.LinearRegression()
    # model.fit(X, Y)

    model = regression.linear_model.OLS(Y, X).fit()

    if verbose:
        print(model.summary())
    
    return model

def create_dataset(ds1: xr.DataArray, nlag: int = 10, include_zero_lag: bool = False):

    nsamples = ds1.shape[0] - nlag 
    if not include_zero_lag:
        nsamples = nsamples - 1

    X = np.full((nsamples, nlag), np.nan)
    Y = np.full((nsamples), np.nan)

    for i in range(nsamples):
        if not include_zero_lag:
            X[i,:] = ds1[i+1:i+1+nlag]
        else: 
            X[i,:] = ds1[i:i+nlag]
        Y[i] = ds1[i]
    return X, Y

def get_AIC_lag(x):
    return 10
    # NLAG = 50
    # print(Y1.shape[0] - NLAG - 1)
    # if NLAG*2 > (Y1.shape[0] - NLAG - 1): 
    #     NLAG = int(np.floor((Y1.shape[0] - NLAG - 1)/2))

def granger_causality(Xr: np.ndarray, X: np.ndarray, Y:np.ndarray, unrestricted_model, restricted_model, nlag: int = 10, verbose: bool = False):
    """_summary_

    Args:
        Y1 (np.ndarray): _description_
        Y2 (np.ndarray): _description_
        unrestricted_model (_type_): _description_
        restricted_model (_type_): _description_
    """
    # Get Predictions
    Yp_restricted = restricted_model.predict(Xr) # Equivalent to np.matmul(X1, restricted_model.coef_) + restricted_model.intercept_
    Yp_unrestricted = unrestricted_model.predict(X) # Equivalent to np.matmul(X2, unrestricted_model.coef_) + unrestricted_model.intercept_

    # Get Residuals
    err_y1 = Y - Yp_restricted
    err_y2 = Y - Yp_unrestricted
    
    # F-test on residuals
    F = np.var(err_y1)/np.var(err_y2) # Ratio of variances
    df1, df2 = len(err_y1) - 1, len(err_y2) - 1 # Number of samples - 1
    p_value = f.sf(F, df1, df2) # Survival function (1 - cdf) at F

    # Signifigance of regression coefficients
    ALPHA = 0.05
    sig_reg_coeffs = unrestricted_model.pvalues[nlag:]

    # Test Granger Causality
    if np.any(sig_reg_coeffs <= ALPHA) and p_value < 0.05:
        causality = True
        magnitude = np.sum(unrestricted_model.params[nlag:])
        if verbose: print(f"Reject Null Hypothesis: Granger Causality DOES exist: {magnitude} with nlag={nlag}, F-test p={p_value}, p-values of regression coefficients: {sig_reg_coeffs}")
    else: 
        magnitude = np.nan
        causality = False
        if verbose: print(f"Fail to Reject Null Hypothesis: Granger Causality DOES NOT exist: F-test p={p_value}, p-values of regression coefficients: {sig_reg_coeffs}")

    return Yp_restricted, Yp_unrestricted, magnitude, causality

In [66]:
def Granger_Pipeline(Y1, Y2, sanity_check: bool = False, include_zero_lag: bool = False, verbose: bool = True):
    Y1 = make_stationary(Y1, sanity_check=sanity_check)
    Y2 = make_stationary(Y2, sanity_check=sanity_check)

    MAXLAG = int(20)
    Y2_pred_Y1, Y1_pred_Y2 = {}, {}
    aic_Y2_pred_Y1, aic_Y1_pred_Y2 = np.inf, np.inf
    for NLAG in range(1, MAXLAG):

        ##### Does Y2 Granger Cause Y1? #####
        # Restricted Model
        X1_ds, Y1_ds = create_dataset(ds1=Y1, nlag=NLAG, include_zero_lag=include_zero_lag)
        restricted_model = multivariable_regression(X1_ds, Y1_ds, verbose=verbose)
        # Unrestricted Model
        X2_ds, Y2_ds = create_dataset(ds1=Y2, nlag=NLAG, include_zero_lag=include_zero_lag)
        X_ds = np.concatenate([X1_ds, X2_ds], axis=1) 
        unrestricted_model = multivariable_regression(X=X_ds, Y=Y1_ds, verbose=verbose) 

        if aic_Y2_pred_Y1 > unrestricted_model.aic:
            aic_Y2_pred_Y1 = unrestricted_model.aic
            if verbose: print("New best model found: ", unrestricted_model.aic)
            Y2_pred_Y1["restricted_model"] = restricted_model
            Y2_pred_Y1["unrestricted_model"] = unrestricted_model
            Y2_pred_Y1["nlag"] = NLAG
            Y2_pred_Y1["Xr"] = X1_ds
            Y2_pred_Y1["X"] = X_ds
            Y2_pred_Y1["Y"] = Y1_ds

        ##### Does Y1 Granger Cause Y2? #####
        # Restricted Model
        X2_ds, Y2_ds = create_dataset(ds1=Y2, nlag=NLAG, include_zero_lag=include_zero_lag)
        restricted_model = multivariable_regression(X2_ds, Y2_ds, verbose=verbose)
        # Unrestricted Model
        X1_ds, Y1_ds = create_dataset(ds1=Y1, nlag=NLAG, include_zero_lag=include_zero_lag)
        X_ds = np.concatenate([X2_ds, X1_ds], axis=1) 
        unrestricted_model = multivariable_regression(X=X_ds, Y=Y2_ds, verbose=verbose) 

        if aic_Y1_pred_Y2 > unrestricted_model.aic:
            aic_Y1_pred_Y2 = unrestricted_model.aic
            if verbose: print("New best model found: ", unrestricted_model.aic)
            Y1_pred_Y2["restricted_model"] = restricted_model
            Y1_pred_Y2["unrestricted_model"] = unrestricted_model
            Y1_pred_Y2["nlag"] = NLAG
            Y1_pred_Y2["Xr"] = X2_ds
            Y1_pred_Y2["X"] = X_ds
            Y1_pred_Y2["Y"] = Y2_ds

    Y1p_restricted, Y1p_unrestricted, magnitudeY1, causalityY1 = granger_causality(
        Xr=Y2_pred_Y1["Xr"],
        X=Y2_pred_Y1["X"],
        Y=Y2_pred_Y1["Y"],
        unrestricted_model=Y2_pred_Y1["unrestricted_model"], 
        restricted_model=Y2_pred_Y1["restricted_model"], 
        nlag=Y2_pred_Y1["nlag"],
        verbose=verbose
    )

    Y2p_restricted, Y2p_unrestricted, magnitudeY2, causalityY2 = granger_causality(
        Xr=Y1_pred_Y2["Xr"],
        X=Y1_pred_Y2["X"],
        Y=Y1_pred_Y2["Y"],
        unrestricted_model=Y1_pred_Y2["unrestricted_model"], 
        restricted_model=Y1_pred_Y2["restricted_model"], 
        nlag=Y1_pred_Y2["nlag"],
        verbose=verbose
    )
    
    return causalityY1, causalityY2, magnitudeY1, magnitudeY2
    

In [6]:
# Use for TESTING
SST_T_WE_CMIP6 = xr.open_dataarray("data/piControl/rolling_gradient_cmip6_WE-Trend.nc")
SST_T_EPT_CMIP6 = xr.open_dataarray("data/piControl/rolling_gradient_cmip6_eastPacificTriangle_trend.nc")
Granger_Pipeline(SST_T_WE_CMIP6.sel(model="CanESM5-1"), SST_T_EPT_CMIP6.sel(model="CanESM5-1"), verbose=False)

# Load and Prep SEB 

In [8]:
SEB_CMIP5 = xr.open_dataset("data/piControl/SEB_CMIP5_full.nc")
SEB_CMIP6 = xr.open_dataset("data/piControl/SEB_CMIP6_full.nc")
SEB_CMIP6

In [17]:
def get_region(name, flux): 
    flux = remove_land_full(ds=flux[name], var=name)
    flux = fix_coords(flux.rename(name).to_dataset())

    flux_west = flux.sel(lat=slice(-5, 5), lon=slice(110, 165)).spatial.average(name)[name]
    flux_east = flux.sel(lat=slice(-5, 5), lon=slice(-135, -80)).spatial.average(name)[name]
    flux_WE = flux_west - flux_east

    flux_EPSA = flux.sel(lat=slice(-40, -5), lon=slice(-95, -70)).spatial.average(name)[name]

    flux_SO = flux.sel(lat=slice(-70, -50), lon=slice(-180, -75)).spatial.average(name)[name]

    flux_E = flux.sel(lat=slice(-5, 5), lon=slice(-135, -80)).spatial.average(name)[name]

    flux_EPT = get_triangle(flux[name])
    flux_EPT = fix_coords(flux_EPT.rename(name).to_dataset()).spatial.average(name)[name]

    return {"WE": flux_WE, "EPSA": flux_EPSA, "SO": flux_SO, "EP": flux_E, "EPT": flux_EPT}


def fix_coords(data):
    data = data.bounds.add_bounds("X")
    data = data.bounds.add_bounds("Y")
    data = data.bounds.add_bounds("T")
    data = xc.swap_lon_axis(data, to=(-180, 180))
    return data


def get_triangle(tos, latmin: float = -38.75, latmax: float = -1.25, lonmin: float = -178.75, lonmax: float = -71.25, RES: float = 2.5):
    DY = latmax - latmin
    DX = lonmax - lonmin 
    dx = RES*round(DX/DY)
    dy = RES

    # print(f"For each latitude step of {dy} degrees, longitude step is {dx}")

    latcoords = np.arange(latmax, latmin-dy, -dy)
    loncoords = np.arange(lonmin, lonmax+dx, dx)
    lonraw = np.arange(lonmin, lonmax+dx, RES)

    ctos = tos.sel(lon=slice(lonmin, lonmax), lat=slice(latmin, latmax))
    nmodel, _, nlon, ntime = ctos.shape
    # print(ctos)

    for i, clon in enumerate(lonraw):
        j = np.where(clon == loncoords)[0]

        if i == nlon: break

        # print("j prior: ", j)
        if len(j) == 0: 
            j = jold
        else: 
            j = j[0]
             
        # print("j: ", j)
        nlats = int(len(latcoords) - j) # nlats below diag
        # print("nlats: ", nlats)
        ctos[:,:nlats,i,:] = np.full((nmodel, nlats,ntime), np.nan) 
        
        jold = j
    
    return ctos


def remove_land_full(ds, var="skt"):
    ds = ds.rename(var).to_dataset()
    ds = xc.swap_lon_axis(ds, to=(-180, 180))
    from global_land_mask import globe
    # Set land to NaN
    lon_grid,lat_grid = np.meshgrid(ds.lon, ds.lat)
    globe_land_mask = globe.is_land(lat_grid,lon_grid)
    globe_land_mask_nd = np.tile(globe_land_mask,(ds[var].shape[0],ds[var].shape[1], 1,1))
    ds_no_land = xr.where(globe_land_mask_nd==True,np.nan,ds[var]) 
    return ds_no_land



In [35]:
LH_EPT = get_region("LH_Other", SEB_CMIP6)["EPT"]
SW_EPSA = get_region("SW", SEB_CMIP6)["EPSA"]

# Run Granger Causality Analysis

In [67]:
for model in LH_EPT.model.values:
    print(model)
    try:
        # Remove NANS
        LH_EPT_model = LH_EPT.sel(model=model).dropna("time")
        SW_EPSA_model = SW_EPSA.sel(model=model, time=LH_EPT_model.time)
        print(LH_EPT_model.shape, SW_EPSA_model.shape)
        # Run Granger Causality Pipeline
        causalityY1, causalityY2, magnitudeY1, magnitudeY2 = Granger_Pipeline(Y1=SW_EPSA_model, Y2=LH_EPT_model, verbose=False)
        print(causalityY1, causalityY2, magnitudeY1, magnitudeY2)
    except Exception as e:
        print(e)
        continue

BCC-CSM2-MR
(109,) (109,)
True False -0.040153155567059645 nan
CESM2-WACCM
(109,) (109,)
True True -0.05103028594762532 -0.016149591607814692
GISS-E2-1-G
(109,) (109,)
True False -0.042565448898546576 nan
EC-Earth3-Veg
(109,) (109,)
True False -0.13658281095499536 nan
GISS-E2-2-H
(109,) (109,)
True False -0.022033532667739997 nan
E3SM-1-1
(109,) (109,)
False False nan nan
SAM0-UNICON
(109,) (109,)
False False nan nan
CESM2
(109,) (109,)
False False nan nan
MPI-ESM-1-2-HAM
(109,) (109,)
False True nan -0.09134496896487107
CMCC-CM2-SR5
(109,) (109,)
False False nan nan
GFDL-ESM4
(109,) (109,)
False False nan nan
GISS-E2-1-H
(109,) (109,)
True True 0.3556908796232929 0.567802682347146
E3SM-1-0
(109,) (109,)
False True nan 0.42795775380580403
CESM2-FV2
(109,) (109,)
True False -0.10387771156835113 nan
INM-CM4-8
(109,) (109,)
True False 0.09762128248682606 nan
FGOALS-f3-L
(109,) (109,)
True False 0.01483211812062113 nan
CESM2-WACCM-FV2
(109,) (109,)
False True nan 0.5058785532047976
EC-Eart