In [None]:
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
sys.path.append(os.path.join(os.path.pardir, 'lesview'))
from lesview import *

In [None]:
# Case configuration: which LES runs to compare and how to label the figure.
# Keys (c1, c2, ...) are case ids; each case becomes one row of the snapshot figure.
casenames = dict(
    c1 = 'R11_MSM97-CT_Stokes_f0_fixdt3',
    c2 = 'R11_MSM97-CT_Stokes_f0_fixdt3_visL',
    # c2 = 'R11_MSM97-CT_Stokes_f0_trs4',
    # c3 = 'R11_MSM97-CT_Stokes_f0_r2',
)
iclabel = 'icL'  # tag embedded in the output figure file name
# Alternative case set (swap in together with iclabel = 'icE'):
# casenames = dict(
#     c1 = 'R11_MSM97-CT_Stokes_f0_fixdt3_IE',
#     c2 = 'R11_MSM97-CT_Stokes_f0_trs4_IL2',
#     c3 = 'R11_MSM97-CT_Stokes_f0_r2_IL',
# )
# iclabel = 'icE'
# Model that produced each case; selects the reader used when loading data.
models = dict(
    c1 = 'ncarles',
    c2 = 'ncarles',
    # c2 = 'oceananigans',
    # c3 = 'oceananigans',
)
# Number substituted into the NCAR-LES file name 'viz.vis.000000.{:06d}.xy.nc';
# only needed for cases whose model is 'ncarles'.
iends = dict(
    c1 = 57600,
    c2 = 28800,
)
# Times to plot (nearest available output is used), one figure column each.
snapshots = ['2000-01-01T02:00:00', '2000-01-01T03:00:00', '2000-01-01T06:00:00', '2000-01-01T12:00:00']
labels = ['2 hr', '3 hr', '6 hr', '12 hr']  # column titles, one per snapshot
abc = ['abcd', 'efgh', 'ijkl']              # panel letters, abc[row][column]

# Fail early with a clear message instead of a KeyError/IndexError deep in the
# loading or plotting loops when the case list is edited inconsistently.
assert set(casenames) <= set(models), 'every case in casenames needs an entry in models'
assert all(name in iends for name in casenames if models[name] == 'ncarles'), \
    'every ncarles case needs an entry in iends'
assert len(labels) == len(snapshots), 'need exactly one label per snapshot'
assert len(abc) >= len(casenames) and all(len(row) >= len(snapshots) for row in abc[:len(casenames)]), \
    'abc must provide a panel letter for every (case, snapshot) panel'

figpath = 'R11_MSM97_CT_Stokes_f0'
os.makedirs(figpath, exist_ok=True)

In [None]:
# Load the field named by `var` for every case into `da` (case id -> DataArray).
# NOTE(review): NCAR-LES cases take the xy slice at index nslc=1, Oceananigans cases
# the level nearest zi=-1.5; presumably these are meant to be the same depth — confirm.
var = 'w'
da = {}
for cn in casenames:
    print(cn)
    model_name = models[cn]
    if model_name not in ('ncarles', 'oceananigans'):
        raise ValueError('Model {:s} not supported'.format(model_name))
    case_dir = os.path.join(os.path.pardir, model_name, casenames[cn])
    if model_name == 'ncarles':
        # The NCAR-LES file name embeds the zero-padded iends value for this case.
        datapath = os.path.join(case_dir, 'viz.vis.000000.{:06d}.xy.nc'.format(iends[cn]))
        raw_field = NCARLESDataVolume(filepath=datapath, fieldname=var).dataset.data_vars[var]
        da[cn] = raw_field.isel(nslc=1)
    else:
        datapath = os.path.join(case_dir, 'fields.jld2')
        raw_field = OceananigansDataVolume(filepath=datapath, fieldname=var).dataset.data_vars[var]
        da[cn] = raw_field.sel(zi=-1.5, method='nearest')
        

In [None]:
# Plot horizontal snapshots of the loaded field: one row per case, one column per
# snapshot time, a shared horizontal colorbar, then save the figure to figpath.
# Reads from earlier cells: casenames, snapshots, labels, abc, da, figpath, iclabel.
levels = np.linspace(-0.02, 0.02, 41)  # contour levels, shared by all panels
nrow = len(casenames)
ncol = len(snapshots)

# Hand-tuned layout for a single-row figure vs. a multi-row figure:
# extra figure height [in], subplots_adjust arguments, colorbar axes rectangle.
if nrow == 1:
    extra_height = 1.5
    adjust_kw = dict(top=0.97, bottom=0.25, left=0.08, right=0.97, hspace=0.08, wspace=0.06)
    cax_rect = [0.25, 0.17, 0.5, 0.03]
else:
    extra_height = 1
    adjust_kw = dict(top=0.95, bottom=0.22, left=0.08, right=0.97, hspace=0.03, wspace=0.06)
    cax_rect = [0.25, 0.1, 0.5, 0.015]

# squeeze=False keeps axarr 2-D for every nrow/ncol, so a single indexing scheme
# handles the one-row and one-column layouts alike.
fig, axarr = plt.subplots(nrow, ncol, sharex='col', sharey='row', squeeze=False)
fig.set_size_inches(8, 8/ncol*nrow + extra_height)
# The case id is bound by this loop (not reused from another cell), so the
# figure does not depend on which cell ran last.
for i, cn in enumerate(casenames):
    for j, ss in enumerate(snapshots):
        ax = axarr[i, j]
        im = xr.plot.contourf(da[cn].sel(time=ss, method='nearest'), ax=ax, y='y',
                              levels=levels, extend='both', add_colorbar=False)
        # Clear the automatic xarray labels; shared labels are set below.
        ax.set_title('')
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_aspect('equal')
        ax.text(0.05, 0.05, '({:s})'.format(abc[i][j]), transform=ax.transAxes,
                va='bottom', ha='left', bbox=dict(facecolor='w', alpha=0.75, edgecolor='none'))
    axarr[i, 0].set_ylabel('$y$ [m]')
for j in range(ncol):
    axarr[0, j].set_title(labels[j], fontsize=10)
    axarr[-1, j].set_xlabel('$x$ [m]')

fig.subplots_adjust(**adjust_kw)
cax = fig.add_axes(cax_rect)
# All panels share `levels`, so the last contour set is valid for the shared colorbar.
cb = fig.colorbar(im, cax=cax, orientation='horizontal')
clabels = np.linspace(levels[0], levels[-1], 5)
cb.set_ticks(clabels)
cb.set_label('$w$ [m s$^{-1}$]')

figname = os.path.join(figpath, 'snapshots-v2-{:s}_r{:g}'.format(iclabel, nrow))
fig.savefig(figname, dpi = 300, facecolor='w')