In [None]:
from glob import glob
import xarray as xr
import cftime
import nc_time_axis
import numpy as np
import matplotlib.pyplot as plt

## Read the CM4 data

In [None]:
# Root directory of the GFDL-CM4 historical (r1i1p1f1) Omon output in the CMIP6 archive.
CM4_Omon = '/archive/uda/CMIP6/CMIP/NOAA-GFDL/GFDL-CM4/historical/r1i1p1f1/Omon'
# Per-variable directories (`thetao` and `volcello`), both under `gr/v20180701`.
# These two variables are multiplied together later to build the heat content.
CM4_thetao = f'{CM4_Omon}/thetao/gr/v20180701'
CM4_volcello = f'{CM4_Omon}/volcello/gr/v20180701'

In [None]:
# Gather the netCDF files of both variables into ONE flat list of path strings.
# `extend` is required here: `append` would insert the whole volcello list as a
# single nested element, leaving a ragged list that is not a list of paths.
# Each glob result is sorted because glob returns files in arbitrary order.
filelist = sorted(glob(f'{CM4_thetao}/*.nc'))
filelist.extend(sorted(glob(f'{CM4_volcello}/*.nc')))

In [None]:
# Chunk sizes are passed explicitly to mirror the on-disk netCDF chunking
# (inspect it with `ncdump -sh file | grep -i chunk`). Without them,
# open_mfdataset appeared to use each whole file as a single chunk, which
# led to serious memory problems.
CM4_T = xr.open_mfdataset(
    filelist,
    chunks=dict(time=1, lev=18, lat=90, lon=180),
)

In [None]:
# Show the dataset repr to inspect its variables and dimensions after opening.
CM4_T

In [None]:
# Rechunk: one time step per chunk, with lev/lat/lon blocks of 35/180/360.
CM4_T = CM4_T.chunk(
    dict(time=1, lev=35, lat=180, lon=360)
)

In [None]:
# Show the dataset repr again to check the result of the rechunking.
CM4_T

## Compute Ocean Heat Content

In [None]:
# Named constants replace the bare 3992 and 1025 factors.
# NOTE(review): presumably the specific heat capacity of seawater (J kg-1 K-1)
# and a reference density (kg m-3) — confirm both against the model configuration.
CP_SEAWATER = 3992
RHO_REF = 1025

# Per-cell heat content: cp * rho * temperature * cell volume.
# The multiplication order is kept so the result is numerically unchanged.
OHC = CP_SEAWATER * RHO_REF * CM4_T['thetao'] * CM4_T['volcello']

In [None]:
# Show the repr of the heat-content array.
OHC

In [None]:
# Heat content summed over all levels with `lev` between 0 and 700
# (label-based slice) and over the whole lat/lon grid.
global_OHC_upper700m = (
    OHC
    .sel(lev=slice(0, 700))
    .sum(dim=('lev', 'lat', 'lon'))
)

In [None]:
# Heat content summed over all levels with `lev` between 0 and 2000
# (label-based slice) and over the whole lat/lon grid.
global_OHC_upper2000m = (
    OHC
    .sel(lev=slice(0, 2000))
    .sum(dim=('lev', 'lat', 'lon'))
)

In [None]:
# Heat content summed over all levels with `lev` between 2000 and 7000
# (label-based slice) and over the whole lat/lon grid.
global_OHC_2000below = (
    OHC
    .sel(lev=slice(2000, 7000))
    .sum(dim=('lev', 'lat', 'lon'))
)

In [None]:
# Full-depth total: sum over every level and the whole lat/lon grid.
global_OHC = OHC.sum(
    dim=('lev', 'lat', 'lon'),
)

In [None]:
# Horizontal total per level: sum over lat/lon only, keeping the `lev` dimension.
global_OHC_level = OHC.sum(
    dim=('lat', 'lon'),
)

## Go fast with dask cluster

In [None]:
# Evaluate the lazy upper-700 sum and time it. xarray's `.load()` keeps the
# computed values on the object itself, so later cells reuse them; `_ =` just
# discards the returned object so the cell shows only the timing.
%time _ = global_OHC_upper700m.load()

In [None]:
# chunks =  (12, 35, 180, 360) -> Wall time: 3min 15s  ! cost of concat netcdf chunks?
# chunks =  (1, 35, 180, 360) -> Wall time: 2min 29s
# chunks =  (1, 1, 180, 360) -> Wall time: 5min 53s ! chunks too small, graph takes forever to build

In [None]:
# Evaluate and time the upper-2000 sum; the loaded values stay on the object.
%time _ = global_OHC_upper2000m.load()

In [None]:
# Evaluate and time the below-2000 sum; the loaded values stay on the object.
%time _ = global_OHC_2000below.load()

In [None]:
# Evaluate and time the full-depth sum; the loaded values stay on the object.
%time _ = global_OHC.load()

In [None]:
# Evaluate and time the per-level horizontal sum; the loaded values stay on the object.
%time _ = global_OHC_level.load()

## Compare with Zanna et al. paper

In [None]:
# Open the Zanna et al. dataset, give its year axis the plain name `time`,
# and flag `time` as a coordinate.
Zanna = (
    xr.open_dataset('/net2/rnd/Zanna_2018/OHC_GF_1870_2018.nc')
    .rename({'time (starting 1870)': 'time'})
    .set_coords(['time'])
)

In [None]:
# One noleap-calendar timestamp per Zanna `time` value, set to 16 July 12:00 of that year.
dates = [
    cftime.DatetimeNoLeap(year, 7, 16, hour=12)
    for year in Zanna['time'].values
]

In [None]:
# Store the noleap dates in the dataset as a variable on a new `cftime` dimension.
Zanna['cftime'] = xr.DataArray(np.array(dates), dims='cftime')
# Make sure `cftime` is registered as a coordinate rather than a data variable.
Zanna = Zanna.set_coords(['cftime'])

In [None]:
# Show the dataset repr to check the renamed `time` axis and the added `cftime` coordinate.
Zanna

In [None]:
def anom_yearly_avg(da, ref_year=1870):
    """Return annual means of ``da`` as anomalies relative to one reference year.

    Parameters
    ----------
    da : xarray.DataArray
        Series with a ``time`` coordinate that supports the ``.dt.year`` accessor.
    ref_year : int, optional
        Year whose annual mean is subtracted from every annual mean.
        Defaults to 1870, the baseline used for the Zanna et al. comparison.

    Returns
    -------
    xarray.DataArray
        Annual means indexed by ``year``, minus the annual mean of ``ref_year``.

    Raises
    ------
    KeyError
        If ``ref_year`` is not one of the years present in ``da``.
    """
    # Average all samples that fall in the same calendar year.
    yearly = da.groupby(da.time.dt.year).mean(dim='time')
    # Anomaly relative to the reference year (1870 by default, as in Zanna et al.).
    return yearly - yearly.sel(year=ref_year)

# Annual-mean anomalies for each depth range, all produced by the same helper.
(
    gOHCanom_upper700m_annual,
    gOHCanom_upper2000m_annual,
    gOHCanom_2000below_annual,
    gOHCanom_annual,
) = (
    anom_yearly_avg(series)
    for series in (
        global_OHC_upper700m,
        global_OHC_upper2000m,
        global_OHC_2000below,
        global_OHC,
    )
)

## Plot the results

In [None]:
# CM4 annual anomaly (divided by 1e21) against the Zanna `OHC_700m` series.
fig, ax = plt.subplots()
(gOHCanom_upper700m_annual / 1e21).plot(ax=ax, color='k', label='CM4')
Zanna['OHC_700m'].plot(ax=ax, color='r', label='Zanna')
ax.legend(fontsize=16)
ax.set_title('OHC upper 700m')
ax.grid()

In [None]:
# Monthly CM4 anomaly relative to the 1870 annual mean. `gOHCanom_upper700m`
# was never defined in the notebook, so this cell failed on a fresh kernel;
# it is built here from the monthly series using the same 1870 baseline as
# the annual anomalies. `drop=True` removes the leftover scalar `year` coordinate.
ref_1870_upper700m = (
    global_OHC_upper700m
    .groupby(global_OHC_upper700m.time.dt.year)
    .mean(dim='time')
    .sel(year=1870, drop=True)
)
gOHCanom_upper700m = global_OHC_upper700m - ref_1870_upper700m

# Copy the Zanna series onto the `cftime` dimension so that both curves are
# drawn against calendar-aware dates instead of integer years.
Zanna['OHC_700m_cf'] = xr.DataArray(Zanna['OHC_700m'].values, dims='cftime')

plt.figure()
ax = plt.axes()
(gOHCanom_upper700m / 1e21).plot(ax=ax, label='CM4 monthly', color='k')
Zanna['OHC_700m_cf'].plot(ax=ax, label='Zanna', color='r')
plt.legend(fontsize=16)
plt.title('OHC upper 700m')
plt.grid()

In [None]:
# CM4 annual anomaly (divided by 1e21) against the Zanna `OHC_2000m` series.
fig, ax = plt.subplots()
(gOHCanom_upper2000m_annual / 1e21).plot(ax=ax, color='k', label='CM4')
Zanna['OHC_2000m'].plot(ax=ax, color='r', label='Zanna')
ax.legend(fontsize=16)
ax.set_title('OHC upper 2000m')
ax.grid()

In [None]:
# CM4 annual anomaly (divided by 1e21) against the Zanna `OHC_below_2000m` series.
fig, ax = plt.subplots()
(gOHCanom_2000below_annual / 1e21).plot(ax=ax, color='k', label='CM4')
Zanna['OHC_below_2000m'].plot(ax=ax, color='r', label='Zanna')
ax.legend(fontsize=16)
ax.set_title('OHC below 2000m')
ax.grid()

In [None]:
# CM4 annual anomaly (divided by 1e21) against the Zanna `OHC_full_depth` series.
fig, ax = plt.subplots()
(gOHCanom_annual / 1e21).plot(ax=ax, color='k', label='CM4')
Zanna['OHC_full_depth'].plot(ax=ax, color='r', label='Zanna')
ax.legend(fontsize=16)
ax.set_title('OHC full depth')
ax.grid()

### All in one

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=[10,10])

# One entry per panel, listed row by row:
# (CM4 annual anomaly, Zanna variable name, panel title).
panels = [
    (gOHCanom_upper700m_annual, 'OHC_700m', 'OHC [ZJ] upper 700m'),
    (gOHCanom_upper2000m_annual, 'OHC_2000m', 'OHC [ZJ] upper 2000m'),
    (gOHCanom_2000below_annual, 'OHC_below_2000m', 'OHC [ZJ] below 2000m'),
    (gOHCanom_annual, 'OHC_full_depth', 'OHC [ZJ] full depth'),
]

# `axs.flat` walks the 2x2 grid in row-major order, matching the list above.
for panel_ax, (cm4_anom, zanna_name, panel_title) in zip(axs.flat, panels):
    (cm4_anom / 1e21).plot(ax=panel_ax, label='CM4hist', color='k')
    Zanna[zanna_name].plot(ax=panel_ax, label='Zanna', color='r')
    panel_ax.legend(fontsize=16)
    panel_ax.set_title(panel_title)
    # Blank the axis labels that the xarray plot calls set automatically.
    panel_ax.set_xlabel("")
    panel_ax.set_ylabel("")
    panel_ax.grid()