## Equatorial plots

In [1]:
%%capture 
# comment above line to see details about the run(s) displayed
# from misc import *
%matplotlib inline

In [2]:
### Temporarily just pasting in all of misc

import xarray as xr
import numpy as np
from datetime import date
import matplotlib.pyplot as plt
import warnings, os, yaml
import matplotlib
import getpass
import netCDF4
# mom6_tools
from mom6_tools.DiagsCase import DiagsCase
from mom6_tools.MOM6grid import MOM6grid
from mom6_tools.m6toolbox import genBasinMasks
from mom6_tools.m6plot import xyplot, ztplot
from mom6_tools.poleward_heat_transport import heatTrans, plotHeatTrans, plotGandW

import pop_tools

matplotlib.rcParams.update({'font.size': 18})
warnings.filterwarnings("ignore")

diag_config_yml = yaml.load(open('/glade/u/home/eromashkova/codes/nbscuid-examples/mom6-solutions/nblibrary/diag_config.yml','r'),
                            Loader=yaml.Loader)
# initialize lists
dcase = []; casename=[]
label = []; OUTDIR=[]
ocn_path = []
ncases = len(diag_config_yml.keys()) - 1
if ncases < 2:
  # Create the case instance
  dcase = DiagsCase(diag_config_yml['Case'], xrformat=True)
  casename.append(dcase.casename)
  label.append(diag_config_yml['Case']['SNAME'])
  ocn_path.append(diag_config_yml['Case']['OCN_DIAG_ROOT'])
  DOUT_S = dcase.get_value('DOUT_S')
  if DOUT_S:
    OUTDIR.append(dcase.get_value('DOUT_S_ROOT')+'/ocn/hist/')
  else:
    OUTDIR.append(dcase.get_value('RUNDIR')+'/')
else:
  for i in range(ncases):
    cname = 'Case{}'.format(i+1)
    dcase = DiagsCase(diag_config_yml[cname], xrformat=True)
    casename.append(dcase.casename)
    label.append(diag_config_yml[cname]['SNAME'])
    ocn_path.append(diag_config_yml[cname]['OCN_DIAG_ROOT'])
    DOUT_S = dcase.get_value('DOUT_S')
    if DOUT_S:
      OUTDIR.append(dcase.get_value('DOUT_S_ROOT')+'/ocn/hist/')
    else:
      OUTDIR.append(dcase.get_value('RUNDIR')+'/')

# set avg dates
avg = diag_config_yml['Avg']
start_date = avg['start_date']
end_date = avg['end_date']

print('Casename is {}'.format(casename))
print('─' * 60)
print('Averaged between {} and {}'.format(start_date, end_date))
print('─' * 60)
print('Output directory is:', OUTDIR)
print('─' * 60)
print('Generated by {}'.format(getpass.getuser()))
print('─' * 60)
print("Last update:", date.today())

# load grid and pre-difined masks
grd = MOM6grid(OUTDIR[0]+casename[0]+'.mom6.static.nc');
grd_xr = MOM6grid(OUTDIR[0]+casename[0]+'.mom6.static.nc', xrformat=True);
depth = grd.depth_ocean
# remote Nan's, otherwise genBasinMasks won't work
depth[np.isnan(depth)] = 0.0
basin_code = genBasinMasks(grd.geolon, grd.geolat, depth, xda=True);

# pop grid
pop_grid = pop_tools.get_grid('POP_gx1v7')

def get_heat_transport_obs():
    """Plots model vs obs poleward heat transport for the global, Pacific and Atlantic basins"""
    # Load Observations
    fObs = netCDF4.Dataset('/glade/work/gmarques/cesm/datasets/Trenberth_and_Caron_Heat_Transport.nc')
    #Trenberth and Caron
    yobs = fObs.variables['ylat'][:]
    NCEP = {}; NCEP['Global'] = fObs.variables['OTn']
    NCEP['Atlantic'] = fObs.variables['ATLn'][:]; NCEP['IndoPac'] = fObs.variables['INDPACn'][:]
    ECMWF = {}; ECMWF['Global'] = fObs.variables['OTe'][:]
    ECMWF['Atlantic'] = fObs.variables['ATLe'][:]; ECMWF['IndoPac'] = fObs.variables['INDPACe'][:]

    #G and W
    Global = {}
    Global['lat'] = np.array([-30., -19., 24., 47.])
    Global['trans'] = np.array([-0.6, -0.8, 1.8, 0.6])
    Global['err'] = np.array([0.3, 0.6, 0.3, 0.1])

    Atlantic = {}
    Atlantic['lat'] = np.array([-45., -30., -19., -11., -4.5, 7.5, 24., 47.])
    Atlantic['trans'] = np.array([0.66, 0.35, 0.77, 0.9, 1., 1.26, 1.27, 0.6])
    Atlantic['err'] = np.array([0.12, 0.15, 0.2, 0.4, 0.55, 0.31, 0.15, 0.09])

    IndoPac = {}
    IndoPac['lat'] = np.array([-30., -18., 24., 47.])
    IndoPac['trans'] = np.array([-0.9, -1.6, 0.52, 0.])
    IndoPac['err'] = np.array([0.3, 0.6, 0.2, 0.05,])

    GandW = {}
    GandW['Global'] = Global
    GandW['Atlantic'] = Atlantic
    GandW['IndoPac'] = IndoPac
    return NCEP, ECMWF, GandW, yobs

NCEP, ECMWF, GandW, yobs = get_heat_transport_obs()

def get_adv_diff(ds):
    # create a ndarray subclass
    class C(np.ndarray): pass

    varName = 'T_ady_2d'
    if varName in ds.variables:
        tmp = np.ma.masked_invalid(ds[varName].values)
        tmp = tmp[:].filled(0.)
        advective = tmp.view(C)
        advective.units = 'W'
    else:
        raise Exception('Could not find "T_ady_2d"')

    varName = 'T_diffy_2d'
    if varName in ds.variables:
        tmp = np.ma.masked_invalid(ds[varName].values)
        tmp = tmp[:].filled(0.)
        diffusive = tmp.view(C)
        diffusive.units = 'W'
    else:
        diffusive = None
        warnings.warn('Diffusive temperature term not found. This will result in an underestimation of the heat transport.')

    varName = 'T_lbd_diffy_2d'
    if varName in ds.variables:
        tmp = np.ma.masked_invalid(ds[varName].values)
        tmp = tmp[:].filled(0.)
        diffusive = diffusive + tmp.view(C)
    else:
        warnings.warn('Horizontal boundary mixing term not found. This will result in an underestimation of the heat transport.')

    return advective, diffusive

def plot_heat_trans(ds, label, linestyle='-'):
    adv, diff = get_adv_diff(ds)
    HT = heatTrans(adv,diff); y = ds.yq
    plt.plot(y, HT, linewidth=3, linestyle=linestyle, label=label)
    return

def heatTrans(advective, diffusive=None, hbd=None, vmask=None, units="W"):
  """Converts vertically integrated temperature advection/diffusion into heat transport"""
  HT = advective[:]
  if diffusive is not None:
    HT = HT + diffusive[:]
  if hbd is not None:
    HT = HT + hbd[:]
  if len(HT.shape) == 3:
    HT = HT.mean(axis=0)
  if units == "Celsius meter3 second-1":
    rho0 = 1.035e3
    Cp = 3992.
    HT = HT * (rho0 * Cp)
    HT = HT * 1.e-15  # convert to PW
  elif units == "W":
    HT = HT * 1.e-15
  else:
    print('Unknown units')
  if vmask is not None: HT = HT*vmask
  HT = HT.sum(axis=-1); HT = HT.squeeze() # sum in x-direction
  return HT

def pop_add_cyclic(ds):

    nj = ds.TLAT.shape[0]
    ni = ds.TLONG.shape[1]

    xL = int(ni/2 - 1)
    xR = int(xL + ni)

    tlon = ds.TLONG.data
    tlat = ds.TLAT.data

    tlon = np.where(np.greater_equal(tlon, min(tlon[:,0])), tlon-360., tlon)
    lon  = np.concatenate((tlon, tlon + 360.), 1)
    lon = lon[:, xL:xR]

    if ni == 320:
        lon[367:-3, 0] = lon[367:-3, 0] + 360.
    lon = lon - 360.

    lon = np.hstack((lon, lon[:, 0:1] + 360.))
    if ni == 320:
        lon[367:, -1] = lon[367:, -1] - 360.

    #-- trick cartopy into doing the right thing:
    #   it gets confused when the cyclic coords are identical
    lon[:, 0] = lon[:, 0] - 1e-8

    #-- periodicity
    lat = np.concatenate((tlat, tlat), 1)
    lat = lat[:, xL:xR]
    lat = np.hstack((lat, lat[:,0:1]))

    TLAT = xr.DataArray(lat, dims=('nlat', 'nlon'))
    TLONG = xr.DataArray(lon, dims=('nlat', 'nlon'))

    dso = xr.Dataset({'TLAT': TLAT, 'TLONG': TLONG})

    # copy vars
    varlist = [v for v in ds.data_vars if v not in ['TLAT', 'TLONG']]
    for v in varlist:
        v_dims = ds[v].dims
        if not ('nlat' in v_dims and 'nlon' in v_dims):
            dso[v] = ds[v]
        else:
            # determine and sort other dimensions
            other_dims = set(v_dims) - {'nlat', 'nlon'}
            other_dims = tuple([d for d in v_dims if d in other_dims])
            lon_dim = ds[v].dims.index('nlon')
            field = ds[v].data
            field = np.concatenate((field, field), lon_dim)
            field = field[..., :, xL:xR]
            field = np.concatenate((field, field[..., :, 0:1]), lon_dim)
            dso[v] = xr.DataArray(field, dims=other_dims+('nlat', 'nlon'),
                                  attrs=ds[v].attrs)


    # copy coords
    for v, da in ds.coords.items():
        if not ('nlat' in da.dims and 'nlon' in da.dims):
            dso = dso.assign_coords(**{v: da})


    return dso

pop_grid_o = pop_add_cyclic(pop_grid)




Basemap module not found. Some regional plots may not function properly
Casename is ['gmom.e23.GJRAv4.TL319_t061_zstar_N65.nuopc.baseline_no_ML.001', 'gmom.e23.GJRAv4.TL319_t061_hycom1_N75.nuopc.baseline_no_ML.001']
────────────────────────────────────────────────────────────
Averaged between 0031-01-01 and 0062-01-01
────────────────────────────────────────────────────────────
Output directory is: ['/glade/scratch/gmarques/gmom.e23.GJRAv4.TL319_t061_zstar_N65.nuopc.baseline_no_ML.001/run/', '/glade/scratch/gmarques/gmom.e23.GJRAv4.TL319_t061_hycom1_N75.nuopc.baseline_no_ML.001/run/']
────────────────────────────────────────────────────────────
Generated by eromashkova
────────────────────────────────────────────────────────────
Last update: 2023-04-07
MOM6 grid successfully loaded... 

MOM6 grid successfully loaded... 

11.16428 64.78855 [391, 434]


In [3]:
# empty parametrizable cell 

cluster_scheduler_address = None

In [4]:
# Parameters
not_casename = "placeholder-casename"
cluster_scheduler_address = "tcp://10.12.206.48:44885"
subset_kwargs = {}


## Connecting to cluster

In [5]:
from dask.distributed import Client

if cluster_scheduler_address is None:
    cluster, client = nbscuid.util.get_ClusterClient()
    cluster.scale(12)
else:
    client = Client(cluster_scheduler_address)
client

0,1
Connection method: Direct,
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/eromashkova/proxy/46109/status,

0,1
Comm: tcp://10.12.206.48:44885,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/eromashkova/proxy/46109/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [6]:
# load WOA18 data
obs_path = '/glade/u/home/gmarques/Notebooks/CESM_MOM6/WOA18_remapping/'
obs_temp = xr.open_dataset(obs_path+'WOA18_TEMP_tx0.66v1_34lev_ann_avg.nc', decode_times=False)
obs_salt = xr.open_dataset(obs_path+'WOA18_SALT_tx0.66v1_34lev_ann_avg.nc', decode_times=False)
# get theta and salt and rename coordinates to be the same as the model's
thetao_obs = obs_temp.theta0.rename({'z_l': 'zl'});
salt_obs = obs_salt.s_an.rename({'z_l': 'zl'});
salt_obs['yh'] = grd.yh
thetao_obs['yh'] = grd.yh

In [9]:
def plot_comparison(model, var, obs, vmin, vmax, diff, cname):
  fig, ax = plt.subplots(nrows=3, ncols=1,
                       sharex=True,sharey=True,
                       figsize=(12,14))

  ax1 = ax.flatten()

  model[var].plot(ax=ax1[0], ylim=(500,0),
                     y="zl",vmin=vmin, vmax=vmax,
                     yincrease=False)

  plt.show()

  obs.sel(yh=0., 
               method="nearest").plot(ax=ax1[1], ylim=(500, 0),
                                      y="zl",vmin=vmin, vmax=vmax,
                                      yincrease=False)

  (model[var] - obs.sel(yh=0., 
               method="nearest")).plot(ax=ax1[2], ylim=(500, 0),
                                      y="zl", cmap="bwr", vmin=-diff, vmax=diff, 
                                      yincrease=False)

  ax[0].set_title('Model, casename = {}'.format(cname)); ax[0].set_xlabel('')
  ax[1].set_title('Obs (WOA18)'); ax[1].set_xlabel('')
  ax[2].set_title('Model - Obs'); ax[2].set_xlabel('Longitude')
    
  print("yes this actually ran")

  return fig

#### Potential temperature

In [10]:
for path, case, i in zip(ocn_path, casename, range(len(casename))):
  ds_temp = xr.open_dataset(ocn_path[i]+casename[i]+'_temp_eq.nc')
  plot_comparison(ds_temp,'temp_eq',thetao_obs,10,30,5,label[i])

yes this actually ran
yes this actually ran


### Salinity

In [9]:
for path, case, i in zip(ocn_path, casename, range(len(casename))):
  ds_salt = xr.open_dataset(ocn_path[i]+casename[i]+'_salt_eq.nc')
  plot_comparison(ds_salt,'salt_eq',salt_obs,33,37,1,label[i])

yes this actually ran
yes this actually ran
