In [None]:
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.dates import DateFormatter
import numpy as np
import pandas as pd
import xarray as xr
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *
from sbl_bbl import *

In [None]:
# Physical parameters of the simulated case (SI units unless noted).
g = 9.81            # gravitational acceleration [m/s^2]
H = 30              # water depth [m]
u10 = 8             # 10-m wind speed [m/s]
N2 = 1.962e-4       # reference buoyancy frequency squared [1/s^2]
bstar = N2 * H      # buoyancy scale used to normalize w'b'
cd = 1.25e-3        # wind drag coefficient
rhoa = 1.225        # air density [kg/m^3]
rhoo = 1026         # seawater density [kg/m^3]
lat = 45.           # latitude (presumably degrees) passed to inertial_period
tau = rhoa/rhoo*cd*u10*u10   # kinematic wind stress [m^2/s^2]
ustar = np.sqrt(tau)         # water-side friction velocity [m/s]
amplitude = 1.0              # surface wave amplitude [m]
wavelength = 60              # [m]
wavenumber = 2.*np.pi/wavelength
# Finite-depth linear dispersion relation: omega^2 = g k tanh(k H).
frequency = np.sqrt(g*wavenumber*np.tanh(wavenumber*H))
# Surface Stokes drift of a monochromatic wave, u_s0 = a^2 k omega.
# (a^2 k alone has units of length, not velocity, so the frequency factor is required.)
us0 = amplitude**2*wavenumber*frequency
# Turbulent Langmuir number La = sqrt(u_* / u_s0).
la = np.sqrt(ustar/us0)
print('La = {:6.3f}'.format(la))
# Inertial period at latitude `lat`; used below as the time scale T_f (plots use t/T_f).
Ti = inertial_period(lat)
print(f'Ti = {Ti:6.3f}')

In [None]:
reslist = 'n128l72 n256l72 n128l144 n256l144 n256l288 n512l288'.split()  # grid resolutions, e.g. n128l72 -> 128^2 x 72 (see plot labels)

In [None]:
# Case identifier (also the prefix of the simulation directories) and the figure output folder.
casename = 'lsc_ymc22_sbl_bbl_v2'
figpath = f'overview_{casename}'
os.makedirs(figpath, exist_ok=True)  # tolerate an existing folder on re-run

In [None]:
# Load the averaged profiles (averages.jld2) for every resolution and both cases:
# ds1_pfls <- '<casename>_<res>' ('Aligned'), ds2_pfls <- '<casename>_rf_<res>' ('Opposite').
ds1_pfls = {}
ds2_pfls = {}
for res in reslist:
    for pfls, case in ((ds1_pfls, casename), (ds2_pfls, casename + '_rf')):
        datapath = os.path.join(os.path.pardir, 'tests', '{:s}_{:s}'.format(case, res))
        filepath = os.path.join(datapath, 'averages.jld2')
        pfls[res] = OceananigansDataProfile(filepath=filepath).dataset

In [None]:
def get_das(ds, n2=None, u_star=None, b_star=None):
    """Derive normalized diagnostics from a dataset of mean profiles.

    Parameters
    ----------
    ds : dataset-like
        Must expose ``data_vars`` containing 'b', 'u', 'v' (differentiable along
        coordinate 'z') and the flux pairs 'wb'/'wbsb', 'wu'/'wusb', 'wv'/'wvsb'.
        Presumably an xarray.Dataset -- TODO confirm.
    n2, u_star, b_star : float, optional
        Normalization scales. When omitted they default to the notebook globals
        ``N2``, ``ustar`` and ``bstar`` (the original behavior).

    Returns
    -------
    dict
        Keys in this order (downstream plotting iterates them):
        NN = (db/dz)/N2, SS = ((du/dz)^2 + (dv/dz)^2)/N2, Ri = (db/dz)/S^2,
        wb = 1e3*(wb + wbsb)/(u_* b_*), wu = (wu + wusb)/u_*^2, wv = (wv + wvsb)/u_*^2.
    """
    # Resolve scales lazily so existing calls keep using the notebook globals.
    n2 = N2 if n2 is None else n2
    u_star = ustar if u_star is None else u_star
    b_star = bstar if b_star is None else b_star

    dv = ds.data_vars
    # Differentiate each field once and reuse it (it was recomputed for every diagnostic).
    dbdz = dv['b'].differentiate(coord='z')
    dudz = dv['u'].differentiate(coord='z')
    dvdz = dv['v'].differentiate(coord='z')
    shear2 = dudz**2 + dvdz**2
    return dict(
        NN=dbdz/n2,
        SS=shear2/n2,
        Ri=dbdz/shear2,
        # Each flux is the sum of the two stored terms ('*sb' presumably the subgrid part -- confirm).
        wb=(dv['wb'] + dv['wbsb'])/u_star/b_star*1e3,
        wu=(dv['wu'] + dv['wusb'])/u_star**2,
        wv=(dv['wv'] + dv['wvsb'])/u_star**2,
    )

In [None]:
# Per-diagnostic plotting metadata, keyed like the dict returned by get_das.
# Raw strings keep the LaTeX backslashes literal (plain '\o', '\p' are invalid escapes).
labels = dict(
    NN=r'$N^2/N_0^2$',
    SS=r'$S^2/N_0^2$',
    Ri=r'$N^2/S^2$',
    wb=r'$10^3\overline{w^\prime b^\prime}/u_*b_*$',
    wu=r'$\overline{w^\prime u^\prime}/u_*^2$',
    wv=r'$\overline{w^\prime v^\prime}/u_*^2$',
)
# Contour levels for each diagnostic (shared by both case columns).
levels = dict(
    NN=np.linspace(0, 4, 41),
    SS=np.linspace(0, 8, 41),
    Ri=np.linspace(0, 1, 5),
    wb=np.linspace(-4, 4, 41),
    wu=np.linspace(-2, 2, 41),
    wv=np.linspace(-1.2, 1.2, 41),
)
# Sequential colormap for the non-negative fields, diverging for the fluxes.
cmaps = dict(
    NN='pink_r',
    SS='pink_r',
    Ri='pink_r',
    wb='RdBu_r',
    wu='RdBu_r',
    wv='RdBu_r',
)
# Panel letters: abc[var][k] labels column k of the row for `var`.
abc = dict(
    NN='ag',
    SS='bh',
    Ri='ci',
    wb='dj',
    wu='ek',
    wv='fl',
)
# Column titles: ds1_pfls -> 'Aligned', ds2_pfls -> 'Opposite' (order of the k loop below).
tags = ['Aligned', 'Opposite']
nv = len(labels)  # one figure row per diagnostic
# Colour of the dotted time markers, chosen to contrast with each colormap.
rlcolor = {'RdBu_r': 'k', 'viridis': 'w', 'bone': 'w', 'bone_r': 'k', 'pink_r': 'k'}
line_kwargs = dict(color='k', linewidth=0.75)
cf = {}
for res in reslist:
    # One figure per resolution: rows = diagnostics, columns = the two cases.
    fig, axarr = plt.subplots(nv, 2, sharex='col')
    fig.set_size_inches([10, 0.4+1.5*nv])
    for k, ds_pfls in enumerate([ds1_pfls, ds2_pfls]):
        ds = ds_pfls[res]
        das = nondim_das(get_das(ds), H=H, Tf=Ti)
        # Curves z0/z1 and the bounds of windows T1 and T2, derived from the stratification.
        z0, z1, tstart1, tend1, tstart2, tend2 = get_merge(das['NN'], nondim=True, Tf=Ti)
        for i, var in enumerate(das.keys()):
            ax = axarr[i, k]
            cmap = cmaps[var]
            # Keep the mappable for the row colorbar; levels are identical in both columns.
            cf[var] = das[var].plot(ax=ax, levels=levels[var], cmap=cmap, add_colorbar=False)
            for tmark in (tstart1, tend1, tstart2, tend2):
                ax.axvline(x=tmark, linestyle=':', color=rlcolor[cmap])
            if i == 0:
                ax.text(tstart1, 0, 'T1', va='bottom', ha='left')
                ax.text(tstart2, 0, 'T2', va='bottom', ha='left')
            z0.plot(ax=ax, **line_kwargs)
            z1.plot(ax=ax, **line_kwargs)
            ax.text(0.98, 0.92, '({:s})'.format(abc[var][k]), transform=ax.transAxes, va='top', ha='right')
            ax.set_title('')
            ax.set_xlim([0, 16])
            # Only the bottom row carries the time label, only the left column the depth label.
            ax.set_xlabel('$t/T_f$' if i == nv-1 else '')
            ax.set_ylabel('$z/H$' if k == 0 else '')
        axarr[0, k].set_title(tags[k])

    # Operate on this figure explicitly rather than on pyplot's "current figure".
    fig.subplots_adjust(top=0.97, bottom=0.06, left=0.08, right=0.87, hspace=0.15, wspace=0.15)
    # One colorbar per row, placed to the right of the second column.
    for i, var in enumerate(das.keys()):
        pos = axarr[i, 1].get_position()
        cax = fig.add_axes([0.9, pos.y0, 0.015, pos.height])
        cb = fig.colorbar(cf[var], cax=cax)
        cb.set_label(labels[var])
    figname = os.path.join(figpath, 'mean_v2-{:s}'.format(res))
    fig.savefig(figname, dpi=300, facecolor='w')


In [None]:
# get_edges result per resolution, computed from the non-dimensional (db/dz)/N2 field;
# edges1 <- ds1_pfls (Aligned), edges2 <- ds2_pfls (Opposite).
edges1 = {}
edges2 = {}
for res in reslist:
    for edge_store, ds_case in ((edges1, ds1_pfls[res]), (edges2, ds2_pfls[res])):
        nn = ds_case.data_vars['b'].differentiate(coord='z')/N2
        edge_store[res] = get_edges(nondim_da(nn, H=H, Tf=Ti))

In [None]:
# Line colour per resolution; resolutions with the same vertical level count share a hue.
colors = {
    'n128l72': 'skyblue',
    'n256l72': 'steelblue',
    'n128l144': 'lightcoral',
    'n256l144': 'firebrick',
    'n256l288': 'mediumpurple',
    'n512l288': 'rebeccapurple',
}
# Legend label per resolution (raw strings: the backslash is part of the LaTeX).
labels = {
    'n128l72': r'$128^2\times72$',
    'n256l72': r'$256^2\times72$',
    'n128l144': r'$128^2\times144$',
    'n256l144': r'$256^2\times144$',
    'n256l288': r'$256^2\times288$',
    'n512l288': r'$512^2\times288$',
}
abc = 'ad be cf'.split()          # panel letters: abc[column][row]
tags = 'Aligned Opposite'.split()  # case name per row
tagy = ['bottom', 'top']           # vertical anchor of the case tag per row
texty = {'top': 0.92, 'bottom': 0.08}
line_kwargs = {'linestyle': '-', 'linewidth': 1}
# Resolution sensitivity: rows = Aligned / Opposite case; columns = layer-edge time
# series, then w'u' profiles averaged over windows T1 and T2 around the merge time.
fig, axarr = plt.subplots(2, 3, sharex='col', sharey='row', gridspec_kw={'width_ratios': [3, 1, 1]})
fig.set_size_inches(8, 5)
dss_pfls = [ds1_pfls, ds2_pfls]
for k, edges in enumerate([edges1, edges2]):
    ax = axarr[k, 0]
    # Both edges per resolution, smoothed with a 5-sample centred running mean in time.
    for res in reslist:
        edges[res][0].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[res], label=labels[res], **line_kwargs)
        edges[res][1].rolling(time=5, center=True).mean().plot(ax=ax, color=colors[res], **line_kwargs)
    ax.set_ylabel('$z/H$')
    ax.set_ylim([-1, 0])
    ax.set_xlim([0, 16])
    ax.text(0.05, texty['top'], '({:s})'.format(abc[0][k]), transform=ax.transAxes, va='top', ha='left')
    ax.text(0.95, texty[tagy[k]], tags[k], transform=ax.transAxes, va=tagy[k], ha='right')
    if k == 0:
        ax.legend(loc='upper right', ncol=3, fontsize=8)
        ax.set_xlabel('')
    else:
        ax.set_xlabel('$t/T_f$')

    for res in reslist:
        # Merge time: last time at which the first edge is still defined (non-NaN).
        time_merge = edges[res][0].dropna(dim='time').time[-1]
        # T1 = [merge-3, merge-2], T2 = [merge+4, merge+5] (time presumably in units of T_f).
        tslice1 = slice(time_merge - 3, time_merge - 2)
        tslice2 = slice(time_merge + 4, time_merge + 5)
        ds = dss_pfls[k][res]
        da = nondim_da((ds.data_vars['wu']+ds.data_vars['wusb'])/ustar**2, H=H, Tf=Ti)
        da.sel(time=tslice1).mean(dim='time').plot(ax=axarr[k, 1], y=da.dims[0], color=colors[res], **line_kwargs)
        da.sel(time=tslice2).mean(dim='time').plot(ax=axarr[k, 2], y=da.dims[0], color=colors[res], **line_kwargs)

    for i in range(2):
        ax = axarr[k, i+1]
        ax.set_ylabel('')
        ax.text(0.95, texty['top'], '({:s})'.format(abc[i+1][k]), transform=ax.transAxes, va='top', ha='right')
        if k == 0:
            ax.set_xlabel('')
            ax.set_title('T{:g}'.format(i+1), fontsize=10)
        else:
            # Raw string: '\o' and '\p' are invalid escape sequences in a plain literal.
            ax.set_xlabel(r'$\overline{w^\prime u^\prime}/u_*^2$')

# Operate on this figure explicitly rather than on pyplot's "current figure".
fig.tight_layout()
fig.subplots_adjust(wspace=0.12)
figname = os.path.join(figpath, 'resolution-sensitivity_v2')
fig.savefig(figname, dpi=300, facecolor='w')