In [None]:
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import numpy as np
import pandas as pd
import xarray as xr
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *
from sbl_bbl import get_edges

In [None]:
# Physical constants and forcing parameters of the LES case (SI units assumed).
g = 9.81                     # gravitational acceleration [m/s^2]
H = 30                       # water depth [m]
u10 = 8                      # wind speed entering the bulk stress formula [m/s]
N2 = 1.962e-4                # reference N_0^2 used to normalize N^2 and S^2 [1/s^2]
bstar = N2 * H               # buoyancy scale b_* used to normalize the buoyancy flux
cd = 1.25e-3                 # air-sea drag coefficient [-]
rhoa = 1.225                 # air density [kg/m^3]
rhoo = 1026                  # seawater density [kg/m^3]
tau = rhoa/rhoo*cd*u10*u10   # kinematic wind stress, bulk formula [m^2/s^2]
ustar = np.sqrt(tau)         # water-side friction velocity u_* [m/s]
amplitude = 1.0              # surface wave amplitude [m]
wavelength = 60              # surface wavelength [m]
wavenumber = 2.*np.pi/wavelength
# Finite-depth linear dispersion relation: omega^2 = g k tanh(k H).
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift of a monochromatic wave, U_s0 = a^2 k omega. The
# frequency factor is what gives U_s0 units of velocity (a^2 k alone is in m).
us0 = amplitude**2*wavenumber*frequency
# Turbulent Langmuir number La_t = sqrt(u_* / U_s0).
la = np.sqrt(ustar/us0)
print(la)

In [None]:
# Case identifier: LES output is read from ../tests/<casename>, and figures
# produced by this notebook are written into ./overview_<casename>/.
casename = 'lsc_ymc22_sbl_bbl_v2'
figpath = f'overview_{casename}'
datapath = os.path.join(os.path.pardir, 'tests', casename)
os.makedirs(figpath, exist_ok=True)  # no error if the folder already exists

In [None]:
# Path to the case's averaged-profile output file (JLD2 format, presumably
# written by Oceananigans given the reader class name -- confirm).
filepath = os.path.join(datapath, 'averages.jld2')
# lesview reader for the profile file; its `.dataset` attribute is used next.
data_pfl = OceananigansDataProfile(filepath=filepath)

In [None]:
# Underlying dataset of the profile reader (presumably an xarray.Dataset, since
# later cells use `.data_vars` and `.differentiate`); shown for inspection.
ds = data_pfl.dataset
ds

In [None]:
# Two 17-hour averaging windows, labelled T1 and T2. Each entry holds the
# ISO-8601 start/end times plus a per-window line style (`line_kw`).
tavg1 = {
    'starttime': '2000-01-04T00:00:00',
    'endtime': '2000-01-04T17:00:00',
    'line_kw': {'color': 'k', 'linestyle': '--'},
}
tavg2 = {
    'starttime': '2000-01-11T00:00:00',
    'endtime': '2000-01-11T17:00:00',
    'line_kw': {'color': 'k', 'linestyle': '-'},
}
tavgs = {'T1': tavg1, 'T2': tavg2}

In [None]:
# Reference-line colour to draw on top of a panel, looked up by the panel's
# colormap name (presumably chosen for contrast with that colormap).
rlcolor = dict(
    RdBu_r='k',
    viridis='w',
    bone='w',
    bone_r='k',
)

In [None]:
# Vertical gradients of the mean profiles, computed once and reused below.
dbdz = ds.data_vars['b'].differentiate(coord='z')   # N^2 = db/dz
dudz = ds.data_vars['u'].differentiate(coord='z')
dvdz = ds.data_vars['v'].differentiate(coord='z')
shear2 = dudz**2 + dvdz**2                          # squared vertical shear S^2
# Fields to plot, in panel order. Fluxes add the 'wX' and 'wXsb' variables
# (presumably resolved + subgrid-scale contributions -- confirm).
# NOTE(review): Ri is inf/nan wherever S^2 == 0.
das = dict(
    NN = dbdz/N2,                 # N^2 / N_0^2
    SS = shear2/N2,               # S^2 / N_0^2
    Ri = dbdz/shear2,             # gradient Richardson number N^2 / S^2
    wb = (ds.data_vars['wb']+ds.data_vars['wbsb'])/ustar/bstar*1e3,
    wu = (ds.data_vars['wu']+ds.data_vars['wusb'])/ustar**2,
    wv = (ds.data_vars['wv']+ds.data_vars['wvsb'])/ustar**2,
)
# Colour-bar labels in matplotlib mathtext. Raw strings keep TeX commands such
# as \overline and \prime from being read as invalid Python escape sequences
# (a SyntaxWarning on Python 3.12+); the resulting text is unchanged.
labels = dict(
    NN = r'$N^2/N_0^2$',
    SS = r'$S^2/N_0^2$',
    Ri = r'$N^2/S^2$',
    wb = r'$10^3\overline{w^\prime b^\prime}/u_*b_*$',
    wu = r'$\overline{w^\prime u^\prime}/u_*^2$',
    wv = r'$\overline{w^\prime v^\prime}/u_*^2$',
)
# Colour-scale levels per panel, given as (start, stop, count) for np.linspace.
# Ri uses 5 coarse levels on [0, 1]; every other field uses 41 levels.
levels = {
    name: np.linspace(start, stop, count)
    for name, (start, stop, count) in (
        ('NN', (0, 3.2, 41)),
        ('SS', (0, 6.4, 41)),
        ('Ri', (0, 1, 5)),
        ('wb', (-4, 4, 41)),
        ('wu', (-2, 2, 41)),
        ('wv', (-1.2, 1.2, 41)),
    )
}
# Colormap per panel: 'bone_r' for the gradient diagnostics (left column)
# and the diverging 'RdBu_r' for the signed flux panels (right column).
cmaps = dict.fromkeys(('NN', 'SS', 'Ri'), 'bone_r')
cmaps.update(dict.fromkeys(('wb', 'wu', 'wv'), 'RdBu_r'))
# Subplot position [row, column] of each panel: the first tuple fills
# column 0 from top to bottom, the second fills column 1.
loc = {
    name: [row, col]
    for col, column in enumerate((('NN', 'SS', 'Ri'), ('wb', 'wu', 'wv')))
    for row, name in enumerate(column)
}
# Per panel: [draw x-axis label, draw y-axis label]. Only the bottom-row
# panels get the time label and only the left-column panels the depth label.
xylabel = {
    name: [show_x, show_y]
    for name, show_x, show_y in (
        ('NN', False, True),
        ('SS', False, True),
        ('Ri', True, True),
        ('wb', False, False),
        ('wu', False, False),
        ('wv', True, False),
    )
}
# Panel tags "(a)" ... "(f)", assigned in panel order.
abc = {
    name: f'({letter})'
    for name, letter in zip(('NN', 'SS', 'Ri', 'wb', 'wu', 'wv'), 'abcdef')
}

# Two edge curves derived from the normalized stratification N^2/N_0^2
# (already computed as das['NN']). They are presumably the surface and bottom
# boundary-layer edges, given the `sbl_bbl` module name -- confirm against
# get_edges. Both are overlaid on every panel below.
z0, z1 = get_edges(das['NN'])
line_kwargs = dict(color='k', linewidth=0.75)  # style of the edge curves

# Two-column time-depth figure: one panel per entry of `das`, placed by `loc`,
# with averaging windows, boundary-layer edges and panel tags overlaid.
nv = len(das) // 2  # number of panel rows (the layout in `loc` has 2 columns)
fig, axarr = plt.subplots(nv, 2, sharex='col')
fig.set_size_inches([10, 0.4+1.5*nv])
date_form = DateFormatter("%d")  # time ticks show the day of month only
for var in das:
    al = loc[var]
    ax = axarr[al[0], al[1]]
    das[var].plot(ax=ax, levels=levels[var], cmap=cmaps[var], cbar_kwargs={'label': labels[var]})
    # Key the reference-line colour on the colormap requested for this panel,
    # rather than on the name of the discrete colormap built by the plot call.
    refline_color = rlcolor[cmaps[var]]
    for tag, tavg in tavgs.items():
        # Dotted vertical lines bracket each averaging window.
        for edge in ('starttime', 'endtime'):
            ax.axvline(x=pd.Timestamp(tavg[edge]), linestyle=':', color=refline_color)
        if al[0] == 0:
            # Name each window once per column, on the top-row panel at y = 0.
            ax.text(pd.Timestamp(tavg['starttime']), 0, tag, va='bottom', ha='left')
    z0.plot(ax=ax, **line_kwargs)
    z1.plot(ax=ax, **line_kwargs)
    ax.xaxis.set_major_formatter(date_form)
    for lb in ax.get_xticklabels(which='major'):
        lb.set(rotation=0, horizontalalignment='right')
    ax.text(0.98, 0.92, abc[var], transform=ax.transAxes, va='top', ha='right')
    ax.set_title('')
    # Axis labels only on the outer panels, as configured in `xylabel`.
    ax.set_xlabel('Time [day]' if xylabel[var][0] else '')
    ax.set_ylabel('Depth [m]' if xylabel[var][1] else '')

fig.tight_layout()
fig.subplots_adjust(hspace=0.15, wspace=0.12)
figname = os.path.join(figpath, 'mean')
fig.savefig(figname, dpi=300, facecolor='w')