Skip to content

Commit

Permalink
Review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Sep 24, 2019
1 parent 2b7dab1 commit 8cef195
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -1518,7 +1518,10 @@ def check_classifier_multioutput(name, estimator):
if hasattr(estimator, "decision_function"):
decision = estimator.decision_function(X)
assert isinstance(decision, np.ndarray)
assert decision.shape == (n_samples, n_labels)
assert (decision.shape == (n_samples, n_labels),
"The shape of the decision function output for"
" multioutput data is incorrect. Expected {}, got {}."
.format((n_samples, n_labels), decision.shape))

dec_pred = (decision > 0).astype(np.int)
dec_exp = estimator.classes_[dec_pred]
Expand All @@ -1529,13 +1532,19 @@ def check_classifier_multioutput(name, estimator):

if isinstance(y_prob, list) and not tags['poor_score']:
for i in range(n_labels):
y_prob[i].shape == (n_samples, n_labels)
assert (y_prob[i].shape == (n_samples, n_labels),
"The shape of the probability for multioutput data is"
" incorrect. Expected {}, got {}."
.format((n_samples, n_labels), y_prob[i].shape))
assert_array_equal(
np.argmax(y_prob[i], axis=1).astype(np.int),
y_pred[:, i]
)
elif not tags['poor_score']:
y_prob.shape == (n_samples, n_labels)
assert (y_prob.shape == (n_samples, n_labels),
"The shape of the probability for multioutput data is"
" incorrect. Expected {}, got {}."
.format((n_samples, n_labels), y_prob.shape))
assert_array_equal(y_prob.round().astype(int), y_pred)

if (hasattr(estimator, "decision_function") and
Expand All @@ -1561,8 +1570,12 @@ def check_regressor_multioutput(name, estimator):
estimator.fit(X, y)
y_pred = estimator.predict(X)

assert y_pred.dtype == np.dtype('float')
assert y_pred.shape == y.shape
assert (y_pred.dtype == np.dtype('f'),
"Multioutput predictions by a regressor are expected to be"
" floating-point precision. Got {} instead".format(y_pred.dtype))
assert (y_pred.shape == y.shape,
"The shape of the orediction for multioutput data is incorrect."
" Expected {}, got {}.")


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
Expand Down

0 comments on commit 8cef195

Please sign in to comment.