In [None]:
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import numpy as np
import pandas as pd
sys.path.append(os.path.join(os.path.pardir, 'gotmtool'))
from gotmtool import *

In [None]:
# Case and turbulence closure to analyse. turbmethod was previously assigned
# twice with the first value silently discarded; select exactly one here.
# Alternatives referenced earlier in this notebook: 'SMCLT-H15', 'SMCLT-KC04'.
casename = 'lsc_ymc22_sbl_bbl_v2'
turbmethod = 'KPPLT-LF17'

# Run directory of this case and the simulation for the chosen closure.
gotm_dir  = os.path.join(os.path.pardir, 'gotm', 'run', casename)
gotm_sim  = Simulation(path=os.path.join(gotm_dir, turbmethod))

# Output directory for figures of this case/closure (created if missing).
figpath  = os.path.join(os.path.pardir, 'gotm', 'figure', '{:s}-{:s}'.format(casename, turbmethod))
os.makedirs(figpath, exist_ok=True)

In [None]:
# Load the output of the selected run via the Simulation object created in the
# configuration cell; the bare `ds` below shows it as the cell's rich output.
ds = gotm_sim.load_data()
ds

In [None]:
# Forcing and wave parameters of the case. ustar is used below to scale
# velocities and momentum fluxes; la is the turbulent Langmuir number.
g = 9.81
H = 30
u10 = 8
N2 = 1.962e-4
bstar = N2 * H
cd = 1.25e-3
rhoa = 1.225
rhoo = 1026
# Bulk formula: tau = rho_a/rho_o * C_d * U10^2, ustar = sqrt(tau).
tau = rhoa/rhoo*cd*u10*u10
ustar = np.sqrt(tau)
amplitude = 1.0
wavelength = 60
wavenumber = 2.*np.pi/wavelength
# Linear dispersion relation (finite depth): sigma^2 = g k tanh(k H).
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift u_s0 = sigma * k * a^2 (deep-water form). The previous
# expression a^2 * k omitted sigma, giving a length instead of a velocity and
# leaving `frequency` unused. NOTE(review): the finite-depth correction
# cosh(2kH) / (2 sinh^2(kH)) is ~1.004 for kH = pi and is neglected here.
us0 = frequency*wavenumber*amplitude**2
# Turbulent Langmuir number La_t = sqrt(u_* / u_s0).
la = np.sqrt(ustar/us0)
print(la)

In [None]:
# Averaging windows keyed by tag: each gives the start/end time of the window
# (parsed later by pandas) and the line style used for its mean profile.
# Alternative T2 window used before: 2000-01-12T00:00:00 to 2000-01-12T17:00:00.
tavgs = {
    'T1': {
        'starttime': '2000-01-04T00:00:00',
        'endtime': '2000-01-04T17:00:00',
        'line_kw': {'color': 'k', 'linestyle': '--'},
    },
    'T2': {
        'starttime': '2000-01-10T00:00:00',
        'endtime': '2000-01-10T17:00:00',
        'line_kw': {'color': 'k', 'linestyle': '-'},
    },
}
tavg1 = tavgs['T1']
tavg2 = tavgs['T2']

In [None]:
def plot_overview(das, levels, labels, tavgs):
    """Plot time-depth sections next to window-averaged profiles.

    One row of panels is drawn per entry of ``das``: the right (wide) panel
    is the 2-D plot of the variable with dotted vertical lines marking each
    averaging window in ``tavgs``; the left (narrow) panel shows the variable
    averaged over each window.

    Parameters
    ----------
    das : dict
        Variable name -> 2-D DataArray with a ``time`` dimension. Its first
        dimension is used as the vertical axis of the profile panel
        (assumed to be depth -- TODO confirm for every input).
    levels : dict
        Variable name -> levels passed to ``DataArray.plot``.
    labels : dict
        Variable name -> colorbar / profile-axis label.
    tavgs : dict
        Window tag -> dict with ``starttime`` and ``endtime`` (strings parsed
        by ``pd.Timestamp`` and used for ``.sel``) and ``line_kw`` (keyword
        arguments for the profile line).

    Returns
    -------
    matplotlib.figure.Figure
    """
    nv = len(das)
    # squeeze=False keeps axarr 2-D (nv x 2) even when nv == 1; the default
    # would return a 1-D array and break the [row, col] indexing below.
    fig, axarr = plt.subplots(nv, 2, squeeze=False,
                              gridspec_kw={'width_ratios': [1, 5]})
    fig.set_size_inches([8, 0.4+2*nv])
    # Window-marker color per colormap name so the markers stay visible;
    # any other colormap falls back to black instead of raising KeyError.
    rlcolor = {'RdBu_r': 'k', 'viridis': 'w'}
    date_form = DateFormatter("%d")
    for i, var in enumerate(das.keys()):
        da = das[var]

        # Right panel: time-depth section with averaging-window markers.
        ax = axarr[i, 1]
        cf = da.plot(ax=ax, levels=levels[var], cbar_kwargs={'label': labels[var]})
        linecolor = rlcolor.get(cf.cmap.name, 'k')
        for tag, tavg in tavgs.items():
            tstart = pd.Timestamp(tavg['starttime'])
            ax.axvline(x=tstart, linestyle=':', color=linecolor)
            ax.axvline(x=pd.Timestamp(tavg['endtime']), linestyle=':', color=linecolor)
            ax.text(tstart, 0, tag, va='bottom', ha='left')
        ax.set_title('')
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.xaxis.set_major_formatter(date_form)
        for lb in ax.get_xticklabels(which='major'):
            lb.set(rotation=0, horizontalalignment='right')

        # Left panel: profiles averaged over each window.
        ax = axarr[i, 0]
        for tag, tavg in tavgs.items():
            tslice = slice(tavg['starttime'], tavg['endtime'])
            da.sel(time=tslice).mean(dim='time').plot(ax=ax, y=da.dims[0], label=tag, **tavg['line_kw'])
        ax.set_title('')
        ax.set_xlabel(labels[var])
        ax.set_ylabel('Depth [m]')
    axarr[0, 0].legend()
    axarr[-1, -1].set_xlabel('Time [day]')

    plt.tight_layout()
    return fig

def get_flux(lam, num, gam):
    """Total vertical flux ``-num * d(lam)/dz + gam``.

    The vertical gradient of ``lam`` is taken by finite differences between
    adjacent rows of ``lam`` and written into the interior rows of an array
    shaped like ``num``. The first and last rows of the gradient stay zero,
    so there the flux equals ``gam``.

    Parameters
    ----------
    lam : xarray.DataArray
        Mean field indexed (level, time) with a 1-D ``z`` coordinate. Assumed
        to have exactly one row fewer than ``num`` (centres vs. interfaces)
        and the same number of time columns -- TODO confirm against gotmtool.
    num : xarray.DataArray
        Eddy diffusivity/viscosity; defines the grid of the result.
    gam : xarray.DataArray
        Additive flux term (presumably the non-local/counter-gradient part).

    Returns
    -------
    xarray.DataArray
    """
    dlamdz = xr.zeros_like(num)
    # Vectorised over time (was a Python loop). np.diff gives
    # (a[1:] - a[:-1]) / (z[1:] - z[:-1]), which equals the former
    # (a[:-1] - a[1:]) / (z[:-1] - z[1:]) exactly since both terms flip sign.
    dz = np.diff(lam.z.data)
    dlamdz.data[1:-1, :] = np.diff(lam.data, axis=0) / dz[:, np.newaxis]
    flux = - num * dlamdz + gam
    return flux

In [None]:
# Mean fields: temperature and Lagrangian (Eulerian + Stokes) velocity
# scaled by ustar. Labels are raw strings so LaTeX backslashes such as
# \circ are not parsed as (invalid) Python escape sequences.
das = dict(
    T = ds.data_vars['temp'],
    u = (ds.data_vars['u']+ds.data_vars['us'])/ustar,
    v = (ds.data_vars['v']+ds.data_vars['vs'])/ustar,
)
labels = dict(
    T = r'$T$ [$^\circ$C]',
    u = r'$u/u_*$',
    v = r'$v/u_*$',
)
levels = dict(
    T = np.linspace(17,20,31),
    u = np.linspace(-40, 40, 41),
    v = np.linspace(-15, 15, 31),
)
fig = plot_overview(das, levels, labels, tavgs)
figname = os.path.join(figpath, 'mean')
fig.savefig(figname, dpi = 300, facecolor='w')

In [None]:
# Turbulent fluxes: heat flux (scaled by 1e4) and momentum fluxes scaled by
# ustar**2. The momentum fluxes also include the -nucl * dus/dz (dvs/dz)
# term. Labels are raw strings so LaTeX backslashes (\overline, \prime,
# \circ) are not parsed as invalid Python escape sequences.
das = dict(
    wT = get_flux(ds.data_vars['temp'].squeeze(),
                  ds.data_vars['nuh'].squeeze(),
                  ds.data_vars['gamh'].squeeze())*1e4,
    wu = (get_flux(ds.data_vars['u'].squeeze(),
                   ds.data_vars['num'].squeeze(),
                   ds.data_vars['gamu'].squeeze())
          - ds.data_vars['nucl'].squeeze()*ds.data_vars['dusdz'].squeeze())/ustar**2,
    wv = (get_flux(ds.data_vars['v'].squeeze(),
                   ds.data_vars['num'].squeeze(),
                   ds.data_vars['gamv'].squeeze())
          - ds.data_vars['nucl'].squeeze()*ds.data_vars['dvsdz'].squeeze())/ustar**2,
)
labels = dict(
    wT = r'$10^4\overline{w^\prime T^\prime}$ [$^\circ$C m/s]',
    wu = r'$\overline{w^\prime u^\prime}/u_*^2$',
    wv = r'$\overline{w^\prime v^\prime}/u_*^2$',
)
levels = dict(
    wT = np.linspace(-1.6, 1.6, 41),
    wu = np.linspace(-2, 2, 41),
    wv = np.linspace(-1.2, 1.2, 41),
)
fig = plot_overview(das, levels, labels, tavgs)
figname = os.path.join(figpath, 'fluxT')
fig.savefig(figname, dpi = 300, facecolor='w')