In [3]:
import numpy as np
import xarray as xr
import dask.array as dsar
from dask.diagnostics import ProgressBar
from scipy.interpolate import PchipInterpolator as pchip
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import intake
import xrft
from xgcm.grid import Grid
from xomega import w_rigid
%matplotlib inline

In [None]:
# Open the Pangeo ocean-channel intake catalog and load the MITgcm channel run
# (flat bottom, 2 km, 15-day snapshots per the entry name) via intake's to_dask().
cat = intake.Catalog(
    "https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean/channel.yaml"
)
ds = cat["MITgcm_channel_flatbottom_02km_run01_phys_snap15D"].to_dask()
ds

In [None]:
# Keep only the meridional band 550-1450 km, applied to both the cell-centre (YC)
# and cell-face (YG) y coordinates (presumably to stay clear of the channel walls).
ds_center = ds.sel({ydim: slice(5.5e5, 14.5e5) for ydim in ("YC", "YG")})

In [None]:
# Physical constants (SI units assumed).
f0, beta = -1.1e-4, 1.4e-11   # Coriolis parameter [1/s] (negative: southern hemisphere) and its y-gradient [1/(m s)]
g, alpha = 9.81, 2e-4         # gravity [m/s^2]; thermal expansion coefficient used in N^2 = alpha*g*dT/dz
r0 = 1e3                      # presumably a reference density [kg/m^3]; not used in the cells shown here

# Snapshots are every 15 days, starting July 1.

In [None]:
# Time index of the snapshot analysed below (labelled "Aug 15" by the author).
# NOTE(review): with 15-day snapshots from July 1, index 4 is 60 days later
# (~Aug 30) and Aug 15 would be index 3 -- confirm against the time coordinate.
date = 4

In [None]:
# Uniform vertical grid of 200 points from the first to the last model
# interface (ds.Zp1); the omega-equation forcing is interpolated onto it.
zPew = np.linspace(ds.Zp1[0], ds.Zp1[-1], 200)
# Spacing of the uniform grid. Builtin float replaces np.float, which was
# deprecated in NumPy 1.20 and removed in 1.24 (AttributeError there).
dZ = float(np.abs(np.diff(zPew)[0]))
# xgcm grid for staggered differences/interpolation; only X is periodic.
grid = Grid(ds_center, periodic=['X'])
# Beta-plane Coriolis parameter about y = 1000 km, evaluated at cell
# centres (YC) and at meridional faces (YG).
fC = f0 + beta*(ds_center.YC-10e5)
fG = f0 + beta*(ds_center.YG-10e5)

# Fields at snapshot `date`:
#   u   zonal velocity            v   meridional velocity
#   phi hydrostatic pressure potential
#   pt  potential temperature
u, v, phi, pt = (ds_center[var][date] for var in ('U', 'V', 'PH', 'T'))

# Vertical derivative of phi (hydrostatic buoyancy, b = d(phi)/dz).
b = (grid.diff(phi, 'Z', boundary='fill')
     / grid.diff(phi.Z, 'Z', boundary='fill'))

# N^2 = alpha*g*dT/dz on the outer vertical points. T uses boundary='extend',
# z uses boundary='fill' (zero padding) for the end differences.
N2 = alpha*g*(grid.diff(pt, 'Z', boundary='extend', to='outer')
              / grid.diff(pt.Z, 'Z', boundary='fill', to='outer'))
# Streamfunction psi = phi / f(y).
psi = phi*fC**-1


def _ddx_spectral(field, xdim):
    """Return d(field)/dx along the periodic dimension ``xdim``, computed spectrally.

    ``field`` is transformed along ``xdim`` with ``xrft.dft`` (no shift),
    multiplied by i*k with k = 2*pi*freq, inverse-transformed with
    ``dask.array.fft.ifft`` and the real part is returned as a DataArray
    on the dims/coords of ``field``.
    """
    fdim = 'freq_' + xdim
    # chunk -1 = one chunk spanning the whole dimension: both FFTs need the
    # transformed axis unchunked (the original hard-coded 500 points).
    fhat = xrft.dft(field.chunk(chunks={xdim: -1}), dim=[xdim], shift=False)
    # Keep fhat as the left operand so the product keeps fhat's dim order.
    dfhat = (1j*fhat*(2*np.pi*fhat[fdim])).chunk(chunks={fdim: -1})
    dfdx = dsar.fft.ifft(dfhat.data, axis=dfhat.get_axis_num(fdim)).real
    return xr.DataArray(dfdx, dims=field.dims, coords=field.coords)


def _ddy(field, ycoord):
    """Finite-difference d(field)/dy on ``grid`` with zero-filled Y boundaries."""
    return (grid.diff(field, 'Y', boundary='fill')
            / grid.diff(field[ycoord], 'Y', boundary='fill'))


def _ddz(field):
    """Finite-difference d(field)/dz on ``grid`` with zero-filled Z boundaries."""
    return (grid.diff(field, 'Z', boundary='fill')
            / grid.diff(field.Z, 'Z', boundary='fill'))


# Zonal derivatives (spectral, X is periodic).
phix = _ddx_spectral(phi, 'XC')
psix = _ddx_spectral(psi, 'XC')
ux = _ddx_spectral(u, 'XG')
vx = _ddx_spectral(v, 'XC')

# Meridional derivatives (finite differences on the staggered grid).
phiy = _ddy(phi, 'YC')
psiy = _ddy(psi, 'YC')
uy = _ddy(u, 'YC')
vy = _ddy(v, 'YG')

# Geostrophic velocities from the streamfunction and their gradients.
ug = -psiy
vg = psix
ugy = _ddy(ug, 'YG')
vgy = _ddy(vg, 'YC')
ugx = _ddx_spectral(ug, 'XC')
vgx = _ddx_spectral(vg, 'XC')

# Horizontal buoyancy gradients: bx = d(phix)/dz, by = d(phiy)/dz.
bx = _ddz(phix)
by = _ddz(phiy)


# ----------------------------------------------------------------------
# Right-hand-side terms; each product ends up on (Z, YC, XC).
# ----------------------------------------------------------------------
# Shared factors: bx moved to Z centres; by moved to Z centres (still on YG)
# and additionally to YC.
_bx_c = grid.interp(bx, 'Z', boundary='fill')
_by_z = grid.interp(by, 'Z', boundary='fill')
_by_c = grid.interp(_by_z, 'Y', boundary='fill')

# Terms built from the geostrophic velocity gradients.
Qtgx = -(grid.interp(ugx, 'Y', boundary='fill')*_bx_c + vgx*_by_c)
Qtgy = -(ugy*_bx_c + grid.interp(vgy*_by_z, 'Y', boundary='fill'))

# Same terms built from the full (model) velocity gradients.
Qtwx = -(grid.interp(ux, 'X')*_bx_c + grid.interp(vx*_by_z, 'Y', boundary='fill'))
Qtwy = -(grid.interp(grid.interp(uy, 'X'), 'Y', boundary='fill')*_bx_c + vy*_by_c)

del _bx_c, _by_z, _by_c
    
# Ageostrophic velocity (model minus geostrophic) on cell centres, using the
# geostrophic velocities ug = -psiy, vg = psix from the derivative step
# (the former redundant re-assignment of ug/vg is dropped).
ua = grid.interp(u, 'X') - grid.interp(ug, 'Y', boundary='fill')
va = grid.interp(v, 'Y', boundary='fill') - vg

# Vertical shear of ua and va, moved back to Z centres. Each is used by both
# the x and y components, so it is built once instead of twice.
_dua_dz = grid.interp(grid.diff(ua, 'Z', boundary='fill')
                      / grid.diff(ua.Z, 'Z', boundary='fill'),
                      'Z', boundary='fill')
_dva_dz = grid.interp(grid.diff(va, 'Z', boundary='fill')
                      / grid.diff(va.Z, 'Z', boundary='fill'),
                      'Z', boundary='fill')

# Qdag = f * (velocity gradient) x (ageostrophic shear) terms.
Qdagx = (grid.interp(vx, 'Y', boundary='fill')*_dua_dz
         - grid.interp(ux, 'X')*_dva_dz) * fC
Qdagy = (vy*_dua_dz
         - grid.interp(grid.interp(uy, 'X'), 'Y', boundary='fill')*_dva_dz) * fC

# Free the lazy intermediates that are no longer needed.
del _dua_dz, _dva_dz, ux, vx, ugx, vgx
    
# QG forcing: twice the geostrophic terms only.
# Higher-order forcing: twice the full-velocity terms plus Qdag.
Qqgx, Qqgy = 2*Qtgx, 2*Qtgy
Qhox, Qhoy = 2*Qtwx + Qdagx, 2*Qtwy + Qdagy
del Qtgx, Qtgy, Qtwx, Qtwy, Qdagx, Qdagy
    
# 2-D spectra of the forcing and of bx (on Z centres): a plain DFT along XC,
# then a windowed (window=True) DFT along the non-periodic YC, evaluated eagerly.
_fields = (Qqgx, Qqgy, Qhox, Qhoy, grid.interp(bx, 'Z', boundary='fill'))
del Qqgx, Qqgy, Qhox, Qhoy, bx

with ProgressBar():
    Qqgxhat, Qqgyhat, Qhoxhat, Qhoyhat, bxhat = (
        xrft.dft(
            xrft.dft(fld.chunk(chunks={'XC': 500}), dim=['XC'], shift=False)
                .chunk(chunks={'YC': 450}),
            dim=['YC'], shift=False, window=True,
        ).compute()
        for fld in _fields
    )
del _fields

# Angular wavenumbers of the spectral grid.
kx = 2*np.pi*Qqgxhat.freq_XC
ky = 2*np.pi*Qqgxhat.freq_YC
# Spectral RHS: beta*bx plus the divergence i*(kx*Qx + ky*Qy) of the forcing.
QGrhs = beta*bxhat + (1j*Qqgxhat*kx + 1j*Qqgyhat*ky)
HOrhs = beta*bxhat + (1j*Qhoxhat*kx + 1j*Qhoyhat*ky)
del Qqgxhat, Qqgyhat, Qhoxhat, Qhoyhat, bxhat

#########################################################################
### Interpolate right-hand side onto a monotonic, uniform vertical grid ###
#########################################################################
# Horizontally averaged N^2 profile, PCHIP-interpolated in |z| onto the
# interior zPew levels.
func = pchip(np.abs(N2.Zp1), N2.mean(['YC','XC']), axis=0)
N2intp = xr.DataArray(func(np.abs(zPew[1:-1])),
                      dims=['Zp1'],
                      coords={'Zp1':zPew[1:-1]}
                      )


def _rhs_on_zpew(rhs):
    """Map a spectral omega-equation RHS from the model Z levels onto zPew.

    Parameters
    ----------
    rhs : xarray.DataArray
        Complex RHS with dims ('Z', 'freq_YC', 'freq_XC') and a ``Z`` coordinate.

    Returns
    -------
    xarray.DataArray
        Dims ('Zp1', 'freq_YC', 'freq_XC') on ``zPew``. Interior levels are
        PCHIP-interpolated in |z| (with extrapolation); the first and last
        levels are set to zero.
    """
    absz = np.abs(rhs.Z)
    zi = np.abs(zPew[1:-1])
    # PCHIP's monotonicity limiter is only defined for real data (recent SciPy
    # rejects complex y), so real and imaginary parts are interpolated
    # separately. extrapolate is a real bool, not the string 'True'.
    interior = (pchip(absz, rhs.real, axis=0, extrapolate=True)(zi)
                + 1j*pchip(absz, rhs.imag, axis=0, extrapolate=True)(zi))
    zero_level = np.zeros_like(rhs[0])[np.newaxis, :, :]
    return xr.DataArray(np.concatenate([zero_level, interior, zero_level], axis=0),
                        dims=['Zp1', 'freq_YC', 'freq_XC'],
                        coords={'Zp1': zPew,
                                'freq_YC': rhs.freq_YC.data,
                                'freq_XC': rhs.freq_XC.data})


HOintp = _rhs_on_zpew(HOrhs)
QGintp = _rhs_on_zpew(QGrhs)

#####################################################
### Omega-equation inversions with xomega.w_rigid ###
#####################################################
def _invert_omega(rhs_intp):
    """Invert the omega equation for an RHS already on the zPew levels.

    Uses the averaged N2intp profile, f0, beta, the spectral wavenumbers
    kx/ky and the uniform spacing dZ; the result is computed eagerly.
    """
    return w_rigid(N2intp, f0, beta, rhs_intp, kx, ky,
                   dZ, dZ0=dZ, dZ1=dZ, zdim='Zp1',
                   dim=['Zp1', 'YC', 'XC'],
                   coord={'Zp1': rhs_intp.Zp1.data,
                          'YC': N2.YC.data,
                          'XC': N2.XC.data}
                   ).compute()


# Higher-order inversion
with ProgressBar():
    wa_ho = _invert_omega(HOintp)

# Quasi-geostrophic inversion
with ProgressBar():
    wa_qg = _invert_omega(QGintp)

del QGrhs, HOrhs, _invert_omega

In [None]:
# Model vertical velocity at snapshot `date`, interpolated in |z| onto zPew for
# comparison with the inverted w. An all-zero bottom level is appended so the
# profile has one value per entry of ds.Zp1.
func = pchip(
    np.abs(ds.Zp1),
    np.append(ds_center.W[date].data,
              np.zeros_like(ds_center.W[date, 0])[np.newaxis, :, :],
              axis=0),
    axis=0,
)
wintp = xr.DataArray(
    func(np.abs(zPew)),
    dims=['Zp1', 'YC', 'XC'],
    coords={'Zp1': zPew,
            'YC': ds_center.YC.data,
            'XC': ds_center.XC.data},
)
# Meridional Hann taper, presumably to mirror the window=True transform along
# YC applied to the omega-equation RHS.
windowy = np.hanning(len(ds_center.YC))
wintp = wintp * windowy[np.newaxis, :, np.newaxis]
wintp

In [None]:
# Compare the model w (top left) with the higher-order inverted w_b (top right)
# and their difference w_ub = w - w_b (bottom) on one zPew level, in m/day.
y_plot = slice(7e5, 13e5)      # meridional range shown [m]
iz_plot = 15                   # index of the zPew level shown
sec_per_day = 86400            # converts m/s to m/day
w_ticks = [-32, -24, -16, -8, 0, 8, 16, 24, 32]   # colour range and colorbar ticks [m/day]
pcm_kw = dict(vmin=w_ticks[0], vmax=w_ticks[-1], cmap='RdBu_r', rasterized=True)

fig = plt.figure(figsize=(8,11))
gs = GridSpec(3,2)
ax1 = fig.add_subplot(gs[0,0])
ax2 = fig.add_subplot(gs[0,1])
ax3 = fig.add_subplot(gs[1:,:])
im = ax1.pcolormesh(wintp.XC*1e-3, wintp.YC.sel(YC=y_plot)*1e-3,
                    wintp.sel(YC=y_plot)[iz_plot]*sec_per_day, **pcm_kw)
ax2.pcolormesh(wa_ho.XC*1e-3, wa_ho.YC.sel(YC=y_plot)*1e-3,
               wa_ho.sel(YC=y_plot)[iz_plot]*sec_per_day, **pcm_kw)
ax3.pcolormesh(wintp.XC*1e-3, wintp.YC.sel(YC=y_plot)*1e-3,
               (wintp[iz_plot] - wa_ho[iz_plot]).sel(YC=y_plot)*sec_per_day, **pcm_kw)
ax1.set_ylabel(r"Y [km]", fontsize=13)
ax3.set_ylabel(r"Y [km]", fontsize=13)
ax3.set_xlabel(r"X [km]", fontsize=13)
ax1.set_title('w', fontsize=16)
ax2.set_title('w$_b$', fontsize=16)
ax3.set_title('w$_{ub}$', fontsize=16)
# State the plotted level, which is otherwise only implicit in iz_plot.
fig.suptitle(f"z = {float(wintp.Zp1[iz_plot]):.0f} m", fontsize=14)
fig.subplots_adjust(right=0.86, hspace=0.28, wspace=.25)
cbar_ax = fig.add_axes([0.9, 0.25, 0.024, 0.5])
cbar = fig.colorbar(im, cax=cbar_ax, ticks=w_ticks)
cbar.set_label(r'[m d$^{-1}$]', fontsize=12)