In [None]:
from dask.distributed import Client

# Attach to a dask scheduler that is already running. NOTE(review): the port
# looks session-specific -- update it to match the scheduler you launched.
client = Client(address="tcp://127.0.0.1:46063")
client

In [2]:
import numpy as np
import xarray as xr
import pandas as pd
import os.path as op
import os
import gsw
from xgcm.grid import Grid
from xmitgcm import open_mdsdataset
from fastjmd95 import rho as densjmd95
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Filesystem locations used by the cells below.
ddir = '/tank/chaocean/qjamet/RUNS/ORAR/reruns2/'   # run output root (contains memb*/ and ensm/)
gdir = '/tank/chaocean/grid_chaO/gridMIT_update1/'  # grid files passed as grid_dir to open_mdsdataset
chaos = '/tank/chaocean/tuchida/'                   # not referenced in the cells shown here
savedir = '/tank/spectre/tuchida/chaocean/Snap/'    # destination of the K-k zarr stores
# Alternative output location used previously: '/tank/chaocean/tuchida/ECycle/'

In [4]:
# --- Physical / model parameters ------------------------------------------
# Several of these are not used in the cells shown here; they are kept so
# other code relying on the names still works.
grav = 9.81            # gravitational acceleration
rhoConst = 999.8       # reference density (same value as 9.998e2)
Cp = 3989.244953       # presumably seawater heat capacity; alternative used before: 3.994e3
# NOTE(review): presumably Laplacian / biharmonic / vertical diffusivities
# of the model configuration -- confirm before relying on them.
Kh = 20.0
K4 = 1e10
Kr = 1e-5
thetaMax = 1e20

# --- Ensemble / latitude settings -----------------------------------------
nensembs = 5           # NOTE(review): differs from len(nmembs) below; unused in the cells shown
ySstart = -15
yNstart = 15
yend = 50
# Previously tried: years = np.arange(2008, 2013, dtype=int)

# --- Snapshot file layout -------------------------------------------------
dnf = 2160             # model iterations between consecutive snapshot files
nfile = 73             # snapshot files per year (73 * 2160 its * 200 s = 365 days)

# --- Dask chunk sizes along the horizontal dimensions ---------------------
xchunk = ychunk = 100

# Ensemble member ids 100..147 (48 members).
nmembs = np.arange(100, 148)

In [5]:
# Year(s) handled in this notebook; one row of `ntimes` per year.
years = np.arange(1967, 1968, dtype=int)
# Model iteration number of every snapshot file: `nfile` files per year,
# spaced `dnf` iterations apart, laid out as (year, file).
ntimes = np.arange(1419120, 1576800, dnf).reshape(years.size, nfile)
yystart = 0   # index into `years` / row of `ntimes` being processed
# Snapshot index the resumed pass starts from; the loop further down processes
# ntimes[yystart, dd+1:].  NOTE(review): presumably the earlier snapshots were
# already written in a previous session -- confirm.
dd = 43
ntimes[yystart, dd:]

array([1512000, 1514160, 1516320, 1518480, 1520640, 1522800, 1524960,
       1527120, 1529280, 1531440, 1533600, 1535760, 1537920, 1540080,
       1542240, 1544400, 1546560, 1548720, 1550880, 1553040, 1555200,
       1557360, 1559520, 1561680, 1563840, 1566000, 1568160, 1570320,
       1572480, 1574640])

In [6]:
# Open a single member snapshot; below it supplies the xgcm Grid and the
# tracer-cell volume `Vol`.
dsg = open_mdsdataset(
    op.join(ddir, 'memb%03d/run1967/ocn/' % 100),
    grid_dir=gdir,
    prefix=['snap_diagOcn'],
    delta_t=2e2,
    iters=ntimes[0, 0],
)
dsg = dsg.chunk({'XC': xchunk, 'XG': xchunk,
                 'YC': ychunk, 'YG': ychunk})

grid = Grid(dsg, periodic=['X'])   # only X is periodic
# Volume of each tracer cell; NaN where maskC == 0 (land).
Vol = (dsg.rA * dsg.drF * dsg.hFacC).where(dsg.maskC != 0.)
grid

<xgcm.Grid>
Z Axis (not periodic, boundary=None):
  * center   Z --> left
  * outer    Zp1 --> center
  * right    Zu --> center
  * left     Zl --> center
X Axis (periodic, boundary=None):
  * center   XC --> left
  * left     XG --> center
Y Axis (not periodic, boundary=None):
  * center   YC --> left
  * left     YG --> center
T Axis (not periodic, boundary=None):
  * center   time

In [7]:
# Snapshot iterations handled in this pass. Processing starts at index dd+1;
# presumably snapshots up to index dd were written in an earlier session
# (TODO confirm before re-running from scratch).
itimes_to_process = ntimes[yystart, dd+1:]
# Horizontal dask chunking applied to every opened snapshot.
snap_chunks = {'YC': ychunk, 'YG': ychunk, 'XC': xchunk, 'XG': xchunk}


def _open_snapshot(run_dir, itime):
    """Open the 'snap_diagOcn' output at iteration `itime` from `run_dir`.

    The single-entry time dimension is dropped with ``isel(time=0)`` and the
    horizontal dimensions are rechunked with ``snap_chunks``.
    """
    return open_mdsdataset(run_dir,
                           grid_dir=gdir,
                           prefix=['snap_diagOcn'], delta_t=2e2,
                           iters=itime,
                           ).isel(time=0).chunk(snap_chunks)


for itime in itimes_to_process:
    # Same sub-directory layout as before ('run1967/ocn/'), derived from the
    # year being processed so it stays consistent with the save path.
    run_subdir = 'run%4d/ocn/' % years[yystart]

    # Stack all members along a new 'nmemb' dimension with ONE concat.
    # Repeated pairwise concat grew quadratically and assumed nmembs started
    # at 100 (otherwise it appended to the previous snapshot's `ds`).
    ds = xr.concat([_open_snapshot(op.join(ddir, 'memb%03d' % nmemb, run_subdir), itime)
                    for nmemb in nmembs],
                   dim='nmemb')
    ds.coords['nmemb'] = ('nmemb', nmembs)
    grid = Grid(ds, periodic=['X'])

    # Reference run ('ensm', presumably the ensemble mean) at the same iteration;
    # member perturbations are taken relative to it.
    dsm = _open_snapshot(op.join(ddir, 'ensm', run_subdir), itime)

    wp = ds.WVEL - dsm.WVEL
    vp = ds.VVEL - dsm.VVEL
    up = ds.UVEL - dsm.UVEL

    # Open areas of the U (west) and V (south) cell faces.
    xA = (ds.dyG * ds.drF * ds.hFacW).reset_coords(drop=True)
    yA = (ds.dxG * ds.drF * ds.hFacS).reset_coords(drop=True)

    # Transports of the perturbation velocity through each face.
    upTrans = up * xA
    vpTrans = vp * yA
    rpTrans = wp * ds.rA

    # Fluxes of u' carried by the perturbation transports.
    Auup = upTrans * up
    Avup = vpTrans * grid.interp(grid.interp(up, 'X'), 'Y', boundary='extend')
    Awup = rpTrans * grid.interp(grid.interp(up * ds.hFacW, 'X'),
                                 'Z', boundary='extend')

    # Flux divergence per unit cell volume at tracer points, land masked.
    # NOTE(review): the minus sign on the Z difference presumably reflects the
    # vertical index increasing downward while W is positive upward.
    Gu = ((grid.diff(Auup, 'X')
           + grid.diff(Avup, 'Y', boundary='extend')
           - grid.diff(Awup, 'Z', boundary='extend')
           ) / ds.drF / ds.rA / ds.hFacC
          ).where(ds.maskC != 0.).reset_coords(drop=True)

    # Same construction for v'.
    Auvp = upTrans * grid.interp(grid.interp(vp, 'X'), 'Y', boundary='extend')
    Avvp = vpTrans * vp
    Awvp = rpTrans * grid.interp(grid.interp(vp * ds.hFacS, 'Y', boundary='extend'),
                                 'Z', boundary='extend')

    Gv = ((grid.diff(Auvp, 'X')
           + grid.diff(Avvp, 'Y', boundary='extend')
           - grid.diff(Awvp, 'Z', boundary='extend')
           ) / ds.drF / ds.rA / ds.hFacC
          ).where(ds.maskC != 0.).reset_coords(drop=True)

    # Reference-run velocity (moved to tracer points) times the ensemble-mean
    # perturbation-flux divergence.
    Kk = (grid.interp(dsm.UVEL, 'X') * Gu.mean('nmemb')
          + grid.interp(dsm.VVEL, 'Y', boundary='extend') * Gv.mean('nmemb'))

    # Quick-look of one level for the first snapshot of this pass. The old
    # guard compared against ntimes[yystart, dd], which this loop never visits.
    if itime == itimes_to_process[0]:
        fig, ax = plt.subplots()
        Kk.isel(Z=1).plot(ax=ax, vmax=1e-7)
        ax.set_title('Kk at Z index 1, iteration %010d' % itime)
        plt.show()
        plt.close(fig)

    dsave = (Kk.reset_coords(drop=True)
             .chunk({'Z': 4, 'YC': ychunk, 'XC': xchunk})
             .to_dataset(name='Kk'))
    dsave.to_zarr(op.join(savedir, '%4d/%010d/K-k.zarr'
                          % (years[yystart], itime)),
                  mode='w')
    dsave.close()
    ds.close()
    dsm.close()
    del Gu, Gv, Kk

    print(itime)

1514160
1516320
1518480
1520640
1522800
1524960
1527120
1529280
1531440
1533600
1535760
1537920
1540080
1542240
1544400
1546560
1548720
1550880
1553040
1555200
1557360
1559520
1561680
1563840
1566000
1568160
1570320
1572480
1574640


In [8]:
# Bounds of the averaging region, in YC / XC coordinate units
# (presumably degrees north / east -- TODO confirm against the grid).
ymax = 43
ymin = 14
xmin = 270
xmax = 337

# Load the per-snapshot K-k stores of the year and stack them along a new
# 'time' dimension with a single concat. Concat inside the loop was quadratic,
# and with one snapshot it never created the 'time' dimension used below.
Kk = xr.concat([xr.open_zarr(op.join(savedir, '%4d/%010d/K-k.zarr'
                                     % (years[yystart], itime))).Kk
                for itime in ntimes[yystart]],
               dim='time')

# Volume-weighted mean of Kk over the region, one value per snapshot.
region = {'YC': slice(ymin, ymax), 'XC': slice(xmin, xmax)}
region_dims = ['Z', 'YC', 'XC']
Kk_volave = ((Kk * Vol).sel(region).sum(region_dims, skipna=True)
             / Vol.sel(region).sum(region_dims, skipna=True))

# Zero-padded fields keep the path free of spaces for single-digit bounds;
# the name is unchanged for the current two-/three-digit bounds.
(Kk_volave.reset_coords(drop=True)
 .chunk({'time': -1})
 .to_dataset(name='Kk')
 .to_zarr(op.join(savedir, '%4d/K-k_VolAve_%02dN-%02dN%03dE-%03dE.zarr'
                  % (years[yystart], ymin, ymax, xmin, xmax)),
          mode='w'))

<xarray.backends.zarr.ZarrStore at 0x7fd6356b8350>