In [1]:
import os
import yaml
import copy
import time
import numpy as np
import pandas as pd
import xarray as xr

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
from scipy.stats import t as student_t  # <-- define at module scope

def _two_sided_t_pvalue(tstat_np, df_np):
    """Two-sided Student-t p-value, 2 * P(T_df > |t|), on plain NumPy arrays.

    Defined at module level (not nested) so it pickles cleanly if the
    dask-parallelized ``apply_ufunc`` below ships it to workers.
    """
    return 2.0 * student_t.sf(np.abs(tstat_np), df_np)


def corr_pvalue(da_x: xr.DataArray, da_y: xr.DataArray, dim: str):
    """Pearson correlation along ``dim`` and its two-sided t-test p-value.

    Parameters
    ----------
    da_x, da_y : xr.DataArray
        The two series to correlate. ``xr.corr`` uses only positions where
        both are non-missing; the test's sample size ``n`` counts positions
        where both are finite.
    dim : str
        Dimension to correlate along; it is reduced away in both outputs.

    Returns
    -------
    r : xr.DataArray
        Pearson correlation as returned by ``xr.corr``.
    p : xr.DataArray
        Two-sided p-value for H0: r == 0, from t = r * sqrt(df / (1 - r**2))
        with df = n - 2. NaN where df <= 0 or r is NaN; 0 where |r| == 1.

    Dask-backed inputs stay lazy: ``np.sqrt``/``where`` dispatch through
    xarray and the Student-t step uses ``dask="parallelized"``.
    """
    r = xr.corr(da_x, da_y, dim=dim)

    # Effective sample size: positions where both inputs are finite.
    n = (np.isfinite(da_x) & np.isfinite(da_y)).sum(dim)
    df = n - 2
    # NaN where there are too few pairs for a test (also avoids sqrt of a negative).
    df_valid = df.where(df > 0)

    # Rounding can push |r| a hair past 1; clip so 1 - r**2 is never negative.
    r_clip = r.clip(-1.0, 1.0)
    one_minus_r2 = 1.0 - r_clip**2
    # |r| == 1 (NaN r compares False here, so it stays NaN below).
    is_perfect = one_minus_r2 <= 0

    # Mask the zero denominator to NaN first (no divide-by-zero warning), then
    # give perfect correlations an infinite t so their p-value is 0, not NaN.
    tstat = r_clip * np.sqrt(df_valid / one_minus_r2.where(~is_perfect))
    tstat = tstat.where(~is_perfect, xr.where(r_clip > 0, np.inf, -np.inf))

    p = xr.apply_ufunc(
        _two_sided_t_pvalue,
        tstat, df_valid,
        dask="parallelized",
        output_dtypes=[np.float64],
    ).where(df > 0)

    return r, p

In [4]:
# Directory holding the metric stores; the save cells build output paths from it.
save_dir = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/'

In [5]:
fn = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax.zarr'
ds_CESM_default = xr.open_zarr(fn)

fn = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/ERA5_minmax.zarr'
ds_ERA5_default = xr.open_zarr(fn)

# ACC per lead year (non-detrended): Pearson r and p-value along 'valid_year'
# between the CESM field at that lead and ERA5, for every ERA5 data variable.
N_LEAD_YEARS = 10     # lead_year positions 0..N_LEAD_YEARS-1 are evaluated
MIN_VALID_YEARS = 10  # leads with fewer overlapping years are skipped

ds_CESM = ds_CESM_default
# ERA5 uses 'year' where CESM uses 'valid_year'; rename so the two align.
ds_ERA5 = ds_ERA5_default.rename({'year': 'valid_year'})
varnames = list(ds_ERA5.keys())

ds_collection = []
for i_lead in range(N_LEAD_YEARS):
    # Inner join on all shared coordinates: keep only labels present in both
    # datasets (e.g. the overlapping valid years).
    ds_CESM_slice, ds_ERA5_slice = xr.align(
        ds_CESM.isel(lead_year=i_lead), ds_ERA5, join="inner"
    )

    # Too few overlapping years for a meaningful correlation -> skip this lead.
    if len(ds_CESM_slice['valid_year']) < MIN_VALID_YEARS:
        continue

    out_vars = {}
    for varname in varnames:
        r, p = corr_pvalue(ds_CESM_slice[varname], ds_ERA5_slice[varname], dim="valid_year")
        out_vars[varname] = r
        out_vars[f"{varname}_pval"] = p

    # NOTE(review): the lead_year label is the positional index i_lead; if
    # ds_CESM carries its own lead_year coordinate values they may differ — confirm.
    ds_out = xr.Dataset(out_vars).expand_dims(lead_year=[i_lead])
    ds_collection.append(ds_out)

if not ds_collection:
    raise ValueError(
        f"No lead year has at least {MIN_VALID_YEARS} overlapping valid years; "
        "cannot build ds_ACC_all."
    )
ds_ACC_all = xr.concat(ds_collection, dim='lead_year')

In [6]:
# Evaluate every lazily-built variable now so later cells work on in-memory arrays.
ds_ACC_all = ds_ACC_all.compute()

In [7]:
# Target store for the non-detrended ACC dataset.
save_name = save_dir + 'CESM_minmax_ACC.zarr'

# Writing is off by default; set to True to (over)write the store (mode='w').
WRITE_ACC_ZARR = False
if WRITE_ACC_ZARR:
    ds_ACC_all.to_zarr(save_name, mode='w')

print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax_ACC.zarr


In [8]:
fn = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax_detrend.zarr'
ds_CESM_detrend = xr.open_zarr(fn)

fn = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/ERA5_minmax_detrend.zarr'
ds_ERA5_detrend = xr.open_zarr(fn)

# ACC per lead year (detrended inputs): Pearson r and p-value along 'valid_year'
# between the CESM field at that lead and ERA5, for every ERA5 data variable.
N_LEAD_YEARS = 10     # lead_year positions 0..N_LEAD_YEARS-1 are evaluated
MIN_VALID_YEARS = 10  # leads with fewer overlapping years are skipped

ds_CESM = ds_CESM_detrend
# ERA5 uses 'year' where CESM uses 'valid_year'; rename so the two align.
ds_ERA5 = ds_ERA5_detrend.rename({'year': 'valid_year'})
varnames = list(ds_ERA5.keys())

ds_collection = []
for i_lead in range(N_LEAD_YEARS):
    # Inner join on all shared coordinates: keep only labels present in both
    # datasets (e.g. the overlapping valid years).
    ds_CESM_slice, ds_ERA5_slice = xr.align(
        ds_CESM.isel(lead_year=i_lead), ds_ERA5, join="inner"
    )

    # Too few overlapping years for a meaningful correlation -> skip this lead.
    if len(ds_CESM_slice['valid_year']) < MIN_VALID_YEARS:
        continue

    out_vars = {}
    for varname in varnames:
        r, p = corr_pvalue(ds_CESM_slice[varname], ds_ERA5_slice[varname], dim="valid_year")
        out_vars[varname] = r
        out_vars[f"{varname}_pval"] = p

    # NOTE(review): the lead_year label is the positional index i_lead; if
    # ds_CESM carries its own lead_year coordinate values they may differ — confirm.
    ds_out = xr.Dataset(out_vars).expand_dims(lead_year=[i_lead])
    ds_collection.append(ds_out)

if not ds_collection:
    raise ValueError(
        f"No lead year has at least {MIN_VALID_YEARS} overlapping valid years; "
        "cannot build ds_ACC_detrend."
    )
ds_ACC_detrend = xr.concat(ds_collection, dim='lead_year')

In [9]:
# Evaluate every lazily-built variable now so later cells work on in-memory arrays.
ds_ACC_detrend = ds_ACC_detrend.compute()

In [10]:
# Target store for the detrended ACC dataset.
save_name = save_dir + 'CESM_minmax_ACC_detrend.zarr'

# Writing is off by default; set to True to (over)write the store (mode='w').
WRITE_ACC_DETREND_ZARR = False
if WRITE_ACC_DETREND_ZARR:
    ds_ACC_detrend.to_zarr(save_name, mode='w')

print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax_ACC_detrend.zarr
