In [1]:
from dask.distributed import Client

# Attach to the dask scheduler that is already running for this session.
scheduler_address = "tcp://127.0.0.1:57939"
client = Client(scheduler_address)
client

0,1
Client  Scheduler: tcp://127.0.0.1:57939  Dashboard: http://127.0.0.1:53140/status,Cluster  Workers: 8  Cores: 40  Memory: 169.33 GB


In [2]:
import numpy as np
import xarray as xr
import pandas as pd
from dask.diagnostics import ProgressBar
import os.path as op
import os
import xrft
import gsw
import time
from scipy.interpolate import PchipInterpolator, interp1d
from xgcm.grid import Grid
from xmitgcm import open_mdsdataset
from MITgcmutils import jmd95
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
from xlayers import finegrid, layers
from xlayers.core import layers_numpy, layers_apply

In [4]:
# Input root (model runs + grid) and output root for the derived sigma-layer products.
ddir, savedir = '/tank/chaocean/', '/tank/topog/tuchida/TWA/'

In [5]:
# Physical / model constants used throughout the notebook.
grav = 9.81          # gravitational acceleration [m s-2]
nensembs = 5         # NOTE(review): not used in the visible cells; the member loop opens range(24, 36)
Kh = 20.
K4 = 1e10
Kr = 1e-5
thetaMax = 1e20
rhoConst = 9.998e2   # reference density used for the pressure and flux conversions below
Cp = 3.994e3         # heat capacity used to convert heat fluxes to temperature tendencies

# Latitude bounds of the analysis window.
ySstart = -15        # southern edge used for 1963
yNstart = 10         # southern edge used for the later years
yend = 50
# Every analysis cell below selects YC/YG with `ystart`, but the year loop at the end is
# the only other place that assigns it.  Default it to the 1963 value so that
# Restart & Run All does not raise NameError; the year loop still overrides it per year.
ystart = ySstart

# Dask chunk sizes and the iteration stride between consecutive snapshot files.
xchunk = 250
ychunk = 400
membchunk = 1
dnf = 2160

In [6]:
# Open the iteration-943920 snapshot (run1963) of each ensemble member and stack the
# members along a new `nmemb` dimension.
member_ids = range(24, 36)


def _open_member_snapshot(member):
    """Open the snapshot/budget diagnostics of one member at iteration 943920, time dropped."""
    return open_mdsdataset(op.join(ddir, 'qjamet/RUNS/ORAR/memb%02d/run1963/ocn/' % member),
                           grid_dir=op.join(ddir, 'grid_chaO/gridMIT_update1/'),
                           prefix=['diag_ocnSnap',
                                   'diag_Tbgt2D', 'diag_Tbgt3D',
                                   'diag_Sbgt2D', 'diag_Sbgt3D'], delta_t=2e2,
                           iters=943920,
                           chunks={'XC': xchunk, 'XG': xchunk}
                           ).isel(time=-1)


# A single concat over the list replaces re-concatenating the growing dataset on every
# iteration (quadratic in the number of members, with repeated coordinate comparisons).
dsnap = xr.concat([_open_member_snapshot(member) for member in member_ids], dim='nmemb')
dsnap

In [7]:
# xgcm grid over the stacked snapshots; only the X axis is treated as periodic.
grid = Grid(dsnap, periodic=['X'])
grid

<xgcm.Grid>
X Axis (periodic):
  * center   XC --> left
  * left     XG --> center
Z Axis (not periodic):
  * center   Z --> left
  * left     Zl --> center
  * outer    Zp1 --> center
  * right    Zu --> center
Y Axis (not periodic):
  * center   YC --> left
  * left     YG --> center

In [8]:
# Coriolis parameter (gsw.f) evaluated at the staggered YG and the centred YC latitudes.
fG, fC = (xr.apply_ufunc(gsw.f, lat, dask='parallelized') for lat in (dsnap.YG, dsnap.YC))

In [9]:
# Density from the JMD95 equation of state at tracer (C), U (W-face) and V (S-face) points.
# Pressure argument: rhoConst * (PHIHYD - g*Z) scaled by 1e-4, i.e. Pa -> dbar
# (assumes PHIHYD is the hydrostatic pressure anomaly divided by rhoConst — confirm).
pres_dbar = (dsnap.PHIHYD-grav*dsnap.Z)*rhoConst*1e-4

rho_snap = xr.apply_ufunc(jmd95.densjmd95, dsnap.SALT, dsnap.THETA, pres_dbar,
                          dask='parallelized', output_dtypes=[float,]
                          ).where(dsnap.maskC != 0.)

# Interpolate S, T and p onto the X-staggered faces before evaluating the EOS.
rho_snapx = xr.apply_ufunc(jmd95.densjmd95,
                           *[grid.interp(fld, 'X') for fld in (dsnap.SALT, dsnap.THETA, pres_dbar)],
                           dask='parallelized', output_dtypes=[float,]
                           ).where(dsnap.maskW != 0.)

# Same on the Y-staggered faces; Y is not periodic, so xgcm's 'fill' boundary is used.
rho_snapy = xr.apply_ufunc(jmd95.densjmd95,
                           *[grid.interp(fld, 'Y', boundary='fill')
                             for fld in (dsnap.SALT, dsnap.THETA, pres_dbar)],
                           dask='parallelized', output_dtypes=[float,]
                           ).where(dsnap.maskS != 0.)

In [10]:
# Horizontal gradients of PHIHYD on U and V points: the face-length-weighted difference of
# the wet-cell PHIHYD, divided by the face-centred area (rAw / rAs), masked to wet faces.
phihyd_wet = dsnap.PHIHYD.where(dsnap.maskC != 0.)

zonal_diff = grid.diff(phihyd_wet * grid.interp(dsnap.dyG, 'X'), 'X')
dpdx_snap = (zonal_diff * dsnap.rAw**-1).where(dsnap.maskW != 0.)

merid_diff = grid.diff(phihyd_wet * grid.interp(dsnap.dxG, 'Y', boundary='fill'),
                       'Y', boundary='fill')
dpdy_snap = (merid_diff * dsnap.rAs**-1).where(dsnap.maskS != 0.)

In [11]:
# Diabatic tendencies of temperature (DibaT) and salinity (DibaS): the Diss_TH / Diss_SLT
# diagnostics on wet cells, plus surface forcing added on the top level only
# (and, for T, penetrating shortwave over the whole column).

# Two-band exponential shortwave profile (0.62 with a 0.6 m scale, 0.38 with 20 m):
# fraction of surface shortwave remaining at each level's upper face Zl.
swfrac = .62 * np.exp(dsnap.Zl/.6) + (1-.62) * np.exp(dsnap.Zl/20.)
# Fraction remaining at each level's lower face.  Shift the profile (not Zl) so the value
# below the deepest face can be filled with 0, i.e. whatever reaches the bottom cell is
# absorbed there.  Shifting Zl left NaN at the deepest level, which made Qsw (and so
# DibaT) NaN in every wet bottom-level cell.
swfrac1 = swfrac.shift(Zl=-1, fill_value=0.)
# Absorbed fraction per level, relabelled onto the tracer-level Z coordinate so it
# broadcasts by name rather than by array position.
sw_absorbed = xr.DataArray((swfrac - swfrac1).data, dims=['Z'],
                           coords={'Z': dsnap.Z.data})
Qsw = (dsnap.oceQsw / (rhoConst*Cp) / (dsnap.drF*dsnap.hFacC)
       * sw_absorbed
      ).where(dsnap.maskC!=0.)
# Surface heat flux without its shortwave part (handled by Qsw above), spread over the top cell.
Tflx = ((dsnap.TFLUX - dsnap.oceQsw)
        / (rhoConst*Cp*dsnap.drF[0]*dsnap.hFacC.isel(Z=0))
       ).where(dsnap.maskInC!=0.)
tsurf_corr = 0.                  # linFSConserveTr = F.
# Positional [:, 0] assumes WTHMASS is ordered (nmemb, level, ...) — TODO confirm.
Surf_corr = (tsurf_corr - dsnap.WTHMASS[:,0].where(dsnap.maskInC!=0.)
            ) / (dsnap.drF[0]*dsnap.hFacC.isel(Z=0))
DibaT = dsnap.Diss_TH.where(dsnap.maskC!=0.) + Qsw
# Add the surface terms on Z[0] only: a one-level slab (the reshape assumes the surface
# field is ordered nmemb, YC, XC) concatenated with zeros for every deeper level.
DibaT = DibaT + xr.concat([xr.DataArray((Tflx+Surf_corr).data.reshape((len(dsnap.nmemb),1,
                                                                       len(dsnap.YC),len(dsnap.XC))),
                                            dims=['nmemb','Z','YC','XC'],
                                            coords={'nmemb':dsnap.nmemb.data,
                                                    'Z':np.array([dsnap.Z[0].data]),
                                                    'YC':dsnap.YC.data,'XC':dsnap.XC.data}),
                               xr.zeros_like(dsnap.Diss_TH.isel(Z=slice(1,None))
                                            ).reset_coords(drop=True)],
                              dim='Z')
# DibaT

# ##################
# Surface salt flux spread over the top cell.
Sflx = (dsnap.SFLUX
        / (rhoConst*dsnap.drF[0]*dsnap.hFacC.isel(Z=0))
       ).where(dsnap.maskInC!=0.)
ssurf_corr = 0.                  # linFSConserveTr = F.
# Positional [:, 0] assumes WSLTMASS is ordered (nmemb, level, ...) — TODO confirm.
Surf_corr = (ssurf_corr - dsnap.WSLTMASS[:,0].where(dsnap.maskInC!=0.)
            ) / (dsnap.drF[0]*dsnap.hFacC.isel(Z=0))
DibaS = dsnap.Diss_SLT.where(dsnap.maskC!=0.)
# Same top-level-only slab construction as for DibaT.
DibaS = DibaS + xr.concat([xr.DataArray((Sflx+Surf_corr).data.reshape((len(dsnap.nmemb),1,
                                                                       len(dsnap.YC),len(dsnap.XC))),
                                            dims=['nmemb','Z','YC','XC'],
                                            coords={'nmemb':dsnap.nmemb.data,
                                                    'Z':np.array([dsnap.Z[0].data]),
                                                    'YC':dsnap.YC.data,'XC':dsnap.XC.data}),
                               xr.zeros_like(dsnap.Diss_SLT.isel(Z=slice(1,None))
                                            ).reset_coords(drop=True)],
                              dim='Z')
DibaS

In [12]:
# Fine vertical grid used by xlayers to remap z-level fields onto density layers.
# The third argument (10) is presumably the sub-cell refinement factor — TODO confirm
# against xlayers.finegrid.
drf_finer, mapindex, mapfact, cellindex = finegrid.finegrid(np.squeeze(dsnap.drF),
                                                            np.squeeze(dsnap.drC),
                                                            10)

In [13]:
# Target density layers: `nlayers` evenly spaced centres spanning Dsig from 20 to 50,
# plus the nlayers+1 interfaces half a spacing either side (linear extrapolation at the ends).
Dsig = 30
nlayers = 60

siglayers = np.linspace(20, 20 + Dsig, nlayers)
layer_index = np.arange(1, nlayers + 1)
func = interp1d(layer_index, siglayers, fill_value='extrapolate')
interface_index = np.arange(.5, nlayers + 1.5, 1)
sigp1layers = func(interface_index)

In [14]:
# Depth Z broadcast over the analysis latitude window at C, V and U points.  The first
# member's top level only serves as the horizontal template.  Requires `ystart`.
# The 12**-1 margin is presumably one 1/12-degree grid step — confirm.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
zz, _ = xr.broadcast(dsnap.Z, rho_snap.isel(Z=0, nmemb=0).sel(YC=lat_window))
zzy, _ = xr.broadcast(dsnap.Z, rho_snapy.isel(Z=0, nmemb=0).sel(YG=lat_window))
zzx, _ = xr.broadcast(dsnap.Z, rho_snapx.isel(Z=0, nmemb=0).sel(YC=lat_window))

In [15]:
# Remap onto sigma layers at tracer points, one member at a time:
# layers_numpy(wet ones) -> dzetaF_snap (presumably layer thickness) and
# layers_numpy(Z)        -> zdz_snap   (presumably Z integrated over each layer).
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
n_memb = len(dsnap.nmemb)

# xlayers remapping arguments and apply_ufunc options shared by both calls.
layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

# NaN-filled (nmemb, YC, XC, sig) containers, filled member by member.
zdz_snap = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XC), len(siglayers))),
                        dims=['nmemb', 'YC', 'XC', 'sig'],
                        coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                                'XC': dsnap.XC.data, 'sig': siglayers}
                        ) * np.nan
dzetaF_snap = zdz_snap.copy()

for tt in range(n_memb):
    rho_anom = (rho_snap.isel(nmemb=tt) - rhoConst).sel(YC=lat_window)
    wet_ones = xr.ones_like(zz).where(dsnap.maskC != 0.).sel(YC=lat_window)
    thickness = xr.apply_ufunc(layers_numpy, wet_ones, rho_anom, **ufunc_kw).compute()
    depth_dz = xr.apply_ufunc(layers_numpy, zz.chunk({'XC': xchunk}), rho_anom,
                              **ufunc_kw).compute()
    dzetaF_snap[tt] = thickness.data
    zdz_snap[tt] = depth_dz.data
    del thickness, depth_dz, wet_ones, rho_anom

# dzetaF_snap = dzetaF_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})
# zdz_snap = zdz_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})

In [16]:
# Same sigma-layer remapping at U points (XG): ones -> dzetaFx_snap, Z -> zxdz_snap.
# NOTE(review): unlike the tracer-point cell, the ones field here is not masked with
# maskW; rho_snapx is already masked, so this is presumably harmless — confirm.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

zxdz_snap = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XG), len(siglayers))),
                         dims=['nmemb', 'YC', 'XG', 'sig'],
                         coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                                 'XG': dsnap.XG.data, 'sig': siglayers}
                         ) * np.nan
dzetaFx_snap = zxdz_snap.copy()

for tt in range(n_memb):
    rho_anom = ((rho_snapx.isel(nmemb=tt) - rhoConst).sel(YC=lat_window)
                .chunk({'XG': xchunk}))
    ones_u = xr.ones_like(zzx).chunk({'XG': xchunk})
    thickness = xr.apply_ufunc(layers_numpy, ones_u, rho_anom, **ufunc_kw).compute()
    depth_dz = xr.apply_ufunc(layers_numpy, zzx.chunk({'XG': xchunk}), rho_anom,
                              **ufunc_kw).compute()
    dzetaFx_snap[tt] = thickness.data
    zxdz_snap[tt] = depth_dz.data
    del thickness, depth_dz, ones_u, rho_anom

# dzetaFx_snap = dzetaFx_snap.chunk({'nmemb':1,'sig':10,'XG':xchunk})
# zxdz_snap = zxdz_snap.chunk({'nmemb':1,'sig':10,'XG':xchunk})

In [17]:
# Same sigma-layer remapping at V points (YG): ones -> dzetaFy_snap, Z -> zydz_snap.
# NOTE(review): the ones field is not masked with maskS here (rho_snapy is) — confirm intent.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yg_window = dsnap.YG.sel(YG=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

zydz_snap = xr.DataArray(np.ones((n_memb, len(yg_window), len(dsnap.XC), len(siglayers))),
                         dims=['nmemb', 'YG', 'XC', 'sig'],
                         coords={'nmemb': range(n_memb), 'YG': yg_window.data,
                                 'XC': dsnap.XC.data, 'sig': siglayers}
                         ) * np.nan
dzetaFy_snap = zydz_snap.copy()

for tt in range(n_memb):
    rho_anom = (rho_snapy.isel(nmemb=tt) - rhoConst).sel(YG=lat_window)
    ones_v = xr.ones_like(zzy).chunk({'XC': xchunk})
    thickness = xr.apply_ufunc(layers_numpy, ones_v, rho_anom, **ufunc_kw).compute()
    depth_dz = xr.apply_ufunc(layers_numpy, zzy.chunk({'XC': xchunk}), rho_anom,
                              **ufunc_kw).compute()
    dzetaFy_snap[tt] = thickness.data
    zydz_snap[tt] = depth_dz.data
    del thickness, depth_dz, ones_v, rho_anom

# dzetaFy_snap = dzetaFy_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})
# zydz_snap = zydz_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})

In [18]:
# Remap the diabatic tendencies DibaT / DibaS onto sigma layers at tracer points,
# giving diaTdz / diaSdz (one member at a time), then rechunk the results for dask.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

diaTdz = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XC), len(siglayers))),
                      dims=['nmemb', 'YC', 'XC', 'sig'],
                      coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                              'XC': dsnap.XC.data, 'sig': siglayers}
                      ) * np.nan
diaSdz = diaTdz.copy()

for tt in range(n_memb):
    rho_anom = (rho_snap.isel(nmemb=tt) - rhoConst).sel(YC=lat_window)
    # Z is kept in a single chunk because it is the core dimension of layers_numpy.
    heat_src = (DibaT.isel(nmemb=tt).where(dsnap.maskC != 0.).sel(YC=lat_window)
                .chunk({'XC': xchunk, 'Z': -1}))
    salt_src = (DibaS.isel(nmemb=tt).where(dsnap.maskC != 0.).sel(YC=lat_window)
                .chunk({'XC': xchunk, 'Z': -1}))
    diaT_layers = xr.apply_ufunc(layers_numpy, heat_src, rho_anom, **ufunc_kw).compute()
    diaS_layers = xr.apply_ufunc(layers_numpy, salt_src, rho_anom, **ufunc_kw).compute()
    diaTdz[tt] = diaT_layers.data
    diaSdz[tt] = diaS_layers.data
    del diaT_layers, diaS_layers, heat_src, salt_src, rho_anom

diaTdz = diaTdz.chunk({'nmemb': 1, 'sig': 10, 'XC': xchunk})
diaSdz = diaSdz.chunk({'nmemb': 1, 'sig': 10, 'XC': xchunk})

In [19]:
# Remap THETA and SALT onto sigma layers at tracer points (tdz / sdz), one member at a
# time, then rechunk the results for dask.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

sdz = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XC), len(siglayers))),
                   dims=['nmemb', 'YC', 'XC', 'sig'],
                   coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                           'XC': dsnap.XC.data, 'sig': siglayers}
                   ) * np.nan
tdz = sdz.copy()

for tt in range(n_memb):
    rho_anom = (rho_snap.isel(nmemb=tt) - rhoConst).sel(YC=lat_window)
    theta_wet = (dsnap.THETA.isel(nmemb=tt).where(dsnap.maskC != 0.).sel(YC=lat_window)
                 .chunk({'XC': xchunk}))
    salt_wet = (dsnap.SALT.isel(nmemb=tt).where(dsnap.maskC != 0.).sel(YC=lat_window)
                .chunk({'XC': xchunk}))
    theta_layers = xr.apply_ufunc(layers_numpy, theta_wet, rho_anom, **ufunc_kw).compute()
    salt_layers = xr.apply_ufunc(layers_numpy, salt_wet, rho_anom, **ufunc_kw).compute()
    tdz[tt] = theta_layers.data
    sdz[tt] = salt_layers.data
    del theta_layers, salt_layers, theta_wet, salt_wet, rho_anom

tdz = tdz.chunk({'nmemb': 1, 'sig': 10, 'XC': xchunk})
sdz = sdz.chunk({'nmemb': 1, 'sig': 10, 'XC': xchunk})

In [20]:
# Remap the pressure gradients dpdx (U points) and dpdy (V points) onto sigma layers,
# one member at a time, giving pxdz_snap / pydz_snap.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
yg_window = dsnap.YG.sel(YG=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

pxdz_snap = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XG), len(siglayers))),
                         dims=['nmemb', 'YC', 'XG', 'sig'],
                         coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                                 'XG': dsnap.XG.data, 'sig': siglayers}
                         ) * np.nan
pydz_snap = xr.DataArray(np.ones((n_memb, len(yg_window), len(dsnap.XC), len(siglayers))),
                         dims=['nmemb', 'YG', 'XC', 'sig'],
                         coords={'nmemb': range(n_memb), 'YG': yg_window.data,
                                 'XC': dsnap.XC.data, 'sig': siglayers}
                         ) * np.nan

for tt in range(n_memb):
    dpdx_win = dpdx_snap[tt].sel(YC=lat_window).chunk({'XG': xchunk})
    rhox_anom = (rho_snapx[tt] - rhoConst).sel(YC=lat_window).chunk({'XG': xchunk})
    px_layers = xr.apply_ufunc(layers_numpy, dpdx_win, rhox_anom, **ufunc_kw).compute()

    dpdy_win = dpdy_snap[tt].sel(YG=lat_window)
    rhoy_anom = (rho_snapy[tt] - rhoConst).sel(YG=lat_window)
    py_layers = xr.apply_ufunc(layers_numpy, dpdy_win, rhoy_anom, **ufunc_kw).compute()

    # Assigned as DataArrays (not .data), as in the original cell.
    pxdz_snap[tt] = px_layers
    pydz_snap[tt] = py_layers
    del px_layers, py_layers, dpdx_win, rhox_anom, dpdy_win, rhoy_anom

# pxdz_snap = pxdz_snap.chunk({'nmemb':1,'sig':10,'XG':xchunk})
# pydz_snap = pydz_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})

In [21]:
# Remap the velocities UVEL (U points) and VVEL (V points) onto sigma layers,
# one member at a time, giving udz_snap / vdz_snap.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
yc_window = dsnap.YC.sel(YC=lat_window)
yg_window = dsnap.YG.sel(YG=lat_window)
n_memb = len(dsnap.nmemb)

layers_kw = {'thetalayers': siglayers, 'mapfact': mapfact,
             'mapindex': mapindex, 'cellindex': cellindex,
             'drf_finer': drf_finer}
ufunc_kw = dict(kwargs=layers_kw, dask='parallelized',
                input_core_dims=[['Z'], ['Z']],
                output_core_dims=[['sig']],
                output_dtypes=[float,], output_sizes={'sig': siglayers.size})

udz_snap = xr.DataArray(np.ones((n_memb, len(yc_window), len(dsnap.XG), len(siglayers))),
                        dims=['nmemb', 'YC', 'XG', 'sig'],
                        coords={'nmemb': range(n_memb), 'YC': yc_window.data,
                                'XG': dsnap.XG.data, 'sig': siglayers}
                        ) * np.nan
vdz_snap = xr.DataArray(np.ones((n_memb, len(yg_window), len(dsnap.XC), len(siglayers))),
                        dims=['nmemb', 'YG', 'XC', 'sig'],
                        coords={'nmemb': range(n_memb), 'YG': yg_window.data,
                                'XC': dsnap.XC.data, 'sig': siglayers}
                        ) * np.nan

for tt in range(n_memb):
    u_win = dsnap.UVEL[tt].sel(YC=lat_window)
    rhox_anom = (rho_snapx[tt] - rhoConst).sel(YC=lat_window).chunk({'XG': xchunk})
    u_layers = xr.apply_ufunc(layers_numpy, u_win, rhox_anom, **ufunc_kw).compute()

    v_win = dsnap.VVEL[tt].sel(YG=lat_window)
    rhoy_anom = (rho_snapy[tt] - rhoConst).sel(YG=lat_window)
    v_layers = xr.apply_ufunc(layers_numpy, v_win, rhoy_anom, **ufunc_kw).compute()

    # Assigned as DataArrays (not .data), as in the original cell.
    udz_snap[tt] = u_layers
    vdz_snap[tt] = v_layers
    del u_layers, v_layers, u_win, rhox_anom, v_win, rhoy_anom

# udz_snap = udz_snap.chunk({'nmemb':1,'sig':10,'XG':xchunk})
# vdz_snap = vdz_snap.chunk({'nmemb':1,'sig':10,'XC':xchunk})

In [22]:
# Collect the sigma-layer depth/thickness fields and write them to
# <savedir>/1963/Zs_membs_0000943920.nc.
dsave = zdz_snap.to_dataset(name='zdz')
for var_name, field in (('zydz', zydz_snap), ('zxdz', zxdz_snap),
                        ('dzeta', dzetaF_snap), ('dzetax', dzetaFx_snap),
                        ('dzetay', dzetaFy_snap)):
    dsave[var_name] = field
# dsave['dzetaz'] = dzetaFz_snap
dsave.coords['sigp1'] = ('sigp1', sigp1layers)
out_year, out_iter = 1963, 943920
dsave.to_netcdf(op.join(savedir, '%4d/Zs_membs_%010d.nc' % (out_year, out_iter)))
dsave.close()

NameError: name 'grids' is not defined

In [25]:
# Sigma-coordinate dataset and xgcm grid (`grids`) whose vertical axis is the sigma
# layers: centres `sig`, interfaces `sigp1`.
lat_window = slice(ystart - 12**-1, yend + 2*12**-1)
dssig = zdz_snap.to_dataset(name='zdz')
# dssig['udz'] = udz
# dssig['vdz'] = vdz
dssig.coords['YG'] = ('YG', dsnap.YG.sel(YG=lat_window).data)
dssig.coords['XG'] = ('XG', dsnap.XG.data)
dssig.coords['sigp1'] = ('sigp1', sigp1layers)

# Buoyancy -g * sigma / rhoConst at layer centres (buoyb) and interfaces (buoybp1).
buoyb, buoybp1 = (-grav * dens * rhoConst**-1 for dens in (dssig.sig, dssig.sigp1))

grids = Grid(dssig, periodic=['X'],
             coords={'Z': {'center': 'sig', 'outer': 'sigp1'},
                     'Y': {'center': 'YC', 'left': 'YG'},
                     'X': {'center': 'XC', 'left': 'XG'}})

In [8]:
# Buoyancy sensitivities across sigma layers: the layer difference of buoyancy divided
# by the layer difference of tdz/dzetaF (resp. sdz/dzetaF, presumably layer-mean T and S).
# Then varpi = (db/dT * diaTdz + db/dS * diaSdz) / dzetaF.  Finally write the dynamics
# fields to <savedir>/1963/Dyn_membs_0000943920.nc.
dbuoy = grids.diff(buoyb, 'Z', boundary='fill')
dbdT = grids.diff(tdz*dzetaF_snap**-1, 'Z', boundary='fill')**-1 * dbuoy
dbds = grids.diff(sdz*dzetaF_snap**-1, 'Z', boundary='fill')**-1 * dbuoy
varpi = ((grids.interp(dbdT, 'Z', boundary='fill') * diaTdz
          + grids.interp(dbds, 'Z', boundary='fill') * diaSdz
          ) * dzetaF_snap**-1).compute()

dsave = vdz_snap.to_dataset(name='vdz')
for var_name, field in (('udz', udz_snap), ('varpi', varpi),
                        ('pxdz', pxdz_snap), ('pydz', pydz_snap)):
    dsave[var_name] = field
# dsave['tdz'] = tdz_snap
# dsave['sdz'] = sdz_snap
dsave.coords['sigp1'] = ('sigp1', sigp1layers)
dsave.to_netcdf(op.join(savedir, '%4d/Dyn_membs_%010d.nc' % (1963, 943920)))
dsave.close()

NameError: name 'tdz' is not defined

In [8]:
# Open a single snapshot for inspection: member 24, run1965, iteration 1103760,
# restricted to the southern-start latitude band and rechunked in X and Y.
lat_band = slice(ySstart - 12**-1, yend + 2*12**-1)
ds = open_mdsdataset(op.join(ddir, 'qjamet/RUNS/ORAR/memb%02d/run%4d/ocn/' % (24, 1965)),
                     grid_dir=op.join(ddir, 'grid_chaO/gridMIT_update1/'),
                     iters=1103760,
                     prefix=['diag_ocnSnap'], delta_t=2e2)
ds = ds.sel(YC=lat_band, YG=lat_band).chunk({'XC': xchunk, 'XG': xchunk,
                                              'YC': ychunk, 'YG': ychunk})
ds

In [9]:
nremap = 90

for year in range(1963,1968):
    if year == 1963:
        ntimes = np.arange( 943920    , 943920+dnf,dnf,dtype=int)
        ystart = ySstart
    else:
        ystart = yNstart
        if year == 1964:
            ntimes = np.arange( 943920+dnf,1101600+dnf,dnf,dtype=int)
        elif year == 1965:
            ntimes = np.arange(1101600+dnf,1259280+dnf,dnf,dtype=int)
        elif year == 1966:
            ntimes = np.arange(1259280+dnf,1416960+dnf,dnf,dtype=int)
        else:
            ntimes = np.arange(1416960+dnf,1574640+dnf,dnf,dtype=int)
        
    # For each snapshot iteration: load the remapped ensemble fields (Zs, Dyns, Varpi zarr
    # stores), build ensemble-mean eddy flux components (e00..e21, presumably the E-P flux
    # tensor given the output path) and their divergences, then write them to
    # '<year>/E-P/<itime>/' as zarr (overwriting any existing store).
    # Latitude window shared by all inputs (presumably one 1/12-degree cell of padding
    # south of ystart and two north of yend).
    ywin = slice(ystart-12**-1, yend+2*12**-1)
    for itime in ntimes:
        ds1 = xr.open_zarr(op.join(savedir,'%4d/Zs/%07d/' % (year,itime))
                          ).sel(YC=ywin, YG=ywin)
        ds2 = xr.open_zarr(op.join(savedir,'%4d/Dyns/%07d/' % (year,itime))
                          ).sel(YC=ywin, YG=ywin)
        ds3 = xr.open_zarr(op.join(savedir,'%4d/Varpi/%07d/' % (year,itime))
                          ).sel(YC=ywin)
        # try/finally so the input stores are closed even if a compute or the write fails
        # (e.g. a lost dask scheduler).
        try:
            # Buoyancy at layer centres (sig) and interfaces (sigp1).
            buoyb = -grav * ds1.sig * rhoConst**-1
            buoybp1 = -grav * ds1.sigp1 * rhoConst**-1

            grid = Grid(ds1, periodic=['X'],
                        coords={'Z':{'center':'sig','outer':'sigp1'},
                                'Y':{'center':'YC','left':'YG'},
                                'X':{'center':'XC','left':'XG'}}
                       )

            # Reciprocal buoyancy steps, built once and reused below.
            # Note the precedence: -a**-1 == -(a**-1), identical to the original inline form.
            neg_inv_db = -grid.diff(buoybp1,'Z',boundary='fill')**-1
            inv_dbc = grid.diff(buoyb,'Z',boundary='fill')**-1

            # Per-member and ensemble-mean "sigma" (thickness per unit buoyancy step)
            # at the F, Fx and Fy positions.
            sigma = ds1.dzetaF * neg_inv_db
            sigmay = ds1.dzetaFy * neg_inv_db
            sigmax = ds1.dzetaFx * neg_inv_db

            dzetaF_b = ds1.dzetaF.mean('nmemb',skipna=True)
            dzetaFx_b = ds1.dzetaFx.mean('nmemb',skipna=True)
            dzetaFy_b = ds1.dzetaFy.mean('nmemb',skipna=True)
            sigmab = dzetaF_b * neg_inv_db
            sigmaxb = dzetaFx_b * neg_inv_db
            sigmayb = dzetaFy_b * neg_inv_db
            inv_sigmab = sigmab**-1
            inv_sigmaxb = sigmaxb**-1
            inv_sigmayb = sigmayb**-1

            zetaxb = (ds1.zxdz*ds1.dzetaFx**-1).mean('nmemb',skipna=True)
            zetayb = (ds1.zydz*ds1.dzetaFy**-1).mean('nmemb',skipna=True)

            # Thickness-weighted ensemble means ("hat") and per-member deviations ("pp").
            vhat = ds2.vdz.mean('nmemb',skipna=True) * dzetaFy_b**-1
            uhat = ds2.udz.mean('nmemb',skipna=True) * dzetaFx_b**-1
            varpihat = ds3.varpidz.mean('nmemb',skipna=True) * dzetaF_b**-1

            vpp = ds2.vdz*ds1.dzetaFy**-1 - vhat
            upp = ds2.udz*ds1.dzetaFx**-1 - uhat
            varpipp = ds3.varpidz*ds1.dzetaF**-1 - varpihat

            zetayp = (ds1.zydz*ds1.dzetaFy**-1 - zetayb)
            zetaxp = (ds1.zxdz*ds1.dzetaFx**-1 - zetaxb)

            mbx = ds2.pxdz * ds1.dzetaFx**-1
            mby = ds2.pydz * ds1.dzetaFy**-1
            mbxp = mbx - mbx.mean('nmemb',skipna=True)
            mbyp = mby - mby.mean('nmemb',skipna=True)

            # Ensemble-mean flux terms (lazy). The *_0 / *_1 pieces are kept separate so
            # they can be saved individually.
            tmp00_0 = (upp**2*sigmax).mean('nmemb',skipna=True)
            tmp00_1 = .5*(zetaxp**2).mean('nmemb',skipna=True)
            tmp11_0 = (vpp**2*sigmay).mean('nmemb',skipna=True)
            tmp11_1 = .5*(zetayp**2).mean('nmemb',skipna=True)
            tmp01 = (grid.interp(grid.interp(vpp,'X'),'Y',boundary='fill')
                     * upp*sigmax).mean('nmemb',skipna=True)
            tmp10 = (grid.interp(grid.interp(upp,'X'),'Y',boundary='fill')
                     * vpp*sigmay).mean('nmemb',skipna=True)
            tmp20_0 = (varpipp*grid.interp(upp,'X')
                       * sigma
                      ).mean('nmemb',skipna=True)
            tmp20_1 = grid.interp(mbxp*zetaxp,'X').mean('nmemb',skipna=True)
            tmp21_0 = (varpipp*grid.interp(vpp,'Y',boundary='fill')
                       * sigma
                      ).mean('nmemb',skipna=True)
            tmp21_1 = grid.interp(mbyp*zetayp,'Y',boundary='fill').mean('nmemb',skipna=True)

            # NOTE(review): each .compute() below evaluates its own dask graph, so shared
            # intermediates (upp, vpp, ...) are likely re-read/re-computed per call. Computing
            # them together (one Dataset.compute()) would share that work at the cost of a
            # higher memory peak.
            start = time.time()
            e00 = ((tmp00_0+tmp00_1) * inv_sigmaxb).compute()
            e01 = (tmp01 * inv_sigmaxb).compute()
            e10 = (tmp10 * inv_sigmayb).compute()
            e11 = ((tmp11_0+tmp11_1) * inv_sigmayb).compute()
            e20_0 = (tmp20_0 * inv_sigmab).compute()
            e20_1 = (tmp20_1 * inv_sigmab).compute()
            e21_0 = (tmp21_0 * inv_sigmab).compute()
            e21_1 = (tmp21_1 * inv_sigmab).compute()

            # Horizontal divergences. NOTE(review): `ds` is not defined in this loop; it must
            # come from an earlier cell and provide the grid metrics dxG, dyG and rA — confirm
            # it matches ds1's XC/XG/YC/YG coordinates.
            e00_x = ((grid.diff(tmp00_0*ds.dyG,'X') * ds.rA**-1
                      + grid.diff(tmp00_1,'X') * grid.interp(ds.dxG,'Y',boundary='fill')**-1
                     ) * inv_sigmab).compute()
            e10_y = (grid.diff(tmp10*ds.dxG,'Y',boundary='fill') * ds.rA**-1
                     * inv_sigmab).compute()
            e01_x = (grid.diff(tmp01*ds.dyG,'X') * ds.rA**-1
                     * inv_sigmab).compute()
            e11_y = ((grid.diff(tmp11_0*ds.dxG,'Y',boundary='fill') * ds.rA**-1
                      + grid.diff(tmp11_1,'Y',boundary='fill') * grid.interp(ds.dyG,'X')**-1
                     ) * inv_sigmab).compute()

            # Buoyancy-direction derivatives: diff along Z, divide by the buoyancy step,
            # interpolate back along Z, scale by 1/sigmab.
            e20_0b = (grid.interp(grid.diff(tmp20_0,'Z',boundary='fill') * inv_dbc,
                                  'Z',boundary='fill')
                      * inv_sigmab).compute()
            e20_1b = (grid.interp(grid.diff(tmp20_1,'Z',boundary='fill') * inv_dbc,
                                  'Z',boundary='fill')
                      * inv_sigmab).compute()
            e21_0b = (grid.interp(grid.diff(tmp21_0,'Z',boundary='fill') * inv_dbc,
                                  'Z',boundary='fill')
                      * inv_sigmab).compute()
            e21_1b = (grid.interp(grid.diff(tmp21_1,'Z',boundary='fill') * inv_dbc,
                                  'Z',boundary='fill')
                      * inv_sigmab).compute()

            end = time.time()
            print("Lapse time:", end-start)

            dsave = e00.to_dataset(name='e00')
            dsave['e01'] = e01
            dsave['e10'] = e10
            dsave['e11'] = e11
            dsave['e00x'] = e00_x
            dsave['e01x'] = e01_x
            dsave['e10y'] = e10_y
            dsave['e11y'] = e11_y
            dsave['e20_0'] = e20_0
            dsave['e20_1'] = e20_1
            dsave['e21_0'] = e21_0
            dsave['e21_1'] = e21_1
            dsave['e20_0b'] = e20_0b
            dsave['e20_1b'] = e20_1b
            dsave['e21_0b'] = e21_0b
            dsave['e21_1b'] = e21_1b
            dsave.coords['YG'] = ('YG',ds1.YG.data)
            dsave.coords['XG'] = ('XG',ds1.XG.data)
            dsave.coords['sigp1'] = ('sigp1',ds1.sigp1)
            dsave.to_zarr(op.join(savedir,'%4d/E-P/%07d/' % (year,itime)), mode='w')
            dsave.close()

            del e00, e01, e10, e11
            del e00_x, e01_x, e10_y, e11_y
            del e20_0, e20_1, e21_0, e21_1, e20_0b, e20_1b, e21_0b, e21_1b
        finally:
            ds1.close()
            ds2.close()
            ds3.close()
        print(itime)

    # client.restart()

Lapse time: 637.5940260887146
943920


distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError


In [8]:
# Adiabatic / baroclinic eddy terms for one year: for each snapshot, form the
# ensemble-mean fluxes varpi''u'', varpi''v'', zeta'mx', zeta'my' (scaled by 1/sigmab)
# and their buoyancy-direction derivatives. Results are written to
# '<year>/Adiab_barocli/<itime>' as zarr.
year = 1963
nremap = 80
ystart = ySstart

# Snapshot iterations to process. Only iteration 943920 is selected here; the commented
# ranges select later snapshots.
ntimes = np.arange( 943920    , 943920+dnf,dnf,dtype=int)
# ntimes = np.arange( 943920+dnf,1101600+dnf,dnf,dtype=int)
# ntimes = np.arange(1101600+dnf,1259280+dnf,dnf,dtype=int)
# ntimes = np.arange(1259280+dnf,1416960+dnf,dnf,dtype=int)
# ntimes = np.arange(1416960+dnf,1574640+dnf,dnf,dtype=int)


def _ddb_over_sigmab(field, grid, inv_dbuoy, inv_sigmab):
    """Buoyancy-direction derivative of an ensemble-mean flux, scaled by 1/sigmab.

    Differences ``field`` along Z and multiplies by ``inv_dbuoy`` (reciprocal of the
    matching buoyancy difference). It then interpolates back along Z, multiplies by
    ``inv_sigmab`` and computes the result eagerly.
    """
    return (grid.interp(grid.diff(field,'Z',boundary='fill') * inv_dbuoy,
                        'Z',boundary='fill')
            * inv_sigmab).compute()


# Latitude window shared by all inputs (presumably one 1/12-degree cell of padding
# south of ystart and two north of yend).
ywin = slice(ystart-12**-1, yend+2*12**-1)

for itime in ntimes:
    ds1 = xr.open_dataset(op.join(savedir,'%4d/Zs_membs_parallel-%2d_%010d.nc'
                                  % (year,nremap,itime))
                         ).sel(YC=ywin, YG=ywin
                              ).chunk({'nmemb':membchunk,'XC':xchunk,'XG':xchunk})
    ds2 = xr.open_dataset(op.join(savedir,'%4d/Dyn_membs_parallel-%2d_%010d.nc'
                                  % (year,nremap,itime))
                         ).sel(YC=ywin, YG=ywin
                              ).chunk({'nmemb':membchunk,'XC':xchunk,'XG':xchunk})
    ds3 = xr.open_dataset(op.join(savedir,'%4d/Varpidz-compressible_membs_parallel-%2d_%010d.nc'
                                  % (year,nremap,itime))
                         ).sel(YC=ywin
                              ).chunk({'nmemb':membchunk,'XC':xchunk})
    # try/finally so the NetCDF inputs are closed even if a compute or the write fails.
    try:
        # Buoyancy at layer centres (sig) and interfaces (sigp1).
        buoyb = -grav * ds1.sig * rhoConst**-1
        buoybp1 = -grav * ds1.sigp1 * rhoConst**-1

        grid = Grid(ds1, periodic=['X'],
                    coords={'Z':{'center':'sig','outer':'sigp1'},
                            'Y':{'center':'YC','left':'YG'},
                            'X':{'center':'XC','left':'XG'}}
                   )

        # Reciprocal buoyancy steps, built once and reused below
        # (-a**-1 == -(a**-1), identical to the original inline form).
        neg_inv_db = -grid.diff(buoybp1,'Z',boundary='fill')**-1
        inv_dbc = grid.diff(buoyb,'Z',boundary='fill')**-1

        # Per-member and ensemble-mean "sigma" at the F position; only these are needed here.
        dzetaF_b = ds1.dzetaF.mean('nmemb',skipna=True)
        sigma = ds1.dzetaF * neg_inv_db
        sigmab = dzetaF_b * neg_inv_db
        inv_sigmab = sigmab**-1

        zetaxb = (ds1.zxdz*ds1.dzetaFx**-1).mean('nmemb',skipna=True)
        zetayb = (ds1.zydz*ds1.dzetaFy**-1).mean('nmemb',skipna=True)

        # Thickness-weighted ensemble means ("hat") and per-member deviations ("pp").
        vhat = ds2.vdz.mean('nmemb',skipna=True) * ds1.dzetaFy.mean('nmemb',skipna=True)**-1
        uhat = ds2.udz.mean('nmemb',skipna=True) * ds1.dzetaFx.mean('nmemb',skipna=True)**-1
        varpihat = ds3.varpidz.mean('nmemb',skipna=True) * dzetaF_b**-1

        vpp = ds2.vdz*ds1.dzetaFy**-1 - vhat
        upp = ds2.udz*ds1.dzetaFx**-1 - uhat
        varpipp = ds3.varpidz*ds1.dzetaF**-1 - varpihat

        zetayp = (ds1.zydz*ds1.dzetaFy**-1 - zetayb)
        zetaxp = (ds1.zxdz*ds1.dzetaFx**-1 - zetaxb)

        mbx = ds2.pxdz * ds1.dzetaFx**-1
        mby = ds2.pydz * ds1.dzetaFy**-1
        mbxp = mbx - mbx.mean('nmemb',skipna=True)
        mbyp = mby - mby.mean('nmemb',skipna=True)

        # Ensemble-mean flux terms (lazy).
        tmp20_0 = (varpipp*grid.interp(upp,'X')
                   * sigma
                  ).mean('nmemb',skipna=True)
        tmp20_1 = grid.interp(mbxp*zetaxp,'X').mean('nmemb',skipna=True)
        tmp21_0 = (varpipp*grid.interp(vpp,'Y',boundary='fill')
                   * sigma
                  ).mean('nmemb',skipna=True)
        tmp21_1 = grid.interp(mbyp*zetayp,'Y',boundary='fill').mean('nmemb',skipna=True)

        start = time.time()
        e20_0 = (tmp20_0 * inv_sigmab).compute()
        e20_1 = (tmp20_1 * inv_sigmab).compute()
        e21_0 = (tmp21_0 * inv_sigmab).compute()
        e21_1 = (tmp21_1 * inv_sigmab).compute()

        e20_0b = _ddb_over_sigmab(tmp20_0, grid, inv_dbc, inv_sigmab)
        e20_1b = _ddb_over_sigmab(tmp20_1, grid, inv_dbc, inv_sigmab)
        e21_0b = _ddb_over_sigmab(tmp21_0, grid, inv_dbc, inv_sigmab)
        e21_1b = _ddb_over_sigmab(tmp21_1, grid, inv_dbc, inv_sigmab)

        end = time.time()
        print("Lapse time:", end-start)

        dsave = e20_0.to_dataset(name='varpipup')
        dsave['zpmxp'] = e20_1
        dsave['varpipvp'] = e21_0
        dsave['zpmyp'] = e21_1
        dsave['varpipup_b'] = e20_0b
        dsave['zpmxp_b'] = e20_1b
        dsave['varpipvp_b'] = e21_0b
        dsave['zpmyp_b'] = e21_1b
        dsave.coords['YG'] = ('YG',ds1.YG.data)
        dsave.coords['XG'] = ('XG',ds1.XG.data)
        dsave.coords['sigp1'] = ('sigp1',ds1.sigp1)
        # mode='w' (as in the E-P cell) so a re-run after an interrupted write overwrites
        # the partial store instead of failing because the path already exists.
        dsave.to_zarr(op.join(savedir,'%4d/Adiab_barocli/%07d' % (year,itime)), mode='w')
        dsave.close()

        del e20_0, e20_1, e21_0, e21_1
        del e20_0b, e21_0b, e20_1b, e21_1b
    finally:
        ds1.close()
        ds2.close()
        ds3.close()
    print(itime)

KeyboardInterrupt: 



In [9]:
# Retry the Adiab_barocli write after the previous cell was interrupted (KeyboardInterrupt).
# Relies on `dsave`, `year` and `itime` left in the kernel by that cell, so it is not
# Restart-&-Run-All safe. mode='w' overwrites a partially written store; the default
# mode would fail because the path already exists.
dsave.to_zarr(op.join(savedir,'%4d/Adiab_barocli/%07d' % (year,itime)), mode='w')
dsave.close()