Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA Add a private check_array with additional parameters #25617

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from .exceptions import InconsistentVersionWarning
from .utils.validation import check_X_y
from .utils.validation import check_array
from .utils.validation import check_array, _check_array
from .utils.validation import _check_y
from .utils.validation import _num_features
from .utils.validation import _check_feature_names_in
Expand Down Expand Up @@ -574,7 +574,7 @@ def _validate_data(
if no_val_X and no_val_y:
raise ValueError("Validation should be done on X, y or both.")
elif not no_val_X and no_val_y:
X = check_array(X, input_name="X", **check_params)
X = _check_array(X, input_name="X", **check_params)
out = X
elif no_val_X and not no_val_y:
y = _check_y(y, **check_params)
Expand All @@ -588,10 +588,10 @@ def _validate_data(
check_X_params, check_y_params = validate_separately
if "estimator" not in check_X_params:
check_X_params = {**default_check_params, **check_X_params}
X = check_array(X, input_name="X", **check_X_params)
X = _check_array(X, input_name="X", **check_X_params)
if "estimator" not in check_y_params:
check_y_params = {**default_check_params, **check_y_params}
y = check_array(y, input_name="y", **check_y_params)
y = _check_array(y, input_name="y", **check_y_params)
else:
X, y = check_X_y(X, y, **check_params)
out = X, y
Expand Down
25 changes: 25 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,3 +1759,28 @@ def test_boolean_series_remains_boolean():

assert res.dtype == expected.dtype
assert_array_equal(res, expected)


def test_custom_asarray():
    """Check that a custom `asarray` function can be used during validation"""
    est = BaseEstimator()

    # This "special" asarray implementation casts every element of the
    # input to int before building the numpy array. That way we can be
    # sure that it, and not the standard asarray, performed the
    # conversion during data validation.
    def my_asarray(array, copy=False, **kwargs):
        converted = [[int(element) for element in row] for row in array]
        factory = np.array if copy else np.asarray
        return factory(converted, **kwargs)

    x = est._validate_data([["1", "2", "3"], ["4", "5", "6"]], asarray=my_asarray)

    assert x.dtype == int
156 changes: 149 additions & 7 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Sylvain Marie
# License: BSD 3 clause

from functools import wraps
from functools import partial, wraps
import warnings
import numbers
import operator
Expand Down Expand Up @@ -642,6 +642,130 @@ def check_array(
estimator=None,
input_name="",
):
"""Input validation on an array, list, sparse matrix or similar.

By default, the input is checked to be a non-empty 2D array containing
only finite values. If the dtype of the array is object, attempt
converting to float, raising on failure.

Parameters
----------
array : object
Input object to check / convert.

accept_sparse : str, bool or list/tuple of str, default=False
String[s] representing allowed sparse matrix formats, such as 'csc',
'csr', etc. If the input is sparse but not in the allowed format,
it will be converted to the first listed format. True allows the input
to be any format. False means that a sparse matrix input will
raise an error.

accept_large_sparse : bool, default=True
If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by
accept_sparse, accept_large_sparse=False will cause it to be accepted
only if its indices are stored with a 32-bit dtype.

.. versionadded:: 0.20

dtype : 'numeric', type, list of type or None, default='numeric'
Data type of result. If None, the dtype of the input is preserved.
If "numeric", dtype is preserved unless array.dtype is object.
If dtype is a list of types, conversion on the first type is only
performed if the dtype of the input is not in the list.

order : {'F', 'C'} or None, default=None
Whether an array will be forced to be fortran or c-style.
When order is None (default), then if copy=False, nothing is ensured
about the memory layout of the output array; otherwise (copy=True)
the memory layout of the returned array is kept as close as possible
to the original array.

copy : bool, default=False
Whether a forced copy will be triggered. If copy=False, a copy might
be triggered by a conversion.

force_all_finite : bool or 'allow-nan', default=True
Whether to raise an error on np.inf, np.nan, pd.NA in array. The
possibilities are:

- True: Force all values of array to be finite.
- False: accepts np.inf, np.nan, pd.NA in array.
- 'allow-nan': accepts only np.nan and pd.NA values in array. Values
cannot be infinite.

.. versionadded:: 0.20
``force_all_finite`` accepts the string ``'allow-nan'``.

.. versionchanged:: 0.23
Accepts `pd.NA` and converts it into `np.nan`

ensure_2d : bool, default=True
Whether to raise a value error if array is not 2D.

allow_nd : bool, default=False
Whether to allow array.ndim > 2.

ensure_min_samples : int, default=1
Make sure that the array has a minimum number of samples in its first
axis (rows for a 2D array). Setting to 0 disables this check.

ensure_min_features : int, default=1
Make sure that the 2D array has some minimum number of features
(columns). The default value of 1 rejects empty datasets.
This check is only enforced when the input data has effectively 2
dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0
disables this check.

estimator : str or estimator instance, default=None
If passed, include the name of the estimator in warning messages.

input_name : str, default=""
The data name used to construct the error message. In particular
if `input_name` is "X" and the data has NaN values and
allow_nan is False, the error message will link to the imputer
documentation.

.. versionadded:: 1.1.0

Returns
-------
array_converted : object
The converted and validated array.
"""
return _check_array(
array,
accept_sparse=accept_sparse,
accept_large_sparse=accept_large_sparse,
dtype=dtype,
order=order,
copy=copy,
force_all_finite=force_all_finite,
ensure_2d=ensure_2d,
allow_nd=allow_nd,
ensure_min_samples=ensure_min_samples,
ensure_min_features=ensure_min_features,
estimator=estimator,
input_name=input_name,
)


def _check_array(
array,
accept_sparse=False,
*,
accept_large_sparse=True,
dtype="numeric",
order=None,
copy=False,
force_all_finite=True,
ensure_2d=True,
allow_nd=False,
ensure_min_samples=1,
ensure_min_features=1,
estimator=None,
input_name="",
asarray=None,
):

"""Input validation on an array, list, sparse matrix or similar.

Expand Down Expand Up @@ -728,6 +852,15 @@ def check_array(

.. versionadded:: 1.1.0

asarray : callable, default=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The array returned by asarray is required to have an interface that works with check_array. Specifically, the array object needs:

  1. array.ndim
  2. array.dtype.kind
  3. array.shape
  4. Work with _assert_all_finite?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference on dtype.kind: Not all array libraries define a dtype.kind attribute with the same semantics as NumPy. The Array API spec now has a isdtype, which is used to determine the "kind" of a dtype.

A callable to use instead of `np.asarray` when converting the input
array. Useful when the input array is not a Numpy array or when the
converted array should be an ndarray from a different library. The callable
should have the same signature as `np.asarray` and in addition support
the `copy` keyword argument.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the function signature be written out here? asarray(a, dtype=None, order=None, copy=False)

(np.asarray has a like kwarg which is not required)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm still thinking about what the best way is to describe how this asarray-like function should behave. Both in terms of its signature and what it can and can not return (your later comment).


.. versionadded:: 1.3.0

Returns
-------
array_converted : object
Expand All @@ -742,6 +875,9 @@ def check_array(

xp, is_array_api = get_namespace(array)

if asarray is None:
asarray = partial(_asarray_with_order, xp=xp)

# store reference to original array to check if copy is needed when
# function returns
array_orig = array
Expand Down Expand Up @@ -865,7 +1001,7 @@ def check_array(
# Conversion float -> int should not contain NaN or
# inf (numpy#14412). We cannot use casting='safe' because
# then conversion float -> int would be disallowed.
array = _asarray_with_order(array, order=order, xp=xp)
array = asarray(array, order=order)
if array.dtype.kind == "f":
_assert_all_finite(
array,
Expand All @@ -876,7 +1012,7 @@ def check_array(
)
array = xp.astype(array, dtype, copy=False)
else:
array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
array = asarray(array, order=order, dtype=dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the third party array libraries you are considering, do they usually work with NumPy dtypes, such as np.float32?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so indeed. It's the case for CuPy and dpctl.tensor which are our main use cases as part of the plugin API at the moment but it's probably the problem of the implementer of the custom asarray to make sure that they understand the usual dtypes we use in scikit-learn.

except ComplexWarning as complex_warning:
raise ValueError(
"Complex data not supported\n{}\n".format(array)
Expand Down Expand Up @@ -947,13 +1083,19 @@ def check_array(
if xp.__name__ in {"numpy", "numpy.array_api"}:
# only make a copy if `array` and `array_orig` may share memory`
if np.may_share_memory(array, array_orig):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If array is a GPU array, I suspect np.may_share_memory will not work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then we won't go down this path no? At least I'd assume xp.__name__ would not be "numpy" in that case.

array = _asarray_with_order(
array, dtype=dtype, order=order, copy=True, xp=xp
array = asarray(
array,
dtype=dtype,
order=order,
copy=True,
)
else:
# always make a copy for non-numpy arrays
array = _asarray_with_order(
array, dtype=dtype, order=order, copy=True, xp=xp
array = asarray(
array,
dtype=dtype,
order=order,
copy=True,
)

return array
Expand Down