Skip to content

Commit

Permalink
Refactoring of upsampling utils + coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ghiggi committed Jan 19, 2022
1 parent 2bb5b8c commit 2e5997c
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 58 deletions.
193 changes: 135 additions & 58 deletions pyresample/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ def upsample(self, x=1, y=1):
swath_def.upsample(x=1, y=1) simply returns the current swath_def.
"""
# TODO: An alternative would be to use geotiepoints.geointerpolator.GeoInterpolator
# But I have some problem using it, see code snippet in the PR description.
# But I have some problem using it, see code snippet in a comment of the PR.
import dask.array as da
import numpy as np
import pyproj
Expand All @@ -824,70 +824,31 @@ def upsample(self, x=1, y=1):
# Return SwathDefinition if nothing to upsample
if x == 1 and y == 1:
return self

def _upsample_ranges_1D(x, factor=1):
ranges2D = np.linspace(x[:-1], x[1:], num=factor, endpoint=False, axis=1)
return np.concatenate((ranges2D.ravel(), [x[-1]]))

def upsample_ranges_2D(x, factor=1, axis=0):
x = np.array(x)
if x.ndim not in [1, 2]:
raise ValueError("Expects 1D or 2D array.")
if not isinstance(axis, int):
raise TypeError("'axis' must be: 0 or 1 integer.")
if axis not in [0, 1]:
raise ValueError("Expects 'axis' 0 or 1")
if not isinstance(factor, int):
raise TypeError("'factor' must be an integer equal or larger than 1.")
if factor < 1:
raise ValueError("'factor' must be an integer equal or larger than 1.")

if x.ndim == 1:
return _upsample_ranges_1D(x, factor=factor)
else:
l_ranges = []
if axis == 1:
for i in range(x.shape[0]):
l_ranges.append(_upsample_ranges_1D(x[i, :], factor=factor))
return np.vstack(l_ranges)
else: # axis = 0
for i in range(x.shape[1]):
l_ranges.append(_upsample_ranges_1D(x[:, i], factor=factor))
return np.vstack(l_ranges).transpose()

def _upsample_corners(corners, x_factor=1, y_factor=1):
new_breaks_xx = upsample_ranges_2D(corners, factor=x_factor, axis=1)
new_corners = upsample_ranges_2D(new_breaks_xx, factor=y_factor, axis=0)
return new_corners
# --------------------------------------------------------------------.
# TODO:
# - Refactor for dask-compatibility
# - Should we make _infer_interval_breaks dask-compatible?

def _get_corners_from_centroids(centroids):
breaks_xx = _infer_interval_breaks(centroids, axis=1)
corners = _infer_interval_breaks(breaks_xx, axis=0)
return corners

def _get_centroids_from_corners(corners):
centroids = (corners[1:, 1:] + corners[:-1, :-1]) / 2
return centroids

# TODO: Decide if compute in memory or with dask
def upsample_centroids(centroid_x, centroid_y, centroid_z, x_factor=1, y_factor=1):
corners_x = _get_corners_from_centroids(centroid_x)
corners_y = _get_corners_from_centroids(centroid_y)
corners_z = _get_corners_from_centroids(centroid_z)
x_new_corners = _upsample_corners(corners_x, x_factor=x_factor, y_factor=y_factor)
y_new_corners = _upsample_corners(corners_y, x_factor=x_factor, y_factor=y_factor)
z_new_corners = _upsample_corners(corners_z, x_factor=x_factor, y_factor=y_factor)
x_new_centroids = _get_centroids_from_corners(x_new_corners)
y_new_centroids = _get_centroids_from_corners(y_new_corners)
z_new_centroids = _get_centroids_from_corners(z_new_corners)
return x_new_centroids, y_new_centroids, z_new_centroids

def _upsample_centroid(centroid, x_factor=1, y_factor=1):
def _upsample_centroid(centroid, x=1, y=1):
corners = _get_corners_from_centroids(centroid)
new_corners = _upsample_corners(corners, x_factor=x_factor, y_factor=y_factor)
new_centroids = _get_centroids_from_corners(new_corners)
# Retrieve corners of the upsampled grid
new_corners = _linspace2D_between_values(corners, num_x=x - 1, num_y=y - 1)
# Get centroids from corners
new_centroids = (new_corners[:-1, :-1] + new_corners[1:, 1:]) / 2
return new_centroids

def upsample_centroids(centroid_x, centroid_y, centroid_z, x=1, y=1):
x_new_centroids = _upsample_centroid(centroid_x, x=x, y=y)
y_new_centroids = _upsample_centroid(centroid_y, x=x, y=y)
z_new_centroids = _upsample_centroid(centroid_z, x=x, y=y)
return x_new_centroids, y_new_centroids, z_new_centroids

# --------------------------------------------------------------------.
# Define geodetic and geocentric projection
geocent = pyproj.Proj(proj='geocent')
latlong = pyproj.Proj(proj='latlong')
Expand All @@ -913,10 +874,9 @@ def _upsample_centroid(centroid, x_factor=1, y_factor=1):
# x,
# y)
# res1 = xr.DataArray(res1, dims=['y', 'x', 'coord'], coords=src_lons.coords)

res = np.stack(upsample_centroids(res[:, :, 0].data,
res[:, :, 1].data,
res[:, :, 2].data, x_factor=x, y_factor=y), axis=2)
res[:, :, 2].data, x=x, y=y), axis=2)
new_centroids = xr.DataArray(da.from_array(res), dims=['y', 'x', 'xyz'])

# Back-conversion to geographic CRS
Expand Down Expand Up @@ -1267,6 +1227,123 @@ def _convert_2D_array(arr, to, dims=None):
raise NotImplementedError


def _linspace1D_between_values(arr, num=0):
"""Dask-friendly function linearly interpolating values between each 1D array values.
This function does not perform extrapolation.
It expects a 1D array as input!
Parameters
----------
arr : (np.ndarray, dask.array.Array)
Numpy or Dask Array to be linearly interpolated between values.
num : int, optional
The number of linearly spaced values to infer between array values.
The default is 0.
Returns
-------
arr : (np.ndarray, dask.array.Array)
Numpy or Dask Array with in-between linearly interpolated values.
Example
-------
Function call: _linspace1D_between_values(arr, num=1)
Input array:
np.array([5.0, 7.0])
Output array:
np.array([5.0, 6.0, 7.0])
"""
import xarray as xr

# Check input validity
if arr.ndim != 1:
raise ValueError("'_linspace1D_between_values' expects a 1D array.")
num = int(num)
if num < 0:
raise ValueError("'x' and 'y' must be an integer equal or larger than 0.")
if num == 0:
return arr
# Define src and dst ties
dst_N = (arr.size - 1) * (num + 1) + 1
src_ties = np.arange(dst_N, step=num + 1)
dst_ties = np.arange(dst_N)
# Interpolate
da = xr.DataArray(
data=arr,
dims=("x"),
coords={"x": src_ties}
)
da_interp = da.interp(x=dst_ties, method="linear")
return da_interp.data


def _linspace2D_between_values(arr, num_x=0, num_y=0):
"""Dask-friendly function linearly interpolating values between each 2D array values.
This function does not perform extrapolation.
It expects a 2D array as input!
Parameters
----------
arr : (np.ndarray, dask.array.Array)
Numpy or Dask Array to be linearly interpolated between values.
num_x : int, optional
The number of linearly spaced values to infer between array values (along x).
. The default is 0.
num_y : int, optional
The number of linearly spaced values to infer between array values (along y).
The default is 0.
Returns
-------
arr : (np.ndarray, dask.array.Array)
Numpy or Dask Array with in-between linearly interpolated values.
Example
-------
Function call: _linspace1D_between_values(arr, num=1)
Input:
np.array([[5.0, 7.0],
[7.0, 9.0]])
Output:
np.array([[5.0, 6.0, 7.0],
[6.0, 7.0, 8.0],
[7.0, 8.0, 9.0]])
"""
import xarray as xr

# Check input validity
if arr.ndim != 2:
raise ValueError("'_linspace2D_between_values' expects a 2D array.")
num_x = int(num_x)
num_y = int(num_y)
if num_x < 0 or num_y < 0:
raise ValueError("'x' and 'y' must be an integer equal or larger than 0.")
if num_x == 0 and num_y == 0:
return arr
# Define src and dst ties
shape = arr.shape
Nx_dst = (shape[1] - 1) * (num_x + 1) + 1
Ny_dst = (shape[0] - 1) * (num_y + 1) + 1

src_ties_x = np.arange(Nx_dst, step=num_x + 1)
src_ties_y = np.arange(Ny_dst, step=num_y + 1)
dst_ties_x = np.arange(Nx_dst)
dst_ties_y = np.arange(Ny_dst)
# Interpolate
da = xr.DataArray(
data=arr,
dims=("y", "x"),
coords={"y": src_ties_y, "x": src_ties_x}
)
da_interp = da.interp({"y": dst_ties_y, "x": dst_ties_x}, method="linear")
return da_interp.data


def _get_extended_lonlats(lon_start, lat_start, lon_end, lat_end, npts, transpose=True):
"""Utils employed by SwathDefinition.extend.
Expand Down
57 changes: 57 additions & 0 deletions pyresample/test/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,63 @@ def test_latslons_arr_conversion(self):
self.assertRaises(TypeError, _convert_2D_array, [dict_format['Numpy']], 'numpy')
self.assertRaises(ValueError, _convert_2D_array, dict_format['Numpy'], 'unvalid_format')

def test_linspace1D_between_values(self):
    """Check 1D in-between interpolation for numpy and dask inputs."""
    import dask.array as da
    import numpy as np

    from pyresample.geometry import _linspace1D_between_values

    src_np = np.array([5.0, 7.0, 9.0])
    src_dask = da.from_array(src_np)
    expected = [5., 6., 7., 8., 9.]

    out_np = _linspace1D_between_values(src_np, num=1)
    out_dask = _linspace1D_between_values(src_dask, num=1)

    # Values must match and the array backend must be preserved
    for out in (out_np, out_dask):
        np.testing.assert_allclose(out, expected)
    assert isinstance(out_np, np.ndarray)
    assert isinstance(out_dask, da.Array)

    # With num=0 nothing is inserted between the values
    unchanged = _linspace1D_between_values(src_np, num=0)
    np.testing.assert_allclose(unchanged, src_np)

    # Invalid inputs: negative num, and an array that is not 1D
    with self.assertRaises(ValueError):
        _linspace1D_between_values(src_np, -1)
    with self.assertRaises(ValueError):
        _linspace1D_between_values(np.zeros((2, 2)), 0)

def test_linspace2D_between_values(self):
    """Check 2D in-between interpolation for numpy and dask inputs."""
    import dask.array as da
    import numpy as np

    from pyresample.geometry import _linspace2D_between_values

    src_np = np.array([[5.0, 7.0],
                       [7.0, 9.0]])
    src_dask = da.from_array(src_np)

    # One value inserted along x, three along y
    expected = np.array([[5., 6., 7.],
                         [5.5, 6.5, 7.5],
                         [6., 7., 8.],
                         [6.5, 7.5, 8.5],
                         [7., 8., 9.]])

    out_np = _linspace2D_between_values(src_np, num_x=1, num_y=3)
    out_dask = _linspace2D_between_values(src_dask, num_x=1, num_y=3)

    # Values must match and the array backend must be preserved
    for out in (out_np, out_dask):
        np.testing.assert_allclose(out, expected)
    assert isinstance(out_np, np.ndarray)
    assert isinstance(out_dask, da.Array)

    # With num_x=0 and num_y=0 nothing is inserted between the values
    unchanged = _linspace2D_between_values(src_np, num_x=0, num_y=0)
    np.testing.assert_allclose(unchanged, src_np)

    # Invalid inputs: negative num_x, negative num_y, and a 1D array
    with self.assertRaises(ValueError):
        _linspace2D_between_values(src_np, -1, 0)
    with self.assertRaises(ValueError):
        _linspace2D_between_values(src_np, 0, -1)
    with self.assertRaises(ValueError):
        _linspace2D_between_values(src_np[0, :], 0, 0)

def test_get_extended_lonlats(self):
import numpy as np

Expand Down

0 comments on commit 2e5997c

Please sign in to comment.