In [1]:
import os, sys
from glob import glob
import numpy as np
import dask
import xarray as xr
import xgcm
import cartopy.crs as ccrs
from cmocean import cm

from matplotlib import pyplot as plt
%matplotlib inline

from mitequinox.utils import *
from mitequinox.sigp import *
from mitequinox.plot import *
from dask import compute, delayed

In [2]:
from dask_jobqueue import PBSCluster
# for heavy processing:
cluster = PBSCluster(cores=6, processes=6,  walltime='03:00:00')
#cluster = PBSCluster(cores=12, walltime='03:00:00')
w = cluster.scale(6*12)

In [3]:
# get dask handles and check dask server status
from dask.distributed import Client
client = Client(cluster)

In [4]:
client

0,1
Client  Scheduler: tcp://10.135.36.243:33276  Dashboard: /user/yux/proxy/8787/status,Cluster  Workers: 72  Cores: 72  Memory: 1.20 TB


______________
# Four terms (Eta, SSU, Ug, Ua)
## needs calculation

In [20]:
time_length = 240

E_dir = '/work/ALT/swot/aval/syn/xy/momentum_balance/daily/'

grd = load_grd().drop(['hFacC','hFacW','hFacS','rA','rAw','rAs'])
dsE = xr.open_zarr(root_data_dir+'zarr/%s.zarr'%('Eta')).isel(time=slice(1512+200,1512+200+time_length))
dsU = xr.open_zarr(root_data_dir+'zarr/%s.zarr'%('SSU')).rename({'i_g': 'i'}).isel(time=slice(200,200+time_length))
dsV = xr.open_zarr(root_data_dir+'zarr/%s.zarr'%('SSV')).rename({'j_g': 'j'}).isel(time=slice(200,200+time_length))

# define (real) time
def iters_to_date(iters, delta_t=3600.):
    t0 = datetime.datetime(2011,11,23,8)    
    ltime = delta_t * (np.array(iters))
    dtime = [t0+dateutil.relativedelta.relativedelta(seconds=t) for t in ltime]    
    return dtime

time_day = iters_to_date(np.arange(time_length))

ds = xr.merge([dsE,dsU,dsV,grd])
ds = ds.assign_coords(time=time_day) 
ds

<xarray.Dataset>
Dimensions:  (face: 13, i: 4320, i_g: 4320, j: 4320, j_g: 4320, k: 90, k_l: 90, k_p1: 91, k_u: 90, time: 240)
Coordinates:
    dtime    (time) datetime64[ns] dask.array<chunksize=(240,), meta=np.ndarray>
  * face     (face) int64 0 1 2 3 4 5 6 7 8 9 10 11 12
  * i        (i) int64 0 1 2 3 4 5 6 7 ... 4313 4314 4315 4316 4317 4318 4319
    iters    (time) int64 dask.array<chunksize=(1,), meta=np.ndarray>
  * j        (j) int64 0 1 2 3 4 5 6 7 ... 4313 4314 4315 4316 4317 4318 4319
  * time     (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
    CS       (face, j, i) float32 dask.array<chunksize=(1, 4320, 4320), meta=np.ndarray>
    Depth    (face, j, i) float32 dask.array<chunksize=(1, 4320, 4320), meta=np.ndarray>
    PHrefC   (k) float32 dask.array<chunksize=(90,), meta=np.ndarray>
    PHrefF   (k_p1) float32 dask.array<chunksize=(91,), meta=np.ndarray>
    SN       (face, j, i) float32 dask.array<chunksize=(1, 4320, 4320), meta=np.ndarray>
    XC   

In [22]:
grd = load_grd().reset_coords()
mask = ((grd.hFacW.rename({'i_g': 'i'}) == 1) &
        (grd.hFacS.rename({'j_g': 'j'}) == 1) 
       ).rename('mask').reset_coords(drop=True)
grd_rspec = xr.merge([mask, grd.XC, grd.YC, grd.Depth])

# coriolis term
lat = grd_rspec['YC']
omega = 7.3/100000
f_ij = 2*omega*np.sin(np.deg2rad(lat))

# define the connectivity between faces
face_connections = {'face':
                    {0: {'X':  ((12, 'Y', False), (3, 'X', False)),
                         'Y':  (None,             (1, 'Y', False))},
                     1: {'X':  ((11, 'Y', False), (4, 'X', False)),
                         'Y':  ((0, 'Y', False),  (2, 'Y', False))},
                     2: {'X':  ((10, 'Y', False), (5, 'X', False)),
                         'Y':  ((1, 'Y', False),  (6, 'X', False))},
                     3: {'X':  ((0, 'X', False),  (9, 'Y', False)),
                         'Y':  (None,             (4, 'Y', False))},
                     4: {'X':  ((1, 'X', False),  (8, 'Y', False)),
                         'Y':  ((3, 'Y', False),  (5, 'Y', False))},
                     5: {'X':  ((2, 'X', False),  (7, 'Y', False)),
                         'Y':  ((4, 'Y', False),  (6, 'Y', False))},
                     6: {'X':  ((2, 'Y', False),  (7, 'X', False)),
                         'Y':  ((5, 'Y', False),  (10, 'X', False))},
                     7: {'X':  ((6, 'X', False),  (8, 'X', False)),
                         'Y':  ((5, 'X', False),  (10, 'Y', False))},
                     8: {'X':  ((7, 'X', False),  (9, 'X', False)),
                         'Y':  ((4, 'X', False),  (11, 'Y', False))},
                     9: {'X':  ((8, 'X', False),  None),
                         'Y':  ((3, 'X', False),  (12, 'Y', False))},
                     10: {'X': ((6, 'Y', False),  (11, 'X', False)),
                          'Y': ((7, 'Y', False),  (2, 'X', False))},
                     11: {'X': ((10, 'X', False), (12, 'X', False)),
                          'Y': ((8, 'Y', False),  (1, 'X', False))},
                     12: {'X': ((11, 'X', False), None),
                          'Y': ((9, 'Y', False),  (0, 'X', False))}}}

# create the grid object
gridx = xgcm.Grid(ds, periodic=False, face_connections=face_connections)
gridx

<xgcm.Grid>
Z Axis (not periodic):
  * center   k --> left
  * left     k_l --> center
  * outer    k_p1 --> center
  * right    k_u --> center
Y Axis (not periodic):
  * center   j --> left
  * left     j_g --> center
X Axis (not periodic):
  * center   i --> left
  * left     i_g --> center

In [23]:
# u_g
ug = -9.8*(gridx.diff( ds.Eta,'Y', boundary='fill')/ds.dyC).rename({'j_g': 'j'})/f_ij
ug = ug.chunk({'face': 1,'i':4320,'j':4320})

vg = 9.8*(gridx.diff( ds.Eta,'X', boundary='fill')/ds.dxC).rename({'i_g': 'i'})/f_ij
vg = vg.chunk({'face': 1,'i':4320,'j':4320})

ds['ug'] = ug
ds['vg'] = vg
ds_1080 = ds.isel(i=slice(0,None,dij), j=slice(0,None,dij))
print(ds_1080)
Efile = work_data_dir+'xy/comparison/Movies/uveta_1080_snapshot.zarr'
%time ds_1080.to_zarr(Efile, mode='w')

<xarray.Dataset>
Dimensions:  (face: 13, i: 1080, i_g: 4320, j: 1080, j_g: 4320, k: 90, k_l: 90, k_p1: 91, k_u: 90, time: 240)
Coordinates:
    dtime    (time) datetime64[ns] dask.array<chunksize=(240,), meta=np.ndarray>
  * face     (face) int64 0 1 2 3 4 5 6 7 8 9 10 11 12
  * i        (i) int64 0 4 8 12 16 20 24 ... 4292 4296 4300 4304 4308 4312 4316
    iters    (time) int64 dask.array<chunksize=(1,), meta=np.ndarray>
  * j        (j) int64 0 4 8 12 16 20 24 ... 4292 4296 4300 4304 4308 4312 4316
  * time     (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
    CS       (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    Depth    (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    PHrefC   (k) float32 dask.array<chunksize=(90,), meta=np.ndarray>
    PHrefF   (k_p1) float32 dask.array<chunksize=(91,), meta=np.ndarray>
    SN       (face, j, i) float32 dask.array<chunksize=(1, 1080, 1080), meta=np.ndarray>
    X

<xarray.backends.zarr.ZarrStore at 0x2b3465934888>

# Read data

In [5]:
time_length = 240
dij = 4
grd = load_grd().reset_coords().isel(i=slice(0,None,dij), j=slice(0,None,dij))
mask = ((grd.hFacW.rename({'i_g': 'i'}) == 1) &
        (grd.hFacS.rename({'j_g': 'j'}) == 1) 
       ).rename('mask').reset_coords(drop=True)
grd_rspec = xr.merge([grd.XC, grd.YC, grd.Depth, grd.CS, grd.SN])

# define (real) time
def iters_to_date(iters, delta_t=3600.):
    t0 = datetime.datetime(2011,11,23,8)    
    ltime = delta_t * (np.array(iters))
    dtime = [t0+dateutil.relativedelta.relativedelta(seconds=t) for t in ltime]    
    return dtime

time_day = iters_to_date(np.arange(time_length))
ds = xr.open_zarr(work_data_dir+'xy/comparison/Movies/uveta_snapshot.zarr')
ds = ds.assign_coords(time=time_day) 
print(ds)
ds = ds.drop_dims(['i_g','j_g','k','k_l','k_p1','k_u'])
print('\n data size: %.1f GB' %(ds.nbytes / 1e9))

In [52]:
# no need to store Eta first

dij=4
time_length = 240
overwrite=True
    
for face in range(13):
#for face in [1]:

    Efile = work_data_dir+'xy/comparison/Movies/Eta_U_Ug_Ua_f%02d.zarr'%(face)

    if not os.path.isdir(Efile) or overwrite:
        
        Eta = ds.Eta.isel(face=face)
        u_rotate = ds.SSU.isel(face=face)*grd_rspec.CS.isel(face=face) - ds.SSV.isel(face=face)*grd_rspec.SN.isel(face=face)
        ug_rotate = ds.ug.isel(face=face)*grd_rspec.CS.isel(face=face) - ds.vg.isel(face=face)*grd_rspec.SN.isel(face=face)
        ua_rotate = u_rotate - ug_rotate
        
        Eta = np.real(Eta).rename('Eta')
        u_rotate = np.real(u_rotate).rename('u_rotate')
        ug_rotate = np.real(ug_rotate).rename('ug_rotate')
        ua_rotate = np.real(ua_rotate).rename('ua_rotate')
        
        ds_rotate = xr.merge([Eta, u_rotate, ug_rotate, ua_rotate])
        #print(ds_rotate)       
        %time ds_rotate.to_zarr(Efile, mode='w')

        print('--- face %d done'%face)

    else:
        print('--- face %d allready computed'%face)

CPU times: user 6.55 s, sys: 331 ms, total: 6.88 s
Wall time: 6.81 s
--- face 0 done
CPU times: user 5.95 s, sys: 483 ms, total: 6.43 s
Wall time: 7.24 s
--- face 1 done
CPU times: user 5.68 s, sys: 374 ms, total: 6.05 s
Wall time: 6.11 s
--- face 2 done
CPU times: user 6.05 s, sys: 338 ms, total: 6.39 s
Wall time: 6.52 s
--- face 3 done
CPU times: user 5.92 s, sys: 382 ms, total: 6.3 s
Wall time: 6.41 s
--- face 4 done
CPU times: user 5.99 s, sys: 351 ms, total: 6.34 s
Wall time: 6.33 s
--- face 5 done
CPU times: user 5.83 s, sys: 387 ms, total: 6.21 s
Wall time: 6.33 s
--- face 6 done
CPU times: user 5.48 s, sys: 368 ms, total: 5.85 s
Wall time: 6.04 s
--- face 7 done
CPU times: user 6.61 s, sys: 337 ms, total: 6.94 s
Wall time: 6.94 s
--- face 8 done
CPU times: user 6.36 s, sys: 349 ms, total: 6.71 s
Wall time: 6.73 s
--- face 9 done
CPU times: user 6.72 s, sys: 366 ms, total: 7.09 s
Wall time: 7.11 s
--- face 10 done
CPU times: user 6.75 s, sys: 419 ms, total: 7.17 s
Wall time: 7.2

# Read final data

In [5]:
time_length = 240
dij = 4
grd = load_grd().reset_coords().isel(i=slice(0,None,dij), j=slice(0,None,dij))
mask = ((grd.hFacW.rename({'i_g': 'i'}) == 1) &
        (grd.hFacS.rename({'j_g': 'j'}) == 1) 
       ).rename('mask').reset_coords(drop=True)
grd_rspec = xr.merge([grd.XC, grd.YC, grd.Depth, grd.CS, grd.SN])

ds_ice = xr.open_zarr(work_data_dir+'xy/sea_ice_mask.zarr')
ice = ds_ice.AREA.isel(i=slice(0,None,dij), j=slice(0,None,dij))

# define (real) time
def iters_to_date(iters, delta_t=3600.):
    t0 = datetime.datetime(2011,11,23,8)    
    ltime = delta_t * (np.array(iters))
    dtime = [t0+dateutil.relativedelta.relativedelta(seconds=t) for t in ltime]    
    return dtime

time_day = iters_to_date(np.arange(time_length))

face_all = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
D = [xr.open_zarr(work_data_dir+'xy/comparison/Movies/Eta_U_Ug_Ua_f%02d.zarr'%(face)) for face in face_all] 
ds = xr.concat(D, dim='face')
print(ds)
print('\n data size: %.1f GB' %(ds.nbytes / 1e9))

<xarray.Dataset>
Dimensions:    (face: 13, i: 1080, j: 1080, time: 240)
Coordinates:
    CS         (face, j, i) float32 dask.array<chunksize=(1, 270, 540), meta=np.ndarray>
    Depth      (face, j, i) float32 dask.array<chunksize=(1, 270, 540), meta=np.ndarray>
    SN         (face, j, i) float32 dask.array<chunksize=(1, 270, 540), meta=np.ndarray>
    XC         (face, j, i) float32 dask.array<chunksize=(1, 270, 540), meta=np.ndarray>
    YC         (face, j, i) float32 dask.array<chunksize=(1, 270, 540), meta=np.ndarray>
    dtime      (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
  * face       (face) int64 0 1 2 3 4 5 6 7 8 9 10 11 12
  * i          (i) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
    iters      (time) int64 256896 257040 257184 257328 ... 291024 291168 291312
  * j          (j) int64 0 4 8 12 16 20 24 ... 4296 4300 4304 4308 4312 4316
  * time       (time) datetime64[ns] 2011-11-23T08:00:00 ... 2011-12-03T07:00:00
Data variables:


In [6]:
import cartopy.feature as cfeature
font_size = 22

def plot_pretty_4(v1, v2, v3, v4, colorbar=False, title=None, label=None, vmin=None, vmax=None, savefig=None, 
                  offline=False, figsize=(20,12), cmmap='thermal', ignore_face=[], coast_resolution='110m'):
    
    if vmin is None:
        vmin = v.min()
    if vmax is None:
        vmax = v.max()
    #
    MPL_LOCK = threading.Lock()
    with MPL_LOCK:
        if offline:
            plt.switch_backend('agg')
        #
        fig = plt.figure(figsize=figsize)
        cmap = getattr(cm, cmmap)
        
        # 1
        ax = fig.add_subplot(411, projection=ccrs.PlateCarree(central_longitude=0))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v1.sel(face=face)
            if face in [6,7,8,9]:
                # this deals with dateline crossing areas
                im = vplt.where( (vplt.XC<=179.9) & (vplt.XC>=0.)).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
                im = vplt.where( (vplt.XC>-179.9) & (vplt.XC<=0) ).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
            else:
                im = vplt.plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)        
        cb = plt.colorbar(im, ax=ax, ticks=[-4, -3, -2, -1, 0, 1, 2, 3, 4])
        cb.set_label(label=label[0], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[0],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-70, -60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60, 70])
        ax.set_yticklabels(['','$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N',''])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)
        if coast_resolution is not None:
            ax.coastlines(resolution=coast_resolution, color='k')
        ax.add_feature(cfeature.LAND)    
        
        # 2
        ax = fig.add_subplot(412, projection=ccrs.PlateCarree(central_longitude=0))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v2.sel(face=face)
            if face in [6,7,8,9]:
                # this deals with dateline crossing areas
                im = vplt.where( (vplt.XC<=179.9) & (vplt.XC>=0.)).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
                im = vplt.where( (vplt.XC>-179.9) & (vplt.XC<=0) ).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
            else:
                im = vplt.plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)      
        cb = plt.colorbar(im, ax=ax, ticks=[-2, -1, 0, 1, 2])
        cb.set_label(label=label[1], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[2],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-70, -60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60, 70])
        ax.set_yticklabels(['','$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N',''])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)        
        if coast_resolution is not None:
            ax.coastlines(resolution=coast_resolution, color='k')
        ax.add_feature(cfeature.LAND)    
        
        # 3
        ax = fig.add_subplot(413, projection=ccrs.PlateCarree(central_longitude=0))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v3.sel(face=face)
            if face in [6,7,8,9]:
                # this deals with dateline crossing areas
                im = vplt.where( (vplt.XC<=179.9) & (vplt.XC>=0.)).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
                im = vplt.where( (vplt.XC>-179.9) & (vplt.XC<=0) ).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
            else:
                im = vplt.plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)      
        cb = plt.colorbar(im, ax=ax, ticks=[-2, -1, 0, 1, 2])
        cb.set_label(label=label[2], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[2],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-70, -60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60, 70])
        ax.set_yticklabels(['','$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N',''])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)                
        if coast_resolution is not None:
            ax.coastlines(resolution=coast_resolution, color='k')
        ax.add_feature(cfeature.LAND)   
        
        # 4
        ax = fig.add_subplot(414, projection=ccrs.PlateCarree(central_longitude=0))
        gen = [0,1,2,3,4,5,7,8,9,10,11,12]
        for face in gen:
            vplt = v4.sel(face=face)
            if face in [6,7,8,9]:
                # this deals with dateline crossing areas
                im = vplt.where( (vplt.XC<=179.9) & (vplt.XC>=0.)).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
                im = vplt.where( (vplt.XC>-179.9) & (vplt.XC<=0) ).plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)
            else:
                im = vplt.plot.pcolormesh(ax=ax,                   
                                transform=ccrs.PlateCarree(), vmin=-2, vmax=2,
                                x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)       
        cb = plt.colorbar(im, ax=ax, ticks=[-2, -1, 0, 1, 2])
        cb.set_label(label=label[3], fontsize=font_size)            
        cb.ax.tick_params(labelsize=font_size)
        ax.set_title(title[3],fontsize=font_size) 
        ax.set_xticks([-180, -135, -90, -45, 0, 45,  90, 135, 180])
        ax.set_xticklabels(['$180\degree$W','$135\degree$W','$90\degree$W','$45\degree$W','$0\degree$','$45\degree$E','$90\degree$E','$135\degree$E','$180\degree$E'])
        ax.set_ylim(-70., 70.)
        ax.set_yticks([-70, -60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60, 70])
        ax.set_yticklabels(['','$60\degree$S','','$40\degree$S','','$20\degree$S','','$0\degree$','','$20\degree$N','','$40\degree$N','','$60\degree$N',''])
        ax.tick_params(direction='out', length=6, width=2)
        ax.set_ylabel('',fontsize=font_size)
        ax.set_xlabel('',fontsize=font_size)
        ax.tick_params(labelsize=font_size)     
        if coast_resolution is not None:
            ax.coastlines(resolution=coast_resolution, color='k')
        ax.add_feature(cfeature.LAND)  
        
        #if title is not None:
        #    ax.set_title(title, size=font_size)
        #if label is not None:
        #    cb.set_label(label=label, size=font_size)     #
            
        if savefig is not None:
            fig.savefig(savefig, dpi=150)
            plt.close(fig)
        #
        if not offline:
            plt.show()
            

In [9]:
vmin, vmax = -4, 4
lds = ds

def genfig(ds, i, overwrite=True):
    
    # !! passing the variable does not work
    eta = ds.Eta.where(ice>0)
    u = ds.u_rotate.where(ice>0)
    ug = ds.ug_rotate.where(ice>0)
    ua = ds.ua_rotate.where(ice>0)
    
    mtime = time_day[i]
    figname = '/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Global_Eta_U_Ug_Ua/Eta_U_Ug_Ua_t%05d'%(i)+'.png'
    #
    if not os.path.isfile(figname) or overwrite:
        
    
        title = ['%s'%mtime, '', '', '']
        label = ['$\eta$ (m)', 'u (m/s)', '$u_g$ (m/s)', '$u_a$ (m/s)']
        
        plot_pretty_4(eta, u, ug, ua,  vmin=vmin, vmax=vmax, cmmap='balance', title=title, label=label, savefig=figname, figsize=(20,24), ignore_face=[6])

        #m = 1.
        m = float(eta.mean().values)
    else:
        m = -1.
    return m

I = range(len(ds['time']))
#I = range(24)
print(I)
values = [delayed(genfig)(lds.isel(time=i), i) for i in I]

range(0, 240)


Process all times

In [10]:
# 2 hours for 10 days
futures = client.compute(values)
%time results = client.gather(futures)

distributed.scheduler - ERROR - Couldn't gather keys {'genfig-11a6b5c8-f1e6-45c7-8ea3-8f73a4f4afc7': [], 'genfig-6da8609a-d36b-4989-9a70-105aa6e98087': []} state: ['memory', 'processing'] workers: []
NoneType: None
distributed.scheduler - ERROR - Workers don't have promised key: [], genfig-11a6b5c8-f1e6-45c7-8ea3-8f73a4f4afc7
NoneType: None
distributed.scheduler - ERROR - Workers don't have promised key: [], genfig-6da8609a-d36b-4989-9a70-105aa6e98087
NoneType: None


CPU times: user 30min 59s, sys: 1h 1min 18s, total: 1h 32min 17s
Wall time: 2h 3min 56s


In [11]:
cluster.close()

distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
asyncio.exceptions.CancelledError
