Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

single precision support in moments functions #5344

Merged
merged 8 commits into from
Apr 23, 2021
54 changes: 53 additions & 1 deletion skimage/_shared/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import numpy.testing as npt
from skimage._shared.utils import (check_nD, deprecate_kwarg,
_validate_interpolation_order,
change_default_value, remove_arg)
change_default_value, remove_arg,
_supported_float_type)
from skimage._shared import testing
from skimage._shared._warnings import expected_warnings

complex_dtypes = [np.complex64, np.complex128]
if hasattr(np, 'complex256'):
complex_dtypes += [np.complex256]


def test_remove_argument():

Expand Down Expand Up @@ -186,5 +191,52 @@ def test_validate_interpolation_order(dtype, order):
assert _validate_interpolation_order(dtype, order) == order


@pytest.mark.parametrize(
'dtype',
[bool, np.float16, np.float32, np.float64, np.uint8, np.uint16, np.uint32,
np.uint64, np.int8, np.int16, np.int32, np.int64]
)
def test_supported_float_dtype_real(dtype):
float_dtype = _supported_float_type(dtype)
if dtype in [np.float16, np.float32]:
assert float_dtype == np.float32
else:
assert float_dtype == np.float64


@pytest.mark.parametrize('dtype', complex_dtypes)
@pytest.mark.parametrize('allow_complex', [False, True])
def test_supported_float_dtype_complex(dtype, allow_complex):
if allow_complex:
float_dtype = _supported_float_type(dtype, allow_complex=allow_complex)
if dtype == np.complex64:
assert float_dtype == np.complex64
else:
assert float_dtype == np.complex128
else:
with testing.raises(ValueError):
_supported_float_type(dtype, allow_complex=allow_complex)


@pytest.mark.parametrize(
'dtype', ['f', 'float32', np.float32, np.dtype(np.float32)]
)
def test_supported_float_dtype_input_kinds(dtype):
assert _supported_float_type(dtype) == np.float32


@pytest.mark.parametrize(
'dtypes, expected',
[
((np.float16, np.float64), np.float64),
([np.float32, np.uint16, np.int8], np.float64),
({np.float32, np.float16}, np.float32),
]
)
def test_supported_float_dtype_sequence(dtypes, expected):
float_dtype = _supported_float_type(dtypes)
assert float_dtype == expected


if __name__ == "__main__":
npt.run_module_suite()
45 changes: 45 additions & 0 deletions skimage/_shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numbers
import sys
import warnings
from collections.abc import Iterable

import numpy as np
from numpy.lib import NumpyVersion
Expand Down Expand Up @@ -579,3 +580,47 @@ def _fix_ndimage_mode(mode):
if NumpyVersion(scipy.__version__) >= '1.6.0':
mode = grid_modes.get(mode, mode)
return mode


new_float_type = {
# preserved types
np.float32().dtype.char: np.float32,
np.float64().dtype.char: np.float64,
np.complex64().dtype.char: np.complex64,
np.complex128().dtype.char: np.complex128,
# altered types
np.float16().dtype.char: np.float32,
'g': np.float64, # np.float128 ; doesn't exist on windows
'G': np.complex128, # np.complex256 ; doesn't exist on windows
}


def _supported_float_type(input_dtype, allow_complex=False):
"""Return an appropriate floating-point dtype for a given dtype.

float32, float64, complex64, complex128 are preserved.
float16 is promoted to float32.
complex256 is demoted to complex128.
Other types are cast to float64.

Parameters
----------
input_dtype : np.dtype or Iterable of np.dtype
The input dtype. If a sequence of multiple dtypes is provided, each
dtype is first converted to a supported floating point type and the
final dtype is then determined by applying `np.result_type` on the
sequence of supported floating point types.
allow_complex : bool, optional
If False, raise a ValueError on complex-valued inputs.

Returns
-------
float_type : dtype
Floating-point dtype for the image.
"""
if isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str):
return np.result_type(*(_supported_float_type(d) for d in input_dtype))
input_dtype = np.dtype(input_dtype)
if not allow_complex and input_dtype.kind == 'c':
raise ValueError("complex valued input is not supported")
return new_float_type.get(input_dtype.char, np.float64)
21 changes: 13 additions & 8 deletions skimage/measure/_moments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from .._shared.utils import check_nD
from .._shared.utils import _supported_float_type, check_nD
from . import _moments_cy
import itertools

Expand Down Expand Up @@ -115,15 +115,17 @@ def moments_coords_central(coords, center=None, order=3):
coords = np.stack(coords, axis=-1)
check_nD(coords, 2)
ndim = coords.shape[1]

float_type = _supported_float_type(coords.dtype)
if center is None:
center = np.mean(coords, axis=0)
center = np.mean(coords, axis=0, dtype=float)

# center the coordinates
coords = coords.astype(float) - center
coords = coords.astype(float_type, copy=False) - center

# generate all possible exponents for each axis in the given set of points
# produces a matrix of shape (N, D, order + 1)
coords = coords[..., np.newaxis] ** np.arange(order + 1)
coords = np.stack([coords ** c for c in range(order + 1)], axis=-1)

# add extra dimensions for proper broadcasting
coords = coords.reshape(coords.shape + (1,) * (ndim - 1))
Expand Down Expand Up @@ -240,10 +242,13 @@ def moments_central(image, center=None, order=3, **kwargs):
"""
if center is None:
center = centroid(image)
calc = image.astype(float)
float_dtype = _supported_float_type(image.dtype)
calc = image.astype(float_dtype, copy=False)
for dim, dim_length in enumerate(image.shape):
delta = np.arange(dim_length, dtype=float) - center[dim]
powers_of_delta = delta[:, np.newaxis] ** np.arange(order + 1)
delta = np.arange(dim_length, dtype=float_dtype) - center[dim]
powers_of_delta = (
delta[:, np.newaxis] ** np.arange(order + 1, dtype=float_dtype)
)
calc = np.rollaxis(calc, dim, image.ndim)
calc = np.dot(calc, powers_of_delta)
calc = np.rollaxis(calc, -1, dim)
Expand Down Expand Up @@ -407,7 +412,7 @@ def inertia_tensor(image, mu=None):
if mu is None:
mu = moments_central(image, order=2) # don't need higher-order moments
mu0 = mu[(0,) * image.ndim]
result = np.zeros((image.ndim, image.ndim))
result = np.zeros((image.ndim, image.ndim), dtype=mu.dtype)

# nD expression to get coordinates ([2, 0], [0, 2]) (2D),
# ([2, 0, 0], [0, 2, 0], [0, 0, 2]) (3D), etc.
Expand Down
61 changes: 49 additions & 12 deletions skimage/measure/tests/test_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skimage._shared import testing
from skimage._shared.testing import (assert_equal, assert_almost_equal,
assert_allclose)
from skimage._shared.utils import _supported_float_type
grlee77 marked this conversation as resolved.
Show resolved Hide resolved


def test_moments():
Expand Down Expand Up @@ -58,6 +59,23 @@ def test_moments_coords():
assert_almost_equal(mu_coords, mu_image)


@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_moments_coords_dtype(dtype):
image = np.zeros((20, 20), dtype=dtype)
image[13:17, 13:17] = 1

expected_dtype = _supported_float_type(dtype)
mu_image = moments(image)
assert mu_image.dtype == expected_dtype

coords = np.array([[r, c] for r in range(13, 17)
for c in range(13, 17)], dtype=dtype)
mu_coords = moments_coords(coords)
assert mu_coords.dtype == expected_dtype

assert_almost_equal(mu_coords, mu_image)


def test_moments_central_coords():
image = np.zeros((20, 20), dtype=np.double)
image[13:17, 13:17] = 1
Expand Down Expand Up @@ -133,32 +151,51 @@ def test_moments_hu():
assert_almost_equal(hu, hu2, decimal=1)


@pytest.mark.parametrize('dtype', ['float32', 'float64'])
def test_moments_hu_dtype(dtype):
image = np.zeros((20, 20), dtype=np.double)
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_moments_dtype(dtype):
image = np.zeros((20, 20), dtype=dtype)
image[13:15, 13:17] = 1

expected_dtype = _supported_float_type(image)
mu = moments_central(image, (13.5, 14.5))
assert mu.dtype == expected_dtype

nu = moments_normalized(mu)
hu = moments_hu(nu.astype(dtype))
assert nu.dtype == expected_dtype

assert hu.dtype == dtype
hu = moments_hu(nu)
assert hu.dtype == expected_dtype


def test_centroid():
image = np.zeros((20, 20), dtype=np.double)
@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_centroid(dtype):
image = np.zeros((20, 20), dtype=dtype)
image[14, 14:16] = 1
image[15, 14:16] = 1/3
image_centroid = centroid(image)
assert_allclose(image_centroid, (14.25, 14.5))


def test_inertia_tensor_2d():
image = np.zeros((40, 40))
if dtype == np.float16:
rtol = 1e-3
elif dtype == np.float32:
rtol = 1e-5
else:
rtol = 1e-7
assert_allclose(image_centroid, (14.25, 14.5), rtol=rtol)


@pytest.mark.parametrize('dtype', [np.float16, np.float32, np.float64])
def test_inertia_tensor_2d(dtype):
image = np.zeros((40, 40), dtype=dtype)
image[15:25, 5:35] = 1 # big horizontal rectangle (aligned with axis 1)
expected_dtype = _supported_float_type(image.dtype)

T = inertia_tensor(image)
assert T.dtype == expected_dtype
assert T[0, 0] > T[1, 1]
np.testing.assert_allclose(T[0, 1], 0)

v0, v1 = inertia_tensor_eigvals(image, T=T)
assert v0.dtype == expected_dtype
assert v1.dtype == expected_dtype
np.testing.assert_allclose(np.sqrt(v0/v1), 3, rtol=0.01, atol=0.05)


Expand Down