In [19]:
# Use the interactive widget (ipympl) backend so figures shown in the notebook are interactive
%matplotlib widget

In [20]:
import os
import dask
import warnings
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
from sys import platform
from streamjoy import stream, wrap_matplotlib
from viztool import FormatScalarFormatter
from dask.distributed import LocalCluster, Client
from dask_jobqueue import PBSCluster
from matplotlib.colors import LinearSegmentedColormap

import streamjoy
# Set streamjoy's global 'codec' option to H.264
# (presumably picked up when stream() writes the .mp4 files below — confirm)
streamjoy.config['codec'] = 'h264'

In [21]:
# Monkey-patch the private `Axis._get_coord_info` of mplot3d so that the
# returned `mins`/`maxs` are pulled inwards by a quarter of `deltas`.
# NOTE(review): presumably this removes the padding matplotlib adds around the
# 3-D axis limits — confirm. `_get_coord_info` is a private API whose signature
# and return tuple may differ between matplotlib versions; verify against the
# installed version.
from mpl_toolkits.mplot3d.axis3d import Axis
# Guard so that re-running this cell does not wrap the already patched method again.
if not hasattr(Axis, '_get_coord_info_old'):
    def _get_coord_info_new(self, renderer):
        # Delegate to the saved original, then shrink the bounds in place.
        mins, maxs, centers, deltas, tc, highs = self._get_coord_info_old(renderer)
        mins += deltas / 4
        maxs -= deltas / 4
        return mins, maxs, centers, deltas, tc, highs
    # Keep a handle on the original implementation before replacing it.
    Axis._get_coord_info_old = Axis._get_coord_info
    Axis._get_coord_info = _get_coord_info_new

def plot_box_frame(ax, xmin, xmax, ymin, ymax, zmin, zmax, cloud=False, **edges_kw):
    """Outline a rectangular box on a 3-D axes with straight line segments.

    Always drawn: the four edges of the rectangle at ``zmax`` and the vertical
    edges at the corners (xmax, ymax), (xmax, ymin) and (xmin, ymin).

    Parameters
    ----------
    ax : object with a ``plot(xs, ys, zs, **kw)`` method (a 3-D axes).
    xmin, xmax, ymin, ymax, zmin, zmax : box extents.
    cloud : bool, default False
        If False, additionally draw the two ``zmin`` edges along ``x = xmax``
        and ``y = ymin``. If True, draw the fourth vertical edge at
        (xmin, ymax) instead.
    **edges_kw : forwarded unchanged to every ``ax.plot`` call.
    """
    x_span, y_span, z_span = [xmin, xmax], [ymin, ymax], [zmin, zmax]
    x_lo, x_hi = [xmin, xmin], [xmax, xmax]
    y_lo, y_hi = [ymin, ymin], [ymax, ymax]

    segments = [
        # rectangle at zmax
        (x_hi, y_span, zmax),
        (x_span, y_lo, zmax),
        (x_lo, y_span, zmax),
        (x_span, y_hi, zmax),
        # vertical edges
        (x_hi, y_hi, z_span),
        (x_hi, y_lo, z_span),
        (x_lo, y_lo, z_span),
    ]
    if cloud:
        # remaining vertical edge at the (xmin, ymax) corner
        segments.append((x_lo, y_hi, z_span))
    else:
        # two edges of the rectangle at zmin
        segments.append((x_hi, y_span, zmin))
        segments.append((x_span, y_lo, zmin))

    for xs, ys, zs in segments:
        ax.plot(xs, ys, zs, **edges_kw)

# @wrap_matplotlib()
def plot_xyz3d(ax, ds, **kwargs):
    """Draw the top, south and east faces of a 3-D box view of `ds` on `ax`.

    The field `var` is shown with filled contours on each face, and contour
    lines of `b` (after subtracting ``M² * x``) are overlaid.

    Parameters
    ----------
    ax : 3-D matplotlib axes to draw on.
    ds : dataset for a single time. It must contain `var` and ``b`` and the
        attribute ``M²``; ``τ₀`` and ``ρ₀`` are needed when `var` is a
        velocity component; ``xC``, ``yC`` and ``aheps`` are needed when
        ``cloud`` is true.
        # assumes each field has exactly three dims whose names sort into
        # x, y, z order (e.g. xC, yC, zC) — TODO confirm
    **kwargs :
        depth_lim : two-element sequence; the top face is taken at the level
            nearest ``-depth_lim[0]`` and the side faces extend down to
            ``-depth_lim[1]``.
        sub_str : label written in the panel.
        blines : contour levels for ``b``.
        cloud : bool; if true, also draw ``ds.aheps`` on a plane at ``z = 52``
            together with a dotted frame reaching up to it.
        var : name of the field to colour (default ``'b'``).
        Everything else is forwarded to ``ax.contourf``.

    Returns
    -------
    The contour set returned by ``ax.contourf`` for the top face.
    """
    depth_lim = kwargs.pop('depth_lim')
    sub_str = kwargs.pop('sub_str')
    blines = kwargs.pop('blines')
    cloud = kwargs.pop('cloud')
    var = kwargs.pop('var', 'b')

    # Derive the variable selection from `var` here instead of reading a
    # notebook-level `varsel` global that is only defined in a later cell.
    # dict.fromkeys de-duplicates (var == 'b') while preserving order.
    varsel = list(dict.fromkeys([var, 'b']))

    xvar, yvar, zvar = sorted(ds[var].dims)
    xb, yb, zb = sorted(ds.b.dims)

    # Top face: horizontal slice nearest to the shallow depth limit.
    z_top = -depth_lim[0]
    top = ds[varsel].sel({zvar: z_top, zb: z_top}, method='nearest')
    # Side faces: from the deep limit up to the level of the top face.
    z_range = {zvar: slice(-depth_lim[1], top[zvar]),
               zb: slice(-depth_lim[1], top[zb])}
    south = ds[varsel].isel({yvar: 0, yb: 0}).sel(z_range)
    east = ds[varsel].isel({xvar: -1, xb: -1}).sel(z_range)

    is_velocity = var in ['u', 'v', 'w']
    if is_velocity:
        # sqrt(τ₀/ρ₀); only needed (and only computed) for velocity fields
        ustar = np.sqrt(ds.attrs['τ₀']/ds.attrs['ρ₀'])
    for face in (top, south, east):
        # Remove the linear-in-x part M²·x from b on every face.
        face['b'] = face.b - ds.attrs['M²']*face[xb]
        if is_velocity:
            face[var] = face[var]/ustar

    xmin, xmax = np.around(south[xvar].min().data), np.around(south[xvar].max().data)
    ymin, ymax = np.around(east[yvar].min().data), np.around(east[yvar].max().data)
    zmin, zmax = np.around(east[zvar].min().data), np.around(top[zvar].data)
    Lx = xmax - xmin
    Ly = ymax - ymin
    Lz = zmax - zmin
    # Stretch factors applied to the box aspect so the box stays readable.
    ry = 1.2 if Lx/Ly >= 2 else 1
    rz = 2.8 if Lz > 100 else 3.6
    zcloud = 52  # height of the optional `aheps` plane

    bkw = dict(levels=blines, linewidths=0.05, colors='xkcd:almost black')
    edges_kw = dict(color='xkcd:charcoal', linewidth=1, zorder=2)

    # --- top face ---
    im = ax.contourf(top[xvar], top[yvar], top[var], zdir='z', offset=top[zvar], **kwargs)
    ax.contour(top[xb], top[yb], top.b, zdir='z', offset=top[zb], **bkw)
    if cloud:
        ax.contourf(ds.xC, ds.yC, ds.aheps, zdir='z', offset=zcloud, cmap='PRGn_r', extend='both',
                    vmin=-0.7, vmax=0.7, levels=np.arange(-0.7, 0.8, 0.1), zorder=2)

    # Thicker b contour lines on the side faces.
    bkw.update(linewidths=0.2, colors='xkcd:almost black')

    # --- south face ---
    X, Z = np.meshgrid(south[xvar], south[zvar])
    ax.contourf(X, south[var], Z, zdir='y', offset=south[yvar], **kwargs)
    X, Z = np.meshgrid(south[xb], south[zb])
    ax.contour(X, south.b, Z, zdir='y', offset=south[yb], **bkw)

    # --- east face ---
    Y, Z = np.meshgrid(east[yvar], east[zvar])
    ax.contourf(east[var], Y, Z, zdir='x', offset=east[xvar], **kwargs)
    Y, Z = np.meshgrid(east[yb], east[zb])
    ax.contour(east.b, Y, Z, zdir='x', offset=east[xb], **bkw)

    ax.axis('off')
    ax.set(xlim=[xmin, xmax],
           ylim=[ymin, ymax],
           zlim=[zmin, zmax])
    plot_box_frame(ax, xmin, xmax, ymin, ymax, zmin, zmax, **edges_kw)
    if cloud:
        # Dotted frame from the top of the box up to the `aheps` plane.
        edges_kw.update(linestyle=':', linewidth=0.8, zorder=1)
        plot_box_frame(ax, xmin, xmax, ymin, ymax, zmax, zcloud, cloud=cloud, **edges_kw)

    ax.text2D(0.06, 0.36, sub_str, fontsize=12, ha='left', va='bottom', transform=ax.transAxes)
    ax.set_box_aspect((1, Ly/Lx*ry, Lz/Lx*rz))
    return im

In [15]:
# Pick the data directory for the current machine. Fail early with a clear
# message instead of printing and hitting a NameError on `data_dir` below.
if platform in ('linux', 'linux2'):
    data_dir = '/glade/derecho/scratch/zhihuaz/FrontalZone/Output/'
elif platform == 'darwin':
    data_dir = '/Users/zhihua/Documents/Work/Research/Projects/TRACE-SEAS/FrontalZone/Data/'
else:
    raise RuntimeError(f'OS not supported: {platform}')

dsf = []         # one lazily loaded dataset per case, in `cnames` order
mask_itime = []  # per case: integer time indices of the frames to animate
cnames = ['s11_M036_Q000_W037_D270_St0',
          's11_M003_Q000_W444_D270_St0',
          's11_M009_Q000_W148_D000_St0',
          's11_M009_Q000_W148_D180_St0',
          's11_M009_Q000_W148_D090_St0',
          's11_M009_Q135_W148_D090_St0',
         ]
for cname in cnames:
    tmpf = xr.open_dataset(data_dir+cname+'_full.nc').chunk({'time':1, 'zC':40})
    # NOTE(review): the dataset is closed right after opening while it still
    # holds lazy chunks — confirm later reads can still access the file.
    tmpf.close()

    # Time expressed in units of 2π/f seconds (rounded to whole seconds).
    tmpf['timeTf'] = tmpf.time/np.timedelta64(int(np.around(2*np.pi/tmpf.f)), 's')

    # Keep only times that are whole multiples of the slice output interval ...
    integer_time = ((tmpf.time / np.timedelta64(int(tmpf.out_interval_slice), 's')) % 1) == 0
    # ... and, where a time stamp is repeated, only its last occurrence.
    unique_time = ~pd.Index(tmpf.time).duplicated(keep='last')
    mask_time = integer_time & unique_time
    tmpf_idx = np.arange(tmpf.time.size)[mask_time]

    mask_itime.append(tmpf_idx)
    dsf.append(tmpf)

In [9]:
@wrap_matplotlib()
def anim3d_group(ds, **plt_kwargs):
    """Render one animation frame: a row of two 3-D box panels, one per case.

    Parameters
    ----------
    ds : sequence of single-time datasets, one per panel.
    **plt_kwargs : options for ``plot_xyz3d``. ``sub_str`` must be a sequence
        holding one label per entry of `ds`.

    Returns
    -------
    The matplotlib figure, titled with ``timeTf`` of the first dataset.
    """
    plt.close()
    fig, axes = plt.subplots(
        1, 2,
        figsize=(6, 3),
        dpi=600,
        subplot_kw={'projection': '3d', 'computed_zorder': False},
    )
    # Let the panels fill the whole canvas.
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
    labels = plt_kwargs.pop('sub_str')

    for idx, case in enumerate(ds):
        plot_xyz3d(axes[idx], case, sub_str=f'{labels[idx]}', **plt_kwargs)
    fig.suptitle(rf'$T_{{inertial}}$ = {ds[0].timeTf.data:.2f}')
    return fig

In [17]:
# Start a PBS-backed dask cluster; the workers spill and log below a per-user
# scratch directory.
USER = os.getenv('USER')
TMPDIR = f'/glade/derecho/scratch/{USER}/temp'
job_script_prologue = [f'export TMPDIR={TMPDIR}', 'mkdir -p $TMPDIR']
cluster_kw = {
    'job_name': 'anim3d',
    'cores': 1,
    'memory': '10GiB',
    'processes': 1,
    'local_directory': f'{TMPDIR}/pbs.$PBS_JOBID/dask/spill',
    'log_directory': f'{TMPDIR}/pbs.$PBS_JOBID/dask/worker_logs',
    'job_extra_directives': ['-j oe', '-A UMCP0036'],
    'job_script_prologue': job_script_prologue,
    'resource_spec': 'select=1:ncpus=1:mem=10GB',
    'queue': 'casper',
    'walltime': '60:00',
    'interface': 'ext',
}
cluster = PBSCluster(**cluster_kw)
client = Client(cluster)
# Show the generated submission script, then request 16 workers.
print(cluster.job_script())
cluster.scale(16)
# print(cluster.dashboard_link.replace(':8787', ':1212/proxy/42843'))

#!/usr/bin/env bash

#PBS -N anim3d
#PBS -q casper
#PBS -l select=1:ncpus=1:mem=10GB
#PBS -l walltime=60:00
#PBS -e /glade/derecho/scratch/zhihuaz/temp/pbs.$PBS_JOBID/dask/worker_logs/
#PBS -o /glade/derecho/scratch/zhihuaz/temp/pbs.$PBS_JOBID/dask/worker_logs/
#PBS -j oe
#PBS -A UMCP0036
export TMPDIR=/glade/derecho/scratch/zhihuaz/temp
mkdir -p $TMPDIR
/glade/work/zhihuaz/conda-envs/trace-seas/bin/python -m distributed.cli.dask_worker tcp://128.117.211.221:42559 --name dummy-name --nthreads 1 --memory-limit 10.00GiB --nanny --death-timeout 60 --local-directory /glade/derecho/scratch/zhihuaz/temp/pbs.$PBS_JOBID/dask/spill --interface ext



In [11]:
# Plot configuration shared by the animation cells below.
var = 'v'                 # field shown with filled contours
varsel = [var, 'b']       # fields pulled from each dataset
letter = 'abcdef'
vmin, vmax = -9, 9        # colour range
depth_lim = [3, 80]       # [top-face depth, bottom depth]
plt_kwargs = {
    'var': var,
    'cmap': 'RdBu_r',
    'vmin': vmin,
    'vmax': vmax,
    'blines': np.arange(-0.02, 2.6, 0.008)*1e-2,
    'levels': np.linspace(vmin, vmax, 64),
    'depth_lim': depth_lim,
    'extend': 'both',
    'cloud': False,
}

In [None]:
%%time
# Animate cases 0 and 1 side by side, pairing their selected time indices.
plt_kwargs.update(sub_str=['DF1', 'DF2'])
stream([(dsf[0].isel(time=i), dsf[1].isel(time=j)) for i,j in zip(mask_itime[0], mask_itime[1])],
       renderer=anim3d_group,
       max_frames=-1,
       renderer_kwargs=plt_kwargs,
       client=client, fps=5,
       write_kwargs=dict(crf='18'),
       threads_per_worker=1,
       uri=f'DF_{var}3d.mp4')
# The cluster is deliberately NOT closed here: the following animation cells
# reuse `client`, and closing the cluster now would leave them without workers
# when the notebook is run top to bottom.

In [None]:
%%time
# Animate cases 2 and 3 side by side, pairing their selected time indices.
plt_kwargs.update(sub_str=['CF1', 'CF2'])
stream([(dsf[2].isel(time=i), dsf[3].isel(time=j)) for i,j in zip(mask_itime[2], mask_itime[3])],
       renderer=anim3d_group,
       max_frames=-1,
       renderer_kwargs=plt_kwargs,
       client=client, fps=5,
       write_kwargs=dict(crf='18'),
       threads_per_worker=1,
       uri=f'CF_{var}3d.mp4')
# The cluster is deliberately NOT closed here: the following animation cell
# reuses `client`, and closing the cluster now would leave it without workers
# when the notebook is run top to bottom.

In [None]:
%%time
# Animate cases 4 and 5 side by side, pairing their selected time indices,
# then shut the PBS cluster down.
plt_kwargs.update(sub_str=['UF1', 'UF1c'])
stream(
    [
        (dsf[4].isel(time=idx_a), dsf[5].isel(time=idx_b))
        for idx_a, idx_b in zip(mask_itime[4], mask_itime[5])
    ],
    renderer=anim3d_group,
    renderer_kwargs=plt_kwargs,
    client=client,
    threads_per_worker=1,
    max_frames=-1,
    fps=5,
    write_kwargs={'crf': '18'},
    uri=f'UF_{var}3d.mp4',
)
cluster.close()

In [14]:
# Close the dask client that was connected to the PBS cluster.
client.close()