diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 6bddd26faa606..dd0c007602654 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1176,7 +1176,15 @@ def column_or_1d(y, *, dtype=None, warn=False): If `y` is not a 1D array or a 2D array with a single row or column. """ xp, _ = get_namespace(y) - y = xp.asarray(y, dtype=dtype) + y = check_array( + y, + ensure_2d=False, + dtype=dtype, + input_name="y", + force_all_finite=False, + ensure_min_samples=0, + ) + shape = y.shape if len(shape) == 1: return _asarray_with_order(xp.reshape(y, -1), order="C", xp=xp)