In [1]:
import os, sys
from glob import glob
import numpy as np
import dask
import xarray as xr
import xgcm
import cartopy.crs as ccrs
from cmocean import cm

from matplotlib import pyplot as plt
%matplotlib inline

from mitequinox.utils import *
from mitequinox.sigp import *
from mitequinox.plot import *
from dask import compute, delayed

In [2]:
from dask_jobqueue import PBSCluster
#cluster = PBSCluster(cores=6, walltime='06:00:00')
cluster = PBSCluster(cores=6, processes=6,  walltime='04:00:00')
w = cluster.scale(6*10)

In [3]:
# get dask handles and check dask server status
from dask.distributed import Client
client = Client(cluster)

In [4]:
client

0,1
Client  Scheduler: tcp://10.135.39.35:40321  Dashboard: http://10.135.39.35:8787/status,Cluster  Workers: 60  Cores: 60  Memory: 1.00 TB


__________

# Global

In [5]:
E_dir = '/work/ALT/swot/aval/syn/xy/momentum_balance/hourly/'

time_length = 240
dij=4

ds_ice = xr.open_zarr(work_data_dir+'xy/sea_ice_mask.zarr')
ice = ds_ice.AREA.isel(i=slice(0,None,dij), j=slice(0,None,dij))

# define (real) time
def iters_to_date(iters, delta_t=3600.):
    t0 = datetime.datetime(2011,11,23,8)    
    ltime = delta_t * (np.array(iters))
    dtime = [t0+dateutil.relativedelta.relativedelta(seconds=t) for t in ltime]    
    return dtime

time_day = iters_to_date(np.arange(time_length))

In [6]:
font_size = 20

def plot_pretty_6(v1, v2, v3, v4, v5, v6, colorbar=False, title=None, label=None, vmin=None, vmax=None, savefig=None, 
                  offline=False, figsize=(20,12), cmmap='thermal', ignore_face=[]):
    
    if vmin is None:
        vmin = v.min()
    if vmax is None:
        vmax = v.max()
    #
    MPL_LOCK = threading.Lock()
    with MPL_LOCK:
        if offline:
            plt.switch_backend('agg')
        #
        fig = plt.figure(figsize=figsize)
        cmap = getattr(cm, cmmap)
        
        # 1
        ax = fig.add_subplot(321, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v1.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[0], fontsize=font_size)   
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[0],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)

        # 2
        ax = fig.add_subplot(322, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v2.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[1], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        #ax.set_title('%s \n $-\zeta v \: (m\, s^{-2})$'%mtime, fontsize=font_size) 
        ax.set_title(title[1],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)        

        # 3
        ax = fig.add_subplot(323, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v3.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[2], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[2],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)                

        # 4
        ax = fig.add_subplot(324, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v4.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[3], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[3],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)     

        # 5
        ax = fig.add_subplot(325, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v5.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[4], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[4],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)     

        # 6
        ax = fig.add_subplot(326, projection=ccrs.PlateCarree(central_longitude=180))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v6.isel(face=face)
            im = vplt.plot.pcolormesh(ax=ax,                   
                            transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                            x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
        cb = plt.colorbar(im, ax=ax)
        cb.set_label(label=label[5], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size) 
        ax.set_title(title[5],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
        ax.set_yticklabels(['$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N'])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)     
        
        #if title is not None:
        #    ax.set_title(title, size=font_size)
        #if label is not None:
        #    cb.set_label(label=label, size=font_size)     #
            
        if savefig is not None:
            fig.savefig(savefig, dpi=100)
            plt.close(fig)
        #
        if not offline:
            plt.show()
            

# zonal

In [7]:
# read data
F = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
du_dt = xr.concat([xr.open_zarr(E_dir+'du_dt_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
f_v = xr.concat([xr.open_zarr(E_dir+'fv_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
zeta_v = xr.concat([xr.open_zarr(E_dir+'zeta_v_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
dKEdx = xr.concat([xr.open_zarr(E_dir+'dKE_dx_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
detadx = xr.concat([xr.open_zarr(E_dir+'deta_dx_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
residuals_u = xr.concat([xr.open_zarr(E_dir+'residuals_u_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')

dudt = du_dt.du_dt
fv = f_v.fv
zeta_v = zeta_v.zeta_v
dKE_dx = dKEdx.dKE_dx
deta_dx = detadx.deta_dx
residuals_u = residuals_u.residuals_u

ds = xr.merge([dudt,fv, zeta_v,dKE_dx, deta_dx, residuals_u])
ds = ds.assign_coords(time=time_day) 
print(ds)
print('\n data size: %.1f GB' %(ds.nbytes / 1e9))

<xarray.Dataset>
Dimensions:      (face: 13, i: 1080, j: 1080, time: 240)
Coordinates:
  * time         (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
  * i            (i) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
  * j            (j) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
    CS           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    Depth        (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    SN           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    XC           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    YC           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
  * face         (face) int64 0 1 2 3 4 5 6 7 8 9 10 11 12
    hFacC        (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    rA           (face, j, i) float32 dask.array<chunks

In [14]:
vmin, vmax = -4e-5, 4e-5
lds = ds

def process(ds, i, overwrite=True):    
    
    du_dt = ds['du_dt'].where(ice>0)
    hadv = -1*ds['zeta_v'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0) + ds['dKE_dx'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    fv = -1*ds['fv'].where(ice>0)
    #zeta_v = -1*ds['zeta_v'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    #dKE_dx = ds['dKE_dx'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    deta_dx = ds['deta_dx'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    residual_u = ds['residuals_u'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    fva = -1*ds['fv'].where(ice>0) + ds['deta_dx'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    
    mtime = time_day[i]
    figname = '/home/uz/yux/mit_equinox/hal/Geostrophy_assessment/Figures/Global_M_U/M_U_global_t%05d'%(i)+'.png'
    #
    if not os.path.isfile(figname) or overwrite:
        
        title = ['%s \n $\partial u/\partial t \: (m\, s^{-2})$'%mtime, r'$-\zeta v + \partial KE/\partial x \: (m\, s^{-2})$', r'$-fv \: (m\, s^{-2})$',
                 r'$g\partial \eta/\partial x \: (m\, s^{-2})$', r'$R_u \: (m\, s^{-2})$', r'$-fv+g\partial \eta/\partial x \: (m\, s^{-2})$']
        label = ['', '', '', '', '', '']
        
        plot_pretty_6(du_dt, hadv, fv, deta_dx, residual_u, fva, colorbar=False, title=title, label=label, savefig=figname, vmin=vmin, vmax=vmax, offline=False, figsize=(40,18), cmmap='balance')
        m = 1
    else:
        m = -1.
    return m

#I = range(len(ds['time']))
I = range(180,240)
#I = range(120,180)
print(I)
values = [delayed(process)(lds.isel(time=i), i) for i in I]

range(180, 240)


Process all times

In [15]:
futures = client.compute(values)
%time results = client.gather(futures)

CPU times: user 7min 47s, sys: 8min 1s, total: 15min 48s
Wall time: 16min 36s


# meridional

In [7]:
F = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
dv_dt = xr.concat([xr.open_zarr(E_dir+'dv_dt_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
f_u = xr.concat([xr.open_zarr(E_dir+'fu_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
zeta_u = xr.concat([xr.open_zarr(E_dir+'zeta_u_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
dKEdy = xr.concat([xr.open_zarr(E_dir+'dKE_dy_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
detady = xr.concat([xr.open_zarr(E_dir+'deta_dy_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')
residuals_v = xr.concat([xr.open_zarr(E_dir+'residuals_v_f%02d.zarr'%(face)) for face in F], dim='face', compat='identical')

dvdt = dv_dt.dv_dt
fu = f_u.fu
zeta_u = zeta_u.zeta_u
dKE_dy = dKEdy.dKE_dy
deta_dy = detady.deta_dy
residuals_v = residuals_v.residuals_v

# merge data
ds = xr.merge([dvdt,fu, zeta_u,dKE_dy, deta_dy, residuals_v])
ds = ds.assign_coords(time=time_day) 
print(ds)
print('\n data size: %.1f GB' %(ds.nbytes / 1e9))

<xarray.Dataset>
Dimensions:      (face: 13, i: 1080, j: 1080, time: 240)
Coordinates:
  * time         (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
  * j            (j) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
  * i            (i) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
    CS           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    Depth        (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    SN           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    XC           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    YC           (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
  * face         (face) int64 0 1 2 3 4 5 6 7 8 9 10 11 12
    hFacC        (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    rA           (face, j, i) float32 dask.array<chunks

In [12]:
vmin, vmax = -4e-5, 4e-5
lds = ds

def process(ds, i, overwrite=True):    
    
    dv_dt = ds['dv_dt'].where(ice>0)
    hadv = ds['zeta_u'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0) + ds['dKE_dy'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    fu = ds['fu'].where(ice>0)
    deta_dy = ds['deta_dy'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    residual_v = ds['residuals_v'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    fua = ds['fu'].where(ice>0) + ds['deta_dy'].isel(i=slice(1,-1),j=slice(1,-1)).where(ice>0)
    
    mtime = time_day[i]
    figname = '/home/uz/yux/mit_equinox/hal/Geostrophy_assessment/Figures/Global_M_V/M_V_global_t%05d'%(i)+'.png'
    #
    if not os.path.isfile(figname) or overwrite:
        
        title = ['%s \n $\partial v/\partial t \: (m\, s^{-2})$'%mtime, r'$\zeta u + \partial KE/\partial y \: (m\, s^{-2})$', r'$fu \: (m\, s^{-2})$',
                 r'$g\partial \eta/\partial y \: (m\, s^{-2})$', r'$R_v \: (m\, s^{-2})$', r'$fu+g\partial \eta/\partial y \: (m\, s^{-2})$']
        label = ['', '', '', '', '', '']
        
        plot_pretty_6(dv_dt, hadv, fu, deta_dy, residual_v, fua, colorbar=False, title=title, label=label, savefig=figname, vmin=vmin, vmax=vmax, offline=False, figsize=(40,18), cmmap='balance')
        m = 1
    else:
        m = -1.
    return m

#I = range(len(ds['time']))
I = range(180,240)
print(I)
values = [delayed(process)(lds.isel(time=i), i) for i in I]

range(180, 240)


In [13]:
futures = client.compute(values)
%time results = client.gather(futures)

CPU times: user 8min 56s, sys: 6min 50s, total: 15min 47s
Wall time: 16min 34s


In [14]:
cluster.close()

distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/home/uz/yux/.conda/envs/equinox/lib/python3.7/site-packages/distributed/utils.py", line 666, in log_errors
    yield
  File "/home/uz/yux/.conda/envs/equinox/lib/python3.7/site-packages/distributed/client.py", line 1276, in _close
    await gen.with_timeout(timedelta(seconds=2), list(coroutines))
concurrent.futures._base.CancelledError
distributed.utils - ERROR - 
Traceback (most recent call last):
  File "/home/uz/yux/.conda/envs/equinox/lib/python3.7/site-packages/distributed/utils.py", line 666, in log_errors
    yield
  File "/home/uz/yux/.conda/envs/equinox/lib/python3.7/site-packages/distributed/client.py", line 1005, in _reconnect
    await self._close()
  File "/home/uz/yux/.conda/envs/equinox/lib/python3.7/site-packages/distributed/client.py", line 1276, in _close
    await gen.with_timeout(timedelta(seconds=