Skip to content

Commit

Permalink
TST Add requires_positive_y estimator tag (#14095)
Browse files Browse the repository at this point in the history
  • Loading branch information
rth authored and jnothman committed Jun 24, 2019
1 parent a717619 commit 78ac1ab
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 41 deletions.
5 changes: 4 additions & 1 deletion doc/developers/contributing.rst
Expand Up @@ -1514,9 +1514,12 @@ The current set of estimator tags are:
non_deterministic
whether the estimator is not deterministic given a fixed ``random_state``

requires_positive_data - unused for now
requires_positive_X - unused for now
whether the estimator requires positive X.

requires_positive_y
whether the estimator requires a positive y (only applicable for regression).

no_validation
whether the estimator skips input-validation. This is only meant for stateless and dummy transformers!

Expand Down
3 changes: 2 additions & 1 deletion sklearn/base.py
Expand Up @@ -17,7 +17,8 @@

_DEFAULT_TAGS = {
'non_deterministic': False,
'requires_positive_data': False,
'requires_positive_X': False,
'requires_positive_y': False,
'X_types': ['2darray'],
'poor_score': False,
'no_validation': False,
Expand Down
68 changes: 37 additions & 31 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -145,7 +145,7 @@ def check_supervised_y_no_nan(name, estimator_orig):
rng = np.random.RandomState(888)
X = rng.randn(10, 5)
y = np.full(10, np.inf)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

errmsg = "Input contains NaN, infinity or a value too large for " \
"dtype('float64')."
Expand Down Expand Up @@ -509,7 +509,7 @@ def check_estimator_sparse_data(name, estimator_orig):
# catch deprecation warnings
with ignore_warnings(category=DeprecationWarning):
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
for matrix_format, X in _generate_sparse_matrix(X_csr):
# catch deprecation warnings
with ignore_warnings(category=(DeprecationWarning, FutureWarning)):
Expand Down Expand Up @@ -592,7 +592,7 @@ def check_sample_weights_list(name, estimator_orig):
y = np.arange(10) % 2
else:
y = np.arange(10) % 3
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
sample_weight = [3] * 10
# Test that estimators don't raise any exception
estimator.fit(X, y, sample_weight=sample_weight)
Expand All @@ -618,7 +618,7 @@ def check_sample_weights_invariance(name, estimator_orig):
[4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.dtype('float'))
y = np.array([1, 1, 1, 1, 2, 2, 2, 2,
1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int'))
y = multioutput_estimator_convert_y_2d(estimator1, y)
y = enforce_estimator_tags_y(estimator1, y)

estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
estimator2.fit(X, y=y, sample_weight=None)
Expand Down Expand Up @@ -648,7 +648,7 @@ def check_dtype_object(name, estimator_orig):
else:
y = (X[:, 0] * 4).astype(np.int)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

estimator.fit(X, y)
if hasattr(estimator, "predict"):
Expand Down Expand Up @@ -703,7 +703,7 @@ def check_dict_unchanged(name, estimator_orig):

y = X[:, 0].astype(np.int)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
if hasattr(estimator, "n_components"):
estimator.n_components = 1

Expand Down Expand Up @@ -742,7 +742,7 @@ def check_dont_overwrite_parameters(name, estimator_orig):
y = X[:, 0].astype(np.int)
if _safe_tags(estimator, 'binary_only'):
y[y == 2] = 1
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand Down Expand Up @@ -795,7 +795,7 @@ def check_fit2d_predict1d(name, estimator_orig):
if tags['binary_only']:
y[y == 2] = 1
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand Down Expand Up @@ -843,7 +843,7 @@ def check_methods_subset_invariance(name, estimator_orig):
if _safe_tags(estimator_orig, 'binary_only'):
y[y == 2] = 1
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand Down Expand Up @@ -882,7 +882,7 @@ def check_fit2d_1sample(name, estimator_orig):
X = 3 * rnd.uniform(size=(1, 10))
y = X[:, 0].astype(np.int)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand Down Expand Up @@ -914,7 +914,7 @@ def check_fit2d_1feature(name, estimator_orig):
X = pairwise_estimator_convert_X(X, estimator_orig)
y = X[:, 0].astype(np.int)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand All @@ -927,7 +927,7 @@ def check_fit2d_1feature(name, estimator_orig):
if name == 'RANSACRegressor':
estimator.residual_threshold = 0.5

y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
set_random_state(estimator, 1)

msgs = ["1 feature(s)", "n_features = 1", "n_features=1"]
Expand All @@ -950,7 +950,7 @@ def check_fit1d(name, estimator_orig):
if tags["no_validation"]:
# FIXME this is a bit loose
return
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if hasattr(estimator, "n_components"):
estimator.n_components = 1
Expand Down Expand Up @@ -1086,7 +1086,7 @@ def check_pipeline_consistency(name, estimator_orig):
X -= X.min()
X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
set_random_state(estimator)
pipeline = make_pipeline(estimator)
estimator.fit(X, y)
Expand Down Expand Up @@ -1115,7 +1115,7 @@ def check_fit_score_takes_y(name, estimator_orig):
else:
y = np.arange(10) % 3
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)
set_random_state(estimator)

funcs = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"]
Expand Down Expand Up @@ -1145,7 +1145,7 @@ def check_estimators_dtypes(name, estimator_orig):
y = X_train_int_64[:, 0]
if _safe_tags(estimator_orig, 'binary_only'):
y[y == 2] = 1
y = multioutput_estimator_convert_y_2d(estimator_orig, y)
y = enforce_estimator_tags_y(estimator_orig, y)

methods = ["predict", "transform", "decision_function", "predict_proba"]

Expand Down Expand Up @@ -1176,7 +1176,7 @@ def check_estimators_empty_data_messages(name, estimator_orig):
X_zero_features = np.empty(0).reshape(3, 0)
# the following y should be accepted by both classifiers and regressors
# and ignored by unsupervised models
y = multioutput_estimator_convert_y_2d(e, np.array([1, 0, 1]))
y = enforce_estimator_tags_y(e, np.array([1, 0, 1]))
msg = (r"0 feature\(s\) \(shape=\(3, 0\)\) while a minimum of \d* "
"is required.")
assert_raises_regex(ValueError, msg, e.fit, X_zero_features, y)
Expand All @@ -1194,7 +1194,7 @@ def check_estimators_nan_inf(name, estimator_orig):
X_train_inf[0, 0] = np.inf
y = np.ones(10)
y[:5] = 0
y = multioutput_estimator_convert_y_2d(estimator_orig, y)
y = enforce_estimator_tags_y(estimator_orig, y)
error_string_fit = "Estimator doesn't check for NaN and inf in fit."
error_string_predict = ("Estimator doesn't check for NaN and inf in"
" predict.")
Expand Down Expand Up @@ -1276,8 +1276,7 @@ def check_estimators_pickle(name, estimator_orig):

estimator = clone(estimator_orig)

# some estimators only take multioutputs
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

set_random_state(estimator)
estimator.fit(X, y)
Expand Down Expand Up @@ -1464,7 +1463,7 @@ def check_classifiers_train(name, classifier_orig, readonly_memmap=False):
n_samples, n_features = X.shape
classifier = clone(classifier_orig)
X = pairwise_estimator_convert_X(X, classifier)
y = multioutput_estimator_convert_y_2d(classifier, y)
y = enforce_estimator_tags_y(classifier, y)

set_random_state(classifier)
# raises error on malformed input for fit
Expand Down Expand Up @@ -1669,7 +1668,7 @@ def check_estimators_fit_returns_self(name, estimator_orig,
X = pairwise_estimator_convert_X(X, estimator_orig)

estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

if readonly_memmap:
X, y = create_memmap_backed_data([X, y])
Expand Down Expand Up @@ -1706,6 +1705,7 @@ def check_supervised_y_2d(name, estimator_orig):
y = np.arange(10) % 2
else:
y = np.arange(10) % 3
y = enforce_estimator_tags_y(estimator_orig, y)
estimator = clone(estimator_orig)
set_random_state(estimator)
# fit
Expand Down Expand Up @@ -1821,7 +1821,7 @@ def check_regressors_int(name, regressor_orig):
X = pairwise_estimator_convert_X(X[:50], regressor_orig)
rnd = np.random.RandomState(0)
y = rnd.randint(3, size=X.shape[0])
y = multioutput_estimator_convert_y_2d(regressor_orig, y)
y = enforce_estimator_tags_y(regressor_orig, y)
rnd = np.random.RandomState(0)
# separate estimators to control random seeds
regressor_1 = clone(regressor_orig)
Expand Down Expand Up @@ -1850,7 +1850,7 @@ def check_regressors_train(name, regressor_orig, readonly_memmap=False):
y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled
y = y.ravel()
regressor = clone(regressor_orig)
y = multioutput_estimator_convert_y_2d(regressor, y)
y = enforce_estimator_tags_y(regressor, y)
if name in CROSS_DECOMPOSITION:
rnd = np.random.RandomState(0)
y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])
Expand Down Expand Up @@ -1894,7 +1894,7 @@ def check_regressors_no_decision_function(name, regressor_orig):
rng = np.random.RandomState(0)
X = rng.normal(size=(10, 4))
regressor = clone(regressor_orig)
y = multioutput_estimator_convert_y_2d(regressor, X[:, 0])
y = enforce_estimator_tags_y(regressor, X[:, 0])

if hasattr(regressor, "n_components"):
# FIXME CCA, PLS is not robust to rank 1 effects
Expand Down Expand Up @@ -2034,7 +2034,7 @@ def check_estimators_overwrite_params(name, estimator_orig):
X -= X.min()
X = pairwise_estimator_convert_X(X, estimator_orig, kernel=rbf_kernel)
estimator = clone(estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

set_random_state(estimator)

Expand Down Expand Up @@ -2123,15 +2123,15 @@ def check_classifier_data_not_an_array(name, estimator_orig):
X = np.array([[3, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1]])
X = pairwise_estimator_convert_X(X, estimator_orig)
y = [1, 1, 1, 2, 2, 2]
y = multioutput_estimator_convert_y_2d(estimator_orig, y)
y = enforce_estimator_tags_y(estimator_orig, y)
check_estimators_data_not_an_array(name, estimator_orig, X, y)


@ignore_warnings(category=DeprecationWarning)
def check_regressor_data_not_an_array(name, estimator_orig):
X, y = _boston_subset(n_samples=50)
X = pairwise_estimator_convert_X(X, estimator_orig)
y = multioutput_estimator_convert_y_2d(estimator_orig, y)
y = enforce_estimator_tags_y(estimator_orig, y)
check_estimators_data_not_an_array(name, estimator_orig, X, y)


Expand Down Expand Up @@ -2237,7 +2237,13 @@ def param_filter(p):
assert param_value == init_param.default, init_param.name


def multioutput_estimator_convert_y_2d(estimator, y):
def enforce_estimator_tags_y(estimator, y):
# Estimators with a `requires_positive_y` tag only accept strictly positive
# data
if _safe_tags(estimator, "requires_positive_y"):
# Create strictly positive y. The minimal increment above 0 is 1, as
# y could be of integer dtype.
y += 1 + abs(y.min())
# Estimators tagged `multioutput_only` raise ValueError if y is 1-D.
# Convert y into a 2-D array for those estimators.
if _safe_tags(estimator, "multioutput_only"):
Expand Down Expand Up @@ -2270,7 +2276,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig):
if hasattr(estimator, 'max_iter'):
iris = load_iris()
X, y_ = iris.data, iris.target
y_ = multioutput_estimator_convert_y_2d(estimator, y_)
y_ = enforce_estimator_tags_y(estimator, y_)

set_random_state(estimator, 0)
if name == 'AffinityPropagation':
Expand Down Expand Up @@ -2480,7 +2486,7 @@ def check_fit_idempotent(name, estimator_orig):
y = rng.normal(size=n_samples)
else:
y = rng.randint(low=0, high=2, size=n_samples)
y = multioutput_estimator_convert_y_2d(estimator, y)
y = enforce_estimator_tags_y(estimator, y)

train, test = next(ShuffleSplit(test_size=.2, random_state=rng).split(X))
X_train, y_train = _safe_split(estimator, X, y, train)
Expand Down
31 changes: 23 additions & 8 deletions sklearn/utils/tests/test_estimator_checks.py
Expand Up @@ -22,12 +22,12 @@
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.estimator_checks import check_outlier_corruption
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression, SGDClassifier
from sklearn.mixture import GaussianMixture
from sklearn.cluster import MiniBatchKMeans
from sklearn.decomposition import NMF
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.linear_model import MultiTaskElasticNet, LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier
Expand Down Expand Up @@ -292,6 +292,18 @@ def _more_tags(self):
return {'binary_only': True}


class RequiresPositiveYRegressor(LinearRegression):
    """LinearRegression variant that rejects targets that are not > 0.

    Test fixture for the ``requires_positive_y`` estimator tag: the tag is
    declared in ``_more_tags`` and ``fit`` raises when the target contains
    a zero or negative entry.
    """

    def _more_tags(self):
        # Advertise that this estimator needs a positive target.
        return {"requires_positive_y": True}

    def fit(self, X, y):
        """Fit the parent regressor after rejecting non-positive targets.

        Raises
        ------
        ValueError
            If any entry of ``y`` is less than or equal to zero.
        """
        # Pass the inputs through check_X_y before inspecting y
        # (presumably this yields array-like outputs; confirm against
        # the helper's documentation).
        X, y = check_X_y(X, y)
        has_non_positive_target = (y <= 0).any()
        if has_non_positive_target:
            raise ValueError('negative y values not supported!')
        return super().fit(X, y)


def test_check_fit_score_takes_y_works_on_deprecated_fit():
# Tests that check_fit_score_takes_y works on a class with
# a deprecated fit method
Expand Down Expand Up @@ -392,22 +404,25 @@ def test_check_estimator():
assert_raises_regex(AssertionError, msg, check_estimator,
LargeSparseNotSupportedClassifier)

# does error on binary_only untagged estimator
msg = 'Only 2 classes are supported'
assert_raises_regex(ValueError, msg, check_estimator,
UntaggedBinaryClassifier)

# non-regression test for estimators transforming to sparse data
check_estimator(SparseTransformer())

# doesn't error on actual estimator
check_estimator(AdaBoostClassifier)
check_estimator(AdaBoostClassifier())
check_estimator(LogisticRegression)
check_estimator(LogisticRegression())
check_estimator(MultiTaskElasticNet)
check_estimator(MultiTaskElasticNet())

# doesn't error on binary_only tagged estimator
check_estimator(TaggedBinaryClassifier)

# does error on binary_only untagged estimator
msg = 'Only 2 classes are supported'
assert_raises_regex(ValueError, msg, check_estimator,
UntaggedBinaryClassifier)
# Check regressor with requires_positive_y estimator tag
check_estimator(RequiresPositiveYRegressor)


def test_check_outlier_corruption():
Expand Down

0 comments on commit 78ac1ab

Please sign in to comment.