In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import xarray as xr
from xgcm import Grid
from matplotlib import pyplot as plt
import cartopy.crs as ccrs

from backgroung import plot_map
import bgcalc_xr as bgxr

In [2]:
model_path = 'http://barataria.tamu.edu:8080/thredds/dodsC/NcML/txla_hindcast_agg'
model = xr.open_dataset(model_path)

In [3]:
# Grid projection setting
llcrnrlat=22.85
llcrnrlon=-97.9
urcrnrlat=30.5
urcrnrlon=-87.5
lon_0 = (llcrnrlon+urcrnrlon)*0.5
extent=[llcrnrlon, urcrnrlon, llcrnrlat, urcrnrlat]
p = ccrs.PlateCarree(central_longitude=lon_0)

#Define region of interest
hmin = 10.
hmax = 50.
mab = 5

lonl = -95
lonr = -91
latl = 28.
latu = 29.6

ROI =(model.h>=hmin)&(model.h<=hmax)& \
    (model.coords['lon_rho']>=lonl)&(model.coords['lon_rho']<=lonr)& \
    (model.coords['lat_rho']>=latl)& (model['lat_rho'] < latu)& \
    (model.mask_rho==1)
ROI[161:166,274:279]=False
model.coords['mask2D'] = (('eta_rho', 'xi_rho'), ROI)

In [4]:
ds = bgxr.Calculator(model)

  yield from self._mapping


In [5]:
tidx = 0
ds.subset('2010-08-01T00')
rate = ds.get_intrate()
SOD = ds.get_btmflux()

In [6]:
uadv , vadv = ds.get_inthadv()

In [7]:
dudx = ds.grid.diff(ds.ds.u, 'xi').where(ds.ROI)

In [8]:
diverg = ds.get_divergence()

In [10]:
vort = ds.get_vorticity()

In [11]:
# start = time.time()
# vadv = ds.advflux_z()
# end = time.time()

# print(end-start)
# vadv.to_netcdf('vadv.nc')

vadv = xr.open_dataset('vadv.nc')

In [34]:
vdiff = ds.get_difflux_z(btm_diff=SOD)

In [35]:
vdiff

<xarray.DataArray (eta_rho: 191, xi_rho: 671)>
array([[ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       ..., 
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan],
       [ nan,  nan,  nan, ...,  nan,  nan,  nan]])
Coordinates:
    lon_rho  (eta_rho, xi_rho) float64 -95.48 -95.48 -95.48 -95.48 -95.48 ...
    lat_rho  (eta_rho, xi_rho) float64 23.11 23.14 23.17 23.2 23.23 23.26 ...
    mask2D   (eta_rho, xi_rho) bool False False False False False False ...
Dimensions without coordinates: eta_rho, xi_rho

In [32]:
start = time.time()

end = time.time()

In [33]:
start - end

-51.41118836402893

In [None]:
plot_map(ds.ds.lon_rho, ds.ds.lat_rho, vadv, vmin=-1e5, vmax=1e5)

In [13]:
from scipy.interpolate import interp1d
import time

In [None]:
flux = xr.zeros_like(ds.zmab)
Js, Is = np.where(ds.ROI == 1)

def get_vflux(ds, point):
    j = point[0]
    i = point[1]
    colum = ds.ds.isel(eta_rho=j, xi_rho=i)
    
    var = ds.ds.isel(eta_rho=j, xi_rho=i)[ds.var]
    w = ds.ds.isel(eta_rho=j, xi_rho=i)['w']
    
    stg_depths = ds.z_w.isel(eta_rho=j, xi_rho=i)
    var_depths = ds.z_rho.isel(eta_rho=j, xi_rho=i)
    zmab = ds.zmab.isel(eta_rho=j, xi_rho=i)
    
    fw_z = interp1d(stg_depths, w)
    if zmab >= var_depths[-1]:
        if zmab >= stg_depths[-1]:
            flux[j, i] = 0.
        else:
            flux[j, i] = fw_z(zmab) * var[-1]
    elif zmab <= var_depths[0]:
        flux[j, i] = fw_z(zmab) * var[0]
    else:
        fvar_z = interp1d(var_depths, var)
        flux[j, i] = fw_z(zmab) * fvar_z(zmab)

In [None]:
from concurrent.futures import ThreadPoolExecutor

In [None]:
e = ThreadPoolExecutor(8)

points = zip(Js, Is)

e.map(get_vflux,points)

In [None]:
dir(ds)