# Climate Mode Calculation
- Calculates a climate-mode (ENSO) index for each CMIP6 model so that the yield response can be compared across models for each phase of the mode.

## Imports

In [11]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import glob as glob
import dask
import os 
import pandas as pd

## Processing SST (tos) data from the 18 underlying LOCA2 models
- The data still need to be regridded to a common grid: every file is on its native grid (`gn`), not a regridded one (`gr`).

### Inputs

In [12]:
# Root directories of the monthly `tos` files (scenario and historical) and the
# directory that receives the ENSO label files.
cmip6_tos_root = "/storage/group/pches/default/users/cta5244/CMIP6_tos"
base_path = f"{cmip6_tos_root}/ssp370_omon_tos"
hist_path = f"{cmip6_tos_root}/hist_omon_tos"
output_path = "/storage/group/pches/default/users/cta5244/enso4_loca2_underlying_models/ssp370"

# Every entry directly under the scenario root, then the entries of the first one.
model_paths = sorted(glob.glob(os.path.join(base_path, "*")))
yearly_paths = sorted(glob.glob(os.path.join(model_paths[0], "*")))

### Computing

In [3]:
from dask_jobqueue import SLURMCluster

# Resources requested for each SLURM job backing a dask worker.
slurm_job_options = dict(
    account="open",  # alternative allocation: "pches"
    cores=1,
    memory="150GiB",
    walltime="02:00:00",
)

cluster = SLURMCluster(**slurm_job_options)

# Ask the scheduler for two such jobs.
cluster.scale(jobs=2)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 40633 instead


In [8]:
from dask.distributed import Client

# Connect this notebook session to the SLURM-backed cluster created above.
client = Client(cluster)

In [9]:
# Show the client summary (dashboard address, worker count, memory).
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://146.186.150.12:40633/status,

0,1
Dashboard: http://146.186.150.12:40633/status,Workers: 2
Total threads: 2,Total memory: 300.00 GiB

0,1
Comm: tcp://146.186.150.12:41165,Workers: 2
Dashboard: http://146.186.150.12:40633/status,Total threads: 2
Started: Just now,Total memory: 300.00 GiB

0,1
Comm: tcp://10.6.1.160:34935,Total threads: 1
Dashboard: http://10.6.1.160:41209/status,Memory: 150.00 GiB
Nanny: tcp://10.6.1.160:46871,
Local directory: /tmp/dask-scratch-space/worker-zrs7ri5r,Local directory: /tmp/dask-scratch-space/worker-zrs7ri5r

0,1
Comm: tcp://10.6.1.159:39189,Total threads: 1
Dashboard: http://10.6.1.159:40029/status,Memory: 150.00 GiB
Nanny: tcp://10.6.1.159:33187,
Local directory: /tmp/dask-scratch-space/worker-2zn7z_31,Local directory: /tmp/dask-scratch-space/worker-2zn7z_31


In [16]:
def calculate_enso_index(model_path):
    """
    Compute a smoothed, high-pass-filtered Niño-4 SST anomaly series for one model.

    Opens every file under ``model_path`` together with the files of the
    same-named directory under the notebook-level ``hist_path``, averages ``tos``
    over the Niño-4 box (5°S-5°N, 160-210° on a 0-360° longitude axis) with
    cos(latitude) weights, removes the 1991-2020 monthly climatology, keeps
    2015-2100, applies a centred 3-step running mean and finally subtracts a
    centred 121-step running mean.

    inputs
    - model_path: directory with one model's scenario ``tos`` files; its last
      path component is also used to find the matching directory under ``hist_path``

    returns
    - xr.DataArray with the filtered index (NaN where a rolling window is
      incomplete), or None when the latitude/longitude coordinates are not in
      one of the recognised layouts (the coordinates are printed in that case)
    """
    scenario_files = sorted(glob.glob(f"{model_path}/*"))
    hist_files = sorted(glob.glob(f"{hist_path}/{model_path.split('/')[-1]}/*"))

    ds_cmip6 = xr.open_mfdataset(scenario_files + hist_files)
    ds_cmip6.load()  # loads in place, so the masking and rolling below run in memory

    # Normalise the 1-D naming variant so that both 1-D layouts share one branch.
    if ('latitude' in ds_cmip6.coords) and (ds_cmip6.latitude.ndim == 1):
        ds_cmip6 = ds_cmip6.rename({'latitude': 'lat', 'longitude': 'lon'})

    # Wrap longitudes to 0-360 on the coordinate that is actually used for the
    # mask; the Niño-4 box crosses the dateline and is only contiguous there.
    if ('lat' in ds_cmip6.coords) and (ds_cmip6.lat.ndim == 1):
        ds_cmip6 = ds_cmip6.assign_coords(lon=(ds_cmip6['lon'] % 360))
        lon2d, lat2d = xr.broadcast(ds_cmip6['lon'], ds_cmip6['lat'])
    elif ('latitude' in ds_cmip6.coords) and (ds_cmip6.latitude.ndim == 2):
        ds_cmip6 = ds_cmip6.assign_coords(longitude=(ds_cmip6['longitude'] % 360))
        lon2d, lat2d = ds_cmip6.longitude, ds_cmip6.latitude
    elif ('lat' in ds_cmip6.coords) and (ds_cmip6.lat.ndim == 2):
        ds_cmip6 = ds_cmip6.assign_coords(lon=(ds_cmip6['lon'] % 360))
        lon2d, lat2d = ds_cmip6.lon, ds_cmip6.lat
    else:
        # Unknown coordinate layout: show it and give up on this model.
        print(ds_cmip6.coords)
        return

    mask_n4 = (lat2d >= -5) & (lat2d <= 5) & (lon2d >= 160) & (lon2d <= 210)
    tos = ds_cmip6['tos'].where(mask_n4)
    # np.cos dispatches to the DataArray; weights must be NaN-free for .weighted().
    weights = np.cos(np.deg2rad(lat2d)).where(mask_n4).fillna(0)

    # Monthly anomalies relative to the 1991-2020 climatology, future period only.
    clim = tos.sel(time=slice('1991', '2020')).groupby('time.month').mean('time')
    anom = tos.groupby('time.month') - clim
    anom = anom.sel(time=slice('2015', '2100'))

    # Area-weighted mean over every non-time dimension.
    spatial_dims = tuple(d for d in anom.dims if d != "time")
    w = weights.broadcast_like(anom).fillna(0)
    nino4 = anom.weighted(w).mean(spatial_dims)

    # 3-step smoothing, then removal of the 121-step running mean.
    nino4_3m = nino4.rolling(time=3, center=True).mean()
    nino4_var = nino4_3m - nino4_3m.rolling(time=121, center=True).mean()
    return nino4_var


In [17]:
# One lazy task per model directory (currently only the first model).
results = [
    dask.delayed(calculate_enso_index)(model_path=model_path)
    for model_path in model_paths[:1]
]

In [18]:
# Run the delayed tasks; `results` is rebound from the list of delayed objects
# to the tuple of their computed values.
results = dask.compute(*results)


KeyboardInterrupt



In [None]:
# Inspect the computed index for each model.
results

## Compute (no dask) future

In [None]:
# Phase-labelling settings for the loop below.
sigma_threshold = True    # True: threshold is 1 sigma of the index; False: fixed 0.5 K
persistence_length = 5    # consecutive time steps of the smoothed index that must pass the threshold

for model_path in model_paths[8:9]:

    yearly_paths = sorted(glob.glob(f"{model_path}/*"))
    # Historical files of the same model; only used if added to open_mfdataset below.
    hist_paths = sorted(glob.glob(f"{hist_path}/{model_path.split('/')[-1]}/*"))
    if not yearly_paths:
        print(f"SKIP (no input files): {model_path}")
        continue

    # output string with gr not gn
    new_list = ['gr' if item == 'gn' else item for item in yearly_paths[0].split("/")[-1].split("_")]
    output_unique_str = "_".join(new_list)
    out_path = f"{output_path}/{output_unique_str}_enso_labels.nc"

    if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        print(f"SKIP (exists): {out_path}")
        continue

    # NOTE(review): only `yearly_paths` is opened, so the 1991-2020 slices used
    # below for the climatology and for sigma can only match whatever part of
    # that period these files contain (presumably 2015-2020 for scenario files).
    # Add `hist_paths` if the full base period is intended. TODO confirm.
    ds_cmip6 = xr.open_mfdataset(
        yearly_paths,  # + hist_paths if needed
        combine="by_coords",
        chunks={"time": 120},
        parallel=True,
        data_vars="minimal",
        coords="minimal",
        compat="override",
        preprocess=lambda d: d[["tos"]].astype({"tos": "float32"}),
    )

    # Normalise the 1-D naming variant so that both 1-D layouts share one branch.
    if ('latitude' in ds_cmip6.coords) and (ds_cmip6.latitude.ndim == 1):
        ds_cmip6 = ds_cmip6.rename({'latitude': 'lat', 'longitude': 'lon'})

    # Wrap longitudes to 0-360 on the coordinate that is actually used for the
    # mask; the Niño-4 box (160-210) crosses the dateline.
    if ('lat' in ds_cmip6.coords) and (ds_cmip6.lat.ndim == 1):
        ds_cmip6 = ds_cmip6.assign_coords(lon=(ds_cmip6['lon'] % 360))
        lon2d, lat2d = xr.broadcast(ds_cmip6['lon'], ds_cmip6['lat'])
    elif ('latitude' in ds_cmip6.coords) and (ds_cmip6.latitude.ndim == 2):
        ds_cmip6 = ds_cmip6.assign_coords(longitude=(ds_cmip6['longitude'] % 360))
        lon2d, lat2d = ds_cmip6.longitude, ds_cmip6.latitude
    elif ('lat' in ds_cmip6.coords) and (ds_cmip6.lat.ndim == 2):
        ds_cmip6 = ds_cmip6.assign_coords(lon=(ds_cmip6['lon'] % 360))
        lon2d, lat2d = ds_cmip6.lon, ds_cmip6.lat
    else:
        # Unknown layout: report it and move on instead of reusing a mask left
        # over from a previous iteration (or failing on an undefined one).
        print(ds_cmip6.coords)
        ds_cmip6.close()
        continue

    mask_n4 = (lat2d >= -5) & (lat2d <= 5) & (lon2d >= 160) & (lon2d <= 210)
    tos = ds_cmip6["tos"].where(mask_n4)
    # np.cos dispatches to the DataArray; weights must be NaN-free for .weighted().
    weights = np.cos(np.deg2rad(lat2d)).where(mask_n4).fillna(0)
    spatial_dims = tuple(d for d in tos.dims if d != "time")

    # Weighted mean that drops the weight of cells where tos is missing, so
    # masked cells inside the box do not bias the average towards zero.
    ts = tos.weighted(weights).mean(spatial_dims).chunk({"time": 120})

    clim = ts.sel(time=slice('1991', '2020')).groupby('time.month').mean('time')
    nino4 = ts.groupby('time.month') - clim
    nino4 = nino4.sel(time=slice('2015', '2100'))

    # High-pass: subtract a centred 121-step running mean, then smooth over 3 steps.
    low_pass = nino4.rolling(time=121, center=True, min_periods=61).mean()
    y_highpass = (nino4 - low_pass).rename("nino4_hp")
    nino4_3m_det = y_highpass.rolling(time=3, center=True, min_periods=2).mean()

    # Evaluate the lazy graph once; sigma is then taken from the in-memory series.
    mode_indx = nino4_3m_det.compute()
    sigma = float(mode_indx.sel(time=slice("1991", "2020")).std("time"))  # 1σ in K
    thr = sigma if sigma_threshold else .5

    # True at a time step when it and the preceding (persistence_length - 1)
    # steps all pass the threshold.
    pos_persist = (mode_indx >=  thr).rolling(time=persistence_length).sum() >= persistence_length
    neg_persist = (mode_indx <= -thr).rolling(time=persistence_length).sum() >= persistence_length

    labels = xr.where(pos_persist, 1, xr.where(neg_persist, -1, 0)).astype("int8")
    labels = labels.rename("enso_phase")
    labels.attrs.update({
        "long_name": "ENSO phase label",
        "description": "Month-by-month phase labels: +1 El Niño, -1 La Niña, 0 Neutral. "
                       f"Computed from 3-month smoothed, detrended Niño-4 index with threshold={thr} K "
                       f"and persistence {persistence_length} consecutive 3 month periods.",
        "flag_values": np.array([-1, 0, 1], dtype="int8"),
        "flag_meanings": "la_nina neutral el_nino",
        "threshold_units": "K",
        "threshold_value": float(thr),
        "persistence_months": int(persistence_length),
        "index_source": "Nino4 3-month running mean (detrended)",
        "unique_str": f"{output_unique_str}",
        "sigma_threshold": f"{sigma_threshold}"
    })
    ds_out = xr.Dataset(
        {
            "enso_phase": labels,
            "nino4_index": mode_indx,
            "threshold": thr
        }
    )

    ds_out.to_netcdf(out_path)
    ds_cmip6.close()  # release the input file handles before the next model
    print("saved:", out_path)
    


### Selecting positive, negative, and neutral years for each run

In [3]:
# table of number of events
ssp = 'ssp370'
# NOTE: this rebinds `output_path` to the parent of the per-scenario directory
# used when the label files were written.
output_path = "/storage/group/pches/default/users/cta5244/enso4_loca2_underlying_models"
paths = sorted(glob.glob(f"{output_path}/{ssp}/*"))

rows = []
long_rows = []

for path in paths:
    # Third "_"-separated token of the file name, tagged with the scenario.
    model_id = f"{(path.split('/')[-1]).split('_')[2]}_{ssp}"

    # Read what is needed and close the file again.
    with xr.open_dataset(path) as ds_phase:
        thr = ds_phase.threshold.values.item()
        # The CMIP6 label files store the index as "nino4_index"; the
        # observational file stores it as "oni_index". Accept either.
        index_name = "oni_index" if "oni_index" in ds_phase else "nino4_index"
        jas_index = ds_phase[index_name].sel(
            time=ds_phase.time.dt.month.isin([7, 8, 9])
        ).load()

    pos_years, neg_years, neu_years = [], [], []
    jas_years = jas_index.time.dt.year

    for year_i in np.unique(jas_years.values):
        season = jas_index.sel(time=jas_years == year_i).values
        # guard in case of missing months:
        if season.size < 3:
            continue
        # A year is positive/negative only if all three JAS months pass the threshold.
        if np.all(season > thr):
            pos_years.append(int(year_i))
        elif np.all(season < -thr):
            neg_years.append(int(year_i))
        else:
            neu_years.append(int(year_i))

    rows.append({
        "model": model_id,
        "threshold": thr,
        "pos_years": pos_years,
        "neg_years": neg_years,
        "neu_years": neu_years,
        "n_pos": len(pos_years),
        "n_neg": len(neg_years),
        "n_neu": len(neu_years),
    })

    for y in pos_years:
        long_rows.append({"model": model_id, "year": y, "label":  1})
    for y in neg_years:
        long_rows.append({"model": model_id, "year": y, "label": -1})
    for y in neu_years:
        long_rows.append({"model": model_id, "year": y, "label":  0})

#pd.DataFrame(rows)[["model","threshold","n_pos","n_neg","n_neu"]].to_csv(
#   f"{output_path}/enso_JAS_counts_by_model_{ssp}.csv", index=False
#)
# Explicit columns keep sort_values working when no label files were found.
pd.DataFrame(long_rows, columns=["model", "year", "label"]).sort_values(["model","year"]).to_csv(
    f"{output_path}/enso_JAS_years_by_model_{ssp}_4tenths_thr.csv", index=False
)

## Compute, historical

In [13]:
# Observational SST file, the output file for its index, and the persistence setting.
noaa_ersst_dir = "/storage/group/pches/default/users/cta5244/noaa_ersstv5"
base_path = f"{noaa_ersst_dir}/noaa_ersstv5.nc"
out_path_hist = f"{noaa_ersst_dir}/noaa_ersstv5_oni.nc"
persistence_length = 5
ds_hist = xr.open_dataset(base_path)

In [15]:
# Wrap longitudes to 0-360 so the box below is a contiguous longitude range.
ds_hist = ds_hist.assign_coords(lon=(ds_hist['lon'] % 360))
lat = ds_hist['lat']
lon = ds_hist['lon']
lon2d, lat2d = xr.broadcast(lon, lat)
# 5°S-5°N, 190-240° on the 0-360 axis.
mask_34 = ((lat2d >= -5) & (lat2d <= 5) &
           (lon2d >= 190) & (lon2d <= 240))
tos = ds_hist["sst"].where(mask_34)
# np.cos dispatches to the DataArray; weights must be NaN-free for .weighted().
weights = np.cos(np.deg2rad(lat2d)).where(mask_34).fillna(0)
spatial_dims = tuple(d for d in tos.dims if d != "time")

# Weighted mean that drops the weight of cells where sst is missing.
ts = tos.weighted(weights).mean(spatial_dims).chunk({"time": 120})

# Monthly anomalies relative to the 1991-2020 climatology.
clim = ts.sel(time=slice('1991','2020')).groupby('time.month').mean('time')
oni = ts.groupby('time.month') - clim

oni_3m = oni.rolling(time=3, center=True, min_periods=3).mean()
thr = .4  # fixed threshold in K (no sigma-based threshold in this cell)

mode_indx = oni_3m.compute()

# True at a time step when it and the preceding (persistence_length - 1)
# steps all pass the threshold.
pos_persist = (mode_indx >=  thr).rolling(time=persistence_length).sum() >= persistence_length
neg_persist = (mode_indx <= -thr).rolling(time=persistence_length).sum() >= persistence_length

labels = xr.where(pos_persist, 1, xr.where(neg_persist, -1, 0)).astype("int8")
labels = labels.rename("enso_phase")
labels.attrs.update({
    "long_name": "ENSO phase label",
    "description": "Month-by-month phase labels: +1 El Niño, -1 La Niña, 0 Neutral. "
                   f"Computed from 3-month smoothed ONI index with threshold={thr} K "
                   f"and persistence {persistence_length} consecutive 3 month periods.",
    "flag_values": np.array([-1, 0, 1], dtype="int8"),
    "flag_meanings": "la_nina neutral el_nino",
    "threshold_units": "K",
    "threshold_value": float(thr),
    "persistence_months": int(persistence_length),
    "index_source": "ONI 3-month running mean",
    "unique_str": "noaa_ersstv5",
    # The threshold above is a fixed value, so this no longer depends on a
    # variable that is only defined in the CMIP6 cell.
    "sigma_threshold": "False"
})
ds_out = xr.Dataset(
    {
        "enso_phase": labels,
        "oni_index": mode_indx,
        "threshold": thr,
    }
)

ds_out.to_netcdf(out_path_hist)
print("saved:", out_path_hist)

saved: /storage/group/pches/default/users/cta5244/noaa_ersstv5/noaa_ersstv5_oni.nc


In [16]:
# table of number of events
ssp = 'ssp370'
out_path_hist = "/storage/group/pches/default/users/cta5244/noaa_ersstv5/noaa_ersstv5_oni.nc"
save_path = "/storage/group/pches/default/users/cta5244/noaa_ersstv5/"
paths = [out_path_hist]
rows = []
long_rows = []

for path in paths:
    model_id = 'NOAA_ersstv5'

    # Read what is needed and close the file again, so the cell that writes
    # `out_path_hist` can be re-run without hitting a still-open file.
    with xr.open_dataset(path) as ds_phase:
        thr = ds_phase.threshold.values.item()
        jas_index = ds_phase["oni_index"].sel(
            time=ds_phase.time.dt.month.isin([7, 8, 9])
        ).load()

    pos_years, neg_years, neu_years = [], [], []
    jas_years = jas_index.time.dt.year

    for year_i in np.unique(jas_years.values):
        season = jas_index.sel(time=jas_years == year_i).values
        # guard in case of missing months:
        if season.size < 3:
            continue
        # A year is positive/negative only if all three JAS months pass the threshold.
        if np.all(season > thr):
            pos_years.append(int(year_i))
        elif np.all(season < -thr):
            neg_years.append(int(year_i))
        else:
            neu_years.append(int(year_i))

    rows.append({
        "model": model_id,
        "threshold": thr,
        "pos_years": pos_years,
        "neg_years": neg_years,
        "neu_years": neu_years,
        "n_pos": len(pos_years),
        "n_neg": len(neg_years),
        "n_neu": len(neu_years),
    })

    for y in pos_years:
        long_rows.append({"model": model_id, "year": y, "label":  1})
    for y in neg_years:
        long_rows.append({"model": model_id, "year": y, "label": -1})
    for y in neu_years:
        long_rows.append({"model": model_id, "year": y, "label":  0})

#pd.DataFrame(rows)[["model","threshold","n_pos","n_neg","n_neu"]].to_csv(
#   f"{output_path}/enso_JAS_counts_by_model_{ssp}.csv", index=False
#)
# Explicit columns keep sort_values working if no year qualified; os.path.join
# avoids the doubled "/" from the trailing slash in `save_path`.
pd.DataFrame(long_rows, columns=["model", "year", "label"]).sort_values(["model","year"]).to_csv(
    os.path.join(save_path, "enso_JAS_years_by_model_4tenths_thr.csv"), index=False
)