# Tests of vertical interpolation code

Relevant methods are located in `crocosi/gridop.py`:
- `interp2z_np_3d`: interpolates 1D/2D/3D arrays along their first dimension
- `interp2z_np`: more flexible interpolation, allows for an extra dimension on input grid and data.
- `interp2z`: xarray based wrapper that may distribute the interpolation across workers

Only `interp2z` is tested for now.

We do not test the extrapolation parameters.

In [1]:
import numpy as np
import xarray as xr
import crocosi.gridop as gop

%matplotlib inline
from matplotlib import pyplot as plt

from itertools import permutations

In [2]:
# start a local dask cluster with default settings and connect a client to it
from dask.distributed import Client, LocalCluster
cluster = LocalCluster()
client = Client(cluster)
# last expression of the cell: displays the client summary in the notebook
client

0,1
Client  Scheduler: tcp://127.0.0.1:63444  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 4  Memory: 17.18 GB


---

## utils

In [3]:
# set synthetic data dimension and parameters

# input vertical grid dimensions (dimension name -> size)
dims = {'time':10, 'z':5, 'y':15, 'x':20}
# weight of each coordinate in the synthetic field v = sum(coord*weight):
# only 'z' contributes, hence v equals z and interpolating v at a depth
# returns that depth
vmap = {'time':0., 'z':1., 'y':0., 'x':0.}

# potential dimensions of the target vertical grid
dims_t = ['z_target1d','y','x']

# target vertical grid: mid-points of the input grid plus one point below
# and two points above the input grid range
z_target = -.5+np.arange(dims['z']+2)

# true value of the interpolated field
z = np.arange(dims['z'])
v_truth = z_target + 0.
# target levels outside the input grid are expected to be masked
# (np.nan rather than np.NaN: the latter alias was removed in NumPy 2.0)
v_truth[z_target<z[0]]=np.nan
v_truth[z_target>z[-1]]=np.nan
print('z_target = {}'.format(z_target))
print('v_truth = {}'.format(v_truth))

z_target = [-0.5  0.5  1.5  2.5  3.5  4.5  5.5]
v_truth = [nan 0.5 1.5 2.5 3.5 nan nan]


In [4]:
def get_dimension_permutations(ndim):
    """ Ordered selections of `ndim` names among the keys of `dims`,
    restricted to those that contain the vertical dimension 'z'
    """
    candidates = permutations(dims.keys(), ndim)
    return [perm for perm in candidates if 'z' in perm]

def get_ztarget_dimension_permutations(ndim):
    """ Ordered selections of `ndim` names among `dims_t`,
    restricted to those that contain the target dimension 'z_target1d'
    """
    candidates = permutations(dims_t, ndim)
    return [perm for perm in candidates if 'z_target1d' in perm]

In [5]:
# example of dimension permutations for input variables:
# all orderings of 3 dimensions among time/z/y/x that include 'z'
get_dimension_permutations(3)

[('time', 'z', 'y'),
 ('time', 'z', 'x'),
 ('time', 'y', 'z'),
 ('time', 'x', 'z'),
 ('z', 'time', 'y'),
 ('z', 'time', 'x'),
 ('z', 'y', 'time'),
 ('z', 'y', 'x'),
 ('z', 'x', 'time'),
 ('z', 'x', 'y'),
 ('y', 'time', 'z'),
 ('y', 'z', 'time'),
 ('y', 'z', 'x'),
 ('y', 'x', 'z'),
 ('x', 'time', 'z'),
 ('x', 'z', 'time'),
 ('x', 'z', 'y'),
 ('x', 'y', 'z')]

In [6]:
# example of dimension permutations for the target grid:
# all orderings of the 3 names in dims_t (each includes 'z_target1d')
get_ztarget_dimension_permutations(3)

[('z_target1d', 'y', 'x'),
 ('z_target1d', 'x', 'y'),
 ('y', 'z_target1d', 'x'),
 ('y', 'x', 'z_target1d'),
 ('x', 'z_target1d', 'y'),
 ('x', 'y', 'z_target1d')]

In [7]:
# display the input dimension sizes
dims

{'time': 10, 'z': 5, 'y': 15, 'x': 20}

In [8]:
# synthetic data generator
def get_synthetic_data(input_p, zt_p=None, chunks=None, numpy=False):
    """ Create a synthetic dataset based on some dimension permutation

    Parameters
    ----------
    input_p: list
        Permutation of dimensions for the input variable v
    zt_p: list, optional
        Permutation of dimensions for the target grid z_target.
        When omitted, the target grid is the 1D coordinate 'z_target1d'
    chunks: dict, int, optional
        Chunks passed to `Dataset.chunk`; data is not rechunked when omitted
    numpy: boolean, optional
        Flag in order to output the underlying arrays (`.data`) instead of
        DataArrays; in that case `chunks` is ignored

    Returns
    -------
    z_target, z, v
        Target grid, 1D input vertical coordinate and input variable
    """

    # assemble coordinates
    _coords = {d: np.arange(size) for d, size in dims.items()}
    _coords.update(z_target1d=z_target)
    ds = xr.Dataset(coords=_coords)

    # create data variable and initial grid
    # (uses the input_p argument: the previous version read the notebook
    # global `p`, which only worked when callers happened to loop over `p`)
    ds['v'] = sum(ds[d]*vmap[d] for d in input_p)
    ds['z_v'] = ds['z'] + 0.*ds['v']
    # the line above necessarily imposes a dimension order which may vary
    # should loop around all potential dimension order for zv

    # create target grid
    if zt_p:
        # only 'z_target1d' contributes to the values, other dimensions
        # are merely broadcast
        _map = [1. if d=='z_target1d' else 0. for d in zt_p]
        ds['z_target'] = sum(ds[d]*m for d, m in zip(zt_p, _map))
    else:
        ds['z_target'] = ds['z_target1d']

    if numpy:
        return ds['z_target'].data, ds['z'].data, ds['v'].data

    # rechunk data
    if chunks:
        ds = ds.chunk(chunks)

    return ds['z_target'], ds['z'], ds['v']

In [9]:
# pick one input permutation and one target grid permutation, build the
# corresponding synthetic data and print the dimensions of each output
p = get_dimension_permutations(3)[2]
pt = get_ztarget_dimension_permutations(3)[2]
# despite its name, `ds` is the (z_target, z, v) tuple returned by
# get_synthetic_data, not a Dataset
ds = get_synthetic_data(p, zt_p=pt)
print('input variables shape {}'.format(p))
print('target grid shape {}'.format(pt))
for v in ds:
    print('{} {}'.format(v.name,v.dims))

input variables shape ('time', 'y', 'z')
target grid shape ('y', 'z_target1d', 'x')
z_target ('y', 'z_target1d', 'x')
z ('z',)
v ('time', 'y', 'z')


---

## interp2z tests

In [10]:
def test_interp2z(zt, z, v, verbose=False):
    """ Interpolate v from grid z onto target grid zt with `gop.interp2z`
    and compare one vertical profile against `v_truth`

    Parameters
    ----------
    zt: xr.DataArray
        Target vertical grid, must have a 'z_target1d' dimension
    z: xr.DataArray
        Input vertical grid
    v: xr.DataArray
        Input variable
    verbose: boolean, optional
        Print input dimensions and input/output shapes

    Raises
    ------
    AssertionError
        If the interpolated profile differs from `v_truth`
    """
    v_out = gop.interp2z(zt, z, v,
                         zt_dim='z_target1d',
                         b_extrap=0, t_extrap=0).compute()
    if verbose:
        print('---------')
        print('v shape: {}'.format(list(v.dims)))
        print('Input shape = {}'.format(v.values.shape))
        print('Ouput shape = {}'.format(v_out.shape))
    # extract the profile along 'z_target1d' at the first index of every
    # other dimension, whatever the number and order of dimensions
    first = {d: 0 for d in v_out.dims if d != 'z_target1d'}
    profile = v_out.isel(first).data
    # assert_equal treats NaNs at matching positions as equal
    np.testing.assert_equal(profile, v_truth)

In [16]:
# not chunked data: every permutation of 1 to 4 input dimensions
# that contains 'z', with the default 1D target grid
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        test_interp2z(*get_synthetic_data(p))
        print('{} success'.format(p))

('z',) success
('time', 'z') success
('z', 'time') success
('z', 'y') success
('z', 'x') success
('y', 'z') success
('x', 'z') success
('time', 'z', 'y') success
('time', 'z', 'x') success
('time', 'y', 'z') success
('time', 'x', 'z') success
('z', 'time', 'y') success
('z', 'time', 'x') success
('z', 'y', 'time') success
('z', 'y', 'x') success
('z', 'x', 'time') success
('z', 'x', 'y') success
('y', 'time', 'z') success
('y', 'z', 'time') success
('y', 'z', 'x') success
('y', 'x', 'z') success
('x', 'time', 'z') success
('x', 'z', 'time') success
('x', 'z', 'y') success
('x', 'y', 'z') success
('time', 'z', 'y', 'x') success
('time', 'z', 'x', 'y') success
('time', 'y', 'z', 'x') success
('time', 'y', 'x', 'z') success
('time', 'x', 'z', 'y') success
('time', 'x', 'y', 'z') success
('z', 'time', 'y', 'x') success
('z', 'time', 'x', 'y') success
('z', 'y', 'time', 'x') success
('z', 'y', 'x', 'time') success
('z', 'x', 'time', 'y') success
('z', 'x', 'y', 'time') success
('y', 'time',

In [17]:
# chunked data: chunks of size 2 along one spatial dimension (x);
# only permutations that contain 'x' are tested
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        if 'x' in p:
            test_interp2z(*get_synthetic_data(p, chunks={'x': 2}))
            print('{} success'.format(p))

('z', 'x') success
('x', 'z') success
('time', 'z', 'x') success
('time', 'x', 'z') success
('z', 'time', 'x') success
('z', 'y', 'x') success
('z', 'x', 'time') success
('z', 'x', 'y') success
('y', 'z', 'x') success
('y', 'x', 'z') success
('x', 'time', 'z') success
('x', 'z', 'time') success
('x', 'z', 'y') success
('x', 'y', 'z') success
('time', 'z', 'y', 'x') success
('time', 'z', 'x', 'y') success
('time', 'y', 'z', 'x') success
('time', 'y', 'x', 'z') success
('time', 'x', 'z', 'y') success
('time', 'x', 'y', 'z') success
('z', 'time', 'y', 'x') success
('z', 'time', 'x', 'y') success
('z', 'y', 'time', 'x') success
('z', 'y', 'x', 'time') success
('z', 'x', 'time', 'y') success
('z', 'x', 'y', 'time') success
('y', 'time', 'z', 'x') success
('y', 'time', 'x', 'z') success
('y', 'z', 'time', 'x') success
('y', 'z', 'x', 'time') success
('y', 'x', 'time', 'z') success
('y', 'x', 'z', 'time') success
('x', 'time', 'z', 'y') success
('x', 'time', 'y', 'z') success
('x', 'z', 'time

In [18]:
# chunked data: spatial and temporal dimension
# (chunks of size 2 along x and size 1 along time);
# only permutations that contain 'x' are tested
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        if 'x' in p:
            test_interp2z(*get_synthetic_data(p, chunks={'x': 2, 'time': 1}))
            print('{} success'.format(p))

('z', 'x') success
('x', 'z') success
('time', 'z', 'x') success
('time', 'x', 'z') success
('z', 'time', 'x') success
('z', 'y', 'x') success
('z', 'x', 'time') success
('z', 'x', 'y') success
('y', 'z', 'x') success
('y', 'x', 'z') success
('x', 'time', 'z') success
('x', 'z', 'time') success
('x', 'z', 'y') success
('x', 'y', 'z') success
('time', 'z', 'y', 'x') success
('time', 'z', 'x', 'y') success
('time', 'y', 'z', 'x') success
('time', 'y', 'x', 'z') success
('time', 'x', 'z', 'y') success
('time', 'x', 'y', 'z') success
('z', 'time', 'y', 'x') success
('z', 'time', 'x', 'y') success
('z', 'y', 'time', 'x') success
('z', 'y', 'x', 'time') success
('z', 'x', 'time', 'y') success
('z', 'x', 'y', 'time') success
('y', 'time', 'z', 'x') success
('y', 'time', 'x', 'z') success
('y', 'z', 'time', 'x') success
('y', 'z', 'x', 'time') success
('y', 'x', 'time', 'z') success
('y', 'x', 'z', 'time') success
('x', 'time', 'z', 'y') success
('x', 'time', 'y', 'z') success
('x', 'z', 'time

In [19]:
# chunked data along x combined with target grids of 1 to 3 dimensions
# NOTE(review): target grids carrying a dimension that is absent from the
# input variable (e.g. 'y' for p == ('z', 'x')) appear to be what triggered
# the ValueError recorded for this cell - confirm whether interp2z is
# expected to support them
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        if 'x' in p:
            # ndimt is the number of dimensions of the target grid; it was
            # previously ignored and 3D target grids were tested repeatedly
            for ndimt in range(1,4):
                for pt in get_ztarget_dimension_permutations(ndimt):
                    test_interp2z(*get_synthetic_data(p, zt_p=pt, chunks={'x': 2}))

ValueError: ('z_target1d', 'x') must be a permuted list of ('x', 'y', 'z_target1d'), unless `...` is included

---

## test interp2z_np

In [15]:
def test_interp2z_np_3d(ds):
    """ Apply `gop.interp2z_np_3d` to the variables of a synthetic dataset
    and print one interpolated profile

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding 'z_target', 'z_v' and 'v'
    """
    target = ds['z_target'].values
    # input grid broadcast against v
    grid = (0.*ds['v']+ds['z_v']).values
    data = ds['v'].values
    out = gop.interp2z_np_3d(target, grid, data, b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(data.shape))
    print('Ouput shape = {}'.format(out.shape))
    # for 2D/3D outputs keep the profile along the first axis only
    if out.ndim in (2, 3):
        out = out[(slice(None),) + (0,)*(out.ndim-1)]
    print(out)
    # optional visual check:
    #hdl = plt.plot(ds['z_target1d'].values, out)
    #plt.grid()
    
# NOTE(review): stale cell - get_ds is not defined at this point (see the
# NameError recorded below) and get_P is not defined in any visible cell;
# presumably get_P was renamed get_dimension_permutations - confirm
test_interp2z_np_3d(get_ds(get_P(3)[6]))
test_interp2z_np_3d(get_ds(get_P(2)[2]))
test_interp2z_np_3d(get_ds(('z',)))

NameError: name 'get_ds' is not defined

In [None]:
def test_interp2z_np(ds):
    """ Apply `gop.interp2z_np` to the variables of a synthetic dataset
    and print one interpolated profile

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding 'z_target', 'z_v' and 'v', with a 'z' dimension
    """
    # position and size of the vertical dimension of v
    z_pos = ds.v._get_axis_num('z')
    z_size = ds.dims['z']
    target = ds['z_target'].values
    grid = (0.*ds['v']+ds['z_v']).values
    data = ds['v'].values
    out = gop.interp2z_np(target, grid, data,
                          zdim=(z_pos, z_size),
                          b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(data.shape))
    print('Ouput shape = {}'.format(out.shape))
    ndim = out.ndim
    # NOTE(review): the vertical axis is moved first for 4D outputs only,
    # 2D/3D outputs are sliced along their first axis as is - confirm
    if ndim == 4:
        out = out.swapaxes(0, z_pos)
    if ndim in (2, 3, 4):
        out = out[(slice(None),) + (0,)*(ndim-1)]
    print(out)
    
# NOTE(review): stale cell (never executed) - get_P is not defined in any
# visible cell; presumably the former name of get_dimension_permutations
test_interp2z_np(get_ds(get_P(4)[0]))
test_interp2z_np(get_ds(get_P(4)[6]))
test_interp2z_np(get_ds(get_P(3)[6]))
test_interp2z_np(get_ds(get_P(2)[2]))
test_interp2z_np(get_ds(('z',)))

In [None]:
def test_interp2z(ds):
    """ Apply `gop.interp2z` to the variables of a synthetic dataset
    and print one interpolated profile

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding 'z_target', 'z_v' and 'v'
    """
    # NOTE(review): arguments are passed as (target grid, v, grid) whereas
    # the executed test above passes (target grid, grid, v) - confirm
    # against the interp2z signature
    out = gop.interp2z(ds['z_target'], ds['v'],
                          (0.*ds['v']+ds['z_v']),
                          b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(ds['v'].values.shape))
    print('Ouput shape = {}'.format(out.shape))
    # z_pos was used without being defined (NameError for 4D outputs)
    z_pos = ds.v.get_axis_num('z')
    if out.ndim==4:
        # swapaxes is an array method: apply it to the underlying data
        out = out.data.swapaxes(0,z_pos)[:,0,0,0]
    elif out.ndim==3:
        out = out[:,0,0]
    elif out.ndim==2:
        out = out[:,0]
    print(out)
    
# NOTE(review): these calls exercise test_interp2z_np, not the
# test_interp2z defined in this cell - looks like a copy/paste leftover.
# get_P is not defined in any visible cell
test_interp2z_np(get_ds(get_P(4)[0]))
test_interp2z_np(get_ds(get_P(4)[6]))
test_interp2z_np(get_ds(get_P(3)[6]))
test_interp2z_np(get_ds(get_P(2)[2]))
test_interp2z_np(get_ds(('z',)))

In [None]:
def test_interp2z(ds):
    """ Apply `gop.interp2z` to the variables of a synthetic dataset, with a
    target grid broadcast along x, and print one interpolated profile

    Parameters
    ----------
    ds: xr.Dataset
        Dataset holding 'x', 'z_target', 'z_v' and 'v'
    """
    # NOTE(review): arguments are passed as (target grid, v, grid) whereas
    # the executed test above passes (target grid, grid, v) - confirm
    # against the interp2z signature
    out = gop.interp2z(ds['x']*0.+ds['z_target'], ds['v'],
                       (0.*ds['v']+ds['z_v']),
                       zt_dim='z_target1d',
                       b_extrap=0, t_extrap=0).compute()
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(ds['v'].values.shape))
    print('Ouput shape = {}'.format(out.shape))
    # public accessor instead of the private DataArray._get_axis_num
    z_pos = ds.v.get_axis_num('z')
    if out.ndim==4:
        # assumes the target dimension of the output sits where 'z' sits
        # in v - TODO confirm
        out = out.data.swapaxes(0,z_pos)[:,0,0,0]
    elif out.ndim==3:
        out = out[:,0,0]
    elif out.ndim==2:
        out = out[:,0]
    print(out)

# same dataset with chunks of size 2 along x, without chunking, and with
# chunks=2 passed as an integer
# NOTE(review): get_P is not defined in any visible cell
test_interp2z(get_ds(get_P(4)[0], chunks={'x': 2}))
test_interp2z(get_ds(get_P(4)[0]))
test_interp2z(get_ds(get_P(4)[0], chunks=2))
# other dimension orders, currently disabled
#test_interp2z(get_ds(get_P(4)[6]))
#test_interp2z(get_ds(get_P(3)[6]))
#test_interp2z(get_ds(get_P(2)[2]))
#test_interp2z(get_ds(('z',)))

In [None]:
# scratch: dataset chunked along x (chunks of size 2) used by the
# inspection cells below
# NOTE(review): get_P is not defined in any visible cell
ds = get_ds(get_P(4)[0], chunks={'x': 2})

In [None]:
# scratch: dimension names of v
ds.v.dims

In [None]:
# scratch: dimension names of v with 'z' replaced by an empty string
['' if d=='z' else d for d in list(ds.v.dims)]

In [None]:
# scratch: display the dataset
ds

In [None]:
# scratch: broadcast v and z_v against each other
# (xr.align alternative kept below for reference)
#xr.align(ds.v, ds.z_v, join='left')
xr.broadcast(ds.v, ds.z_v)

def get_ds(p, chunks=None):
    """ Create a synthetic dataset based on some dimension order p

    Parameters
    ----------
    p: iterable of str
        Dimension names, each must be a key of `dims`.
        NOTE(review): p only feeds a first dataset that is overwritten by
        the `xr.merge` call below, so it has no effect on the output - confirm
        whether this is intended
    chunks: dict, int, optional
        Chunks passed to `Dataset.chunk`; data is not rechunked when omitted

    Returns
    -------
    xr.Dataset
        Dataset with variables 'z_v' (z), 'v' (z, y, x) and
        'z_target' (z_target1d)
    """

    # assemble coordinates
    _coords = {d: np.arange(dims[d]) for d in p}
    _coords.update(z_target1d=np.arange(.5,6))
    # this dataset is replaced by the merge below
    ds = xr.Dataset(coords=_coords)

    # create data variable and initial grid
    #ds['v'] = 0.
    #for d in p:
    #    ds['v'] = ds['v'] + ds[d]*vmap[d]
    #ds['v'] = sum([ds[d]*vmap[d] for d in p])
    #ds['z_v'] = ds['z'] + 0.*ds['v']
    z = xr.DataArray(np.arange(dims['z']), dims=['z']).rename('z_v')
    # v holds the value of z at every (y, x) point; the horizontal sizes
    # (10, 20) are hard coded and do not depend on `dims`
    v = xr.DataArray(np.broadcast_to(z.data[:,None,None],(z.size, 10, 20)),
                     dims=['z', 'y', 'x']).rename('v')
    # 1D target grid: 0.5, 1.5, ..., 5.5
    zt = xr.DataArray(np.arange(.5,6), dims=['z_target1d']).rename('z_target')
    ds = xr.merge([z,v,zt])
    
    # the line above necessarily imposes a dimension order which may vary
    # should loop around all potential dimension order for zv

    # create target grid
    #ds['z_target'] = ds['z_target1d']
    # need to vary number of dimensions and their order on the line above
    if chunks:
        ds = ds.chunk(chunks)
    return ds

# scratch: dataset with chunks=2, then mean of every variable
# NOTE(review): get_P is not defined in any visible cell
ds = get_ds(get_P(4)[0], chunks=2)
ds.mean().compute()

# scratch: np.mean applied to v through apply_ufunc, the tuple (1,2) being
# forwarded as the second positional argument of np.mean
xr.apply_ufunc(np.mean, ds['v'], (1,2), dask='parallelized', output_dtypes=[np.float64])

In [None]:
# scratch: shapes returned by _align_dims_but_dim0 for several input shapes
# NOTE(review): _align_dims_but_dim0 is not defined nor imported in any
# visible cell; presumably a helper of crocosi.gridop - confirm
x,y = np.ones((2,)), np.ones((3,1,2))
print(x.shape, y.shape, _align_dims_but_dim0(x,y).shape)

x,y = np.ones((2,2,1)), np.ones((3,1,2))
print(x.shape, y.shape, _align_dims_but_dim0(x,y).shape)

x,y = np.ones((2,1, 2)), np.ones((3,1,2))
print(x.shape, y.shape, _align_dims_but_dim0(x,y).shape)

x,y = np.ones((2,2,2)), np.ones((3,1,2))
print(x.shape, y.shape, _align_dims_but_dim0(x,y).shape)

In [None]:
# scratch: display x
x

In [None]:
# scratch: display the swapaxes method of x
x.swapaxes