In [None]:
import os
import sys
import copy
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
sys.path.append(os.path.join(os.path.pardir, 'gotmtool'))
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from gotmtool import *
from lesview import *

In [None]:
casename = 'lsc_ymc22_sbl_bbl_rf'
# GOTM turbulence closure compared with the LES ('KPPLT-LF17' is the alternative setting)
turbmethod = 'SMCLT-H15'

# LES: averaged profiles written by Oceananigans to averages.jld2
datapath = os.path.join(os.path.pardir, 'oceananigans', casename)
filepath = os.path.join(datapath, 'averages.jld2')
data_pfl = OceananigansDataProfile(filepath=filepath)
les_ds = data_pfl.dataset

# GOTM: the base run of `turbmethod` and its '-Rlx' run (labelled 'GOTM-R' in the plots)
gotm_dir = os.path.join(os.path.pardir, 'gotm', 'run', casename)
gotm_sim1 = Simulation(path=os.path.join(gotm_dir, turbmethod))
gotm_sim2 = Simulation(path=os.path.join(gotm_dir, turbmethod + '-Rlx'))
gotm_ds1 = gotm_sim1.load_data()
gotm_ds2 = gotm_sim2.load_data()

In [None]:
# Show the LES dataset to check its variables and coordinates
les_ds

In [None]:
# Show the dataset of the base GOTM run
gotm_ds1

In [None]:
# Show the dataset of the '-Rlx' GOTM run
gotm_ds2

In [None]:
# Forcing and wave parameters (SI units); ustar is used below to normalize the plots.
g = 9.81                      # gravitational acceleration [m/s^2]
H = 30                        # water depth used in the wave dispersion relation [m]
u10 = 10                      # 10-m wind speed [m/s]
cd = 1.25e-3                  # air-sea drag coefficient
rhoa = 1.225                  # air density [kg/m^3]
rhoo = 1026                   # seawater density [kg/m^3]
tau = rhoa/rhoo*cd*u10*u10    # kinematic wind stress (bulk formula) [m^2/s^2]
ustar = np.sqrt(tau)          # water-side friction velocity [m/s]
amplitude = 1.13              # wave amplitude [m]
wavelength = 60               # wavelength [m]
wavenumber = 2.*np.pi/wavelength
# finite-depth dispersion relation: sigma^2 = g k tanh(k H)
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift of a monochromatic wave, u_s0 = a^2 k sigma [m/s].
# The frequency factor is what gives u_s0 units of velocity.
us0 = amplitude**2*wavenumber*frequency
# turbulent Langmuir number La_t = sqrt(u*/u_s0)
la = np.sqrt(ustar/us0)
print(la)

In [None]:
# Averaging windows T1 and T2, each one inertial period (17 hours) long
startdate1, enddate1 = '2000-01-02T00:00:00', '2000-01-02T17:00:00'
startdate2, enddate2 = '2000-01-06T00:00:00', '2000-01-06T17:00:00'
tslice1 = slice(startdate1, enddate1)
tslice2 = slice(startdate2, enddate2)

In [None]:
def cmp_da(da0, da1, da2, var, units, levels, cmap, depth=-30):
    """Compare one variable from the LES and two GOTM runs.

    Draws two figures:

    1. Three stacked time-depth plots (``da0`` = LES, ``da1`` = GOTM,
       ``da2`` = GOTM-R). Dotted lines at the module-level ``startdate1``,
       ``enddate1``, ``startdate2`` and ``enddate2`` mark the averaging
       windows T1 and T2.
    2. Two panels of vertical profiles averaged over ``tslice1`` (T1) and
       ``tslice2`` (T2).

    Parameters
    ----------
    da0, da1, da2 : xarray.DataArray
        Fields to compare. Each needs a ``time`` dimension, and its first
        dimension is used as the vertical (y) axis of the profile panels.
    var : str
        Variable label (may contain matplotlib mathtext).
    units : str
        Units label; ``'unitless'`` suppresses the ``[units]`` suffix.
    levels : array-like or None
        Color levels for the time-depth plots. Their min/max also set the
        x-limits of the profile panels; with ``None`` those panels autoscale.
    cmap : str
        Colormap name. The T1/T2 marker lines are white on ``'viridis'`` and
        black on any other colormap.
    depth : float, optional
        Lower y-limit of every panel (the upper limit is 0).

    Returns
    -------
    tuple of matplotlib.figure.Figure
        ``(fig1, fig2)``: the time-depth figure and the profile figure.
    """
    if units == 'unitless':
        xlabel = var
    else:
        xlabel = var+' ['+units+']'
    # Marker-line color chosen for contrast with the colormap. Fall back to
    # black instead of raising KeyError for colormaps not listed here.
    line_color = {'viridis': 'w', 'RdBu_r': 'k'}.get(cmap, 'k')
    fields = (da0, da1, da2)
    window_edges = [pd.Timestamp(t) for t in (startdate1, enddate1, startdate2, enddate2)]

    # --- Figure 1: time-depth sections (plot all panels first, then format) ---
    fig1, axarr = plt.subplots(3, 1, sharex='col')
    fig1.set_size_inches([7, 7])
    for ax, da in zip(axarr, fields):
        da.plot(ax=ax, levels=levels, cmap=cmap, cbar_kwargs={'label': xlabel})
    for ax in axarr:
        ax.set_ylim([depth, 0])
        ax.set_xlabel('')
        ax.set_title('')
        ax.set_ylabel('Depth [m]')
        for edge in window_edges:
            ax.axvline(x=edge, linestyle=':', color=line_color)
        ax.text(pd.Timestamp(startdate1), 0, 'T1', va='bottom', ha='left')
        ax.text(pd.Timestamp(startdate2), 0, 'T2', va='bottom', ha='left')
    fig1.tight_layout()

    # --- Figure 2: profiles averaged over T1 and T2 ---
    styles = (('-', 'LES'), ('--', 'GOTM'), (':', 'GOTM-R'))
    fig2, axarr2 = plt.subplots(1, 2, sharey='row')
    fig2.set_size_inches([6, 3])
    for ax, tslice in zip(axarr2, (tslice1, tslice2)):
        for da, (linestyle, label) in zip(fields, styles):
            da.sel(time=tslice).mean(dim='time').plot(
                ax=ax, y=da.dims[0], color='k', linestyle=linestyle, label=label)
    ylabels = ('Depth [m]', '')
    titles = ('Time averaged T1', 'Time averaged T2')
    for ax, ylabel, title in zip(axarr2, ylabels, titles):
        ax.set_ylim([depth, 0])
        if levels is not None:
            ax.set_xlim([np.min(levels), np.max(levels)])
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
    fig2.tight_layout()
    return fig1, fig2
    
    
def plot_mean_fields(var, units, levels, cmap, lesvar, gotmvar, lesshift=0, scale=1):
    """Compare a mean (or second-moment) field between the LES and both GOTM runs.

    For ``lesvar == 'tke'``, the LES field is 0.5*(uu + vv + ww), with ``ww``
    interpolated from ``zi`` onto ``z``. Otherwise it is ``les_ds[lesvar]``.
    For ``gotmvar`` 'u' or 'v', the GOTM field is the velocity plus 'us' or
    'vs' respectively.

    ``lesshift`` is added to the LES field only. ``scale`` multiplies all
    three fields. The fields are labelled with ``var``/``units`` and passed
    to ``cmp_da``.
    """
    stokes_name = {'u': 'us', 'v': 'vs'}.get(gotmvar)

    def gotm_field(ds):
        # the GOTM variable (plus its Stokes-drift partner for u/v), squeezed and scaled
        field = ds.data_vars[gotmvar]
        if stokes_name is not None:
            field = field + ds.data_vars[stokes_name]
        return field.squeeze()*scale

    les_vars = les_ds.data_vars
    if lesvar == 'tke':
        les_field = 0.5*(les_vars['uu']+les_vars['vv']+les_vars['ww'].interp(zi=les_ds.z))
    else:
        les_field = les_vars[lesvar]
    da0 = (les_field+lesshift)*scale
    da1 = gotm_field(gotm_ds1)
    da2 = gotm_field(gotm_ds2)
    for da in (da0, da1, da2):
        da.attrs.update(long_name=var, units=units)
    cmp_da(da0, da1, da2, var, units, levels, cmap)


def get_flux(lam, num, gam):
    """Rebuild a turbulent flux from a mean profile: ``-num * d(lam)/dz + gam``.

    The gradient of ``lam`` is taken by finite differences between adjacent
    ``lam.z`` levels and placed on the interior points of ``num``'s grid.
    The first and last points of that grid get a zero gradient, so the flux
    there is just ``gam``.

    Parameters
    ----------
    lam : xarray.DataArray
        Mean quantity with a ``z`` coordinate along its FIRST axis, e.g.
        dims ``(z, time)``, or ``(z,)`` for a single profile.
    num : xarray.DataArray
        Eddy diffusivity/viscosity. Its first axis must be exactly one
        element longer than ``lam``'s, and its remaining axes must match.
    gam : xarray.DataArray or scalar
        Additional flux term added as-is (e.g. GOTM's ``gamh``/``gamu``).

    Returns
    -------
    xarray.DataArray
        The flux, on the same grid/coordinates as ``num``.
    """
    grad = xr.zeros_like(num)
    # spacing between adjacent z levels, shaped to broadcast over trailing axes (e.g. time)
    dz = np.diff(lam.z.data).reshape((-1,) + (1,) * (lam.ndim - 1))
    # vectorised over all trailing axes; equal to the former per-time-step loop
    grad.data[1:-1] = np.diff(lam.data, axis=0) / dz
    return - num * grad + gam
    
    
def plot_turbulent_flux(var, units, levels, cmap, lesvar1, lesvar2, gotmvar1, gotmvar2, gotmvar3, gotmvar4=None, gotmvar5=None, scale=1):
    """Compare a vertical turbulent flux between the LES and both GOTM runs.

    The LES flux is ``lesvar1 + lesvar2`` (e.g. ``'wt'`` + ``'wtsb'``).

    Each GOTM flux is ``get_flux(gotmvar1, gotmvar2, gotmvar3)``, i.e.
    ``-gotmvar2 * d(gotmvar1)/dz + gotmvar3``. When both ``gotmvar4`` and
    ``gotmvar5`` are given, their product is also subtracted; if only one
    is given, it is ignored.

    All three fields are multiplied by ``scale``, labelled with
    ``var``/``units``, and passed to ``cmp_da``.
    """
    subtract_product = (gotmvar4 is not None) and (gotmvar5 is not None)

    def gotm_flux(ds):
        # flux for one GOTM run; singleton dimensions are squeezed out first
        gvars = ds.data_vars
        flux = get_flux(gvars[gotmvar1].squeeze(),
                        gvars[gotmvar2].squeeze(),
                        gvars[gotmvar3].squeeze())*scale
        if subtract_product:
            flux -= gvars[gotmvar4].squeeze()*gvars[gotmvar5].squeeze()*scale
        return flux

    les_vars = les_ds.data_vars
    da0 = (les_vars[lesvar1]+les_vars[lesvar2])*scale
    da1 = gotm_flux(gotm_ds1)
    da2 = gotm_flux(gotm_ds2)
    for da in (da0, da1, da2):
        da.attrs.update(long_name=var, units=units)
    cmp_da(da0, da1, da2, var, units, levels, cmap)

In [None]:
# T: temperature (LES 'T' vs GOTM 'temp')
levels = np.linspace(17, 20, 31)
# raw string: '\c' is an invalid escape in a normal string literal
plot_mean_fields('T', r'$^\circ$C', levels, 'viridis', 'T', 'temp')

In [None]:
# u, normalized by u* (plot_mean_fields adds the Stokes drift 'us' to GOTM u)
levels = np.linspace(0, 40, 41)
plot_mean_fields(r'$\overline{u}/u_*$', 'unitless', levels, 'viridis', 'u', 'u', scale=1/ustar)

In [None]:
# v, normalized by u* (plot_mean_fields adds the Stokes drift 'vs' to GOTM v)
levels = np.linspace(-15, 15, 31)
plot_mean_fields(r'$\overline{v}/u_*$', 'unitless', levels, 'RdBu_r', 'v', 'v', scale=1/ustar)

In [None]:
# Temperature flux: LES 'wt' + 'wtsb' vs GOTM -nuh*dT/dz + gamh.
# Values are multiplied by 1e4, so the axis unit is 10^-4 degC m/s.
levels = np.linspace(-1.6, 1.6, 41)
plot_turbulent_flux(r'$\overline{w^\prime T^\prime}$', r'$10^{-4}$ $^\circ$C m/s', levels, 'RdBu_r', 'wt', 'wtsb', 'temp', 'nuh', 'gamh', scale=1e4)

In [None]:
# Momentum flux w'u', normalized by u*^2 (GOTM: -num*du/dz + gamu - nucl*dusdz)
levels = np.linspace(-2, 2, 41)
plot_turbulent_flux(r'$\overline{w^\prime u^\prime}/u_*^2$', 'unitless', levels, 'RdBu_r', 'wu', 'wusb', 'u', 'num', 'gamu', 'dusdz', 'nucl', scale=1/ustar**2)

In [None]:
# Momentum flux w'v', normalized by u*^2 (GOTM: -num*dv/dz + gamv - nucl*dvsdz)
levels = np.linspace(-1.2, 1.2, 41)
plot_turbulent_flux(r'$\overline{w^\prime v^\prime}/u_*^2$', 'unitless', levels, 'RdBu_r', 'wv', 'wvsb', 'v', 'num', 'gamv', 'dvsdz', 'nucl', scale=1/ustar**2)

In [None]:
# uu: along-x velocity variance, normalized by u*^2
levels = np.linspace(0, 15, 31)
plot_mean_fields(r'$\overline{u^\prime u^\prime}/u_*^2$', 'unitless', levels, 'viridis', 'uu', 'uu', scale=1/ustar**2)

In [None]:
# vv: along-y velocity variance, normalized by u*^2
levels = np.linspace(0, 15, 31)
plot_mean_fields(r'$\overline{v^\prime v^\prime}/u_*^2$', 'unitless', levels, 'viridis', 'vv', 'vv', scale=1/ustar**2)

In [None]:
# ww: vertical velocity variance, normalized by u*^2
levels = np.linspace(0, 6, 31)
plot_mean_fields(r'$\overline{w^\prime w^\prime}/u_*^2$', 'unitless', levels, 'viridis', 'ww', 'ww', scale=1/ustar**2)

In [None]:
# TKE normalized by u*^2 (the LES TKE is assembled from uu, vv, ww inside plot_mean_fields)
levels = np.linspace(0, 10, 41)
plot_mean_fields('$TKE/u_*^2$', 'unitless', levels, 'viridis',
                 lesvar='tke', gotmvar='tke', scale=1/ustar**2)