# Tests of vertical interpolation code

Relevant methods are located in `crocosi/gridop.py`:
- `interp2z_np_3d`: interpolates 1D/2D/3D arrays along their first dimensions
- `interp2z_np`: more flexible interpolation, allows for an extra dimension on input grid and data.
- `interp2z`: xarray based wrapper that may distribute the interpolation across workers

Only interp2z is tested for now.

Extrapolation parameters are not tested.

In [1]:
import numpy as np
import xarray as xr
import crocosi.gridop as gop

%matplotlib inline
from matplotlib import pyplot as plt

from itertools import permutations

In [2]:
from dask.distributed import Client, LocalCluster
# start a local dask cluster with default settings and attach a client to it
cluster = LocalCluster()
client = Client(cluster)
# last expression of the cell: displays the client summary in the notebook
client

0,1
Client  Scheduler: tcp://127.0.0.1:50423  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 4  Cores: 4  Memory: 17.18 GB


---

## utils

In [3]:
# set synthetic data dimension and parameters

# input vertical grid dimensions: dimension name -> size
dims = {'time':10, 'z':5, 'y':15, 'x':20}
# weight of each coordinate in the synthetic field (only 'z' contributes)
vmap = {'time':0., 'z':1., 'y':0., 'x':0.}

# potential dimensions of the target vertical grid
dims_t = ['z_target1d','y','x']

# target vertical grid: half-integer levels, extending past both ends
# of the input grid
z_target = -.5+np.arange(dims['z']+2)

# true value of the interpolated field: equal to the target depth inside
# the input grid range, NaN outside of it
# (np.nan is used: the np.NaN alias was removed in NumPy 2.0)
z = np.arange(dims['z'])
v_truth = np.where((z_target<z[0]) | (z_target>z[-1]), np.nan, z_target)
print('z_target = {}'.format(z_target))
print('v_truth = {}'.format(v_truth))

z_target = [-0.5  0.5  1.5  2.5  3.5  4.5  5.5]
v_truth = [nan 0.5 1.5 2.5 3.5 nan nan]


In [4]:
def get_dimension_permutations(ndim, with_dims=None):
    """ Generate permutations of the input variable dimensions

    Permutations are drawn from the keys of the module-level `dims`
    mapping; only those containing 'z' are returned.

    Parameters
    ----------
    ndim: int
        Number of dimensions in each permutation
    with_dims: list, optional
        Dimensions that are required to appear in each permutation

    Returns
    -------
    list of tuple
        Permutations of length `ndim` containing 'z' and all `with_dims`
    """
    # 'z' is always mandatory, user-requested dimensions come on top of it
    required = ('z',) + tuple(with_dims or ())
    return [
        perm
        for perm in permutations(dims, ndim)
        if all(name in perm for name in required)
    ]

def get_ztarget_dimension_permutations(ndim, dims=None):
    """ Generate permutations of the target grid dimensions

    Parameters
    ----------
    ndim: int
        Number of dimensions in each permutation
    dims: list or tuple, optional
        Dimensions used on top of 'z_target1d'. When omitted (or empty),
        the module-level `dims_t` list is used instead.

    Returns
    -------
    list of tuple
        Permutations of length `ndim` that contain 'z_target1d'
    """
    if dims:
        # unpacking accepts any sequence (e.g. the tuples returned by
        # the permutation helpers), not only lists
        candidates = ['z_target1d', *dims]
    else:
        candidates = dims_t
    return [p for p in permutations(candidates, ndim) if 'z_target1d' in p]

In [5]:
# example of dimension permutations for input variables:
# all orderings of 3 dimensions that include 'z'
get_dimension_permutations(3)

[('time', 'z', 'y'),
 ('time', 'z', 'x'),
 ('time', 'y', 'z'),
 ('time', 'x', 'z'),
 ('z', 'time', 'y'),
 ('z', 'time', 'x'),
 ('z', 'y', 'time'),
 ('z', 'y', 'x'),
 ('z', 'x', 'time'),
 ('z', 'x', 'y'),
 ('y', 'time', 'z'),
 ('y', 'z', 'time'),
 ('y', 'z', 'x'),
 ('y', 'x', 'z'),
 ('x', 'time', 'z'),
 ('x', 'z', 'time'),
 ('x', 'z', 'y'),
 ('x', 'y', 'z')]

In [6]:
# example of dimension permutations for the target grid:
# all orderings of 3 dimensions that include 'z_target1d'
get_ztarget_dimension_permutations(3)

[('z_target1d', 'y', 'x'),
 ('z_target1d', 'x', 'y'),
 ('y', 'z_target1d', 'x'),
 ('y', 'x', 'z_target1d'),
 ('x', 'z_target1d', 'y'),
 ('x', 'y', 'z_target1d')]

In [7]:
# synthetic data generator
def get_synthetic_data(input_p, zt_p=None, chunks=None, numpy=False):
    """ Create a synthetic dataset based on some dimension permutation

    Coordinates are built from the module-level `dims` and `z_target`;
    the data variable is a sum of coordinates weighted by the
    module-level `vmap`.

    Parameters
    ----------
    input_p: list
        Permutation of dimensions for the input variable v
    zt_p: list, optional
        Permutation of dimensions for the target grid zt. When omitted,
        the target grid is the 1D 'z_target1d' coordinate.
    chunks: dict, int, optional
        Chunks applied to the dataset before returning (ignored when
        `numpy` is True)
    numpy: boolean
        Flag in order to output numpy arrays instead of xarray objects

    Returns
    -------
    tuple
        (z_target, z, v): target grid, 1D input vertical coordinate and
        data variable
    """

    # assemble coordinates
    _coords = {d: np.arange(size) for d, size in dims.items()}
    _coords.update(z_target1d=z_target)
    ds = xr.Dataset(coords=_coords)

    # create data variable and initial grid
    # (iterate over the `input_p` argument, not over a global `p`)
    ds['v'] = sum(ds[d]*vmap[d] for d in input_p)
    # z_v is stored in the dataset but is not part of the returned tuple
    ds['z_v'] = ds['z'] + 0.*ds['v']
    # the line above necessarily imposes a dimension order which may vary
    # should loop around all potential dimension order for zv

    # create target grid
    if zt_p:
        # only 'z_target1d' contributes to the values, other dimensions
        # are present for broadcasting purposes only
        _map = [1. if d=='z_target1d' else 0. for d in zt_p]
        ds['z_target'] = sum(ds[d]*m for d, m in zip(zt_p, _map))
    else:
        ds['z_target'] = ds['z_target1d']

    if numpy:
        return ds['z_target'].data, ds['z'].data, ds['v'].data

    # rechunk data
    if chunks:
        ds = ds.chunk(chunks)

    return ds['z_target'], ds['z'], ds['v']

In [14]:
# pick one permutation for the input variable and one for the target grid
p = get_dimension_permutations(3)[2]
pt = get_ztarget_dimension_permutations(3)[2]
# `ds` is the tuple returned by get_synthetic_data, not a Dataset
ds = get_synthetic_data(p, zt_p=pt)
print('input variables shape {}'.format(p))
print('target grid shape {}'.format(pt))
# show the name and dimensions of each returned array
for v in ds:
    print('{} {}'.format(v.name,v.dims))

input variables shape ('time', 'y', 'z')
target grid shape ('y', 'z_target1d', 'x')
z_target ('y', 'z_target1d', 'x')
z ('z',)
v ('time', 'y', 'z')


---

## interp2z tests

In [11]:
def test_interp2z(zt, z, v, verbose=False, zt_dim='z_target1d', **kwargs):
    """ Interpolate v with gop.interp2z and compare one profile to v_truth

    Parameters
    ----------
    zt: xarray.DataArray
        Target vertical grid
    z: xarray.DataArray
        Input vertical grid
    v: xarray.DataArray
        Input variable
    verbose: boolean, optional
        Print input/output shapes and the interpolated field
    zt_dim: str, optional
        Name of the vertical dimension of the interpolated field
    **kwargs:
        Passed to gop.interp2z

    Raises
    ------
    AssertionError
        If the extracted vertical profile differs from the module-level
        `v_truth` over their common length
    """
    v_out = gop.interp2z(zt, z, v,
                         zt_dim=zt_dim,
                         b_extrap=0, t_extrap=0, **kwargs)
    if verbose:
        print('---------')
        print('v shape: {}'.format(list(v.dims)))
        print('Input shape = {}'.format(v.values.shape))
        print('Ouput shape = {}'.format(v_out.shape))
        print('interpolated field = {}'.format(v_out))
    # move the vertical dimension first (public xarray API)
    z_pos = v_out.get_axis_num(zt_dim)
    profile = v_out.values.swapaxes(0, z_pos)
    # keep the vertical profile at index 0 of every other dimension,
    # whatever the number of dimensions
    profile = profile[(slice(None),) + (0,)*(profile.ndim-1)]
    # only the common leading part is compared
    size_min = min(profile.size, v_truth.size)
    np.testing.assert_equal(profile[:size_min], v_truth[:size_min])

### zt is 1D, z and v are not chunked

In [10]:
# loop over 1D to 4D input variables and all their dimension orderings;
# the target grid is the default 1D one and data is not chunked
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        test_interp2z(*get_synthetic_data(p))
        #print('{} success'.format(p))
# reached only if no assertion failed
print('success')

success


### zt is 1D, z and v are chunked only in one spatial dimension

In [11]:
# same loop as for unchunked data, restricted to permutations that
# contain 'x' so that chunking along 'x' applies to the variable
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        if 'x' in p:
            test_interp2z(*get_synthetic_data(p, chunks={'x': 2}))
            #print('{} success'.format(p))
# reached only if no assertion failed
print('success')

success


### zt is 1D, z and v are chunked in one spatial dimension + time

In [12]:
# chunked data: spatial and temporal dimension
for ndim in range(1,5):
    for p in get_dimension_permutations(ndim):
        if 'x' in p:
            test_interp2z(*get_synthetic_data(p, chunks={'x': 2, 'time': 1}))
            #print('{} success'.format(p))
print('success')

success


### zt now has a variable number of dimensions, z and v are up to 3D and chunked along x

In [13]:
for ndim in range(1,4): # should be range(1,5) but see next cell issues
    for p in get_dimension_permutations(ndim, with_dims=['x']):
        # target grid dimensionality goes from 0 up to that of the variable
        for ndimt in range(ndim+1):
            # the target grid may use any non-vertical dimension of the variable
            _tdim = list(p)
            _tdim.remove('z')
            for pt in get_ztarget_dimension_permutations(ndimt, dims=_tdim):
                #print(ndim, p, ndimt, _tdim, pt)
                test_interp2z(*get_synthetic_data(p, zt_p=pt, chunks={'x': 2}))
# reached only if no assertion failed
print('success')

success


### zt now has a variable number of dimensions, z and v are 4D and chunked along x

This may hang randomly: distributed operations may be executed too rapidly, so a pause was added between computations.

In [14]:
from time import sleep

# 4D case: only the first dimension permutation containing 'x' is tested
ndim=4
p = get_dimension_permutations(ndim, with_dims=['x'])[0]
for ndimt in range(ndim+1):
    # the target grid may use any non-vertical dimension of the variable
    _tdim = list(p)
    _tdim.remove('z')
    for pt in get_ztarget_dimension_permutations(ndimt, dims=_tdim):
        #print(ndim, p, ndimt, _tdim, pt)
        test_interp2z(*get_synthetic_data(p, zt_p=pt, chunks={'x': 2}))
        # pause between two computations
        sleep(2.)

# reached only if no assertion failed
print('success')

success


### zt now has the same vertical dimension name as z and v but different values, z and v are 4D and chunked along x

In [13]:
# same vertical dim name but different shape and values
ndim, ndimt=4, 2
p = get_dimension_permutations(ndim, with_dims=['x'])[0]
_tdim = list(p)
_tdim.remove('z')
pt = get_ztarget_dimension_permutations(ndimt, dims=_tdim)[0]
zt, z, v = get_synthetic_data(p, zt_p=pt, chunks={'x': 2})
# truncate the target grid to the input vertical size and give it the
# same vertical dimension name as the input grid
zt = zt.isel(z_target1d=slice(0,v.z.size)).rename({'z_target1d': 'z'})

# test if same vertical dimension but different values fails by default
try:
    test_interp2z(zt, z, v, zt_dim='z')
except ValueError:
    print('success')
else:
    # without this branch the check would pass silently when no error is raised
    raise AssertionError('interp2z was expected to raise ValueError '
                         'without override_dims')

# test if override_dims does its job in the same case
test_interp2z(zt, z, v, zt_dim='z', verbose=True, override_dims=True)
print('success')

success
---------
v shape: ['time', 'z', 'y', 'x']
Input shape = (10, 5, 15, 20)
Ouput shape = (10, 5, 15, 20)
interpolated field = <xarray.DataArray (time: 10, z: 5, y: 15, x: 20)>
dask.array<transpose, shape=(10, 5, 15, 20), dtype=float64, chunksize=(10, 5, 15, 2), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) int64 0 1 2 3 4 5 6 7 8 9
  * z        (z) float64 -0.5 0.5 1.5 2.5 3.5
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
  * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
success


---

## other tests that will need to be updated ... or not

In [15]:
def test_interp2z_np_3d(ds):
    """ Run gop.interp2z_np_3d on a synthetic dataset and print one profile

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset holding the 'z_target', 'z_v' and 'v' variables
    """
    field = ds['v']
    target_grid = ds['z_target'].values
    # input grid broadcast against the variable
    source_grid = (0.*field + ds['z_v']).values
    result = gop.interp2z_np_3d(target_grid, source_grid, field.values,
                                b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(field.dims)))
    print('Input shape = {}'.format(field.values.shape))
    print('Ouput shape = {}'.format(result.shape))
    # 2D/3D outputs: keep the profile along the first axis at index 0 of
    # the remaining axes; other ranks are printed unchanged
    if result.ndim in (2, 3):
        result = result[(slice(None),) + (0,)*(result.ndim-1)]
    print(result)
    # optional visual check: plot result against ds['z_target1d'].values
    
# NOTE(review): `get_ds` and `get_P` are not defined in this notebook
# (the recorded output is a NameError); these calls need to be updated
test_interp2z_np_3d(get_ds(get_P(3)[6]))
test_interp2z_np_3d(get_ds(get_P(2)[2]))
test_interp2z_np_3d(get_ds(('z',)))

NameError: name 'get_ds' is not defined

In [None]:
def test_interp2z_np(ds):
    """ Run gop.interp2z_np on a synthetic dataset and print one profile

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset holding the 'z_target', 'z_v' and 'v' variables and a
        'z' dimension
    """
    field = ds['v']
    # position and size of the vertical dimension (public xarray API:
    # get_axis_num and the name -> length mapping Dataset.sizes)
    z_pos = field.get_axis_num('z')
    z_size = ds.sizes['z']
    out = gop.interp2z_np(ds['z_target'].values,
                          (0.*field+ds['z_v']).values,
                          field.values,
                          zdim=(z_pos, z_size),
                          b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(field.dims)))
    print('Input shape = {}'.format(field.values.shape))
    print('Ouput shape = {}'.format(out.shape))
    # NOTE(review): only the 4D case moves the vertical axis first; the
    # 2D/3D cases slice along axis 0 whatever z_pos is - confirm intent
    if out.ndim==4:
        out = out.swapaxes(0,z_pos)[:,0,0,0]
    elif out.ndim==3:
        out = out[:,0,0]
    elif out.ndim==2:
        out = out[:,0]
    print(out)
    
# NOTE(review): `get_ds` and `get_P` are not defined in this notebook;
# these calls need to be updated before this cell can run
test_interp2z_np(get_ds(get_P(4)[0]))
test_interp2z_np(get_ds(get_P(4)[6]))
test_interp2z_np(get_ds(get_P(3)[6]))
test_interp2z_np(get_ds(get_P(2)[2]))
test_interp2z_np(get_ds(('z',)))

In [None]:
def test_interp2z(ds):
    """ Run gop.interp2z on a synthetic dataset and print one profile

    NOTE(review): this redefines `test_interp2z`, which replaces any
    earlier function of the same name once the cell is run. Arguments are
    passed as (target grid, v, input grid) - confirm this order against
    the current gop.interp2z signature.

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset holding the 'z_target', 'z_v' and 'v' variables
    """
    out = gop.interp2z(ds['z_target'], ds['v'],
                          (0.*ds['v']+ds['z_v']),
                          b_extrap=0, t_extrap=0)
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(ds['v'].values.shape))
    print('Ouput shape = {}'.format(out.shape))
    if out.ndim==4:
        # z_pos was previously undefined here; swapaxes is applied to the
        # numpy values of the interpolated field
        z_pos = ds.v.get_axis_num('z')
        out = out.values.swapaxes(0,z_pos)[:,0,0,0]
    elif out.ndim==3:
        out = out[:,0,0]
    elif out.ndim==2:
        out = out[:,0]
    print(out)
    
# NOTE(review): `get_ds` and `get_P` are not defined in this notebook.
# These calls also target test_interp2z_np rather than the test_interp2z
# defined in this cell - confirm which one is intended
test_interp2z_np(get_ds(get_P(4)[0]))
test_interp2z_np(get_ds(get_P(4)[6]))
test_interp2z_np(get_ds(get_P(3)[6]))
test_interp2z_np(get_ds(get_P(2)[2]))
test_interp2z_np(get_ds(('z',)))

In [None]:
def test_interp2z(ds):
    """ Run gop.interp2z with an x-dependent target grid and print one profile

    NOTE(review): this redefines `test_interp2z`, which replaces any
    earlier function of the same name once the cell is run. Arguments are
    passed as (target grid, v, input grid) - confirm this order against
    the current gop.interp2z signature.

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset holding the 'x' coordinate and the 'z_target', 'z_v' and
        'v' variables
    """
    # the target grid is broadcast along x before interpolation
    out = gop.interp2z(ds['x']*0.+ds['z_target'], ds['v'],
                       (0.*ds['v']+ds['z_v']),
                       zt_dim='z_target1d',
                       b_extrap=0, t_extrap=0).compute()
    print('---------')
    print('v shape: {}'.format(list(ds.v.dims)))
    print('Input shape = {}'.format(ds['v'].values.shape))
    print('Ouput shape = {}'.format(out.shape))
    # position of the vertical dimension in the input variable
    # (public xarray API)
    z_pos = ds.v.get_axis_num('z')
    if out.ndim==4:
        out = out.data.swapaxes(0,z_pos)[:,0,0,0]
    elif out.ndim==3:
        out = out[:,0,0]
    elif out.ndim==2:
        out = out[:,0]
    print(out)

# NOTE(review): `get_ds` and `get_P` are not defined in this notebook;
# these calls need to be updated before this cell can run
test_interp2z(get_ds(get_P(4)[0], chunks={'x': 2}))
test_interp2z(get_ds(get_P(4)[0]))
test_interp2z(get_ds(get_P(4)[0], chunks=2))
# disabled cases
#test_interp2z(get_ds(get_P(4)[6]))
#test_interp2z(get_ds(get_P(3)[6]))
#test_interp2z(get_ds(get_P(2)[2]))
#test_interp2z(get_ds(('z',)))