In [1]:
import os
import yaml
import copy
import time
import numpy as np
import pandas as pd
import xarray as xr

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

### Tmean, Tmax, TPmean, TPmax

In [8]:
# Input directory holding the SMYLE zarr stores, and output directory for the
# derived datasets written by this notebook.
base_dir = '/glade/derecho/scratch/ksha/EPRI_data/CESM2_SMYLE/'
save_dir = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/'

In [4]:
var_map = ["TREFHT", "TREFHTMX", "PRECT"]

# Opened stores keyed by path. Each entry is the member-averaged dataset, so a
# store shared by several (verif_year, lead_year) pairs is opened only once.
_cache = {}

# Output names of the time-reduced fields (maximum and mean over the year).
names_max = {"PRECT": "PRECT_max", "TREFHTMX": "TREFHTMX_max"}
names_mean = {"PRECT": "PRECT_mean", "TREFHT": "TREFHT_mean"}

list_per_verif = []
for verif_year in range(1968, 2021):
    # Calendar-year window of the verification year.
    window = slice(f"{verif_year}-01-01", f"{verif_year}-12-31")

    list_per_lead = []
    for lead_year in range(10):
        # The store name carries 1 November of the year before init_year.
        init_year = verif_year - lead_year
        fn_CESM = base_dir + f"SMYLE_{init_year-1}-11-01_daily_ensemble.zarr"

        ens_mean = _cache.get(fn_CESM)
        if ens_mean is None:
            ens_mean = xr.open_zarr(fn_CESM)[var_map].mean("member")
            _cache[fn_CESM] = ens_mean

        ds_CESM = ens_mean.sel(time=window)

        ds_max = ds_CESM[list(names_max)].max("time", skipna=True).rename(names_max)
        ds_mean = ds_CESM[list(names_mean)].mean("time", skipna=True).rename(names_mean)

        # Tag the merged fields with their lead so they can be stacked below.
        ds_merge = xr.merge([ds_max, ds_mean]).expand_dims(lead_year=[lead_year])
        list_per_lead.append(ds_merge)

    ds_per_verif = xr.concat(list_per_lead, dim="lead_year").expand_dims(
        valid_year=[verif_year]
    )
    list_per_verif.append(ds_per_verif)

# Stack all verification years and keep each variable in a single chunk.
ds_all = xr.concat(list_per_verif, dim="valid_year").chunk(
    {"valid_year": -1, "lead_year": -1, "lat": 192, "lon": 288}
)

In [5]:
# Evaluate the dataset assembled above and rebind ds_all to the computed result.
ds_all = ds_all.compute()

In [7]:
# Target store for ds_all. The write itself is left commented out, so running
# this cell only reports the path.
save_name = f"{save_dir}CESM_minmax.zarr"
# ds_all.to_zarr(save_name, mode='w')
print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax.zarr


### Detrend

In [3]:
def detrend_linear(da, dim="time"):
    """Subtract a least-squares straight line fitted along ``dim``.

    The regressor is the integer position along ``dim`` (0, 1, 2, ...), not
    the coordinate values. Non-finite entries of ``da`` are masked out of the
    fit, so slope and intercept are estimated from finite samples only.

    Parameters
    ----------
    da : xarray.DataArray
        Data containing the dimension ``dim``.
    dim : str, default "time"
        Dimension along which the line is fitted and removed.

    Returns
    -------
    xarray.DataArray
        ``da`` minus the fitted line.
    """
    n_steps = da.sizes[dim]
    step = xr.DataArray(np.arange(n_steps), dims=dim, coords={dim: da[dim]})

    # Mask both regressor and data wherever the data are not finite.
    is_good = np.isfinite(da)
    x = step.where(is_good)
    y = da.where(is_good)

    x_bar = x.mean(dim, skipna=True)
    y_bar = y.mean(dim, skipna=True)
    x_anom = x - x_bar
    y_anom = y - y_bar

    # slope = cov(x, y) / var(x), both taken as means over the finite samples
    slope = (x_anom * y_anom).mean(dim, skipna=True) / (x_anom ** 2).mean(dim, skipna=True)
    intercept = y_bar - slope * x_bar

    # The line is evaluated on the unmasked index before being subtracted.
    return da - (slope * step + intercept)

In [4]:
# Open the CESM_minmax.zarr store; this hard-coded path equals the save_name
# printed by the cell that holds the (commented-out) to_zarr call for ds_all.
fn = '/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax.zarr'
ds_CESM = xr.open_zarr(fn)

In [6]:
year_valid = np.arange(1968, 2021)

# Remove the linear trend separately for each lead time
list_per_lead = []
for lead_year in range(10):
    ds_CESM_sub = ds_CESM.isel(lead_year=lead_year).sel(valid_year=year_valid)

    # Work on a copy: assigning into an alias of ds_CESM_sub would modify the
    # dataset whose data_vars are being iterated and overwrite the
    # non-detrended subset with detrended values.
    ds_CESM_detrend = ds_CESM_sub.copy()
    for v in ds_CESM_sub.data_vars:
        ds_CESM_detrend[v] = detrend_linear(ds_CESM_sub[v], dim="valid_year")
    list_per_lead.append(ds_CESM_detrend)

# Stack the leads back together and mark every variable as detrended.
ds_all_detrend = xr.concat(list_per_lead, dim='lead_year')
ds_all_detrend = ds_all_detrend.rename({v: f"{v}_detrend" for v in ds_all_detrend.data_vars})
ds_all_detrend = ds_all_detrend.chunk({"valid_year": -1, "lead_year": -1, "lat": 192, "lon": 288})

In [9]:
# Target store for ds_all_detrend. The write itself is left commented out, so
# running this cell only reports the path.
save_name = f"{save_dir}CESM_minmax_detrend.zarr"
# ds_all_detrend.to_zarr(save_name, mode='w')
print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS_GLOBE/CESM_minmax_detrend.zarr


In [10]:
# Display the dataset repr to inspect its variables, shape and chunking.
ds_all_detrend

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 223.59 MiB 223.59 MiB Shape (10, 53, 192, 288) (10, 53, 192, 288) Dask graph 1 chunks in 317 graph layers Data type float64 numpy.ndarray",10  1  288  192  53,

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 223.59 MiB 223.59 MiB Shape (10, 53, 192, 288) (10, 53, 192, 288) Dask graph 1 chunks in 317 graph layers Data type float64 numpy.ndarray",10  1  288  192  53,

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 223.59 MiB 223.59 MiB Shape (10, 53, 192, 288) (10, 53, 192, 288) Dask graph 1 chunks in 317 graph layers Data type float64 numpy.ndarray",10  1  288  192  53,

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 223.59 MiB 223.59 MiB Shape (10, 53, 192, 288) (10, 53, 192, 288) Dask graph 1 chunks in 317 graph layers Data type float64 numpy.ndarray",10  1  288  192  53,

Unnamed: 0,Array,Chunk
Bytes,223.59 MiB,223.59 MiB
Shape,"(10, 53, 192, 288)","(10, 53, 192, 288)"
Dask graph,1 chunks in 317 graph layers,1 chunks in 317 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
