-
-
Notifications
You must be signed in to change notification settings - Fork 25.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG + 2] Add check to regression models to raise error when targets are NaN #5431
Changes from 14 commits
92068ef
b4e746b
7745319
b1a1aa8
a5860a1
6df1fe2
d62e316
34c8399
7f82026
ef4faf6
011dcd2
c657d77
eea5ba3
3bfbe79
52a6acc
7b3da77
290e0ea
1f48814
f725485
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,8 +10,8 @@ | |
from ..base import BaseEstimator, ClassifierMixin | ||
from ..preprocessing import LabelEncoder | ||
from ..multiclass import _ovr_decision_function | ||
from ..utils import check_array, check_random_state, column_or_1d | ||
from ..utils import compute_class_weight, deprecated | ||
from ..utils import check_array, check_random_state, column_or_1d, check_X_y | ||
from ..utils import ConvergenceWarning, compute_class_weight, deprecated | ||
from ..utils.extmath import safe_sparse_dot | ||
from ..utils.validation import check_is_fitted | ||
from ..utils.multiclass import check_classification_targets | ||
|
@@ -151,7 +151,8 @@ def fit(self, X, y, sample_weight=None): | |
raise TypeError("Sparse precomputed kernels are not supported.") | ||
self._sparse = sparse and not callable(self.kernel) | ||
|
||
X = check_array(X, accept_sparse='csr', dtype=np.float64, order='C') | ||
#X = check_array(X, accept_sparse='csr', dtype=np.float64, order='C') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this commented line please |
||
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this needed? looks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well actually no, sorry (which makes |
||
y = self._validate_targets(y) | ||
|
||
sample_weight = np.asarray([] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
from ..base import RegressorMixin | ||
from ..externals import six | ||
from ..feature_selection.from_model import _LearntSelectorMixin | ||
from ..utils import check_array | ||
from ..utils import check_array, check_X_y | ||
from ..utils import check_random_state | ||
from ..utils import compute_sample_weight | ||
from ..utils.multiclass import check_classification_targets | ||
|
@@ -151,6 +151,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, | |
random_state = check_random_state(self.random_state) | ||
if check_input: | ||
X = check_array(X, dtype=DTYPE, accept_sparse="csc") | ||
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so we don't accept csr? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wrote There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should, but it converts them into csc form. Probably the csc format is optimized for the DecisionTree code, hence the conversion. |
||
if issparse(X): | ||
X.sort_indices() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,6 +131,30 @@ def _yield_classifier_checks(name, Classifier): | |
if 'class_weight' in Classifier().get_params().keys(): | ||
yield check_class_weight_classifiers | ||
|
||
def check_supervised_y_no_nan(name, Estimator): | ||
# Checks that the Estimator targets are not NaN. | ||
|
||
warnings.simplefilter("ignore") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are there warnings? If it's due to the division by zero, you can replace it with y1 = np.inf * np.ones(10) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are there warnings? If it's due to the division by zero, you can replace it with y1 = np.inf * np.ones(10) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh, smart =] I'll push this change soon. |
||
np.random.seed(888) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Can you replace with this rng = np.random.RandomState(888) just to be consistent with others. |
||
X = np.random.randn(10, 5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should use a local random number generator here, as in 'rng = np.random.RandomState(888)' and not seed the global, as it create a side effect and is not concurrency safe. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought I commented the same thing. Maybe disappeared in the git diffs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @MechCoder: I thought that I fixed this in 290e0ea. |
||
y1 = np.random.randn(10) / 0. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you should set the seed, I guess There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should it matter? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will not matter in the absolute sense. But it is better to have deterministic input for every run (both X and y1) |
||
y2 = np.random.randn(10, 2) / 0. | ||
|
||
errmsg = "Input contains NaN, infinity or a value too large for " \ | ||
"dtype('float64')." | ||
try: | ||
if "MultiTask" in name: | ||
Estimator().fit(X, y2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ironically if I didn't add that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoops. Nice catch. However I still think code reuse would be a good idea. Can you
in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in commit 52a6acc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, you would have still discovered the bug, I think, as that also checks the name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, can you change this to use the existing function? Thanks ! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check out commit 1f48814. Thanks! |
||
else: | ||
Estimator().fit(X, y1) | ||
except ValueError as e: | ||
if str(e) != errmsg: | ||
raise ValueError("Estimator {0} raised warning as expected, but " | ||
"does not match expected error message" \ | ||
.format(name)) | ||
else: | ||
raise ValueError("Estimator {0} should have raised error on fitting " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nan in target / y maybe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Addressed in commit 86487ae.) |
||
"array y with NaN value.".format(name)) | ||
|
||
def _yield_regressor_checks(name, Regressor): | ||
# TODO: test with intercept | ||
|
@@ -141,6 +165,7 @@ def _yield_regressor_checks(name, Regressor): | |
yield check_estimators_partial_fit_n_features | ||
yield check_regressors_no_decision_function | ||
yield check_supervised_y_2d | ||
yield check_supervised_y_no_nan | ||
if name != 'CCA': | ||
# check that the regressor handles int input | ||
yield check_regressors_int | ||
|
@@ -207,10 +232,10 @@ def check_estimator(Estimator): | |
Parameters | ||
---------- | ||
Estimator : class | ||
Class to check. | ||
Class to check. Estimator is a class object (not an instance). | ||
|
||
""" | ||
name = Estimator.__class__.__name__ | ||
name = Estimator.__name__ | ||
check_parameters_default_constructible(name, Estimator) | ||
for check in _yield_all_checks(name, Estimator): | ||
check(name, Estimator) | ||
|
@@ -695,6 +720,7 @@ def check_estimators_empty_data_messages(name, Estimator): | |
|
||
|
||
def check_estimators_nan_inf(name, Estimator): | ||
# Checks that Estimator X's do not contain NaN or inf. | ||
rnd = np.random.RandomState(0) | ||
X_train_finite = rnd.uniform(size=(10, 3)) | ||
X_train_nan = rnd.uniform(size=(10, 3)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from sklearn.utils.estimator_checks import check_estimator | ||
from sklearn.utils.estimator_checks import check_estimators_unfitted | ||
from sklearn.ensemble import AdaBoostClassifier | ||
from sklearn.linear_model import MultiTaskElasticNet | ||
from sklearn.utils.validation import check_X_y, check_array | ||
|
||
|
||
|
@@ -75,7 +76,8 @@ def test_check_estimator(): | |
msg = "Estimator doesn't check for NaN and inf in predict" | ||
assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict) | ||
# check for sparse matrix input handling | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sweet |
||
msg = "Estimator type doesn't seem to fail gracefully on sparse data" | ||
name = NoSparseClassifier.__name__ | ||
msg = "Estimator " + name + " doesn't seem to fail gracefully on sparse data" | ||
# the check for sparse input handling prints to the stdout, | ||
# instead of raising an error, so as not to remove the original traceback. | ||
# that means we need to jump through some hoops to catch it. | ||
|
@@ -92,6 +94,7 @@ def test_check_estimator(): | |
|
||
# doesn't error on actual estimator | ||
check_estimator(AdaBoostClassifier) | ||
check_estimator(MultiTaskElasticNet) | ||
|
||
|
||
def test_check_estimators_unfitted(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like ConvergenceWarning is unused.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this one wasn't addressed, right?