[MRG+1] Fixes #7578 added check_decision_proba_consistency in estimator_checks #8253

Merged
merged 32 commits into from
Mar 7, 2017

Conversation

@shubham0704 (Author) commented Jan 31, 2017

Reference Issue

Fix #7578

What does this implement/fix? Explain your changes.

It adds the test function requested in the issue, which checks whether the outputs of predict_proba and decision_function are perfectly rank-correlated.

Any other comments?

I need to understand the testing part of this function. I have done the pep8 linting and pyflakes, but received one error from the nose check stating set_testing_parameters() takes exactly 1 argument (0 given), with error=1. Also, where is the best place to yield this function? I did not make that change because I was unsure.

@@ -114,7 +116,7 @@ def _yield_classifier_checks(name, Classifier):
yield check_classifiers_regression_target
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
# TODO some complication with -1 label
and name not in ["DecisionTreeClassifier",
@lesteve (Member) commented Feb 1, 2017


You need to put the and on the previous line to make flake8 happy. The error from Travis is (it gives you a hint as to what to do):

./sklearn/utils/estimator_checks.py:119:9: W503 line break before binary operator
        and name not in ["DecisionTreeClassifier",
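For reference, a minimal sketch of the layout flake8 expects here (the yielded check is a placeholder; the diff elides it):

if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
        # TODO some complication with -1 label
        name not in ["DecisionTreeClassifier",
                     "ExtraTreeClassifier"]):
    # W503 is satisfied because each line break now comes *after* the
    # binary operator "and", never before it.
    yield some_check  # placeholder for whichever check this condition guards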

@lesteve lesteve changed the title [WEP] Fixes #7578 added check_rank_corr in estimator_checks [WIP] Fixes #7578 added check_rank_corr in estimator_checks Feb 1, 2017
@jnothman (Member) commented Feb 2, 2017

spearmanr internally performs rankdata followed by corrcoef. I think rankdata (or a stable argsort) followed by a test for equality should suffice and be more efficient.
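A minimal sketch of the suggested approach (the helper and sample values are illustrative, not the PR's code):

import numpy as np
from scipy.stats import rankdata

def same_ranking(a, b):
    # spearmanr would rank both arrays and then compute a correlation
    # coefficient; comparing the ranks directly skips the second step.
    return np.array_equal(rankdata(a), rankdata(b))

# Illustrative scores: decision_function values and positive-class
# probabilities that order the samples identically (including the tie).
dec = np.array([-1.2, 0.3, 2.5, 0.3])
proba = np.array([0.15, 0.55, 0.92, 0.55])
assert same_ranking(dec, proba)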



@ignore_warnings(category=DeprecationWarning)
def check_rank_corr(name, Estimator):
Member:

perhaps call it check_decision_proba_consistency.

predict_proba methods has outputs with perfect rank correlation.
"""

X, Y = make_multilabel_classification(n_classes=2, n_labels=1,
Member:

Why are we using multilabel data? Why not just binary?

try:
classif = OneVsRestClassifier(estimator)
classif.fit(X, Y)
a = classif.predict_proba([i for i in range(20)])
Member:

Usually the input would be 2d. Why is it 1d? For test data, you can just generate something random and uniform, or draw from a similar distribution to the training data.
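A sketch of what that could look like (shapes and the RandomState seed here are arbitrary):

import numpy as np

rng = np.random.RandomState(0)
X_train = rng.uniform(size=(30, 4))   # 2d: (n_samples, n_features)
# Test data drawn from the same distribution, also 2d; a 1d list such as
# [i for i in range(20)] would be treated as a single sample or rejected.
X_test = rng.uniform(size=(20, 4))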


if hasattr(estimator, "predict_proba"):
try:
classif = OneVsRestClassifier(estimator)
Member:

why are we doing OvR?

a = classif.predict_proba([i for i in range(20)])
b = classif.decision_function([i for i in range(20)])
assert_equal(
rankdata(a, method='average'), rankdata(b, method='average'))
Member:

method shouldn't matter as long as tied values have tied ranks. But if we're working with non-binary classification, we need to do this comparison column-wise. Use assert_array_equal rather than assert_equal.
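A sketch of the column-wise comparison (a hypothetical helper, not the PR's code):

import numpy as np
from numpy.testing import assert_array_equal
from scipy.stats import rankdata

def assert_ranks_match_columnwise(a, b):
    # Each column holds the scores for one class, so ranks must be
    # compared per column; the tie-breaking `method` is irrelevant as
    # long as the same one is used for both arrays.
    for col in range(a.shape[1]):
        assert_array_equal(rankdata(a[:, col]), rankdata(b[:, col]))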

assert_equal(
rankdata(a, method='average'), rankdata(b, method='average'))

except ValueError:
Member:

What is this meant to catch? The try block should wrap the smallest scope that we want to excuse; otherwise this test can pass when all estimators raise a ValueError because the test itself is broken.
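A self-contained sketch of the narrower scoping the reviewer is asking for (the estimator choice is arbitrary):

from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression

X, y = make_blobs(n_samples=50, centers=2, random_state=0)
est = LogisticRegression()

# fit() stays OUTSIDE the try block: if fitting raised ValueError for
# every estimator because the test itself were broken, an enclosing
# except clause would swallow it and the check would pass vacuously.
est.fit(X, y)
try:
    proba = est.predict_proba(X)  # only the call we might want to excuse
except ValueError as exc:
    print("predict_proba raised:", exc)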


if hasattr(estimator, "decision_function"):

if hasattr(estimator, "predict_proba"):
Member:

use and

@shubham0704 (Author)

@jnothman Travis CI fails even though no errors were found.

@lesteve (Member) commented Feb 3, 2017

@jnothman Travis CI fails even though no errors were found.

Yeah, I have opened an issue with Travis about that (it is under investigation):
travis-ci/travis-ci#7264

@jnothman (Member) commented Feb 6, 2017

You seem to have mixed up a number of different patches in your PR. You should be using a branch in your fork to avoid this, not master.

@jnothman (Member) commented Feb 6, 2017

You can keep this one on master, but you need to revert or otherwise remove your changes pertaining to other issues.

@shubham0704 shubham0704 changed the title [WIP] Fixes #7578 added check_rank_corr in estimator_checks [MRG] Fixes #7578 added check_decision_proba_consistency in estimator_checks Feb 7, 2017
@shubham0704 (Author)

@jnothman anything else needed?

@@ -56,8 +56,9 @@

BOSTON = None
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']
MULTI_OUTPUT = ['CCA', 'DecisionTreeRegressor', 'ElasticNet',
'ExtraTreeRegressor', 'ExtraTreesRegressor', 'GaussianProcess',
MULTI_OUTPUT = ['CCA', 'DecisionTreeClassifier', 'DecisionTreeRegressor',
Member:

all others here are regressors. What makes you sure it's appropriate to include multioutput classifiers here?

Author:

In the function check_supervised_y_2d, the line inside the warning section, estimator.fit(X, y[:, np.newaxis]), does not give any warnings for the classifiers I included, therefore I added them to the MULTI_OUTPUT list. Otherwise it would give me an error that `expected 1 DataConversionWarning, got: `. I checked the sklearn documentation for DecisionTreeClassifier; it says y can accept [n_samples, n_outputs].
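A sketch of the behaviour being described, assuming the check works roughly as outlined (estimator and shapes are illustrative):

import warnings
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(n_samples=30, centers=2, random_state=0)

# check_supervised_y_2d fits with a column-vector y and expects a
# DataConversionWarning from estimators that silently ravel it. A
# multi-output estimator like DecisionTreeClassifier accepts y of shape
# (n_samples, n_outputs) natively, so no warning is raised, which is
# why it has to be listed in MULTI_OUTPUT.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DecisionTreeClassifier().fit(X, y[:, np.newaxis])
print(len(caught), "warning(s) raised")  # expected: 0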

Member:

Is this change relevant to the rest of the PR? Perhaps it should be a separate PR.

Author:

Sure, making changes then.

Author:

Nosetests fail if I do not include them. Can one PR reference two issues? Otherwise this will not pass. Maybe I will open a new issue and reference both it and this one from this PR. What do you say?

Member:

Could you be more specific what fails if you do not include them?

Author:

[screenshot: nosetests error output]

These are the errors when I do not include them.

@@ -113,12 +114,12 @@ def _yield_classifier_checks(name, Classifier):
# basic consistency testing
yield check_classifiers_train
yield check_classifiers_regression_target
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]):
Member:

please don't add these parentheses

# TODO some complication with -1 label
and name not in ["DecisionTreeClassifier",
"ExtraTreeClassifier"]):
if (name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
Member:

please don't add these parentheses

@@ -8,6 +8,7 @@

set -e


Member:

?

Author:

This was added by mistake during the days when Travis went down, when I foolishly tried to make Travis work in order to get my PR to pass tests, as this was my first one. Will correct it.

return (p.name != 'self'
and p.kind != p.VAR_KEYWORD
and p.kind != p.VAR_POSITIONAL)
return (p.name != 'self' and p.kind != p.VAR_KEYWORD and
Member:

why change this?

travis.log Outdated
@@ -0,0 +1,87 @@
Command line:
Member:

please remove this file.


@ignore_warnings(category=DeprecationWarning)
def check_decision_proba_consistency(name, Estimator):
"""
Member:

We don't usually add docstrings to checks, because nose doesn't play nicely.
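A sketch of the convention: nose displays a test's docstring in place of its name, so common checks describe themselves with a comment instead (the import path is the one used at the time of this PR; newer versions moved it to sklearn.utils._testing):

from sklearn.utils.testing import ignore_warnings

@ignore_warnings(category=DeprecationWarning)
def check_decision_proba_consistency(name, Estimator):
    # Check that decision_function and predict_proba outputs have
    # perfect rank correlation. (A comment, not a docstring, so nose
    # keeps reporting the check's name.)
    pass  # body elided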

predict_proba methods has outputs with perfect rank correlation.
"""
rnd = np.random.RandomState(0)
X_train = (3*rnd.uniform(size=(10, 4))).astype(int)
Member:

we usually use integer features, or binary, in common tests in case estimators can't deal with real-valued features.

Author:

Absolutely. I added .astype(int).

if (hasattr(estimator, "decision_function") and
hasattr(estimator, "predict_proba")):

estimator.fit(X_train, y)
Member:

Nothing seems to be entering this case: I've modified it to say assert False, but nothing fails in sklearn/tests/test_common.py.

Member:

Ah. You've put this in as a regressor check.

@@ -162,6 +163,7 @@ def _yield_regressor_checks(name, Regressor):
yield check_regressors_no_decision_function
yield check_supervised_y_2d
yield check_supervised_y_no_nan
yield check_decision_proba_consistency
Member:

This should be in ...classifier_checks, not regressors.
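A sketch of the intended placement (surrounding checks elided; stand-ins are used so the snippet runs):

def check_classifiers_train(name, Classifier):
    pass  # stand-in for an existing common check

def check_decision_proba_consistency(name, Classifier):
    pass  # stand-in for the check added in this PR

def _yield_classifier_checks(name, Classifier):
    # ... existing classifier checks ...
    yield check_classifiers_train
    # The new check compares decision_function with predict_proba, which
    # only classifiers have, so it belongs here, not in the regressor list.
    yield check_decision_proba_consistency

for check in _yield_classifier_checks("SomeClassifier", object):
    print(check.__name__)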

@shubham0704 (Author)

Awesome, that's great, I get what you are trying to say: with likelihood(point belongs to A) / likelihood(point belongs to B), if both are around 0.5, say 0.6/0.4, the point can belong to either side, so the values won't peak so much. Thanks a lot. It should definitely work. I will make changes and update.

@shubham0704 (Author) commented Feb 17, 2017

@jnothman I did not exactly take the points in the middle; I just brought the cluster centres nearer so that they kind of overlap. The other thing I considered was to draw ellipses around both blobs and take all the points outside them for the test set, but this worked. Is it fine or should I improve it?

# TODO some complication with -1 label
and name not in ["DecisionTreeClassifier",
"ExtraTreeClassifier"]):
if name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]:
Member:

Why did you change this from and to a separate if? That's what creates the errors in your screenshot.

Author:

Making changes.

centers = [(2, 2), (4, 4)]
X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
centers=centers, cluster_std=1.0, shuffle=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
Member:

With this approach, the probabilities are again going to be very peaked around 0 and 1, since the blobs are more or less linearly separable, encouraging numerical precision errors etc. For the test set, I'd just use np.random.randn() + 3 or something.
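A sketch of the suggested test set (the feature count matches the two-dimensional centres above; the exact offset is a judgment call):

import numpy as np

rng = np.random.RandomState(0)
# Points around (3, 3) sit between the blob centres (2, 2) and (4, 4),
# so predicted probabilities stay away from the saturated 0/1 region
# where rank comparisons get numerically fragile.
X_test = rng.randn(20, 2) + 3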

@jnothman (Member) left a comment

LGTM

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
random_state=0)

X_test = np.random.randn(20, 2)+4
Member:

please insert spaces around +

Author:

Thanks a lot for all the reviews on the PR. Learnt a lot. Making changes.

@jnothman jnothman changed the title [MRG] Fixes #7578 added check_decision_proba_consistency in estimator_checks [MRG+1] Fixes #7578 added check_decision_proba_consistency in estimator_checks Feb 22, 2017
@jnothman (Member)

Please add an entry in what's new. Put it in API changes to say "Estimators with both x and y are now required ..."

@shubham0704 (Author) commented Feb 22, 2017

Got my Networks exam now; will surely do it by evening.

@lesteve (Member) commented Feb 23, 2017

@shubham0704 please use "Fix #issueNumber" in your PR description; this way the associated issue gets closed automatically when the PR is merged. For more details, look at this. I have edited your description, but please remember to do it next time.

@shubham0704 (Author)

Sure. Thanks @lesteve.

@shubham0704 (Author)

[RFC] - request for close :)
Note: this is on master, so I have to use ad-hoc methods to address the other issues.
Thanks

return (p.name != 'self'
and p.kind != p.VAR_KEYWORD
and p.kind != p.VAR_POSITIONAL)
return (p.name != 'self' and
Member:

For next time, try not to change things that are not related to your PR. This adds noise into the diff and makes it harder for the review to be efficient.

Author:

Sure @lesteve. Thanks a lot.

@lesteve lesteve merged commit 02c705e into scikit-learn:master Mar 7, 2017
@lesteve (Member) commented Mar 7, 2017

LGTM, merging, thanks a lot!
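For readers following along, a rough reconstruction of the final check, pieced together from the review thread above; the merged code may differ in details:

import numpy as np
from numpy.testing import assert_array_equal
from scipy.stats import rankdata
from sklearn.datasets import make_blobs

def check_decision_proba_consistency(name, Estimator):
    # Overlapping blobs keep probabilities away from 0/1 saturation.
    centers = [(2, 2), (4, 4)]
    X, y = make_blobs(n_samples=100, random_state=0, centers=centers,
                      cluster_std=1.0, shuffle=True)
    X_test = np.random.randn(20, 2) + 4
    estimator = Estimator()
    if (hasattr(estimator, "decision_function") and
            hasattr(estimator, "predict_proba")):
        estimator.fit(X, y)
        a = estimator.predict_proba(X_test)[:, 1]  # positive-class column
        b = estimator.decision_function(X_test)
        # For a binary problem, both outputs must rank the test points
        # identically.
        assert_array_equal(rankdata(a), rankdata(b))

# Example usage under this sketch:
from sklearn.linear_model import LogisticRegression
check_decision_proba_consistency("LogisticRegression", LogisticRegression)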

@Przemo10 Przemo10 mentioned this pull request Mar 17, 2017
herilalaina pushed a commit to herilalaina/scikit-learn that referenced this pull request Mar 26, 2017
massich pushed a commit to massich/scikit-learn that referenced this pull request Apr 26, 2017
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017
jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
Successfully merging this pull request may close these issues.

Common test: predict_proba as a monotonic transformation of decision_function