Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG + 2] Add check to regression models to raise error when targets are NaN #5431

Merged
merged 19 commits into from Oct 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
92068ef
#5322: Wrote check_supervised_y_no_nan in estimator_checks
hlin117 Oct 17, 2015
b4e746b
Resolving rebase conflicts on forest.py
hlin117 Oct 20, 2015
7745319
#5322: Added check for DecisionTreeRegressor
hlin117 Oct 17, 2015
b1a1aa8
#5322: Resolving errors for SVR and NuSVR
hlin117 Oct 18, 2015
a5860a1
#5322: Fixed build for tree failures, fixing estimator_checks.check_e…
hlin117 Oct 18, 2015
6df1fe2
#5322: Added check for MultiTaskElasticNet and MultiTaskLasso
hlin117 Oct 18, 2015
d62e316
#5322: Resolving Random Forest build errors (bad input shape)
hlin117 Oct 18, 2015
34c8399
#5322: Fixing the fix - now RFs should be okay
hlin117 Oct 18, 2015
7f82026
#5322: Fixing test in test_estimator_checks.test_check_estimator
hlin117 Oct 19, 2015
ef4faf6
#5322: check_supervised_y_no_nan raises error when error message is i…
hlin117 Oct 19, 2015
011dcd2
#5322: Fixed python3 build
hlin117 Oct 19, 2015
c657d77
#5322: Regression test for test_check_estimator
hlin117 Oct 19, 2015
eea5ba3
#5322: Made error message in check_supervised_y_no_nan more helpful
hlin117 Oct 19, 2015
3bfbe79
#5322: Making check_supervised_y_no_nan deterministic
hlin117 Oct 20, 2015
52a6acc
#5322: Addressing @MechCoder's changes
hlin117 Oct 21, 2015
7b3da77
#5322: Removing unused ConvergenceWarning import in svm/base.py
hlin117 Oct 21, 2015
290e0ea
#5322: Changed check_supervised_y_no_nan to not seed global
hlin117 Oct 21, 2015
1f48814
#5322: Using multioutput_estimator_convert_y_2d in check_supervised_y…
hlin117 Oct 21, 2015
f725485
#5322: Small change to estimator_check.py
hlin117 Oct 21, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion sklearn/ensemble/forest.py
Expand Up @@ -209,7 +209,8 @@ def fit(self, X, y, sample_weight=None):
Returns self.
"""
# Validate or convert input data
X = check_array(X, dtype=DTYPE, accept_sparse="csc")
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/coordinate_descent.py
Expand Up @@ -1642,7 +1642,7 @@ def fit(self, X, y):
# X and y must be of type float64
X = check_array(X, dtype=np.float64, order='F',
copy=self.copy_X and self.fit_intercept)
y = np.asarray(y, dtype=np.float64)
y = check_array(y, dtype=np.float64, ensure_2d=False)

if hasattr(self, 'l1_ratio'):
model_str = 'ElasticNet'
Expand Down
4 changes: 2 additions & 2 deletions sklearn/svm/base.py
Expand Up @@ -10,7 +10,7 @@
from ..base import BaseEstimator, ClassifierMixin
from ..preprocessing import LabelEncoder
from ..multiclass import _ovr_decision_function
from ..utils import check_array, check_random_state, column_or_1d
from ..utils import check_array, check_random_state, column_or_1d, check_X_y
from ..utils import compute_class_weight, deprecated
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted
Expand Down Expand Up @@ -151,7 +151,7 @@ def fit(self, X, y, sample_weight=None):
raise TypeError("Sparse precomputed kernels are not supported.")
self._sparse = sparse and not callable(self.kernel)

X = check_array(X, accept_sparse='csr', dtype=np.float64, order='C')
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed? looks like self._validate_targets does it anyway.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well actually no, sorry (which makes _validate_targets a strange name).

y = self._validate_targets(y)

sample_weight = np.asarray([]
Expand Down
3 changes: 2 additions & 1 deletion sklearn/tree/tree.py
Expand Up @@ -28,7 +28,7 @@
from ..base import RegressorMixin
from ..externals import six
from ..feature_selection.from_model import _LearntSelectorMixin
from ..utils import check_array
from ..utils import check_array, check_X_y
from ..utils import check_random_state
from ..utils import compute_sample_weight
from ..utils.multiclass import check_classification_targets
Expand Down Expand Up @@ -151,6 +151,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
random_state = check_random_state(self.random_state)
if check_input:
X = check_array(X, dtype=DTYPE, accept_sparse="csc")
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we don't accept csr?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote accept_sparse='csc' because the original check for sparse matrices checked for csc matrices. Not sure whether we've actually tested whether DecisionTreeRegressor or DecisionTreeClassifier can take in csr matrices for targets.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should but converts them into csc form. Probably the csc format is optimized for the DecisionTree code hence.

if issparse(X):
X.sort_indices()

Expand Down
33 changes: 27 additions & 6 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -131,6 +131,26 @@ def _yield_classifier_checks(name, Classifier):
if 'class_weight' in Classifier().get_params().keys():
yield check_class_weight_classifiers

def check_supervised_y_no_nan(name, Estimator):
    """Check that fitting on a non-finite target raises a ValueError.

    Fits ``Estimator`` on a target vector filled with ``inf`` (caught by the
    same finiteness validation as NaN) and verifies that a ``ValueError``
    carrying the standard input-validation message is raised.
    """
    # Fixed seed keeps this check deterministic across runs.
    rng = np.random.RandomState(888)
    X = rng.randn(10, 5)
    y = np.ones(10) * np.inf
    # Multi-task estimators require a 2-D y; convert where needed.
    y = multioutput_estimator_convert_y_2d(name, y)

    errmsg = "Input contains NaN, infinity or a value too large for " \
             "dtype('float64')."
    try:
        Estimator().fit(X, y)
    except ValueError as e:
        # The error must be the standard validation message, not some
        # unrelated failure that happens to be a ValueError.
        if str(e) != errmsg:
            raise ValueError("Estimator {0} raised error as expected, but "
                             "it does not match the expected error message"
                             .format(name))
    else:
        raise ValueError("Estimator {0} should have raised error on fitting "
                         "array y with NaN value.".format(name))

def _yield_regressor_checks(name, Regressor):
# TODO: test with intercept
Expand All @@ -141,6 +161,7 @@ def _yield_regressor_checks(name, Regressor):
yield check_estimators_partial_fit_n_features
yield check_regressors_no_decision_function
yield check_supervised_y_2d
yield check_supervised_y_no_nan
if name != 'CCA':
# check that the regressor handles int input
yield check_regressors_int
Expand Down Expand Up @@ -207,10 +228,10 @@ def check_estimator(Estimator):
Parameters
----------
Estimator : class
Class to check.
Class to check. Estimator is a class object (not an instance).

"""
name = Estimator.__class__.__name__
name = Estimator.__name__
check_parameters_default_constructible(name, Estimator)
for check in _yield_all_checks(name, Estimator):
check(name, Estimator)
Expand Down Expand Up @@ -695,6 +716,7 @@ def check_estimators_empty_data_messages(name, Estimator):


def check_estimators_nan_inf(name, Estimator):
# Checks that Estimator X's do not contain NaN or inf.
rnd = np.random.RandomState(0)
X_train_finite = rnd.uniform(size=(10, 3))
X_train_nan = rnd.uniform(size=(10, 3))
Expand Down Expand Up @@ -1431,9 +1453,8 @@ def param_filter(p):
def multioutput_estimator_convert_y_2d(name, y):
    # Multi-task estimators reject 1-D targets; give them a column vector.
    # Every other estimator gets y back untouched.
    if "MultiTask" not in name:
        return y
    return np.reshape(y, (-1, 1))


Expand All @@ -1445,7 +1466,7 @@ def check_non_transformer_estimators_n_iter(name, estimator,
X, y_ = iris.data, iris.target

if multi_output:
y_ = y_[:, np.newaxis]
y_ = np.reshape(y_, (-1, 1))

set_random_state(estimator, 0)
if name == 'AffinityPropagation':
Expand Down
5 changes: 4 additions & 1 deletion sklearn/utils/tests/test_estimator_checks.py
Expand Up @@ -8,6 +8,7 @@
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.estimator_checks import check_estimators_unfitted
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.utils.validation import check_X_y, check_array


Expand Down Expand Up @@ -75,7 +76,8 @@ def test_check_estimator():
msg = "Estimator doesn't check for NaN and inf in predict"
assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
# check for sparse matrix input handling
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sweet

msg = "Estimator type doesn't seem to fail gracefully on sparse data"
name = NoSparseClassifier.__name__
msg = "Estimator " + name + " doesn't seem to fail gracefully on sparse data"
# the check for sparse input handling prints to the stdout,
# instead of raising an error, so as not to remove the original traceback.
# that means we need to jump through some hoops to catch it.
Expand All @@ -92,6 +94,7 @@ def test_check_estimator():

# doesn't error on actual estimator
check_estimator(AdaBoostClassifier)
check_estimator(MultiTaskElasticNet)


def test_check_estimators_unfitted():
Expand Down