<span style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">An Exception was encountered at '<a href="#papermill-error-cell">In [13]</a>'.</span>

# Horizontal Mean difference and RMS: model versus observations

`mom6_tools.horizontalMean` is a collection of functions for computing the horizontal mean of the **difference** and **rms** (model versus obs). This notebook serves as an example of how to compute the following operations:


$$diff(t,z)= A_{TOT}(z)^{-1}\sum_{i=1}^n (y_i(z) - \hat{y}_i(z)) A_i(z),$$

$$rms(t,z)= [A_{TOT}(z)^{-1}\sum_{i=1}^n (y_i(z) - \hat{y}_i(z))^2 A_i(z)]^{1/2},$$

where $y_i(z)$ is the model output at point $i$ and level $z$, $\hat{y}_i(z)$ is the observation at point $i$ and level $z$, $n$ is the total number of grid points in the horizontal (i.e., NX x NY), $A_{i}(z)$ is the area of grid cell $i$ at level $z$, and $A_{TOT}(z) = \sum_{i=1}^n A_i(z)$ is the total ocean area at level $z$.

**Important**:

For the purpose of calculating T and S changes in specific regions, $A_{i}(z)$ is multiplied by basin masks generated via ``mom6_tools.m6toolbox.genBasinMasks``. See this [notebook](https://mom6-tools.readthedocs.io/en/latest/examples/region_masks.html) showing how to generate these masks.

In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are picked up
# without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
from mom6_tools.DiagsCase import DiagsCase
from mom6_tools.MOM6grid import MOM6grid
from mom6_tools.drift import HorizontalMeanDiff_da, HorizontalMeanRmse_da 
from mom6_tools.m6plot import ztplot
from ncar_jobqueue import NCARCluster
from dask.distributed import Client
from mom6_tools.m6toolbox import genBasinMasks, request_workers, \
                                 weighted_temporal_mean, add_global_attrs
from IPython.display import display, Markdown, Latex
import yaml, intake, os
import xarray as xr
import matplotlib
import numpy as np
%matplotlib inline

Basemap module not found. Some regional plots may not function properly


In [3]:
# Read in the yaml file
diag_config_yml_path = "diag_config.yml"
# The file is opened through a context manager so the handle is closed as soon
# as parsing ends (the bare open() call used before never closed it).
# NOTE(review): yaml.Loader can construct arbitrary Python objects; keep it only
# while diag_config.yml is a trusted file, otherwise prefer yaml.safe_load.
with open(diag_config_yml_path, 'r') as diag_config_file:
  diag_config_yml = yaml.load(diag_config_file, Loader=yaml.Loader)

# load avg dates
avg = diag_config_yml['Avg']

# Create the case instance
dcase = DiagsCase(diag_config_yml['Case'])
# Pick the directory holding the model output: the '/ocn/hist/' folder under
# DOUT_S_ROOT when DOUT_S is truthy, RUNDIR otherwise
DOUT_S = dcase.get_value('DOUT_S')
if DOUT_S:
  OUTDIR = dcase.get_value('DOUT_S_ROOT')+'/ocn/hist/'
else:
  OUTDIR = dcase.get_value('RUNDIR')

print('Output directory is:', OUTDIR)
print('Casename is:', dcase.casename)

Output directory is: /glade/scratch/gmarques/archive/g.e23_b15.GJRAv4.TL319_t232_zstar_N65.baseline.001/ocn/hist/
Casename is: g.e23_b15.GJRAv4.TL319_t232_zstar_N65.baseline.001


In [4]:
# User-editable run settings, gathered as attributes of a bare namespace class
######################################################

class args:
  # averaging window, taken from the 'Avg' section of the yaml file
  start_date = avg['start_date']
  end_date = avg['end_date']
  casename = dcase.casename
  # name of the intake catalog entry used for the observations
  obs = "woa-2018-tx2_3v2-annual-all"
  # file names: case name followed by the suffixes from the 'Fnames' section
  z = dcase.casename+diag_config_yml['Fnames']['z']
  static = dcase.casename+diag_config_yml['Fnames']['static']
  savefigs = False
  nw = 6 # requesting 6 workers

In [5]:
# Parameters
# NOTE(review): presumably injected by papermill when the notebook is executed
# (the error banner in this notebook mentions papermill) — confirm before
# editing these values by hand.
sname = "placeholder-sname"
subset_kwargs = {}
product = "/glade/u/home/eromashkova/codes/mom6-tools/docs/source/examples/computed_notebooks/placeholder-sname/TS_drift.ipynb"


In [6]:
# read grid info
grd = MOM6grid(OUTDIR+'/'+args.static, xrformat=True)

# Depth and cell area are looked up under two alternative variable names.
# Only a missing attribute triggers the fallback; any other failure now
# propagates instead of being hidden by a bare except.
try:
  depth = grd.depth_ocean.values
except AttributeError:
  depth = grd.deptho.values

# Keep the area only where grd.wet > 0; everything else becomes NaN
try:
  area = grd.area_t.where(grd.wet > 0)
except AttributeError:
  area = grd.areacello.where(grd.wet > 0)

MOM6 grid successfully loaded... 



In [7]:
# Replace NaNs with zeros, otherwise genBasinMasks won't work.
# np.where builds a new array, so the array obtained from grd is not
# overwritten in place and re-running this cell gives the same result.
depth = np.where(np.isnan(depth), 0.0, depth)
basin_code = genBasinMasks(grd.geolon.values, grd.geolat.values, depth, xda=True)

In [8]:
# Create the cluster object and ask it to scale to args.nw workers
cluster = NCARCluster()
cluster.scale(args.nw)
# Attach a dask client to the cluster; the bare name on the last line makes
# the notebook display the client summary
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: /proxy/45807/status,

0,1
Dashboard: /proxy/45807/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.12.206.57:42319,Workers: 0
Dashboard: /proxy/45807/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [9]:
def preprocess(ds):
    """Make sure the dataset carries both 'thetao' and 'so'.

    Whichever of the two variables is absent from ``ds.variables`` is added
    as an all-zero array built with ``xr.zeros_like(ds.h)``. The dataset is
    modified in place and also returned.

    NOTE(review): relies on ``ds`` having a variable ``h`` — confirm that
    every input file provides it.
    """
    for name in ('thetao', 'so'):
        if name not in ds.variables:
            ds[name] = xr.zeros_like(ds.h)

    return ds

In [10]:
# read dataset
# preprocess is handed to open_mfdataset so the files can be combined along
# "time" even when thetao or so is absent from a file
ds = xr.open_mfdataset(OUTDIR+'/'+args.z,
    parallel=True,
    combine="nested", # concatenate in order of files
    concat_dim="time", # concatenate along time
    preprocess=preprocess,
    ).chunk({"time": 12}) # re-chunk with a chunk size of 12 along time

In [11]:
# Time-weighted mean of thetao, tagged with descriptive attributes
var = 'thetao'
attrs = dict(
    description='Annual mean climatology for '+var,
    start_date=args.start_date,
    end_date=args.end_date,
    reduction_method='annual mean weighted by days in each month',
    casename=dcase.casename,
)

thetao_model = weighted_temporal_mean(ds,var)
# Replace whatever attributes the result carried with the ones built above
thetao_model.attrs = attrs

In [12]:
# Time-weighted mean of so (salinity), tagged with descriptive attributes
var = 'so'
attrs = dict(
    description='Annual mean climatology for '+var,
    start_date=args.start_date,
    end_date=args.end_date,
    reduction_method='annual mean weighted by days in each month',
    casename=dcase.casename,
)

salt_model = weighted_temporal_mean(ds,var)
# Replace whatever attributes the result carried with the ones built above
salt_model.attrs = attrs

<span id="papermill-error-cell" style="color:red; font-family:Helvetica Neue, Helvetica, Arial, sans-serif; font-size:2em;">Execution using papermill encountered an exception here and stopped:</span>

In [13]:
# load obs
# Open the intake catalog whose location is given by the 'oce_cat' entry of the yaml file
catalog = intake.open_catalog(diag_config_yml['oce_cat'])
# NOTE(review): the recorded papermill run stopped in this cell with
# "IndexError: list index out of range"; the traceback is not shown, so which
# call raised it is unknown. Check that the args.obs catalog entry resolves
# to readable data — TODO confirm.
obs = catalog[args.obs].to_dask()

IndexError: list index out of range

In [None]:
# Model-minus-observation differences for temperature (thetao) and salinity (so)
temp_diff = thetao_model - obs.thetao
salt_diff = salt_model - obs.so

## Construct a 3D area with land values masked

In [None]:
# Stack the 2D cell-area field once per vertical level of temp_diff
area3d = np.tile(area.values, (len(temp_diff.z_l), 1, 1))
# Wrap the stacked areas in a DataArray that reuses the names of the 2nd to 4th
# dimensions of temp_diff and its z_l / yh / xh coordinates
mask3d = xr.DataArray(
    area3d,
    dims=temp_diff.dims[1:4],
    coords={
        temp_diff.dims[1]: temp_diff.z_l,
        temp_diff.dims[2]: temp_diff.yh,
        temp_diff.dims[3]: temp_diff.xh,
    },
)
# x == x is False only where x is NaN, so the area is kept wherever the first
# record of temp_diff holds a valid value and is set to NaN elsewhere
area3d_masked = mask3d.where(temp_diff[0] == temp_diff[0])

## Horizontal Mean difference (model - obs)

In [None]:
%%time
# Horizontal mean of the temperature difference, passing the masked 3D areas as
# weights and basin_code as the basins
temp_bias = HorizontalMeanDiff_da(temp_diff,weights=area3d_masked, basins=basin_code)

In [None]:
print('Saving temp_bias...\n')
# os.makedirs(..., exist_ok=True) creates the folder when it is missing and
# raises if that is impossible; the previous shell call ignored its exit status
os.makedirs('ncfiles', exist_ok=True)

var = 'thetao'
attrs = {'casename': dcase.casename,
         'description': 'Annual mean bias for '+var,
         'obs': args.obs
        }

# Attach the attributes, then write ncfiles/<casename>_<var>_drift.nc
add_global_attrs(temp_bias,attrs)
temp_bias.to_netcdf('ncfiles/'+str(dcase.casename)+'_{}_drift.nc'.format(var))

## Temperature

In [None]:
# One depth-time plot of the horizontally averaged temperature bias per region
for reg in temp_bias.region:
    # remove NaN's
    diff_reg = temp_bias.sel(region=reg).dropna('z_l')
    # Largest remaining z_l value, evaluated once and kept as a plain float
    max_depth = float(diff_reg.z_l.max())
    # Split the vertical axis at 500 only when the profile extends beyond it
    if max_depth <= 500.0:
        splitscale = None
    else:
        splitscale = [0., -500., -max_depth]

    # z_l is negated for plotting; the case name is read through the public
    # attribute used everywhere else in this notebook
    ztplot(diff_reg.values, diff_reg.time.values, diff_reg.z_l.values*-1, ignore=np.nan, splitscale=splitscale,
           suptitle=dcase.casename, contour=True,
           title= str(reg.values) + ', Potential Temperature [C], (model - obs)',
           extend='both', colormap='dunnePM', autocenter=True, tunits='Year', show=True)

In [None]:
%%time
# Horizontal mean of the salinity difference, passing the masked 3D areas as
# weights and basin_code as the basins
salt_bias = HorizontalMeanDiff_da(salt_diff,weights=area3d_masked, basins=basin_code)

In [None]:
print('Saving salt_bias...\n')
# os.makedirs(..., exist_ok=True) creates the folder when it is missing and
# raises if that is impossible; the previous shell call ignored its exit status
os.makedirs('ncfiles', exist_ok=True)

var = 'so'
attrs = {'casename': dcase.casename,
         'description': 'Annual mean bias for '+var,
         'obs': args.obs
        }

# Attach the attributes, then write ncfiles/<casename>_<var>_drift.nc
add_global_attrs(salt_bias,attrs)
salt_bias.to_netcdf('ncfiles/'+str(dcase.casename)+'_{}_drift.nc'.format(var))

## Salinity

In [None]:
# One depth-time plot of the horizontally averaged salinity bias per region
for reg in salt_bias.region:
    # remove NaN's
    diff_reg = salt_bias.sel(region=reg).dropna('z_l')
    # Largest remaining z_l value, evaluated once and kept as a plain float
    max_depth = float(diff_reg.z_l.max())
    # Split the vertical axis at 500 only when the profile extends beyond it
    if max_depth <= 500.0:
        splitscale = None
    else:
        splitscale = [0., -500., -max_depth]

    # z_l is negated for plotting; the case name is read through the public
    # attribute used everywhere else in this notebook
    ztplot(diff_reg.values, diff_reg.time.values, diff_reg.z_l.values*-1, ignore=np.nan, splitscale=splitscale,
           suptitle=dcase.casename, contour=True,
           title= str(reg.values) + ', Salinity [psu], (model - obs)',
           extend='both', colormap='dunnePM', autocenter=True, tunits='Year', show=True)

## Horizontal Mean RMSe (model - obs)

In [None]:
# TODO
# Horizontal-mean RMS error of temperature. This section reports RMSe, so the
# RMSE routine is called here, exactly as done for salinity (salt_rms);
# HorizontalMeanDiff_da would return the mean bias instead.
temp_rms = HorizontalMeanRmse_da(temp_diff,weights=area3d_masked, basins=basin_code)

## Temperature

In [None]:
# One depth-time plot of the horizontally averaged temperature RMSe per region
for reg in temp_rms.region:
    # remove NaN's
    diff_reg = temp_rms.sel(region=reg).dropna('z_l')
    # Largest remaining z_l value, evaluated once and kept as a plain float
    max_depth = float(diff_reg.z_l.max())
    # Split the vertical axis at 500 only when the profile extends beyond it
    if max_depth <= 500.0:
        splitscale = None
    else:
        splitscale = [0., -500., -max_depth]

    # z_l is negated for plotting; the case name is read through the public
    # attribute used everywhere else in this notebook
    ztplot(diff_reg.values, diff_reg.time.values, diff_reg.z_l.values*-1, ignore=np.nan, splitscale=splitscale,
           suptitle=dcase.casename, contour=True,
           title= str(reg.values) + ', Potential Temperature [C], RMSe',
           extend='both', colormap='dunnePM', autocenter=False, tunits='Year', show=True)

In [None]:
# TODO
# Horizontal-mean RMS error of salinity, passing the masked 3D areas as weights
# and basin_code as the basins
salt_rms = HorizontalMeanRmse_da(salt_diff,weights=area3d_masked, basins=basin_code)

## Salinity

In [None]:
# One depth-time plot of the horizontally averaged salinity RMSe per region
for reg in salt_rms.region:
    # remove NaN's
    diff_reg = salt_rms.sel(region=reg).dropna('z_l')
    # Largest remaining z_l value, evaluated once and kept as a plain float
    max_depth = float(diff_reg.z_l.max())
    # Split the vertical axis at 500 only when the profile extends beyond it
    if max_depth <= 500.0:
        splitscale = None
    else:
        splitscale = [0., -500., -max_depth]

    # z_l is negated for plotting; the case name is read through the public
    # attribute used everywhere else in this notebook
    ztplot(diff_reg.values, diff_reg.time.values, diff_reg.z_l.values*-1, ignore=np.nan,
           splitscale=splitscale, suptitle=dcase.casename, contour=True,
           title= str(reg.values) + ', Salinity [psu], RMSe', extend='both',
           colormap='dunnePM', autocenter=False, tunits='Year', show=True);