In [None]:
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import numpy as np
import pandas as pd
import xarray as xr
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *
from sbl_bbl import *

In [None]:
# Forcing and wave parameters used to normalise the diagnostics below.
# Units are not stated in the notebook; SI is assumed throughout — TODO confirm.
g = 9.81            # gravitational acceleration
H = 30              # depth used in the dispersion relation and in bstar
u10 = 8             # wind speed entering the bulk stress formula
N2 = 1.962e-4       # reference stratification used to normalise N^2 and S^2
bstar = N2 * H      # buoyancy scale
cd = 1.25e-3        # drag coefficient
rhoa = 1.225        # air density
rhoo = 1026         # ocean density
lat = 45.           # latitude passed to inertial_period
# Kinematic wind stress from the bulk formula, and the matching friction velocity.
tau = rhoa/rhoo*cd*u10*u10
ustar = np.sqrt(tau)
# Monochromatic surface wave.
amplitude = 1.0
wavelength = 60
wavenumber = 2.*np.pi/wavelength
# Finite-depth dispersion relation: omega^2 = g k tanh(k H).
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift of a deep-water monochromatic wave, a^2 k omega.
# The wave frequency is required here: a^2 k alone has units of length, which
# would leave ustar/us0 (and hence La) dimensional.
us0 = amplitude**2*wavenumber*frequency
# Turbulent Langmuir number, sqrt(ustar/us0).
la = np.sqrt(ustar/us0)
print('La = {:6.3f}'.format(la))
# Inertial period at latitude `lat`, from the project helper inertial_period.
# Later cells wrap Ti as pd.Timedelta(Ti, 's'), i.e. they treat it as seconds.
Ti = inertial_period(lat)
ti_report = 'Ti = {:6.3f}'.format(Ti)
print(ti_report)

In [None]:
# Resolution tags of the runs to compare. Each tag is appended to the case name
# to form a run directory; the legend labels defined later in the notebook
# render 'n<N>l<L>' as N^2 x L.
reslist = [
    'n128l72',
    'n256l72',
    'n128l144',
    'n256l144',
    'n512l288',
]

In [None]:
# Case to analyse, and the directory (created if missing) that receives its figures.
casename = 'lsc_ymc22_sbl_bbl_v2_rf'
figpath_template = 'overview_{:s}'
figpath = figpath_template.format(casename)
os.makedirs(figpath, exist_ok=True)

In [None]:
# Load the averaged-profile dataset of every resolution, keyed by resolution tag.
# Each run lives in ../tests/<casename>_<res>/averages.jld2.
ds_pfls = dict()
tests_root = os.path.join(os.path.pardir, 'tests')
for res in reslist:
    run_dir = '{:s}_{:s}'.format(casename, res)
    datapath = os.path.join(tests_root, run_dir)
    filepath = os.path.join(datapath, 'averages.jld2')
    profile = OceananigansDataProfile(filepath=filepath)
    ds_pfls[res] = profile.dataset

In [None]:
# Per-variable plot configuration for the six-panel overview figure.
# Every dict below is keyed by the same variable names: NN, SS, Ri, wb, wu, wv.

# Colorbar labels. Raw strings keep the LaTeX backslashes literal without
# relying on Python tolerating unknown escapes such as '\o' and '\p'.
labels = dict(
    NN = r'$N^2/N_0^2$',
    SS = r'$S^2/N_0^2$',
    Ri = r'$N^2/S^2$',
    wb = r'$10^3\overline{w^\prime b^\prime}/u_*b_*$',
    wu = r'$\overline{w^\prime u^\prime}/u_*^2$',
    wv = r'$\overline{w^\prime v^\prime}/u_*^2$',
)
# Contour levels handed to the xarray plot call.
levels = dict(
    NN = np.linspace(0, 3.2, 41),
    SS = np.linspace(0, 6.4, 41),
    Ri = np.linspace(0, 1, 5),
    wb = np.linspace(-4, 4, 41),
    wu = np.linspace(-2, 2, 41),
    wv = np.linspace(-1.2, 1.2, 41),
)
# Colormaps: sequential for the gradient-based fields, diverging for the fluxes.
cmaps = dict(
    NN = 'bone_r',
    SS = 'bone_r',
    Ri = 'bone_r',
    wb = 'RdBu_r',
    wu = 'RdBu_r',
    wv = 'RdBu_r',
)
# [row, column] of each panel in the subplot grid.
loc = dict(
    NN = [0,0],
    SS = [1,0],
    Ri = [2,0],
    wb = [0,1],
    wu = [1,1],
    wv = [2,1],
)
# [show x label, show y label] for each panel.
xylabel = dict(
    NN = [False, True],
    SS = [False, True],
    Ri = [True, True],
    wb = [False, False],
    wu = [False, False],
    wv = [True, False],
)
# Panel letters drawn in the upper-right corner of each panel.
abc = dict(
    NN = '(a)',
    SS = '(b)',
    Ri = '(c)',
    wb = '(d)',
    wu = '(e)',
    wv = '(f)',
)
# Colour of the vertical reference lines, chosen per colormap.
rlcolor = {'RdBu_r': 'k', 'viridis': 'w', 'bone': 'w', 'bone_r': 'k'}
# One six-panel overview figure per resolution, saved as <figpath>/mean_<res>.
for res in reslist:
    ds = ds_pfls[res]

    # Vertical gradients of the mean buoyancy and velocity profiles.
    dbdz = ds.data_vars['b'].differentiate(coord='z')
    dudz = ds.data_vars['u'].differentiate(coord='z')
    dvdz = ds.data_vars['v'].differentiate(coord='z')
    shear2 = dudz**2 + dvdz**2

    # Fields to plot, normalised to match the colorbar labels.
    das = dict(
        NN = dbdz/N2,
        SS = shear2/N2,
        Ri = dbdz/shear2,
        wb = (ds.data_vars['wb']+ds.data_vars['wbsb'])/ustar/bstar*1e3,
        wu = (ds.data_vars['wu']+ds.data_vars['wusb'])/ustar**2,
        wv = (ds.data_vars['wv']+ds.data_vars['wvsb'])/ustar**2,
    )
    # Two depth-vs-time curves returned by get_edges for the normalised
    # stratification; they are overlaid on every panel.
    z0, z1 = get_edges(dbdz/N2)
    line_kwargs = dict(color='k', linewidth=0.75)

    # Last time at which z0 is defined (not NaN); the two averaging windows
    # T1 and T2 are placed relative to it and each span Ti seconds.
    time_merge = z0.dropna(dim='time').time[-1]
    window = pd.Timedelta(Ti, 's')
    before = pd.Timedelta(1.5, 'D')
    after = pd.Timedelta(3.5, 'D')
    tstart1 = time_merge - window - before
    tend1 = time_merge - before
    tstart2 = time_merge + after
    tend2 = time_merge + after + window

    nv = len(das)//2
    fig, axarr = plt.subplots(nv, 2, sharex='col')
    fig.set_size_inches([10, 0.4+1.5*nv])
    date_form = DateFormatter("%d")
    for var, da in das.items():
        row, col = loc[var]
        ax = axarr[row, col]
        cmap = cmaps[var]
        cf = da.plot(ax=ax, levels=levels[var], cmap=cmap, cbar_kwargs={'label': labels[var]})
        # Dotted lines at the start and end of both averaging windows.
        for tmark in (tstart1, tend1, tstart2, tend2):
            ax.axvline(x=pd.to_datetime(tmark.data), linestyle=':', color=rlcolor[cmap])
        # Name the windows on the top row only.
        if row == 0:
            ax.text(pd.to_datetime(tstart1.data), 0, 'T1', va='bottom', ha='left')
            ax.text(pd.to_datetime(tstart2.data), 0, 'T2', va='bottom', ha='left')
        for edge in (z0, z1):
            edge.plot(ax=ax, **line_kwargs)
        # Show only the day of month on the time axis.
        ax.xaxis.set_major_formatter(date_form)
        for lb in ax.get_xticklabels(which='major'):
            lb.set(rotation=0, horizontalalignment='right')
        ax.text(0.98, 0.92, abc[var], transform=ax.transAxes, va='top', ha='right')
        ax.set_title('')
        show_xlabel, show_ylabel = xylabel[var]
        ax.set_xlabel('Time [day]' if show_xlabel else '')
        ax.set_ylabel('Depth [m]' if show_ylabel else '')

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.15, wspace=0.12)
    figname = os.path.join(figpath, 'mean_{:s}'.format(res))
    fig.savefig(figname, dpi = 300, facecolor='w')

In [None]:
# Output of get_edges for the normalised stratification (db/dz divided by N2)
# of each resolution, keyed by resolution tag.
edges = dict()
for res in reslist:
    ds = ds_pfls[res]
    stratification = ds.data_vars['b'].differentiate(coord='z')/N2
    edges[res] = get_edges(stratification)

In [None]:
# Resolution-sensitivity figure: the edge curves of every resolution over time
# (left panel) and the normalised momentum flux averaged over the windows T1
# and T2 (middle and right panels). Saved as <figpath>/resolution-sensitivity.
# NOTE: `labels` and `abc` are rebound here to resolution-keyed / list forms.

# Line colour and legend label per resolution tag.
colors = dict(
    n128l72   = 'skyblue',
    n256l72   = 'steelblue',
    n128l144  = 'lightcoral',
    n256l144  = 'firebrick',
    n512l288  = 'rebeccapurple',
)
labels = dict(
    n128l72   = '$128^2\\times72$',
    n256l72   = '$256^2\\times72$',
    n128l144  = '$128^2\\times144$',
    n256l144  = '$256^2\\times144$',
    n512l288  = '$512^2\\times288$',
)
# Cases whose name contains 'rf' use panel letters (d)-(f), otherwise (a)-(c).
if 'rf' in casename:
    abc = ['(d)', '(e)', '(f)']
else:
    abc = ['(a)', '(b)', '(c)']
line_kwargs = dict(linestyle='-', linewidth=1)
# Defined locally so this cell does not depend on a variable left over from
# the loop body of the overview-figure cell.
date_form = DateFormatter("%d")

fig, axarr = plt.subplots(1, 3, sharey='row', gridspec_kw={'width_ratios': [3, 1, 1]})
fig.set_size_inches(8,2.5)

# Left panel: both edge curves of each resolution, smoothed with a centred
# 5-point rolling mean in time. Only the first curve carries the legend label.
ax = axarr[0]
for res in reslist:
    edges[res][0].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[res], **line_kwargs, label=labels[res])
    edges[res][1].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[res], **line_kwargs)
ax.legend(loc='upper right', ncol=3, fontsize=8)
ax.set_ylabel('Depth [m]')
ax.set_ylim([-30, 0])
ax.set_xlim([pd.Timestamp('2000-01-01T00:00:00'), pd.Timestamp('2000-01-09T00:00:00')])
ax.set_xlabel('Time [day]')
ax.text(0.05, 0.92, abc[0], transform=ax.transAxes, va='top', ha='left')
ax.xaxis.set_major_formatter(date_form)
for lb in ax.get_xticklabels(which='major'):
    lb.set(rotation=0, horizontalalignment='right')

# Middle and right panels: time-mean flux profiles over T1 and T2. The windows
# are placed relative to the last time at which the first edge curve is not NaN.
for res in reslist:
    time_merge = edges[res][0].dropna(dim='time').time[-1]
    tslice1 = slice(time_merge - pd.Timedelta(Ti, 's') - pd.Timedelta(1.5, 'D'), time_merge - pd.Timedelta(1.5, 'D'))
    tslice2 = slice(time_merge + pd.Timedelta(3.5, 'D'), time_merge + pd.Timedelta(3.5, 'D') + pd.Timedelta(Ti, 's'))
    ds = ds_pfls[res]
    da = (ds.data_vars['wu']+ds.data_vars['wusb'])/ustar**2
    # The first dimension of `da` goes on the y axis (assumed to be depth — TODO confirm).
    da.sel(time=tslice1).mean(dim='time').plot(ax=axarr[1], y=da.dims[0], color=colors[res], **line_kwargs)
    da.sel(time=tslice2).mean(dim='time').plot(ax=axarr[2], y=da.dims[0], color=colors[res], **line_kwargs)

for i in range(2):
    ax = axarr[i+1]
    ax.set_ylabel('')
    # Raw string keeps the LaTeX backslashes literal ('\o' and '\p' are not valid escapes).
    ax.set_xlabel(r'$\overline{w^\prime u^\prime}/u_*^2$')
    ax.text(0.95, 0.92, abc[i+1], transform=ax.transAxes, va='top', ha='right')
    ax.text(0.1, 0.1, 'T{:g}'.format(i+1), transform=ax.transAxes, va='bottom', ha='left')

plt.tight_layout()
plt.subplots_adjust(wspace=0.12)
figname = os.path.join(figpath, 'resolution-sensitivity')
fig.savefig(figname, dpi = 300, facecolor='w')