New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG + 2] Add check to regression models to raise error when targets are NaN #5431
Changes from all commits
92068ef
b4e746b
7745319
b1a1aa8
a5860a1
6df1fe2
d62e316
34c8399
7f82026
ef4faf6
011dcd2
c657d77
eea5ba3
3bfbe79
52a6acc
7b3da77
290e0ea
1f48814
f725485
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
from ..base import RegressorMixin | ||
from ..externals import six | ||
from ..feature_selection.from_model import _LearntSelectorMixin | ||
from ..utils import check_array | ||
from ..utils import check_array, check_X_y | ||
from ..utils import check_random_state | ||
from ..utils import compute_sample_weight | ||
from ..utils.multiclass import check_classification_targets | ||
|
@@ -151,6 +151,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, | |
random_state = check_random_state(self.random_state) | ||
if check_input: | ||
X = check_array(X, dtype=DTYPE, accept_sparse="csc") | ||
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer: so we don't accept csr?
Reviewer: I wrote [comment truncated in this capture].
Reviewer: It should accept csr, but it converts it into csc form. The csc format is probably what the DecisionTree code is optimized for, hence the conversion. |
||
if issparse(X): | ||
X.sort_indices() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,6 +131,26 @@ def _yield_classifier_checks(name, Classifier): | |
if 'class_weight' in Classifier().get_params().keys(): | ||
yield check_class_weight_classifiers | ||
|
||
def check_supervised_y_no_nan(name, Estimator): | ||
# Checks that the Estimator targets are not NaN. | ||
|
||
rng = np.random.RandomState(888) | ||
X = rng.randn(10, 5) | ||
y = np.ones(10) * np.inf | ||
y = multioutput_estimator_convert_y_2d(name, y) | ||
|
||
errmsg = "Input contains NaN, infinity or a value too large for " \ | ||
"dtype('float64')." | ||
try: | ||
Estimator().fit(X, y) | ||
except ValueError as e: | ||
if str(e) != errmsg: | ||
raise ValueError("Estimator {0} raised warning as expected, but " | ||
"does not match expected error message" \ | ||
.format(name)) | ||
else: | ||
raise ValueError("Estimator {0} should have raised error on fitting " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nan in target / y maybe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Addressed in commit 86487ae.) |
||
"array y with NaN value.".format(name)) | ||
|
||
def _yield_regressor_checks(name, Regressor): | ||
# TODO: test with intercept | ||
|
@@ -141,6 +161,7 @@ def _yield_regressor_checks(name, Regressor): | |
yield check_estimators_partial_fit_n_features | ||
yield check_regressors_no_decision_function | ||
yield check_supervised_y_2d | ||
yield check_supervised_y_no_nan | ||
if name != 'CCA': | ||
# check that the regressor handles int input | ||
yield check_regressors_int | ||
|
@@ -207,10 +228,10 @@ def check_estimator(Estimator): | |
Parameters | ||
---------- | ||
Estimator : class | ||
Class to check. | ||
Class to check. Estimator is a class object (not an instance). | ||
|
||
""" | ||
name = Estimator.__class__.__name__ | ||
name = Estimator.__name__ | ||
check_parameters_default_constructible(name, Estimator) | ||
for check in _yield_all_checks(name, Estimator): | ||
check(name, Estimator) | ||
|
@@ -695,6 +716,7 @@ def check_estimators_empty_data_messages(name, Estimator): | |
|
||
|
||
def check_estimators_nan_inf(name, Estimator): | ||
# Checks that Estimator X's do not contain NaN or inf. | ||
rnd = np.random.RandomState(0) | ||
X_train_finite = rnd.uniform(size=(10, 3)) | ||
X_train_nan = rnd.uniform(size=(10, 3)) | ||
|
@@ -1431,9 +1453,8 @@ def param_filter(p): | |
def multioutput_estimator_convert_y_2d(name, y): | ||
# Estimators in mono_output_task_error raise ValueError if y is of 1-D | ||
# Convert into a 2-D y for those estimators. | ||
if name in (['MultiTaskElasticNetCV', 'MultiTaskLassoCV', | ||
'MultiTaskLasso', 'MultiTaskElasticNet']): | ||
return y[:, np.newaxis] | ||
if "MultiTask" in name: | ||
return np.reshape(y, (-1, 1)) | ||
return y | ||
|
||
|
||
|
@@ -1445,7 +1466,7 @@ def check_non_transformer_estimators_n_iter(name, estimator, | |
X, y_ = iris.data, iris.target | ||
|
||
if multi_output: | ||
y_ = y_[:, np.newaxis] | ||
y_ = np.reshape(y_, (-1, 1)) | ||
|
||
set_random_state(estimator, 0) | ||
if name == 'AffinityPropagation': | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from sklearn.utils.estimator_checks import check_estimator | ||
from sklearn.utils.estimator_checks import check_estimators_unfitted | ||
from sklearn.ensemble import AdaBoostClassifier | ||
from sklearn.linear_model import MultiTaskElasticNet | ||
from sklearn.utils.validation import check_X_y, check_array | ||
|
||
|
||
|
@@ -75,7 +76,8 @@ def test_check_estimator(): | |
msg = "Estimator doesn't check for NaN and inf in predict" | ||
assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict) | ||
# check for sparse matrix input handling | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sweet |
||
msg = "Estimator type doesn't seem to fail gracefully on sparse data" | ||
name = NoSparseClassifier.__name__ | ||
msg = "Estimator " + name + " doesn't seem to fail gracefully on sparse data" | ||
# the check for sparse input handling prints to the stdout, | ||
# instead of raising an error, so as not to remove the original traceback. | ||
# that means we need to jump through some hoops to catch it. | ||
|
@@ -92,6 +94,7 @@ def test_check_estimator(): | |
|
||
# doesn't error on actual estimator | ||
check_estimator(AdaBoostClassifier) | ||
check_estimator(MultiTaskElasticNet) | ||
|
||
|
||
def test_check_estimators_unfitted(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Reviewer: Is this needed? It looks like `self._validate_targets` does it anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Reviewer: Well, actually no — sorry. (Which makes `_validate_targets` a strange name.)