## Ocean heat content in CESM2-LE
- Generate netCDF files containing the annual-mean ocean heat content (OHC) anomaly at each depth level, regridded to a regular 1° global grid, with one file per CESM2-LE ensemble member.

In [1]:
import warnings
# warnings.filterwarnings("ignore")
import intake
import matplotlib.pyplot as plt
import numpy as np
import netCDF4 as nc
import xarray as xr
import xesmf as xe
# from cmip6_preprocessing.preprocessing import combined_preprocessing
from xmip.preprocessing import combined_preprocessing
import cartopy.crs as ccrs
import cartopy.feature as cfeaturf
from cartopy.util import add_cyclic_point
from matplotlib import gridspec
import matplotlib.pylab as pl
import scipy.stats as ss
import scipy.signal as sg
import seaborn as sns
import pandas as pd
import zarr
import gcsfs
import requests
import sys
import gsw
# Open the intake-esm datastore described by the CESM2-LE catalog JSON (fetched over HTTPS from GitHub).
# `catalog` is queried with .search(...) by the cells below.
catalog = intake.open_esm_datastore('https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json')

def grid_area(lon,lat):
    """Approximate the area of every cell of a regular lon/lat grid.

    Parameters
    ----------
    lon : 1-D array of longitudes in degrees. The width of the last column
        is the wrap-around gap back to ``lon[0]`` (i.e. ``lon[0]-lon[-1]+360``).
    lat : 1-D array of latitudes in degrees (at least two values). The height
        of the last row reuses the spacing of the preceding one.

    Returns
    -------
    2-D array of shape ``(len(lat), len(lon))``: zonal extent times meridional
    extent of each cell, using a sphere of radius 6.371e6 (metres, hence m^2).
    """
    radius = 6.371*1e6

    # Angular cell sizes in radians; both vectors keep the input length.
    lon_step = np.deg2rad(np.append(np.diff(lon), lon[0]-lon[-1]+360.))
    lat_step = np.deg2rad(np.append(np.diff(lat), np.diff(lat)[-1]))

    # Zonal extent shrinks with cos(latitude); broadcasting builds the 2-D field.
    cos_lat = np.cos(np.deg2rad(lat))
    width   = (radius * cos_lat[:, np.newaxis]) * lon_step[np.newaxis, :]
    height  = radius * lat_step[:, np.newaxis]
    return width * height

# consts
# Target regular 1-degree grid: 360 longitudes (0.5..359.5) and 180 latitudes (-89.5..89.5).
t1d_lon = np.linspace(0.5, 359.5, 360) # T grid, 1 deg
t1d_lat = np.linspace(-89.5, 89.5, 180) # T grid, 1 deg
# Days in each calendar month with a 28-day February (no leap years);
# used as weights when monthly values are averaged into annual means.
mon_wgt = np.array([31., 28., 31., 30., 31., 30., 31., 31., 30., 31., 30., 31.])
cp_0    = 3991.868 # [J/kg/K] isobaric heat capacity that relates potential enthalpy to Conservative Temperature (TEOS-10)
# Cell areas of the 1-degree target grid, shape (len(t1d_lat), len(t1d_lon)).
t1d_area= grid_area(t1d_lon, t1d_lat)

In [2]:
# Load the static ocean grid from the catalog, derive cell areas, layer
# thicknesses and land/ocean masks on the native grid, then regrid the
# per-level ocean volume onto the regular 1-degree grid (t1d_lon, t1d_lat).
grid_subset = catalog.search(component='ocn', frequency='static', experiment='historical', forcing_variant='cmip6')
grid = grid_subset.to_dataset_dict(storage_options={'anon':True})['ocn.historical.static.cmip6']

# keep TLONG, TLAT, and raw data as DataArray before xESMF regridding (others are ndarray)
TLONG, TLAT = grid.coords['TLONG'].reset_coords(drop=True), grid.coords['TLAT'].reset_coords(drop=True)
# NOTE(review): ULONG/ULAT are not used anywhere in this cell.
ULONG, ULAT = grid.coords['ULONG'].reset_coords(drop=True), grid.coords['ULAT'].reset_coords(drop=True)
TAREA       = grid.coords['TAREA'].values*1e-4 #(cm^2->m^2)
# NaN areas are replaced by 0 so they contribute nothing to products below.
TAREA_mod   = np.where(np.isnan(TAREA), 0., TAREA)

REGION_MASK = grid.coords['REGION_MASK']
dz          = grid.coords['dz'].values*1e-2 #(cm->m)

KMT         = grid.coords['KMT'].values # k Index of Deepest Grid Cell on T Grid (note: contains NaN)
# NaN KMT is treated as 0, i.e. no valid level in that column.
KMT_mod     = np.where(np.isnan(KMT), 0., KMT)

ocn2dpass   = np.where(REGION_MASK>0., 1., np.nan) # land=np.nan, ocn=1 (only open-ocean)
# NOTE(review): ocn_area is not used anywhere in this cell.
ocn_area    = ocn2dpass*TAREA_mod

# 3-D mask: 1 where the 0-based level index is below KMT_mod, NaN elsewhere.
ilev_1d     = np.arange(len(dz))
ocn3dpass   = np.where(ilev_1d[:,np.newaxis,np.newaxis]>=KMT_mod[np.newaxis,:,:],np.nan,1.)
# Volume of each native cell: mask * horizontal area * layer thickness.
ocn_vol     = ocn3dpass*TAREA_mod[np.newaxis,:,:]*dz[:,np.newaxis,np.newaxis]

# Ocean volume regridding. Area weights applied in order to conserve the volume.
# The volume is divided by the native cell area before regridding and
# multiplied by the target cell area (t1d_area) afterwards.
v_re = {}
grid_in              = {'lon': TLONG.values, 'lat': TLAT.values}
grid_out             = {'lon': t1d_lon, 'lat': t1d_lat}
regridder            = xe.Regridder(grid_in, grid_out, 'bilinear', periodic=True, ignore_degenerate=True)
v_re['ocn_vol_norm'] = regridder(ocn_vol/TAREA_mod[np.newaxis,:,:], skipna=True, na_thres=.5) 
# na_thres: set to NaN only if the "ratio" of missing values exceeds the threshold level 
v_re['ocn_vol']      = v_re['ocn_vol_norm']*t1d_area[np.newaxis,:,:]


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


In [3]:
# Compute the annual-mean ocean heat content (OHC) anomaly at every level of
# the 1-degree grid for a range of ensemble members and write one netCDF file
# per member. Monthly TEMP/SALT are taken from the historical run up to
# 2014-12 and from the ssp370 run from 2015-01 onwards.
comp    = 'ocn' # atm, ocn, ice, lnd...
iv_ls   = ['TEMP', 'SALT']

ie1     = 'historical'
ie2     = 'ssp370'

# Analysis period; it must start in January and end in December because the
# monthly series is reshaped into (year, 12, ...) below.
tsta = "2004-01"
tend = "2021-12"
# Derived from the period bounds so the record length cannot drift from them.
nyr  = int(tend[:4]) - int(tsta[:4]) + 1
nt   = nyr*12

# Open each monthly store once. The catalog queries do not depend on the
# ensemble member, so repeating them inside the member loop only re-opens the
# same remote datasets.
key1    = comp+'.'+ie1+'.monthly.cmip6'
key2    = comp+'.'+ie2+'.monthly.cmip6'
da_hist = {}
da_ssp  = {}
for iv in iv_ls:
    # historical temperature data under smbb forcing not available in cesm2-le-aws (Sep 26, 2023)
    cat         = catalog.search(variable=iv, frequency='monthly', experiment=ie1, forcing_variant='cmip6')
    dset_dict1  = cat.to_dataset_dict(storage_options={'anon':True})
    da_hist[iv] = dset_dict1[key1][iv]
    cat         = catalog.search(variable=iv, frequency='monthly', experiment=ie2, forcing_variant='cmip6')
    dset_dict2  = cat.to_dataset_dict(storage_options={'anon':True})
    da_ssp[iv]  = dset_dict2[key2][iv]

# Ensemble member labels are read from the historical TEMP array.
dset = da_hist['TEMP']
m_ls = dset.coords['member_id'].values

# The source and target grids are the same for every variable and member, so
# the regridder (and its weights) is built once and reused.
grid_in   = {'lon': TLONG, 'lat': TLAT}
grid_out  = {'lon': t1d_lon, 'lat': t1d_lat}
regridder = xe.Regridder(grid_in, grid_out, 'bilinear', periodic=True, ignore_degenerate=True)

# Select the range of the ensembles for the following OHC calculations
msta = 0
mend = 1

for imN in range(msta,mend,1):
    im = m_ls[imN]
    print(f'Ensemble {imN+1:d}.')
    v  = {}
    for iv in iv_ls:
        # Splice historical and scenario segments for this member along time.
        v[iv]    = xr.concat([da_hist[iv].sel(time=slice(tsta,"2014-12"), member_id=im),
                              da_ssp[iv].sel(time=slice("2015-01",tend), member_id=im)], dim="time")
        v_re[iv] = regridder(v[iv], skipna=True, na_thres=.5) # na_thres: set to NaN only if the "ratio" of missing values exceeds the threshold level

    # get values (slow; ~1 min for each ensemble)
    v_re['pt'] = v_re['TEMP'].values
    v_re['sp'] = v_re['SALT'].values

    # anomalies relative to the whole-period averages
    # (a separate mean for each calendar month: axis 0 is the year axis)
    shp_ym = (-1,12,len(dz),len(t1d_lat),len(t1d_lon))
    v_re['pt_cl']  = np.average(v_re['pt'].reshape(shp_ym),axis=0)
    v_re['pt_an']  = v_re['pt'].reshape(shp_ym) - v_re['pt_cl'][np.newaxis,...]
    # Flat (month, lev, lat, lon) view of the anomalies for the time loop below.
    pt_an_mon = v_re['pt_an'].reshape(-1,len(dz),len(t1d_lat),len(t1d_lon))

    # Compute in-situ density and ocean heat content at each level
    z_t              = v['TEMP']['z_t'].values*1e-2 #(cm->m)
    lev_, lat_, lon_ = np.broadcast_arrays(z_t[:,np.newaxis,np.newaxis], t1d_lat[np.newaxis,:,np.newaxis], t1d_lon[np.newaxis,np.newaxis,:])
    v_re['p']        = gsw.p_from_z(-lev_, lat_)

    # NOTE(review): the anomaly multiplied by cp_0 is potential temperature,
    # while Conservative Temperature is only used for the density - confirm
    # this is the intended definition of OHC.
    v_re['ohc_lev']  = np.full((nt,len(z_t),len(t1d_lat),len(t1d_lon)), np.nan, np.float32)
    for it in range(nt):  # `it` indexes months, not years
        v_sa  = gsw.SA_from_SP(v_re['sp'][it,...], v_re['p'], lon_, lat_)
        v_ct  = gsw.CT_from_pt(v_sa, v_re['pt'][it,...])
        v_rho = gsw.density.rho(v_sa, v_ct, v_re['p'])
        v_re['ohc_lev'][it,:,:,:] = cp_0*v_rho*pt_an_mon[it,:,:,:]*v_re['ocn_vol']

    # Annual means weighted by the number of days in each month.
    v_re['ohc_lev_ann'] = np.average(v_re['ohc_lev'].reshape(shp_ym),axis=1,weights=mon_wgt)

    # --- Save
    out_f_name = '/home/jovyan/OHU/processed_data/OHC_'+tsta+'_'+tend+'_'+str(im)+'.nc'
    print(out_f_name)
    # The context manager closes the file even if one of the writes fails.
    with nc.Dataset(out_f_name, 'w', format='NETCDF4') as out_ds:
        out_ds.createDimension('time', nyr)
        out_ds.createDimension('lev', len(z_t))
        out_ds.createDimension('lat', len(t1d_lat))
        out_ds.createDimension('lon', len(t1d_lon))
        out_ds.description = 'Annual-mean OHC anomaly from '+tsta+' to '+tend+' in emsemble member - '+str(im)+' of CESM2 LE. Reference State is the average of the whole period.'
        time_var = out_ds.createVariable('time', 'f4', ('time',))
        lev_var  = out_ds.createVariable('lev', 'f4', ('lev',))
        lat_var  = out_ds.createVariable('lat', 'f4', ('lat',))
        lon_var  = out_ds.createVariable('lon', 'f4', ('lon',))
        time_var[:] = np.arange(nyr)  # year index counted from the first year of the period
        lev_var[:]  = z_t
        lat_var[:]  = t1d_lat
        lon_var[:]  = t1d_lon
        out_v1               = out_ds.createVariable('ohc_lev_ann', 'f4', ('time', 'lev', 'lat', 'lon'))
        out_v1.units         = 'J'
        out_v1.standard_name = 'OHC'
        out_v1[:]            = v_re['ohc_lev_ann']


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


Ensemble 1.

--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'



--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


/home/jovyan/OHU/processed_data/OHC_2004-01_2021-12_r10i1181p1f1.nc
