In [1]:
import numpy as np
import xarray as xr
import dask.array as dsar
# import xroms as xm
import os.path as op
import scipy.interpolate as naiso
# import scipy.integrate as intg
# import xgcm.grid as xgd
import importlib

import matplotlib.pyplot as plt
%matplotlib inline

In [16]:
N = 16
ds = xr.DataArray(np.random.rand(N,N,N,N),
                  dims=['time','s','eta','xi'],
                  coords={'t':('time',np.arange(N)),
                          'z':(('s','eta','xi'),-np.arange(N)[:,np.newaxis,np.newaxis]*np.ones((N,N))),
                         'y':('eta',np.arange(N)),'x':('xi',np.arange(N))}
                 )
ds

<xarray.DataArray (time: 16, s: 16, eta: 16, xi: 16)>
array([[[[ 0.801352, ...,  0.398072],
         ..., 
         [ 0.61678 , ...,  0.430534]],

        ..., 
        [[ 0.025122, ...,  0.325356],
         ..., 
         [ 0.626726, ...,  0.969897]]],


       ..., 
       [[[ 0.338694, ...,  0.162314],
         ..., 
         [ 0.535866, ...,  0.470606]],

        ..., 
        [[ 0.977966, ...,  0.57539 ],
         ..., 
         [ 0.327781, ...,  0.501317]]]])
Coordinates:
    t        (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
    z        (s, eta, xi) float64 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ...
    y        (eta) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
    x        (xi) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
Dimensions without coordinates: time, s, eta, xi

In [8]:
znew = np.linspace(0,N,int(2*N))

In [4]:
def _interpn(y,x,xnew):
    """
    Interpolates and flips the vertical coordinate so as
    the bottom layer is at the top of the array in sigma coordinates.
    """
    f = naiso.interp1d(x,y,
                    fill_value='extrapolate')
    return f(xnew)

def _interp_wrap(interp_func):
    """
    Wrapper function for `xroms._interpolate1D`.
    """
    def func(a, b, bnew, axes=None):
        # if a.ndim > 4 or len(axes) > 3:
        #     raise ValueError("Data has too many dimensions "
        #                     "and/or too many axes to detrend over.")
        if axes is None:
            axes = tuple(range(a.ndim))
        else:
            if len(set(axes)) < len(axes):
                raise ValueError("Duplicate axes are not allowed.")

        if len(axes)>1:
            raise NotImplementedError("Interpolation over multiple "
                                    "axes is not implemented yet.")

        for each_axis in axes:
            if len(a.chunks[each_axis]) != 1:
                raise ValueError('The axis along the interpolation applied '
                                'cannot be chunked.')

        if a.ndim == 1:
            return dsar.map_blocks(interp_func, a, b, bnew,
                                   chunks=a.chunks, dtype=a.dtype
                                  )
        else:
            for each_axis in range(a.ndim):
                if each_axis not in axes:
                    if len(a.chunks[each_axis]) != a.shape[each_axis]:
                        raise ValueError("The axes other than ones to "
                                        "interpolate should have "
                                        "a chunk length of 1.")
            return dsar.map_blocks(interp_func, a, b, bnew,
                                   chunks=a.chunks, dtype=a.dtype
                                  )

    return func

def _apply_interp(da, z, znew, axes):
    """Wrapper function for applying interpolation"""
    if da.chunks:
        func = _interp_wrap(_interpn)
        da = xr.DataArray(func(da.data, z, znew, axes=axes),
                        dims=da.dims, coords=da.coords)
    else:
        if da.ndim == 1:
            da = xr.DataArray(_interpn(da, z, znew),
                            dims=da.dims, coords=da.coords)
        else:
            # da = detrendn(da, axes=axis_num)
        # else:
            raise ValueError("Data should be dask array "
                            "for multidimensional data.")

    return da

In [12]:
def sig2z(da, zr, zi, nvar=None, axes=None,
        dim=None, coord=None):
    """
    Interpolate variables on \sigma coordinates onto z coordinates.

    Parameters
    ----------
    da : `dask.array`
        The data on sigma coordinates to be interpolated
    grid : `xgcm.object`
        The grid information of the data set.
    zr : `numpy.array`
        The depths corresponding to sigma layers
    zi : `numpy.array`
        The depths which to interpolate the data on
    nvar : str (optional)
        Name of the variable. Only necessary when the variable is
        horizontal velocity.

    Returns
    -------
    dai : `dask.array`
        The data interpolated onto a spatial uniform z coordinate
    """

    if np.diff(zi)[0] < 0. or zi.max() <= 0.:
        raise ValueError("The values in `zi` should be postive and increasing.")
    if np.any(np.absolute(zr[0]) > np.absolute(zr[-1])):
        raise ValueError("`zr` should have the deepest depth at the last index")
    if zr.shape != da.shape[-3:]:
        raise ValueError("`zr` should have the same "
                        "spatial dimensions as `da`.")

    if dim == None:
        dim = da.dims
    if coord == None:
        coord = da.coords
    # N = da.shape
    # nzi = len(zi)
    # if len(N) == 4:
    #     dai = np.empty((N[0],nzi,N[-2],N[-1]))
    #     # dim = [dimd[0],'z',dimd[-2],dimd[-1]]
    #     # coord = {dimd[0]:da.coords[dimd[0]],
    #     #         'z':-zi, dimd[-2]:da.coords[dimd[-2]],
    #     #         dimd[-1]:da.coords[dimd[-1]]
    #     #         }
    # elif len(N) == 3:
    #     dai = np.empty((nzi,N[-2],N[-1]))
    #     # dim = ['z',dimd[-2],dimd[-1]]
    #     # coord = {'z':-zi, dimd[-2]:da.coords[dimd[-2]],
    #     #         dimd[-1]:da.coords[dimd[-1]]
    #     #         }
    # else:
    #     raise ValueError("The data should at least have three dimensions")
    # dai[:] = np.nan

#     zi = -zi[::-1] # ROMS has deepest level at index=0
    zi = -zi

#     if nvar=='u':  # u variables
#         zl = grid.interp(zr, 'X', boundary='extend')
#         # zl = .5*(np.roll(np.roll(zr, -1, axis=-1), -1, axis=-2)
#         #          + np.roll(zr, -1, axis=-2)
#         #         )
#     elif nvar=='v': # v variables
#         zl = grid.interp(zr, 'Y', boundary='extend')
#         # zl = .5*(np.roll(zr, -1, axis=-1)
#         #          + np.roll(np.roll(zr, -1, axis=-2), -1, axis=-1)
#         #         )
#     else:
    zl = zr

    dai = _apply_interp(da, zl, zi, axes)
    # for i in range(N[-1]):
    #     for j in range(N[-2]):
    #         # only bother for sufficiently deep regions
    #         if zl[:,j,i].min() < -1e2:
    #             # only interp on z above topo
    #             ind = np.argwhere(zi >= zl[:,j,i].min())
    #             if len(N) == 4:
    #                 for s in range(N[0]):
    #                     dai[s,:len(ind),j,i] = _interpolate(zl[:,j,i],
    #                                                         da[s,:,j,i].values,
    #                                                         zi[int(ind[0]):])
    #             else:
    #                 dai[:len(ind),j,i] = _interpolate(zl[:,j,i],
    #                                                  da[:,j,i].values,
    #                                                  zi[int(ind[0]):])
    return dai
#     return xr.DataArray(dai.data, dims=dim, coords=coord)

In [18]:
test = sig2z(ds.chunk({'time':1,'eta':1,'xi':1}), ds.z.chunk({'eta':1,'xi':1}).data, znew, 
             axes=(1,),
             dim=[wdim[0],'z',wdim[-2],wdim[-1]], 
             coord={wdim[0]:w.coords[wdim[0]],'z':-zi,
                    wdim[-2]:w.coords[wdim[-2]],wdim[-1]:w.coords[wdim[-1]]}
            )
test.compute()

NameError: name 'wdim' is not defined