# Bias-correct CESM2 LENS temperature data using ERA5 reanalysis

In [21]:
# Display output of plots directly in Notebook
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

import intake
import numpy as np
import pandas as pd
import xarray as xr
# import s3fs
import seaborn as sns
import re
# import nest_asyncio
# nest_asyncio.apply()
import xesmf as xe

In [22]:
import fsspec.implementations.http as fshttp
from pelicanfs.core import PelicanFileSystem, PelicanMap, OSDFFileSystem 

In [23]:
import dask 
from dask_jobqueue import PBSCluster
from dask.distributed import Client
from dask.distributed import performance_report

In [24]:
# First and last year of the two analysis windows. They are kept as strings
# because they are concatenated into the Zarr store names further down.
init_year0, init_year1 = '1991', '2020'      # initial period
final_year0, final_year1 = '2071', '2100'    # final period

In [25]:
def to_daily(ds):
    """Split the ``time`` dimension of ``ds`` into ``year`` and ``day`` dimensions.

    Each time step is labelled with its calendar year and its day of year;
    the pair is then used as a multi-index on ``time`` and unstacked, so the
    result has ``year`` and ``day`` in place of ``time``.

    Parameters
    ----------
    ds : object with a datetime-like ``time`` coordinate (xarray Dataset/DataArray)

    Returns
    -------
    The same kind of object, reshaped to (..., "year", "day").
    """
    timestamps = ds.time.dt

    # Attach year / day-of-year labels along the existing time axis.
    labelled = ds.assign_coords(
        year=("time", timestamps.year.data),
        day=("time", timestamps.dayofyear.data),
    )

    # Index time by (year, day) and unstack it into two separate dimensions.
    indexed = labelled.set_index(time=("year", "day"))
    return indexed.unstack("time")

In [26]:
# Root scratch directory and the locations of the Zarr stores used below.
rda_scratch = '/gpfs/csfs1/collections/rda/scratch/harshah'
zarr_path   = rda_scratch + "/tas_zarr/"
# zarr_path already ends with "/", so the sub-directories are appended without
# a leading slash; this avoids "tas_zarr//means/"-style double separators.
mean_path   = zarr_path + "means/"
stdev_path  = zarr_path + "stdevs/"

## Create a PBS cluster

In [27]:
# Describe the PBS jobs that will host the Dask workers.
cluster = PBSCluster(
    job_name="dask-wk24-hpc",
    queue="casper",
    walltime="5:00:00",
    # One single-threaded worker process per PBS job.
    cores=1,
    processes=1,
    memory="8GiB",
    resource_spec="select=1:ncpus=1:mem=8GB",
    # Scratch locations for spilled data and worker logs.
    local_directory=f"{rda_scratch}/dask/spill",
    log_directory=f"{rda_scratch}/dask/logs/",
    # Network interface used by the workers ('ib0' was the alternative tried).
    interface="ext",
)

In [28]:
# Ask PBS for 20 worker jobs, then attach a client to the cluster. Without a
# connected Client, Dask computations run on the default local scheduler
# instead of on these workers.
cluster.scale(20)
client = Client(cluster)

In [29]:
# Display the cluster summary (dashboard link, worker count, total memory).
cluster

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/43023/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.96:44747,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/43023/status,Total threads: 0
Started: Just now,Total memory: 0 B


## Load CESM2 LENS temperature data

In [30]:
# Open the intake-esm catalog (JSON descriptor stored under the scratch directory).
cesm_cat = intake.open_esm_datastore(
    f"{rda_scratch}/intake_catalogs/posix/aws-cesm2-le.json"
)
cesm_cat

Unnamed: 0,unique
Unnamed: 0,322
variable,53
long_name,51
component,4
experiment,2
forcing_variant,2
frequency,3
vertical_levels,3
spatial_domain,3
units,20


In [31]:
# Narrow the catalog to the daily TREFHTMX entries.
# NOTE(review): TREFHTMX is presumably the daily maximum reference-height
# temperature — confirm against the catalog's long_name column.
cesm_temp = cesm_cat.search(variable ='TREFHTMX', frequency ='daily')
cesm_temp

Unnamed: 0,unique
Unnamed: 0,4
variable,1
long_name,1
component,1
experiment,2
forcing_variant,2
frequency,1
vertical_levels,1
spatial_domain,1
units,1


In [32]:
# List the Zarr store paths behind the matching catalog entries.
cesm_temp.df['path'].values

array(['/glade/campaign/collections/rda/transfer/chifan_AWS/ncar-cesm2-lens/atm/daily/cesm2LE-historical-cmip6-TREFHTMX.zarr',
       '/glade/campaign/collections/rda/transfer/chifan_AWS/ncar-cesm2-lens/atm/daily/cesm2LE-historical-smbb-TREFHTMX.zarr',
       '/glade/campaign/collections/rda/transfer/chifan_AWS/ncar-cesm2-lens/atm/daily/cesm2LE-ssp370-cmip6-TREFHTMX.zarr',
       '/glade/campaign/collections/rda/transfer/chifan_AWS/ncar-cesm2-lens/atm/daily/cesm2LE-ssp370-smbb-TREFHTMX.zarr'],
      dtype=object)

In [33]:
# Load every matching entry into a dict of datasets; the keys have the form
# 'component.experiment.frequency.forcing_variant'.
dsets_cesm = cesm_temp.to_dataset_dict()


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


In [34]:
# Pick out the historical and SSP3-7.0 datasets for each forcing variant.
# 'smbb' forcing variant:
historical_smbb  = dsets_cesm['atm.historical.daily.smbb']
future_smbb      = dsets_cesm['atm.ssp370.daily.smbb']

# 'cmip6' forcing variant:
historical_cmip6 = dsets_cesm['atm.historical.daily.cmip6']
future_cmip6     = dsets_cesm['atm.ssp370.daily.cmip6']

In [35]:
# %%time
# merge_ds_smbb = xr.concat([historical_smbb, future_smbb], dim='time')
# merge_ds_smbb = merge_ds_smbb.dropna(dim='member_id')

# merge_ds_cmip6= xr.concat([historical_cmip6, future_cmip6], dim='time')
# merge_ds_cmip6 = merge_ds_cmip6.dropna(dim='member_id')

In [36]:
# merge_ds_cmip6

In [37]:
# t_smbb      = merge_ds_smbb.TREFHTMX
# t_cmip6     = merge_ds_cmip6.TREFHTMX
# t_init_cmip6 = t_cmip6.sel(time=slice(init_year0, init_year1))
# t_init_smbb  = t_smbb.sel(time=slice(init_year0, init_year1))
# t_init       = xr.concat([t_init_cmip6,t_init_smbb],dim='member_id')
# t_init

In [38]:
# t_init_day = to_daily(t_init)
# #t_init_day

In [39]:
# t_fut_cmip6 = t_cmip6.sel(time=slice(final_year0, final_year1))
# t_fut_smbb  = t_smbb.sel(time=slice(final_year0, final_year1))
# t_fut       = xr.concat([t_fut_cmip6,t_fut_smbb],dim='member_id')
# t_fut_day   = to_daily(t_fut)
# t_fut_day

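### Save means and standard deviations

The cells in this section, and the merge/selection cells just above it, are left commented out as a record of how the day-of-year means and standard deviations are computed and written to Zarr. The resulting stores are re-opened in the regridding section below.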

In [41]:
# init_means   = t_init_day.mean({'year','member_id'})
# init_stdevs  = t_init_day.std({'year','member_id'})
# final_means  = t_fut_day.mean({'year','member_id'})
# final_stdevs = t_fut_day.std({'year','member_id'})

In [None]:
# %%time
# #Save 
# init_means.to_dataset().to_zarr(mean_path + 'cesm2_'+ init_year0 + '_' + init_year1+ '_means.zarr',mode='w')
# init_stdevs.to_dataset().to_zarr(stdev_path + 'cesm2_'+ init_year0 + '_' + init_year1+ '_stdevs.zarr',mode='w') 
# final_means.to_dataset().to_zarr(mean_path + 'cesm2_'+ final_year0 + '_' + final_year1+ '_means.zarr',mode='w')
# final_stdevs.to_dataset().to_zarr(stdev_path + 'cesm2_'+ final_year0 + '_' + final_year1+ '_stdevs.zarr',mode='w') 

## Access ERA5 data and regrid CESM2 LENS data onto the finer ERA5 grid

In [42]:
%%time
tas_daily = xr.open_zarr(zarr_path + "e5_tas2m_daily_1940_2023.zarr").VAR_2T
tas_daily

CPU times: user 7.94 ms, sys: 175 μs, total: 8.12 ms
Wall time: 33.2 ms


Unnamed: 0,Array,Chunk
Bytes,118.79 GiB,288.45 MiB
Shape,"(30712, 721, 1440)","(1000, 139, 544)"
Dask graph,558 chunks in 2 graph layers,558 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 118.79 GiB 288.45 MiB Shape (30712, 721, 1440) (1000, 139, 544) Dask graph 558 chunks in 2 graph layers Data type float32 numpy.ndarray",1440  721  30712,

Unnamed: 0,Array,Chunk
Bytes,118.79 GiB,288.45 MiB
Shape,"(30712, 721, 1440)","(1000, 139, 544)"
Dask graph,558 chunks in 2 graph layers,558 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [54]:
# Re-open the stored CESM2 means for the initial period. The Dataset is kept
# as well because it is the source grid handed to the regridder.
init_means_ds = xr.open_zarr(f"{mean_path}cesm2_{init_year0}_{init_year1}_means.zarr")
init_means    = init_means_ds["TREFHTMX"]
init_means

Unnamed: 0,Array,Chunk
Bytes,76.99 MiB,76.99 MiB
Shape,"(192, 288, 365)","(192, 288, 365)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 76.99 MiB 76.99 MiB Shape (192, 288, 365) (192, 288, 365) Dask graph 1 chunks in 2 graph layers Data type float32 numpy.ndarray",365  288  192,

Unnamed: 0,Array,Chunk
Bytes,76.99 MiB,76.99 MiB
Shape,"(192, 288, 365)","(192, 288, 365)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [61]:
# Re-open the remaining stored CESM2 statistics: final-period means and the
# standard deviations for both periods.
final_means  = xr.open_zarr(f"{mean_path}cesm2_{final_year0}_{final_year1}_means.zarr")["TREFHTMX"]
init_stdevs  = xr.open_zarr(f"{stdev_path}cesm2_{init_year0}_{init_year1}_stdevs.zarr")["TREFHTMX"]
final_stdevs = xr.open_zarr(f"{stdev_path}cesm2_{final_year0}_{final_year1}_stdevs.zarr")["TREFHTMX"]

In [45]:
# Target grid for regridding: the ERA5 latitude/longitude coordinates,
# renamed to `lat`/`lon` (the coordinate names xESMF looks for).
ds_out = xr.Dataset(
    coords={name: tas_daily.coords[name] for name in ('latitude', 'longitude')}
).rename(latitude='lat', longitude='lon')
ds_out

In [50]:
# Reshape the ERA5 daily series into (year, day-of-year) and wrap it in a Dataset.
# NOTE(review): the result has 366 day-of-year slots, whereas the CESM2
# statistics opened above have 365 — the day axes will need aligning before
# the two are combined.
tas    = to_daily(tas_daily)
tas_ds = tas.to_dataset()
tas_ds

Unnamed: 0,Array,Chunk
Bytes,120.33 GiB,316.72 MiB
Shape,"(721, 1440, 85, 366)","(139, 544, 3, 366)"
Dask graph,558 chunks in 16 graph layers,558 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 120.33 GiB 316.72 MiB Shape (721, 1440, 85, 366) (139, 544, 3, 366) Dask graph 558 chunks in 16 graph layers Data type float32 numpy.ndarray",721  1  366  85  1440,

Unnamed: 0,Array,Chunk
Bytes,120.33 GiB,316.72 MiB
Shape,"(721, 1440, 85, 366)","(139, 544, 3, 366)"
Dask graph,558 chunks in 16 graph layers,558 chunks in 16 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [55]:
%%time 
regridder = xe.Regridder(init_means_ds, ds_out, "bilinear")
regridder

CPU times: user 15.7 s, sys: 438 ms, total: 16.1 s
Wall time: 18.3 s


xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_192x288_721x1440.nc 
Reuse pre-computed weights? False 
Input grid shape:           (192, 288) 
Output grid shape:          (721, 1440) 
Periodic in longitude?      False

In [57]:
# Interpolate the initial-period means onto the ERA5 grid; keep_attrs=True
# carries the variable's attributes over to the result.
init_means_regrid = regridder(init_means, keep_attrs=True)
init_means_regrid

Unnamed: 0,Array,Chunk
Bytes,1.41 GiB,1.41 GiB
Shape,"(365, 721, 1440)","(365, 721, 1440)"
Dask graph,1 chunks in 9 graph layers,1 chunks in 9 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.41 GiB 1.41 GiB Shape (365, 721, 1440) (365, 721, 1440) Dask graph 1 chunks in 9 graph layers Data type float32 numpy.ndarray",1440  721  365,

Unnamed: 0,Array,Chunk
Bytes,1.41 GiB,1.41 GiB
Shape,"(365, 721, 1440)","(365, 721, 1440)"
Dask graph,1 chunks in 9 graph layers,1 chunks in 9 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [62]:
%%time
# Regrid other variables
init_stdevs_regrid  = regridder(init_stdevs, keep_attrs=True)
final_means_regrid  = regridder(final_means, keep_attrs=True)
final_stdevs_regrid = regridder(final_stdevs, keep_attrs=True)

CPU times: user 931 ms, sys: 32 ms, total: 963 ms
Wall time: 1.02 s


## Now perform bias correction

In [23]:
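# Helpers for the global-mean surface temperature (GMST) calculation:
# locate the latitude coordinate and compute a latitude-weighted global mean.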

def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    ``'lat'`` is preferred over ``'latitude'`` when both are present.

    Raises
    ------
    RuntimeError
        If neither name is among ``ds.coords``.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds):
    """Return the latitude-weighted mean of ``ds``.

    The mean is taken over every dimension except ``time`` and ``member_id``,
    with each grid cell weighted by the cosine of its latitude.

    xarray's weighted reduction is used so that missing values are handled
    correctly: the weights of NaN cells are left out of the normalisation.
    Multiplying by pre-normalised weights and taking a plain mean does not do
    this and is biased whenever the field contains NaNs.

    Parameters
    ----------
    ds : xarray Dataset or DataArray with a ``lat`` or ``latitude`` coordinate
        given in degrees.

    Returns
    -------
    Same kind of object as ``ds``, reduced over the non-kept dimensions.

    Raises
    ------
    RuntimeError
        If no latitude coordinate is found (raised by ``get_lat_name``).
    """
    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    # Use an ordered list of dimension names rather than a set.
    reduce_dims = [dim for dim in ds.dims if dim not in ('time', 'member_id')]
    return ds.weighted(weight).mean(reduce_dims)

### Calculate GMST 

#### Now compute (spatially weighted) Global Mean