In [None]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *

In [None]:
# Case key -> run directory; data are read from ../<model>/<run directory>/<file>.
casenames = {
    'c1': 'R11_MSM97-CT_Stokes_f0_fixdt3',
    'c2': 'R11_MSM97-CT_Stokes_f0_fixdt3_visL',
    'c3': 'R11_MSM97-CT_Stokes_f0',
}
# Case key -> model name (selects the reader and the output file name).
models = {'c1': 'ncarles', 'c2': 'ncarles', 'c3': 'oceananigans'}
# Model name -> profile output file inside each run directory.
filenames = {
    'ncarles': 'his.mp.vis.000001.028801.nc',
    'oceananigans': 'averages.jld2',
}
# Per-case plot styling and legend text.
linestyles = {'c1': '--', 'c2': ':', 'c3': '-'}
labels = {
    'c1': 'NCAR LES',
    'c2': 'NCAR LES (sgsL)',
    'c3': 'Oceananigans',
}

# Output directory for figures (created if missing).
figpath = 'R11_MSM97_CT_Stokes_f0'
os.makedirs(figpath, exist_ok=True)

In [None]:
# Load the depth-profile dataset of every case with the reader for its model.
ds = dict()
for cn, casename in casenames.items():
    print(cn)
    model = models[cn]
    # Path is resolved before the model check, so a missing file-name entry
    # surfaces as a KeyError exactly as before.
    datapath = os.path.join(os.path.pardir, model, casename, filenames[model])
    if model == 'oceananigans':
        reader = OceananigansDataProfile
    elif model == 'ncarles':
        reader = NCARLESDataProfile
    else:
        raise ValueError('Model {:s} not supported'.format(model))
    ds[cn] = reader(filepath=datapath).dataset
        

In [None]:
# Show the dataset of case 'c1' (an NCAR LES run) to inspect its variables and coordinates.
ds['c1']

In [None]:
# --- Surface gravity wave parameters ---
g, a = 9.81, 0.8                 # gravitational acceleration [m/s^2], wave amplitude [m]
wavenum = 2*np.pi/60             # wavenumber k of a 60 m wavelength [1/m]
f = np.sqrt(g*wavenum)           # wave angular frequency from deep-water dispersion, sqrt(g k)
us0 = a**2*wavenum*f             # surface Stokes drift a^2 k sigma [m/s]

# --- Stratification and surface forcing ---
alpha, NNT0 = 2e-4, 0.01         # thermal expansion coefficient [1/K], initial dT/dz [K/m]
NN0 = alpha*g*NNT0               # initial buoyancy frequency squared, alpha g dT/dz
Qt0, h0 = 1.221e-5, 33           # surface temperature flux (presumably kinematic, K m/s); depth scale h0 [m]
wstar = (alpha*g*Qt0*h0)**(1/3)  # convective velocity scale (alpha g Qt0 h0)^(1/3)

depth = 50                       # layer depth [m] averaged over in the time series
def get_vars(ds, var):
    """Return a kinetic-energy profile diagnostic normalized by ``wstar**2``.

    Parameters
    ----------
    ds : xarray.Dataset
        Profile dataset from the lesview readers. A dataset that contains
        ``'uxym'`` is read with the NCAR LES variable names (``uxym``, ``vxym``,
        ``stokes``, ``ups``, ``vps``, ``wps``); any other dataset is read with
        the Oceananigans names (``u``, ``v``, ``uu``, ``vv``, ``ww``).
    var : str
        ``'mkeL'`` -- half the squared Lagrangian-mean horizontal velocity;
        ``'mkeE'`` -- half the squared Eulerian-mean horizontal velocity;
        ``'tke'``  -- half the sum of the three second-moment fields (presumably
        velocity variances). The vertical component is stored on the ``zi``
        grid and is interpolated onto ``ds.z``.

    Returns
    -------
    xarray.DataArray
        The requested diagnostic divided by ``wstar**2``.

    Raises
    ------
    ValueError
        If ``var`` is not one of ``'mkeL'``, ``'mkeE'``, ``'tke'``.

    Notes
    -----
    Reads the module-level constants ``wstar``, ``us0`` and ``wavenum``.
    NCAR LES ``uxym`` is treated as the Eulerian mean (the ``stokes`` profile
    is added for ``mkeL``). Oceananigans ``u`` is treated as the Lagrangian
    mean (the analytic Stokes drift ``us0*exp(2*wavenum*z)`` is subtracted
    for ``mkeE``).
    """
    if var not in ('mkeL', 'mkeE', 'tke'):
        raise ValueError('Variable {:s} not found.'.format(var))

    if 'uxym' in ds.data_vars:
        # NCAR LES output.
        u = ds.data_vars['uxym']
        v = ds.data_vars['vxym']
        if var == 'mkeL':
            ke = (u + ds.data_vars['stokes'])**2 + v**2
        elif var == 'mkeE':
            ke = u**2 + v**2
        else:
            ke = ds.data_vars['ups'] + ds.data_vars['vps'] + ds.data_vars['wps'].interp(zi=ds.z)
    else:
        # Oceananigans output.
        u = ds.data_vars['u']
        v = ds.data_vars['v']
        if var == 'mkeL':
            ke = u**2 + v**2
        elif var == 'mkeE':
            # 1-D Stokes-drift profile on u's own z coordinate. xarray aligns and
            # broadcasts by dimension name, so this works whatever the order of
            # u's dimensions (the previous broadcast_to/transpose required (z, time)).
            us = us0 * np.exp(2.*wavenum*u.z)
            ke = (u - us)**2 + v**2
        else:
            ke = ds.data_vars['uu'] + ds.data_vars['vv'] + ds.data_vars['ww'].interp(zi=ds.z)
    return 0.5*ke/wstar**2

def time2hour(da, tdim=None):
    """Return a copy of ``da`` whose time coordinate is elapsed hours since its first value.

    Parameters
    ----------
    da : xarray.DataArray
        Data with a datetime64 or timedelta64 time coordinate.
    tdim : str, optional
        Name of the time dimension. Defaults to ``da.dims[0]`` (the previous,
        hard-wired behaviour).

    Returns
    -------
    xarray.DataArray
        Copy of ``da`` with coordinate ``tdim`` replaced by float hours relative
        to the first sample, tagged ``long_name='$t$'`` and ``units='hr'``.

    Raises
    ------
    IndexError
        If the time coordinate is empty (there is no first sample to reference).
    """
    if tdim is None:
        tdim = da.dims[0]
    times = np.asarray(da.coords[tdim].data)
    # pd.to_datetime cannot convert timedelta64 values, so elapsed-time axes
    # (common in model output) are parsed as timedeltas instead.
    if np.issubdtype(times.dtype, np.timedelta64):
        stamps = pd.to_timedelta(times)
    else:
        stamps = pd.to_datetime(times)
    time_hour = np.asarray((stamps - stamps[0]).total_seconds()) / 3600
    da_new = da.assign_coords({tdim: time_hour})
    da_new.coords[tdim].attrs['long_name'] = '$t$'
    da_new.coords[tdim].attrs['units'] = 'hr'
    return da_new

In [None]:
fig, axarr = plt.subplots(2, 1, sharex='col')
fig.set_size_inches(5, 4.5)

# Legend labels. Raw strings keep LaTeX backslashes literal; the previous
# non-raw '\prime' contained the invalid escape '\p' (DeprecationWarning,
# SyntaxWarning on Python 3.12+). The string values are unchanged.
mkeLlb = r'$\frac{1}{2}\langle|\overline{\mathbf{u}}^L|^2\rangle/w_*^2$'
mkeElb = r'$\frac{1}{2}\langle|\overline{\mathbf{u}}|^2\rangle/w_*^2$'
tkelb  = r'$\frac{1}{2}\langle\overline{|\mathbf{u}^\prime|^2}\rangle/w_*^2$'


def upper_layer_series(da):
    """Average ``da`` over ``z >= -depth`` and relabel its time axis in hours."""
    return time2hour(da.where(da.z >= -depth).mean(dim='z'))


# (a) Mean kinetic energy: black uses the Lagrangian-mean velocity, gray the Eulerian-mean.
ax = axarr[0]
for cn in casenames:
    mkeL = get_vars(ds[cn], 'mkeL')
    mkeLM = upper_layer_series(mkeL)
    mkeLM.plot(ax=ax, color='k', linestyle=linestyles[cn])
    mkeE = get_vars(ds[cn], 'mkeE')
    mkeEM = upper_layer_series(mkeE)
    mkeEM.plot(ax=ax, color='gray', linestyle=linestyles[cn])
ax.set_ylabel('MKE')
ax.set_xlabel('')
ax.set_ylim([0, 1.5])
ax.text(0.03, 0.08, '(a)', transform=ax.transAxes, va='bottom', ha='left')
# NaN proxy lines: give the legend one entry per colour without drawing anything.
ax.plot(np.nan, np.nan, color='k', label=mkeLlb)
ax.plot(np.nan, np.nan, color='gray', label=mkeElb)
ax.legend(loc='center right', bbox_to_anchor=(1, 0.67), ncol=2, framealpha=1)

# (b) Turbulent kinetic energy, one line style per case.
ax = axarr[1]
for cn in casenames:
    tke = get_vars(ds[cn], 'tke')
    tkeM = upper_layer_series(tke)
    tkeM.plot(ax=ax, color='k', linestyle=linestyles[cn], label=labels[cn])
ax.set_ylabel('TKE')
ax.set_xlabel('Time [hour]')
ax.set_ylim([0, 0.4])
ax.legend(ncol=2, loc='lower right', framealpha=1)
ax.text(0.03, 0.92, '(b)', transform=ax.transAxes, va='top', ha='left')

# Shared styling for both panels.
for ax in axarr:
    ax.grid()
    ax.set_xlim([0, 24])

plt.tight_layout()
figname = os.path.join(figpath, 'timeseries-ke')
fig.savefig(figname, dpi=300, facecolor='w')