In [None]:
import sys
import os
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np
import pandas as pd
import xarray as xr
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
sys.path.append(os.path.join(os.path.pardir, 'gotmtool'))
from lesview import *
from gotmtool import *
from sbl_bbl import *

In [None]:
# Physical, forcing and wave parameters of the case (SI units assumed).
g = 9.81            # gravitational acceleration
H = 30              # water depth (used as the depth in the dispersion relation below)
u10 = 8             # 10-m wind speed
N2 = 1.962e-4       # reference buoyancy frequency squared; used below to normalize N^2
bstar = N2 * H
cd = 1.25e-3        # drag coefficient
rhoa = 1.225        # air density
rhoo = 1026         # ocean reference density
lat = 45.           # latitude, passed to inertial_period
alphaT = 2.0e-4     # thermal expansion coefficient (converts dT/dz to db/dz later)
tau = rhoa/rhoo*cd*u10*u10   # kinematic wind stress from the bulk formula
ustar = np.sqrt(tau)         # water-side friction velocity

# Monochromatic surface wave.
amplitude = 1.0
wavelength = 60
wavenumber = 2.*np.pi/wavelength
# Angular frequency from the finite-depth dispersion relation sigma^2 = g k tanh(k H).
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift of a monochromatic wave, U_s = sigma * k * a^2.
# The frequency factor is required for U_s to be a velocity; without it the
# Langmuir number below would be computed from a length.
us0 = frequency*wavenumber*amplitude**2
# Turbulent Langmuir number La = sqrt(u* / U_s).
la = np.sqrt(ustar/us0)
print('La = {:6.3f}'.format(la))
# Inertial period at this latitude; profiles and times are later nondimensionalized by it.
Ti = inertial_period(lat)
print('Ti = {:6.3f}'.format(Ti))

In [None]:
casename = 'lsc_ymc22_sbl_bbl_v2'
# Every figure saved by this notebook goes into a directory named after the case.
figpath = 'overview_' + casename
os.makedirs(figpath, exist_ok=True)

In [None]:
# Tags of the GOTM relaxation runs, built from the relaxation time scales
# (presumably seconds; 0 presumably means no relaxation — confirm against the run setup).
rlxlist = list(map('Rlx{:g}'.format, (0, 60., 600., 3600.)))
rlxlist

In [None]:
# Turbulence closure used by the GOTM runs (alternative closure kept for reference).
turbmethod = 'SMCLT-H15'
# turbmethod = 'KPPLT-LF17'
# Run key -> suffix appended to casename to form that run's case directory.
runs = {'r1': '', 'r2': '_rf'}
nlev = 144   # number of GOTM levels; part of each GOTM run-directory name
ds_gotm = {}   # run key -> {relaxation tag -> GOTM dataset}
ds_ocgn = {}   # run key -> Oceananigans dataset loaded from 'averages.jld2'
for rkey, suffix in runs.items():
    case_dir = '{:s}{:s}'.format(casename, suffix)
    # The GOTM case directory depends only on the run, not on the relaxation
    # time, so it is built once per run.
    gotm_dir = os.path.join(os.path.pardir, 'gotm', 'run', case_dir)
    ds_pfls = {}
    for rlx in rlxlist:
        gotm_sim = Simulation(path=os.path.join(gotm_dir, '{:s}_L{:g}_{:s}'.format(turbmethod, nlev, rlx)))
        ds_pfls[rlx] = gotm_sim.load_data()
    ds_gotm[rkey] = ds_pfls
    ocgn_dir = os.path.join(os.path.pardir, 'tests', case_dir)
    filepath = os.path.join(ocgn_dir, 'averages.jld2')
    ds_ocgn[rkey] = OceananigansDataProfile(filepath=filepath).dataset

In [None]:
# Averaging windows chosen by get_tslice from the nondimensional stratification
# N^2 / N2 of each run. GOTM: g * alphaT * dT/dz (temperature indexed [:, :, 0, 0],
# presumably (time, z, lat, lon) with singleton horizontal dims — confirm);
# Oceananigans: db/dz directly from the buoyancy 'b'.
tslice_gotm = {}
tslice_ocgn = {}
for rkey in runs:
    print('---- GOTM {:s} ----'.format(rkey))
    tslice_gotm[rkey] = {
        rlx: get_tslice(
            nondim_da(ds_gotm[rkey][rlx].data_vars['temp'][:, :, 0, 0].differentiate(coord='z')*alphaT*g/N2,
                      H=H, Tf=Ti),
            Ti)
        for rlx in rlxlist
    }
    print('---- Oceananigans {:s} ----'.format(rkey))
    N2_ocgn = nondim_da(ds_ocgn[rkey].data_vars['b'].differentiate(coord='z')/N2, H=H, Tf=Ti)
    tslice_ocgn[rkey] = get_tslice(N2_ocgn, Ti)

In [None]:
def get_das_gotm(ds_gotm, U0=0.0):
    """Scale GOTM mean velocities and momentum fluxes by the friction velocity.

    Parameters
    ----------
    ds_gotm : dataset
        GOTM output exposing, via ``.data_vars``, the variables 'u', 'v',
        'num', 'gamu', 'gamv', 'nucl', 'dusdz' and 'dvsdz'.
    U0 : float, optional
        Reference velocity subtracted from 'u' before scaling (default 0.0).

    Returns
    -------
    dict
        'u' = (u - U0) / ustar and 'v' = v / ustar (Eulerian velocities;
        a Lagrangian variant would add 'us' / 'vs'), plus 'wu' and 'wv':
        the flux from get_flux minus nucl times the Stokes-drift shear,
        divided by ustar**2. Uses the module-level ``ustar``.
    """
    fields = ds_gotm.data_vars

    def _scaled_flux(vel, gam, stokes_shear):
        # get_flux(velocity, num, gam) minus nucl * d(Stokes drift)/dz, over ustar**2.
        total = get_flux(fields[vel].squeeze(),
                         fields['num'].squeeze(),
                         fields[gam].squeeze()) \
            - fields['nucl'].squeeze()*fields[stokes_shear].squeeze()
        return total/ustar**2

    return {
        'u': (fields['u'] - U0)/ustar,
        'v': fields['v']/ustar,
        'wu': _scaled_flux('u', 'gamu', 'dusdz'),
        'wv': _scaled_flux('v', 'gamv', 'dvsdz'),
    }

def get_das_ocgn(ds_ocgn, U0=0):
    """Scale Oceananigans mean velocities and momentum fluxes by the friction velocity.

    Parameters
    ----------
    ds_ocgn : dataset
        Oceananigans profile output exposing, via ``.data_vars``, the variables
        'u', 'v', 'wu', 'wv', 'wusb' and 'wvsb'.
    U0 : float, optional
        Reference velocity subtracted from 'u' before scaling (default 0).

    Returns
    -------
    dict
        'u' = (u - U0) / ustar, 'v' = v / ustar, and 'wu' / 'wv' = the sum of
        the 'wu'+'wusb' (resp. 'wv'+'wvsb') fields divided by ustar**2
        ('wusb'/'wvsb' presumably subgrid-scale contributions — confirm).
        Uses the module-level ``ustar``.
    """
    fields = ds_ocgn.data_vars
    flux_scale = ustar**2
    return {
        'u': (fields['u'] - U0)/ustar,
        'v': fields['v']/ustar,
        'wu': (fields['wu'] + fields['wusb'])/flux_scale,
        'wv': (fields['wv'] + fields['wvsb'])/flux_scale,
    }

In [None]:
# Profile comparison figure. Rows: u, w'u', v, w'v'. Column pairs: the two runs
# (tagged 'Aligned' / 'Opposite'), each averaged over two windows (titled T1, T2).
# LES (Oceananigans) profiles are black, GOTM profiles are blue (one line per
# relaxation time), and the Stokes drift is overlaid in red on the u panels.
# Raw strings keep the LaTeX backslashes literal: '\o' and '\p' are invalid escape
# sequences in ordinary string literals (SyntaxWarning since Python 3.12).
labels = dict(
    u  = r'$(\overline{u}-U_0)/u_*$',
    wu = r'$\overline{w^\prime u^\prime}/u_*^2$',
    v  = r'$\overline{v}/u_*$',
    wv = r'$\overline{w^\prime v^\prime}/u_*^2$',
)
# Panel letters for each row, indexed by column (k*2+j).
abc = dict(
    u  = 'aeim',
    wu = 'bfjn',
    v  = 'cgko',
    wv = 'dhlp',
)
# Panel-letter corner per row and run: 't'/'b' (top/bottom) then 'l'/'r' (left/right).
abc_loc = dict(
    u  = ['br', 'br'],
    wu = ['tl', 'tr'],
    v  = ['tr', 'tr'],
    wv = ['br', 'br'],
)
lr = {'l': 'left', 'r': 'right'}
tb = {'t': 'top', 'b': 'bottom'}
abc_x = {'left': 0.1, 'right': 0.9}
abc_y = {'top': 0.9, 'bottom': 0.1}
# Whether to draw a vertical x=0 reference line, per row and run ('t'/'f').
rline = dict(
    u  = 'tt',
    v  = 'tf',
    wu = 'tt',
    wv = 'tt',
)
tf = {'t': True, 'f': False}
tags = ['Aligned', 'Opposite']
titles = ['T1', 'T2']
# Reference velocity subtracted from u for each run.
U0 = {'r1': 0.25, 'r2': -0.25}

# Transparency and line style of the GOTM profile for each relaxation time.
alpha = dict(
    Rlx0      = 1,
    Rlx60     = 1,
    Rlx600    = 0.6,
    Rlx3600   = 0.3,
    Rlx86400  = 0.1,
)
linestyle = dict(
    Rlx0      = '--',
    Rlx60     = '-',
    Rlx600    = '-',
    Rlx3600   = '-',
    Rlx86400  = '-',
)
# Inset placement (matplotlib location codes: 2 = upper left, 7 = center right) and x-range.
inset_loc = {'r1': 2, 'r2': 7}
inset_xlim = {'r1': [-20, 0], 'r2': [0, 20]}
fig, axarr = plt.subplots(4, 4, sharey='row')
fig.set_size_inches([9, 9])

for k, rkey in enumerate(runs.keys()):
    # LES
    das_ocgn = nondim_das(get_das_ocgn(ds_ocgn[rkey], U0=U0[rkey]), H=H, Tf=Ti)
    # GOTM, one set of profiles per relaxation time
    das_gotm = {}
    for rlx in rlxlist:
        das_gotm[rlx] = nondim_das(get_das_gotm(ds_gotm[rkey][rlx], U0=U0[rkey]), H=H, Tf=Ti)
    # Stokes drift at the first output time; it depends only on the run, so it is
    # computed once here instead of once per averaging window.
    da = nondim_da(ds_gotm[rkey]['Rlx0'].data_vars['us']/ustar, H=H, Tf=Ti).isel(time=0)
    for j in range(2):
        ax = axarr[0, k*2+j]
        da.plot(ax=ax, y=da.dims[0], linestyle='-', linewidth=1, color='tab:red', zorder=1)
        # Inset zooming into the lowest part of the u profile (y-range [-1, -0.9] below).
        axins = inset_axes(ax, width="40%", height="50%", loc=inset_loc[rkey], borderpad=1)
        for i, var in enumerate(labels.keys()):
            ax = axarr[i, k*2+j]
            if tf[rline[var][k]]:
                ax.axvline(x=0, linewidth=0.75, color='k', zorder=0)
            # LES: time mean over averaging window j
            da1 = das_ocgn[var].sel(time=tslice_ocgn[rkey][j]).mean(dim='time')
            da1.plot(ax=ax, y=da1.dims[0], linestyle='-', linewidth=1, color='k')
            if i == 0:
                da1.plot(ax=axins, y=da1.dims[0], linestyle='-', linewidth=0.75, color='k')
            # GOTM: each relaxation time has its own averaging windows
            for rlx in rlxlist:
                da2 = das_gotm[rlx][var].sel(time=tslice_gotm[rkey][rlx][j]).mean(dim='time')
                da2.plot(ax=ax, y=da2.dims[0], linestyle=linestyle[rlx], linewidth=1, color='tab:blue', alpha=alpha[rlx])
                if i == 0:
                    da2.plot(ax=axins, y=da2.dims[0], linestyle=linestyle[rlx], linewidth=0.75, color='tab:blue', alpha=alpha[rlx])
            ax.set_title('')
            ax.set_xlabel(labels[var])
            ax.set_ylabel('')
            # Place the panel letter in the corner selected by abc_loc.
            va_abc = tb[abc_loc[var][k][0]]
            ha_abc = lr[abc_loc[var][k][1]]
            ax.text(abc_x[ha_abc], abc_y[va_abc], '({:s})'.format(abc[var][k*2+j]),
                    transform=ax.transAxes, va=va_abc, ha=ha_abc)
        axins.set_title('')
        axins.set_xlabel('')
        axins.set_ylabel('')
        axins.set_xlim(inset_xlim[rkey])
        axins.set_ylim([-1, -0.9])
        axins.tick_params(labelleft=False, labelbottom=False)

for i in range(4):
    ax = axarr[i, 0]
    ax.set_ylabel('$z/H$')
    ax.set_ylim([-1, 0])
    ax = axarr[0, i]
    ax.set_title(titles[i % 2], fontsize=9)

# Run tags centered above each column pair.
for i in range(2):
    ax = axarr[0, i*2]
    ax.text(1.1, 1.05, tags[i], transform=ax.transAxes, va='bottom', ha='center')

fig.subplots_adjust(top=0.95, bottom=0.08, left=0.08, right=0.97, hspace=0.35, wspace=0.15)
figname = os.path.join(figpath, 'profiles-v2-relax-{:s}'.format(turbmethod))
fig.savefig(figname, dpi=300, facecolor='w')

In [None]:
# Layer edges (via get_edges) of the nondimensional stratification
# g * alphaT * dT/dz / N2 of every GOTM relaxation run; edges1 holds run 'r1',
# edges2 holds run 'r2'. get_edges presumably returns a pair of edge time series
# (indexed [0] and [1] below) — confirm against sbl_bbl.
edges1 = {}
edges2 = {}
for rlx in rlxlist:
    per_run = []
    for rkey in ('r1', 'r2'):
        ds_run = ds_gotm[rkey][rlx]
        dbdz = ds_run.data_vars['temp'].differentiate(coord='z')*alphaT*g/N2
        per_run.append(get_edges(nondim_da(dbdz[:, :, 0, 0], H=H, Tf=Ti)))
    edges1[rlx], edges2[rlx] = per_run

In [None]:
# Relaxation-time sensitivity figure for GOTM. Left column: 5-point running means
# of the two edge series from get_edges, one color per relaxation time. Middle and
# right columns: u and w'u' profiles averaged over a window before (T1, dashed) and
# after (T2, solid) the last time at which the first edge series is defined
# (presumably when the two boundary layers merge — confirm).
colors = dict(
    Rlx0      = 'darkorange',
    Rlx60     = 'forestgreen',
    Rlx600    = 'steelblue',
    Rlx3600   = 'firebrick',
    Rlx86400  = 'rebeccapurple',
)
labels = {rlx: rlx for rlx in rlxlist}

# Upper limit of the t/T_f axis for each turbulence closure.
xmax = {
    'SMCLT-H15': 10,
    'SMCLT-KC04': 17,
    'KPPLT-LF17': 17,
}
abc = ['ad', 'be', 'cf']
tags = ['Aligned', 'Opposite']
tagy = ['bottom', 'top']
texty = {'top': 0.92, 'bottom': 0.08}
# Raw strings avoid the invalid '\o'/'\p' escape sequences (SyntaxWarning since Python 3.12).
# NOTE(review): the middle column plots (u + us)/u*, but its label reads \overline{u}/u_* — confirm.
xlabels = [r'$\overline{u}/u_*$', r'$\overline{w^\prime u^\prime}/u_*^2$']
line_kwargs = dict(linestyle='-', linewidth=1)
line_kwargs1 = dict(linestyle='--', linewidth=0.75)   # window T1 (before)
line_kwargs2 = dict(linestyle='-', linewidth=0.75)    # window T2 (after)
fig, axarr = plt.subplots(2, 3, sharey='row', gridspec_kw={'width_ratios': [3, 1, 1]})
fig.set_size_inches(8, 5)
dss_pfls = [ds_gotm['r1'], ds_gotm['r2']]
for k, edges in enumerate([edges1, edges2]):
    ax = axarr[k, 0]
    for rlx in rlxlist:
        edges[rlx][0].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[rlx], **line_kwargs, label=labels[rlx])
        edges[rlx][1].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[rlx], **line_kwargs)
    ax.set_ylabel('$z/H$')
    ax.set_ylim([-1, 0])
    ax.set_xlim([0, xmax[turbmethod]])
    ax.text(0.05, texty['top'], '({:s})'.format(abc[0][k]), transform=ax.transAxes, va='top', ha='left')
    ax.text(0.95, texty[tagy[k]], tags[k], transform=ax.transAxes, va=tagy[k], ha='right')
    if k == 0:
        ax.legend(loc='upper right', ncol=2, fontsize=9)
        ax.set_xlabel('')
    else:
        ax.set_xlabel('$t/T_f$')

    for rlx in rlxlist:
        # Last time (in units of T_f) at which the first edge series is defined.
        time_merge = edges[rlx][0].dropna(dim='time').time[-1]
        tslice1 = slice(time_merge - 3, time_merge - 2)
        tslice2 = slice(time_merge + 4, time_merge + 5)
        ds = dss_pfls[k][rlx]
        da = nondim_da((ds.data_vars['u']+ds.data_vars['us']).squeeze()/ustar, H=H, Tf=Ti)
        da.sel(time=tslice1).mean(dim='time').plot(ax=axarr[k, 1], y=da.dims[0], color=colors[rlx], **line_kwargs1)
        da.sel(time=tslice2).mean(dim='time').plot(ax=axarr[k, 1], y=da.dims[0], color=colors[rlx], **line_kwargs2)
        # Momentum flux: get_flux(...) minus nucl times the Stokes-drift shear.
        da = nondim_da((get_flux(ds.data_vars['u'].squeeze(),
                                 ds.data_vars['num'].squeeze(),
                                 ds.data_vars['gamu'].squeeze())
                        - ds.data_vars['nucl'].squeeze()*ds.data_vars['dusdz'].squeeze())/ustar**2, H=H, Tf=Ti)
        da.sel(time=tslice1).mean(dim='time').plot(ax=axarr[k, 2], y=da.dims[0], color=colors[rlx], **line_kwargs1)
        da.sel(time=tslice2).mean(dim='time').plot(ax=axarr[k, 2], y=da.dims[0], color=colors[rlx], **line_kwargs2)

    for i in range(2):
        ax = axarr[k, i+1]
        ax.set_title('')
        ax.set_ylabel('')
        ax.text(0.95, texty['top'], '({:s})'.format(abc[i+1][k]), transform=ax.transAxes, va='top', ha='right')
        if k == 0:
            ax.set_xlabel('')
        else:
            ax.set_xlabel(xlabels[i])

# Legend proxies for the two averaging windows. They are empty lines created on the
# axes that owns the legend, not on pyplot's implicit "current" axes.
l1, = axarr[1, 1].plot([], [], 'k', **line_kwargs1)
l2, = axarr[1, 1].plot([], [], 'k', **line_kwargs2)
axarr[1, 1].legend([l1, l2], ['T1', 'T2'], loc='center right', fontsize=9)

fig.tight_layout()
fig.subplots_adjust(wspace=0.12)
figname = os.path.join(figpath, 'relax-sensitivity_gotm-v2-{:s}'.format(turbmethod))
fig.savefig(figname, dpi=300, facecolor='w')