In [1]:
import os
import yaml
import copy

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
import xesmf as xe

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

### Fort_Bragg: maximum total precipitation (TP_max) — ACC by lead year

In [4]:
# Site evaluated in this notebook; its metric stores live in <METRICS_ROOT>/<key>/.
key = 'Fort_Bragg'
# Root of the per-site metric stores. Defaults to the original GLADE scratch location but can be
# redirected (e.g. when running on another machine) via the EPRI_METRICS_ROOT environment variable.
METRICS_ROOT = os.environ.get(
    'EPRI_METRICS_ROOT', '/glade/derecho/scratch/ksha/EPRI_data/METRICS'
).rstrip('/')
# Trailing slash kept: later cells append file names to base_dir by plain string concatenation.
base_dir = f'{METRICS_ROOT}/{key}/'

In [5]:
# CESM and ERA5 maximum-total-precipitation (TP_max) stores for this site.
fn_CESM = base_dir + 'CESM_TP_max.zarr'
fn_ERA5 = base_dir + 'ERA5_TP_max.zarr'

# Open both stores in the same order (CESM first, then ERA5).
ds_CESM, ds_ERA5 = map(xr.open_zarr, (fn_CESM, fn_ERA5))

In [6]:
# Put CESM (init year) and ERA5 (calendar year) on a shared "year_valid" dimension so they can be
# aligned. Guarded so that re-running this cell does not fail on the already-renamed datasets.
if 'init_time' in ds_CESM.dims:
    ds_CESM = ds_CESM.rename({'init_time': 'year_valid'})
if 'year' in ds_ERA5.dims:
    ds_ERA5 = ds_ERA5.rename({'year': 'year_valid'})
varnames = list(ds_ERA5.data_vars)

N_LEADS = 10    # lead years evaluated (positions along the CESM 'year' dimension)
MIN_YEARS = 10  # minimum number of overlapping valid years required to compute an ACC

ds_collection = []
leads_used = []  # lead index of each entry in ds_collection (leads may be skipped)

for i_lead in range(N_LEADS):
    ds_CESM_slice = ds_CESM.isel(year=i_lead)
    # convert init year to valid year by adding the lead (assumes integer year labels — TODO confirm)
    ds_CESM_slice = ds_CESM_slice.assign_coords(
        year_valid=ds_CESM_slice['year_valid'] + i_lead
    )

    # keep only the valid years present in both datasets
    ds_CESM_slice, ds_ERA5_slice = xr.align(ds_CESM_slice, ds_ERA5, join="inner")

    # skip leads without enough overlapping years for a meaningful correlation
    if ds_CESM_slice.sizes['year_valid'] < MIN_YEARS:
        continue

    # anomaly correlation across valid years, one field per variable
    ds_ACC = xr.Dataset({
        varname: xr.corr(ds_CESM_slice[varname], ds_ERA5_slice[varname], dim="year_valid")
        for varname in varnames
    })
    ds_collection.append(ds_ACC)
    leads_used.append(i_lead)

if not ds_collection:
    raise ValueError(
        f"No lead year has at least {MIN_YEARS} overlapping valid years between CESM and ERA5."
    )

# Stack the per-lead ACC fields. The stacking dimension keeps its historical name 'year_valid'
# (same layout as ACC stores written by earlier runs) but it indexes lead years, not calendar
# years; the 'lead_year' coordinate records which lead each entry is, so skipped leads cannot
# silently shift positions.
ds_ACC_all = xr.concat(ds_collection, dim='year_valid').assign_coords(
    lead_year=('year_valid', leads_used)
)

In [8]:
save_name = base_dir + 'ACC_TP_max.zarr'
# Writing is opt-in: set SAVE_ACC = True to persist ds_ACC_all. xarray's default to_zarr mode
# creates the store and fails if it already exists, so an existing result is not clobbered.
SAVE_ACC = False
if SAVE_ACC:
    ds_ACC_all.to_zarr(save_name)
print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS/Fort_Bragg/ACC_TP_max.zarr


### Fort_Bragg: SPEI (annual mean and minimum) — ACC by lead year

In [11]:
# CESM and ERA5 SPEI stores for this site.
fn_CESM = base_dir + 'CESM_SPEI.zarr'
fn_ERA5 = base_dir + 'ERA5_SPEI.zarr'

# Open both stores in the same order (CESM first, then ERA5).
ds_CESM, ds_ERA5 = map(xr.open_zarr, (fn_CESM, fn_ERA5))

In [12]:
# Collapse the ERA5 SPEI series to calendar-year statistics (mean and min over "time").
# skipna=False: a year containing any missing value yields NaN instead of a partial statistic.
# NOTE(review): a calendar year with fewer time steps than usual (e.g. the last year of the
# record) is still aggregated; confirm ERA5 covers whole years, and that calendar years line up
# with the 12-month CESM lead-year bins.
ds_group = ds_ERA5.groupby("time.year")

ds_mean = ds_group.mean(dim="time", skipna=False)
ds_mean = ds_mean.rename({name: f"{name}_mean" for name in ds_mean.data_vars})

ds_min = ds_group.min(dim="time", skipna=False)
ds_min = ds_min.rename({name: f"{name}_min" for name in ds_min.data_vars})

# One dataset holding <var>_mean and <var>_min for every calendar year.
ds_ERA5_merge = xr.merge([ds_mean, ds_min])

In [13]:
# Aggregate CESM SPEI over lead months into lead-year statistics: lead months [0, 12) -> lead
# year 0, [12, 24) -> lead year 1, ... (right=False makes every bin left-closed).
MONTHS_PER_YEAR = 12
N_LEAD_YEARS = 10
bins = np.arange(0, N_LEAD_YEARS * MONTHS_PER_YEAR + 1, MONTHS_PER_YEAR)  # 0, 12, ..., 120

# groupby_bins silently drops lead months outside the bins, and a bin with fewer than 12 months
# would give a partial-year statistic. Check every lead-year bin is complete (this also catches a
# 1-based lead_time_month, which would leave only 11 months in bin 0).
lead_months = np.asarray(ds_CESM["lead_time_month"].values)
months_per_bin = [
    int(np.count_nonzero((lead_months >= lo) & (lead_months < hi)))
    for lo, hi in zip(bins[:-1], bins[1:])
]
assert all(n == MONTHS_PER_YEAR for n in months_per_bin), (
    f"expected {MONTHS_PER_YEAR} lead months per lead-year bin, got {months_per_bin}"
)

gb = ds_CESM.groupby_bins("lead_time_month", bins=bins, right=False)
lead_years = np.arange(len(bins) - 1)  # integer lead-year labels replacing the interval bins

ds_mean = (
    gb.mean("lead_time_month", skipna=False)
      .rename({"lead_time_month_bins": "year"})
      .assign_coords(year=lead_years)
)
ds_min = (
    gb.min("lead_time_month", skipna=False)
      .rename({"lead_time_month_bins": "year"})
      .assign_coords(year=lead_years)
)

ds_mean = ds_mean.rename({v: f"{v}_mean" for v in ds_mean.data_vars})
ds_min  = ds_min.rename({v: f"{v}_min"  for v in ds_min.data_vars})

# compat="override" skips the equality check on variables shared by ds_mean and ds_min and takes
# them from ds_mean. NOTE(review): presumably needed for coordinates carried by both — confirm.
ds_CESM_merge = xr.merge([ds_mean, ds_min], compat="override")

In [14]:
# Put ERA5 (calendar year) and CESM (init year, shifted to valid year per lead later) on a shared
# "year_valid" dimension. Guarded so that re-running this cell is a no-op instead of failing on
# the already-renamed datasets.
if 'year' in ds_ERA5_merge.dims:
    ds_ERA5_merge = ds_ERA5_merge.rename({'year': 'year_valid'})
if 'init_time' in ds_CESM_merge.dims:
    ds_CESM_merge = ds_CESM_merge.rename({'init_time': 'year_valid'})

In [15]:
varnames = list(ds_ERA5_merge.data_vars)

N_LEADS = 10    # lead years evaluated (positions along the CESM 'year' dimension)
MIN_YEARS = 10  # minimum number of overlapping valid years required to compute an ACC

ds_collection = []
leads_used = []  # lead index of each entry in ds_collection (leads may be skipped)

for i_lead in range(N_LEADS):
    ds_CESM_slice = ds_CESM_merge.isel(year=i_lead)
    # convert init year to valid year by adding the lead (assumes integer year labels — TODO confirm)
    ds_CESM_slice = ds_CESM_slice.assign_coords(
        year_valid=ds_CESM_slice['year_valid'] + i_lead
    )

    # keep only the valid years present in both datasets
    ds_CESM_slice, ds_ERA5_slice = xr.align(ds_CESM_slice, ds_ERA5_merge, join="inner")

    # skip leads without enough overlapping years for a meaningful correlation
    if ds_CESM_slice.sizes['year_valid'] < MIN_YEARS:
        continue

    # anomaly correlation across valid years, one field per variable
    ds_ACC = xr.Dataset({
        varname: xr.corr(ds_CESM_slice[varname], ds_ERA5_slice[varname], dim="year_valid")
        for varname in varnames
    })
    ds_collection.append(ds_ACC)
    leads_used.append(i_lead)

if not ds_collection:
    raise ValueError(
        f"No lead year has at least {MIN_YEARS} overlapping valid years between CESM and ERA5."
    )

# Stack the per-lead ACC fields. The stacking dimension keeps its historical name 'year_valid'
# (same layout as ACC stores written by earlier runs) but it indexes lead years, not calendar
# years; the 'lead_year' coordinate records which lead each entry is, so skipped leads cannot
# silently shift positions.
ds_ACC_all = xr.concat(ds_collection, dim='year_valid').assign_coords(
    lead_year=('year_valid', leads_used)
)

In [16]:
save_name = base_dir + 'ACC_SPEI.zarr'
# Writing is opt-in: set SAVE_ACC = True to persist ds_ACC_all. xarray's default to_zarr mode
# creates the store and fails if it already exists, so an existing result is not clobbered.
SAVE_ACC = False
if SAVE_ACC:
    ds_ACC_all.to_zarr(save_name)
print(save_name)

/glade/derecho/scratch/ksha/EPRI_data/METRICS/Fort_Bragg/ACC_SPEI.zarr
