Skip to content

Commit

Permalink
FIX raise error for float to int casting with NaN, inf in check_array (
Browse files Browse the repository at this point in the history
  • Loading branch information
rth authored and glemaitre committed Sep 20, 2019
1 parent 97185ec commit e6a4dc9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
Expand Up @@ -554,6 +554,10 @@ Changelog
and sparse matrix.
:pr:`14538` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

- |Fix| :func:`utils.check_array` is now raising an error instead of casting
NaN to integer.
:pr:`14872` by `Roman Yurchak`_.

:mod:`sklearn.metrics`
..................................

Expand Down
20 changes: 20 additions & 0 deletions sklearn/utils/tests/test_validation.py
Expand Up @@ -202,6 +202,26 @@ def test_check_array_force_all_finite_object():
check_array(X, dtype=None, force_all_finite=True)


@pytest.mark.parametrize(
"X, err_msg",
[(np.array([[1, np.nan]]),
"Input contains NaN, infinity or a value too large for.*int"),
(np.array([[1, np.nan]]),
"Input contains NaN, infinity or a value too large for.*int"),
(np.array([[1, np.inf]]),
"Input contains NaN, infinity or a value too large for.*int"),
(np.array([[1, np.nan]], dtype=np.object),
"cannot convert float NaN to integer")]
)
@pytest.mark.parametrize("force_all_finite", [True, False])
def test_check_array_force_all_finite_object_unsafe_casting(
X, err_msg, force_all_finite):
# casting a float array containing NaN or inf to int dtype should
# raise an error irrespective of the force_all_finite parameter.
with pytest.raises(ValueError, match=err_msg):
check_array(X, dtype=np.int, force_all_finite=force_all_finite)


@ignore_warnings
def test_check_array():
# accept_sparse == False
Expand Down
20 changes: 17 additions & 3 deletions sklearn/utils/validation.py
Expand Up @@ -32,7 +32,7 @@
warnings.simplefilter('ignore', NonBLASDotWarning)


def _assert_all_finite(X, allow_nan=False):
def _assert_all_finite(X, allow_nan=False, msg_dtype=None):
"""Like assert_all_finite, but only for ndarray."""
# validation is also imported in extmath
from .extmath import _safe_accumulator_op
Expand All @@ -52,7 +52,11 @@ def _assert_all_finite(X, allow_nan=False):
if (allow_nan and np.isinf(X).any() or
not allow_nan and not np.isfinite(X).all()):
type_err = 'infinity' if allow_nan else 'NaN, infinity'
raise ValueError(msg_err.format(type_err, X.dtype))
raise ValueError(
msg_err.format
(type_err,
msg_dtype if msg_dtype is not None else X.dtype)
)
# for object dtype data, we only check for NaNs (GH-13254)
elif X.dtype == np.dtype('object') and not allow_nan:
if _object_dtype_isnan(X).any():
Expand Down Expand Up @@ -494,7 +498,17 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
with warnings.catch_warnings():
try:
warnings.simplefilter('error', ComplexWarning)
array = np.asarray(array, dtype=dtype, order=order)
if dtype is not None and np.dtype(dtype).kind in 'iu':
# Conversion float -> int should not contain NaN or
# inf (numpy#14412). We cannot use casting='safe' because
# then conversion float -> int would be disallowed.
array = np.asarray(array, order=order)
if array.dtype.kind == 'f':
_assert_all_finite(array, allow_nan=False,
msg_dtype=dtype)
array = array.astype(dtype, casting="unsafe", copy=False)
else:
array = np.asarray(array, order=order, dtype=dtype)
except ComplexWarning:
raise ValueError("Complex data not supported\n"
"{}\n".format(array))
Expand Down

0 comments on commit e6a4dc9

Please sign in to comment.