In [None]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *

In [None]:
# Which group of simulations to analyse; it selects one entry from each per-group table.
ctname = 'CT'

# Per-group tables keyed by ctname:
#   Qt0s   - surface temperature-flux scale (used for w_* and to normalize temperature fluxes)
#   depths - depth [m] of the lowest level shown in the figures
#   w2maxs - upper x-limit of the velocity-variance panels
Qt0s = {'CT': 1.221e-5, 'CT05': 1.221e-6, 'CT5h': 1.221e-4}
depths = {'CT': 50, 'CT05': 40, 'CT5h': 80}
w2maxs = {'CT': 0.6, 'CT05': 0.6, 'CT5h': 0.8}

# One simulation directory per wave case; only the Stokes tag in the name differs.
casenames = {
    key: 'R11_MSM97-{:s}{:s}_f0_fixdt3'.format(ctname, stokes_tag)
    for key, stokes_tag in [('c1', ''), ('c2', '_Stokes6'), ('c3', '_Stokes'), ('c4', '_Stokes5')]
}
# Wave amplitude a [m] of each case (c1 is the a = 0 reference).
amp = {'c1': 0.0, 'c2': 0.45, 'c3': 0.8, 'c4': 1.423}

# Plot styling, one entry per case in casenames order.
linestyles = [':', '-.', '--', '-']
colors = ['tab:blue', 'tab:green', 'tab:orange', 'tab:red']
labels = ['$a$ = {:4.2f} m'.format(a) for a in amp.values()]

# Output step of the profile history file to read (an earlier run used 28800).
iend = 57600
filename = 'his.mp.vis.000001.{:06d}.nc'.format(iend + 1)

# Directory that receives the saved figures.
figpath = 'R11_MSM97_{:s}_Stokes_f0'.format(ctname)
os.makedirs(figpath, exist_ok=True)

In [None]:
# Display the legend labels built from the wave amplitudes in `amp`.
labels

In [None]:
# Read each case's profile history file into `ds`, keyed like `casenames`.
ds = {}
for cn, casedir in casenames.items():
    print(cn)  # progress: name of the case being read
    datapath = os.path.join(os.path.pardir, 'ncarles', casedir, filename)
    ds[cn] = NCARLESDataProfile(filepath=datapath).dataset

In [None]:
# Inspect the dataset of case c1, the zero-amplitude case (amp['c1'] == 0.0).
ds['c1']

In [None]:
# Physical constants and normalization scales for the selected case group (ctname).
g = 9.81        # gravitational acceleration
rho = 1026      # reference density (not referenced by the cells shown here)
cp = 3991       # specific heat (not referenced by the cells shown here)
alpha = 2e-4    # thermal expansion coefficient
NNT0 = 0.01     # reference vertical temperature gradient used to normalize d(txym)/dz
NN0 = alpha * g * NNT0  # the same reference gradient expressed as N^2 (alpha*g*dT/dz)
Qt0 = Qt0s[ctname]
h0 = 33         # presumably a reference boundary-layer depth [m] - TODO confirm
# Convective velocity scale w_* = (B0*h0)**(1/3), with buoyancy flux B0 = alpha*g*Qt0.
wstar = (alpha * g * Qt0 * h0) ** (1 / 3)
depth = depths[ctname]
def get_vars(ds, var):
    """Return one normalized diagnostic computed from an NCAR LES profile dataset.

    Parameters
    ----------
    ds : dataset
        Object exposing ``data_vars`` (fields such as 'uxym', 'stokes', 'txym',
        'wtle', 'wps', ...) and a ``z`` coordinate, e.g. the ``.dataset`` of an
        ``NCARLESDataProfile``.
    var : str
        Diagnostic name: 'uL', 'uE', 'uS', 'NN', 'wNN', 'tNN', 'wt', 'wtsb',
        'ww', 'uu', 'vv', 'wu', 'wusb', 'wv', 'sk' or 'tke'.

    Returns
    -------
    The requested field, scaled by module-level constants: velocities by
    ``wstar``, (co)variances by ``wstar**2``, temperature fluxes by ``Qt0``,
    temperature gradients by ``NNT0`` and shear products by ``NN0``.
    'sk' is the dimensionless ratio wcube / wps**1.5.

    Raises
    ------
    ValueError
        If ``var`` is not one of the names above.
    """
    def ddz(field):
        # Vertical derivative along the 'z' coordinate.
        return field.differentiate(coord='z')

    # Lazy recipes: only the requested one is evaluated, on ds.data_vars.
    recipes = {
        # Lagrangian (uxym + Stokes), Eulerian (uxym) and Stokes-drift velocities.
        'uL': lambda v: (v['uxym'] + v['stokes']) / wstar,
        'uE': lambda v: v['uxym'] / wstar,
        'uS': lambda v: v['stokes'] / wstar,
        # N^2/N0^2 from the vertical gradient of 'txym'.
        'NN': lambda v: ddz(v['txym']) / NNT0,
        'wNN': lambda v: -ddz(v['uxym']) * ddz(v['stokes']) / NN0,
        # NOTE(review): this uses the shear of (uxym - stokes), whereas 'wNN' uses the
        # shear of uxym alone and 'uL' is uxym + stokes; confirm the intended velocity.
        'tNN': lambda v: ddz(v['txym']) / NNT0 - ddz(v['uxym'] - v['stokes']) * ddz(v['stokes']) / NN0,
        # 'wtle' only; adding the SGS part 'wtsb' was deliberately left out.
        'wt': lambda v: v['wtle'] / Qt0,
        'wtsb': lambda v: v['wtsb'] / Qt0,
        'ww': lambda v: v['wps'] / wstar**2,
        'uu': lambda v: v['ups'] / wstar**2,
        'vv': lambda v: v['vps'] / wstar**2,
        'wu': lambda v: v['uwle'] / wstar**2,
        'wusb': lambda v: v['uwsb'] / wstar**2,
        'wv': lambda v: v['vwle'] / wstar**2,
        'sk': lambda v: v['wcube'] / v['wps']**1.5,
        # 'wps' lives on 'zi'; interpolate it onto 'z' before adding 'ups' and 'vps'.
        'tke': lambda v: 0.5 * (v['ups'] + v['vps'] + v['wps'].interp(zi=ds.z)) / wstar**2,
    }
    if var not in recipes:
        raise ValueError('Variable {:s} not found.'.format(var))
    return recipes[var](ds.data_vars)

In [None]:
# Averaging window for every time-mean profile: 2000-01-01 12:00 to 2000-01-02 00:00.
# (A shorter window starting at 2000-01-01T16:00:00 was also used.)
tslice = slice('2000-01-01T12:00:00', '2000-01-02T00:00:00')

In [None]:
# Figure: time-mean (over tslice) vertical profiles of six diagnostics for all wave cases.
# Solid lines show each panel's primary variable; dashed lines show its companion
# (u^S, N_*^2, SGS buoyancy flux, TKE, v'^2). Colours identify the case.
fig, axarr = plt.subplots(2, 3, sharey='row')
fig.set_size_inches(8,6)
variables = ['uL', 'NN', 'wt', 'ww', 'uu', 'sk']
# Raw strings: LaTeX commands such as \overline and \prime are not valid Python escapes.
xlabels = dict(
    uL = r'$\overline{u}^L/w_*$, $u^S/w_*$',
    NN = r'$N^2/N^2_0$, $N_*^2/N^2_0$',
    wt = r'$\overline{w^\prime b^\prime}/B_0$, $q^{sgs}_b/B_0$',
    ww = r'$\overline{{w^\prime}^2}/w_*^2$, TKE$/w_*^2$',
    sk = r'$\overline{{w^\prime}^3}/(\overline{{w^\prime}^2})^{3/2}$',
    uu = r'$\overline{{u^\prime}^2}/w_*^2$, $\overline{{v^\prime}^2}/w_*^2$',
)
xlims = dict(
    uL = [-2,3],
    NN = [-1.5,2.5],
    wt = [-0.3,1.3],
    ww = [0,w2maxs[ctname]],
    sk = [-3.5,1.5],
    uu = [0,w2maxs[ctname]],
)
# Companion variable drawn dashed on a panel, with the vertical coordinate it is plotted on.
# Panels not listed here ('sk') have no companion curve.
companions = dict(
    uL = ('uS', 'z'),
    NN = ('tNN', 'z'),
    wt = ('wtsb', 'zi'),
    ww = ('tke', 'z'),
    uu = ('vv', 'z'),
)
# Panel-letter placement, one character per panel: 'a' -> left/top, 'b' -> right/bottom.
lr = 'aaaaaa'
tb = 'bbbaaa'
leftright = {'a': 'left', 'b': 'right'}
topbottom = {'a': 'top', 'b': 'bottom'}
abc = 'abcdef'
xabc = {'a': 0.05, 'b': 0.95}
yabc = {'a': 0.95, 'b': 0.05}
for k, (var, ax) in enumerate(zip(variables, axarr.flat)):
    ax.axvline(x=0, color='k', linewidth=0.75)
    if var in companions:
        companion, zdim = companions[var]
        for i, cn in enumerate(casenames.keys()):
            da = get_vars(ds[cn], companion)
            da.sel(time=tslice).mean(dim='time').plot(ax=ax,y=zdim,color=colors[i],linestyle='--')
    for i, cn in enumerate(casenames.keys()):
        da = get_vars(ds[cn], var)
        # Profiles are on either the 'z' or the 'zi' vertical coordinate.
        if 'z' in da.coords:
            yaxis = 'z'
        elif 'zi' in da.coords:
            yaxis = 'zi'
        else:
            # Fail loudly instead of silently reusing the previous variable's axis.
            raise ValueError('Variable {:s} has neither a z nor a zi coordinate.'.format(var))
        da.sel(time=tslice).mean(dim='time').plot(ax=ax,y=yaxis,color=colors[i])
    ax.set_xlim(xlims[var])
    ax.set_ylim([-depth,0])
    ax.set_ylabel('')
    ax.set_xlabel(xlabels[var])
    ax.text(xabc[lr[k]], yabc[tb[k]], '({:s})'.format(abc[k]), transform=ax.transAxes, va=topbottom[tb[k]], ha=leftright[lr[k]])
    ax.grid()

for ax in axarr[:, 0]:
    ax.set_ylabel('$z$ [m]')

# Line-style legends: NaN "proxy" lines are invisible but carry a style and a label.
legend_specs = [
    (axarr[0,0], r'$\overline{u}^L/w_*$', r'$u^S/w_*$', 'lower right'),
    (axarr[0,1], r'$N^2/N^2_0$', r'$N_*^2/N^2_0$', 'center right'),
    (axarr[0,2], r'$\overline{w^\prime b^\prime}/B_0$', r'$q^{sgs}_b/B_0$', 'lower right'),
    (axarr[1,0], r'$\overline{{w^\prime}^2}/w_*^2$', r'TKE$/w_*^2$', 'lower right'),
    (axarr[1,1], r'$\overline{{u^\prime}^2}/w_*^2$', r'$\overline{{v^\prime}^2}/w_*^2$', 'lower right'),
]
for ax, solid_label, dashed_label, loc in legend_specs:
    ax.plot(np.nan, np.nan, color='k', label=solid_label)
    ax.plot(np.nan, np.nan, color='k', linestyle='--', label=dashed_label)
    ax.legend(loc=loc)

# Case legend (colour = wave amplitude), anchored below the bottom-right panel.
ax = axarr[1,-1]
for i, cn in enumerate(casenames.keys()):
    ax.plot(np.nan, np.nan, color=colors[i], label=labels[i])
ax.legend(ncol=4, loc='lower right', bbox_to_anchor=(0.7, -0.45))

fig.subplots_adjust(top=0.97, bottom=0.18, left=0.09, right=0.97, hspace=0.3, wspace=0.15)
figname = os.path.join(figpath, 'profiles-ncar-cmp-wave-stratification')
fig.savefig(figname, dpi = 300, facecolor='w')


In [None]:
# Figure: time-depth sections of the normalized momentum fluxes w'u'/w_*^2 (left column)
# and w'v'/w_*^2 (right column), one row per wave case, with one shared colorbar.
fig, axarr = plt.subplots(4, 2, sharex='col', sharey='row')
fig.set_size_inches(8,5)
levels = np.linspace(-0.02,0.02,41)
varnames = ['wu', 'wv']
# Raw strings: \overline and \prime are not valid Python escape sequences.
titles = [r'$\overline{w^\prime u^\prime}/w_*^2$', r'$\overline{w^\prime v^\prime}/w_*^2$']
# Case labels derived from the amplitude table rather than a second hard-coded list.
labels = ['$a$ = {:4.2f} m'.format(a) for a in amp.values()]
# Panel letters: left column a-d, right column e-h.
abc = ['abcd', 'efgh']
last_row = axarr.shape[0] - 1
for i, cn in enumerate(casenames.keys()):
    for j, var in enumerate(varnames):
        ax = axarr[i,j]
        da = get_vars(ds[cn], var)
        # Replace absolute timestamps with hours elapsed since the first output.
        time_dtime = pd.to_datetime(da.coords['time'].data)
        time_hr = (time_dtime - time_dtime[0]).total_seconds() / 3600
        da_new = da.assign_coords({'time': time_hr})
        # Every panel uses the same levels, so the last mappable can drive the shared colorbar.
        im = da_new.plot(ax=ax, levels=levels, extend='both', add_colorbar=False)
        ax.set_ylim([-depth,0])
        ax.set_xlabel('')
        ax.set_ylabel('')
        if i == 0:
            ax.set_title(titles[j], fontsize=10)
        elif i == last_row:
            ax.set_xlabel('Time [hour]')
        panel_text = '({:s})'.format(abc[j][i])
        if j == 0:
            # Only the left column names the case.
            panel_text = '{:s} {:s}'.format(panel_text, labels[i])
        ax.text(0.03, 0.1, panel_text, transform=ax.transAxes, va='bottom', ha='left')
    axarr[i,0].set_ylabel('$z$ [m]')
fig.subplots_adjust(top=0.95, bottom=0.24, left=0.09, right=0.98, hspace=0.2, wspace=0.06)
cax = fig.add_axes([0.25, 0.11, 0.5, 0.02])
cb = fig.colorbar(im, cax=cax, orientation='horizontal')
cb.ax.tick_params(rotation=30)
figname = os.path.join(figpath, 'profiles-ncar-momentum-flux')
fig.savefig(figname, dpi = 300, facecolor='w')

In [None]:
# Figure: time-depth sections of the resolved (left) and SGS (right) along-wind momentum
# flux, normalized by w_*^2, one row per wave case, with one shared colorbar.
fig, axarr = plt.subplots(4, 2, sharex='col', sharey='row')
fig.set_size_inches(8,5)
levels = np.linspace(-0.02,0.02,41)
varnames = ['wu', 'wusb']
# Raw strings: \overline and \prime are not valid Python escape sequences.
titles = [r'$\overline{w^\prime u^\prime}/w_*^2$', r'$\overline{w^\prime u^\prime}^{sgs}/w_*^2$']
# Case labels derived from the amplitude table rather than a second hard-coded list.
labels = ['$a$ = {:4.2f} m'.format(a) for a in amp.values()]
# Panel letters: left column a-d, right column e-h.
abc = ['abcd', 'efgh']
last_row = axarr.shape[0] - 1
for i, cn in enumerate(casenames.keys()):
    for j, var in enumerate(varnames):
        ax = axarr[i,j]
        da = get_vars(ds[cn], var)
        # Replace absolute timestamps with hours elapsed since the first output.
        time_dtime = pd.to_datetime(da.coords['time'].data)
        time_hr = (time_dtime - time_dtime[0]).total_seconds() / 3600
        da_new = da.assign_coords({'time': time_hr})
        # Every panel uses the same levels, so the last mappable can drive the shared colorbar.
        im = da_new.plot(ax=ax, levels=levels, extend='both', add_colorbar=False)
        ax.set_ylim([-depth,0])
        ax.set_xlabel('')
        ax.set_ylabel('')
        if i == 0:
            ax.set_title(titles[j], fontsize=10)
        elif i == last_row:
            ax.set_xlabel('Time [hour]')
        panel_text = '({:s})'.format(abc[j][i])
        if j == 0:
            # Only the left column names the case.
            panel_text = '{:s} {:s}'.format(panel_text, labels[i])
        ax.text(0.03, 0.1, panel_text, transform=ax.transAxes, va='bottom', ha='left')
    axarr[i,0].set_ylabel('$z$ [m]')
fig.subplots_adjust(top=0.95, bottom=0.24, left=0.09, right=0.98, hspace=0.2, wspace=0.06)
cax = fig.add_axes([0.25, 0.11, 0.5, 0.02])
cb = fig.colorbar(im, cax=cax, orientation='horizontal')
cb.ax.tick_params(rotation=30)
figname = os.path.join(figpath, 'profiles-ncar-momentum-flux-x')
fig.savefig(figname, dpi = 300, facecolor='w')

In [None]:
# Lagrangian velocity uL/w_* (uxym + Stokes drift) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth (autoscaled y-limits).
levels = np.linspace(-2, 2, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'uL')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_ylim([-depth, 0])

In [None]:
# Stokes drift uS/w_* for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth.
fig, axarr = plt.subplots(4, 1, sharex='col')
fig.set_size_inches(8,6)
fig2, ax2 = plt.subplots(1)
fig2.set_size_inches(6,4)
fig3, ax3 = plt.subplots(1)
fig3.set_size_inches(6,3)
# Define the contour levels here rather than relying on `levels` left over from whichever
# cell ran last (same range the uL cell uses), so this cell survives out-of-order execution.
levels = np.linspace(-2,2,41)
for i, cn in enumerate(casenames.keys()):
    ax = axarr[i]
    da = get_vars(ds[cn], 'uS')
    da.plot(ax=ax, levels=levels)
    da.sel(time=tslice).mean(dim='time').plot(ax=ax2,y='z',color=colors[i])
    da.where(da.z>=-depth).mean(dim='z').plot(ax=ax3,color=colors[i])
    ax.set_ylim([-depth,0])
    ax.set_xlabel('')
ax2.set_ylim([-depth,0])

In [None]:
# Shear-Stokes product 'wNN' = -(d uxym/dz)(d stokes/dz)/NN0 for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth.
levels = np.linspace(-2, 2, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'wNN')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Modified stratification N_*^2/N_0^2 ('tNN') for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth.
levels = np.linspace(-2, 2, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'tNN')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Stratification N^2/N_0^2 ('NN') for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth.
levels = np.linspace(-2, 2, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'NN')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Vertical velocity variance w'^2/w_*^2 ('ww', on the zi coordinate) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over zi >= -depth.
levels = np.linspace(0.0, 0.6, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'ww')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='zi', color=colors[row])
    field.where(field.zi >= -depth).mean(dim='zi').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Along-wind velocity variance u'^2/w_*^2 ('uu') for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth.
levels = np.linspace(0.0, 0.6, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'uu')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Cross-wind velocity variance v'^2/w_*^2 ('vv') for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth (y-range matches the contour levels).
levels = np.linspace(0.0, 0.6, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'vv')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
level_range = [levels.min(), levels.max()]
ax2.set_xlim(level_range)
ax2.set_ylim([-depth, 0])
ax3.set_ylim(level_range)

In [None]:
# Resolved momentum flux w'u'/w_*^2 ('wu', on the zi coordinate) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over zi >= -depth (y-range matches the contour levels).
levels = np.linspace(-0.01, 0.01, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'wu')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='zi', color=colors[row])
    field.where(field.zi >= -depth).mean(dim='zi').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
level_range = [levels.min(), levels.max()]
ax2.set_xlim(level_range)
ax2.set_ylim([-depth, 0])
ax3.set_ylim(level_range)

In [None]:
# Resolved momentum flux w'v'/w_*^2 ('wv', on the zi coordinate) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over zi >= -depth (y-range matches the contour levels).
levels = np.linspace(-0.01, 0.01, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'wv')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='zi', color=colors[row])
    field.where(field.zi >= -depth).mean(dim='zi').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
level_range = [levels.min(), levels.max()]
ax2.set_xlim(level_range)
ax2.set_ylim([-depth, 0])
ax3.set_ylim(level_range)

In [None]:
# Turbulent kinetic energy TKE/w_*^2 ('tke') for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over z >= -depth (y-range matches the contour levels).
levels = np.linspace(0.0, 0.6, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'tke')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='z', color=colors[row])
    field.where(field.z >= -depth).mean(dim='z').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
level_range = [levels.min(), levels.max()]
ax2.set_xlim(level_range)
ax2.set_ylim([-depth, 0])
ax3.set_ylim(level_range)

In [None]:
# Vertical velocity skewness wcube/wps**1.5 ('sk', on the zi coordinate) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over zi >= -depth.
levels = np.linspace(-2.4, 2.4, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'sk')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='zi', color=colors[row])
    field.where(field.zi >= -depth).mean(dim='zi').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])

In [None]:
# Normalized temperature flux 'wt' = wtle/Qt0 (on the zi coordinate) for every case:
#   fig  - time-depth section of each case (one row per case),
#   fig2 - profile averaged over tslice,
#   fig3 - time series averaged over zi >= -depth.
levels = np.linspace(-1.2, 1.2, 41)
fig, axarr = plt.subplots(4, 1, sharex='col')
fig2, ax2 = plt.subplots(1)
fig3, ax3 = plt.subplots(1)
fig.set_size_inches(8, 6)
fig2.set_size_inches(6, 4)
fig3.set_size_inches(6, 3)
for row, case in enumerate(casenames):
    field = get_vars(ds[case], 'wt')
    section_ax = axarr[row]
    field.plot(ax=section_ax, levels=levels)
    field.sel(time=tslice).mean(dim='time').plot(ax=ax2, y='zi', color=colors[row])
    field.where(field.zi >= -depth).mean(dim='zi').plot(ax=ax3, color=colors[row])
    section_ax.set_ylim([-depth, 0])
    section_ax.set_xlabel('')
ax2.set_xlim([levels.min(), levels.max()])
ax2.set_ylim([-depth, 0])