Skip to content

Commit

Permalink
ENH Improve column_or_1d error message (#15926)
Browse files Browse the repository at this point in the history
  • Loading branch information
lesteve authored and thomasjpfan committed Dec 20, 2019
1 parent a5542e9 commit 9626583
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v0.23.rst
Expand Up @@ -106,3 +106,9 @@ Changelog
- |Fix| :func:`tree.plot_tree` `rotate` parameter was unused and has been
deprecated.
:pr:`15806` by :user:`Chiara Marmo <cmarmo>`.

:mod:`sklearn.utils`
....................

- |Enhancement| improve error message in :func:`utils.validation.column_or_1d`.
:pr:`15926` by :user:`Loïc Estève <lesteve>`.
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_classification.py
Expand Up @@ -486,7 +486,7 @@ def test_multilabel_confusion_matrix_errors():
# Bad sample_weight
with pytest.raises(ValueError, match="inconsistent numbers of samples"):
multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
with pytest.raises(ValueError, match="bad input shape"):
with pytest.raises(ValueError, match="should be a 1d array"):
multilabel_confusion_matrix(y_true, y_pred,
sample_weight=[[1, 2, 3],
[2, 3, 4],
Expand Down
4 changes: 2 additions & 2 deletions sklearn/preprocessing/tests/test_label.py
Expand Up @@ -222,7 +222,7 @@ def test_label_encoder_negative_ints():
def test_label_encoder_str_bad_shape(dtype):
le = LabelEncoder()
le.fit(np.array(["apple", "orange"], dtype=dtype))
msg = "bad input shape"
msg = "should be a 1d array"
with pytest.raises(ValueError, match=msg):
le.transform("apple")

Expand All @@ -245,7 +245,7 @@ def test_label_encoder_errors():
le.inverse_transform([-2, -3, -4])

# Fail on inverse_transform("")
msg = "bad input shape ()"
msg = r"should be a 1d array.+shape \(\)"
with pytest.raises(ValueError, match=msg):
le.inverse_transform("")

Expand Down
4 changes: 3 additions & 1 deletion sklearn/utils/validation.py
Expand Up @@ -743,7 +743,9 @@ def column_or_1d(y, warn=False):
DataConversionWarning, stacklevel=2)
return np.ravel(y)

raise ValueError("bad input shape {0}".format(shape))
raise ValueError(
"y should be a 1d array, "
"got an array of shape {} instead.".format(shape))


def check_random_state(seed):
Expand Down

0 comments on commit 9626583

Please sign in to comment.