
[MRG] Enforcing strict monotonicity for IsotonicRegression #17790

Closed · wants to merge 29 commits

Conversation


@dsleo dsleo commented Jun 30, 2020

Reference Issues/PRs

As per our discussion, this fixes issue #16321.

What does this implement/fix? Explain your changes.

Because IsotonicRegression is not strictly monotonic, calibration with IsotonicRegression can change rank-based metrics. This PR adds a strict argument to IsotonicRegression to enforce strict monotonicity by keeping only the unique values of the thresholds y_thresholds_.
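For illustration, a minimal numpy sketch of the deduplication idea (not the PR's actual code, and the threshold values are made up): flat segments in the interpolation thresholds are consecutive x values mapped to the same y, and keeping one x per unique y makes linear interpolation through the remaining points strictly increasing.

```python
import numpy as np

# Hypothetical fitted isotonic thresholds with a flat segment:
# x = 0.3, 0.5, 0.7 all map to y = 0.4.
x_thresholds = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
y_thresholds = np.array([0.2, 0.4, 0.4, 0.4, 0.8])

# Keep only the first x for each unique y value (the idea behind
# strict=True): the flat segment collapses to a single point.
y_unique, first_idx = np.unique(y_thresholds, return_index=True)
x_unique = x_thresholds[first_idx]

# Interpolating through the deduplicated thresholds is strictly
# increasing, unlike the original, which is flat on [0.3, 0.7].
preds = np.interp([0.3, 0.5, 0.7], x_unique, y_unique)
```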

Any other comments?

Not sure whether the strict argument should default to True for CalibratedClassifierCV if the selected method is isotonic?


@NicolasHug NicolasHug left a comment


Thanks @dsleo ,

I understand this is still WIP, but I took a quick look.

We should also add details in the User guides of both calibration and isotonic.

@NicolasHug

> Not sure whether the strict argument should default to True for CalibratedClassifierCV if the selected method is isotonic?

It should be False for CalibratedClassifierCV and IsotonicRegression so that results don't change in the next versions

dsleo and others added 7 commits July 1, 2020 17:03
docstring wording

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

dsleo commented Jul 1, 2020

Thanks @NicolasHug for the review! I made the modifications following your comments and I've added some description in the calibration user guide. This should be ready for full review now!


@NicolasHug NicolasHug left a comment


Thanks @dsleo

I think we should also detail in the User Guide of isotonic how the strict mode works.

We could also update the example https://scikit-learn.org/dev/auto_examples/miscellaneous/plot_isotonic_regression.html#sphx-glr-auto-examples-miscellaneous-plot-isotonic-regression-py to illustrate the difference between strict and non-strict mode.

(Also, not directly related to this PR, but if you could include a link to that example in the User Guide it would be great!)


dsleo commented Jul 2, 2020

I've made the changes according to your comments and updated the example. A few things to note:

  • To ensure strict monotonicity outside the training domain, this extrapolates unless out_of_bounds is set to clip. This is now made explicit in the docstring; let me know if it's not clear enough.
  • When fit on increasing data, setting increasing=False makes the output function constant, so strict monotonicity cannot be enforced. For now a ValueError is raised, but it's not technically correct; ideally we could have a custom error. What do you think?
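A pure-numpy sketch of the second point, assuming squared loss (under which the best decreasing fit to increasing data is the constant mean): deduplicating the thresholds would leave a single point, from which no strictly monotonic interpolation can be built.

```python
import numpy as np

# Increasing data: the best *decreasing* fit under squared loss is a
# constant (PAVA merges everything into one block, giving the mean).
y = np.array([0.0, 1.0, 2.0, 3.0])
fit_decreasing = np.full_like(y, y.mean())

# After deduplication only one threshold survives, so strict
# monotonicity cannot be enforced.
n_thresholds = np.unique(fit_decreasing).size
```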

I'll try to find the time tomorrow to fill in the details in the UG on the strict mode.

@dsleo dsleo changed the title [WIP] Enforcing strict monotonicity for IsotonicRegression [MRG] Enforcing strict monotonicity for IsotonicRegression Jul 3, 2020

dsleo commented Jul 9, 2020

Little ping @NicolasHug in case this was drowned in all the other issues and PRs :)


@lucyleeow lucyleeow left a comment


A note about out_of_bounds that we should fix.

I think this shows that we should add strict=True/False with pytest.mark.parametrize to more tests, or at least to the OOB tests: test_isotonic_regression_oob_clip, test_isotonic_regression_oob_nan, test_isotonic_regression_oob_raise, etc.

Comment on lines 165 to 166
When set to `True`, points outside the training domain will be
extrapolated, unless `out_of_bounds="clip"`.

@lucyleeow lucyleeow Aug 25, 2020


Maybe we want to add another option to out_of_bounds if we want to extrapolate values outside of the training domain.
It doesn't make sense to give extrapolated values when out_of_bounds is 'nan' or 'raise' (and it isn't consistent with the behaviour when strict=False). If we want to offer extrapolation, we should add an 'extrapolate' option to out_of_bounds. But I am also not sure whether extrapolating is a good idea; @ogrisel and @NicolasHug will know more.
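To make the contrast concrete, here is a numpy sketch (not sklearn's implementation; the threshold values are made up) of what 'clip'-like vs 'nan'-like handling would mean for out-of-domain queries:

```python
import numpy as np

x = np.array([0.2, 0.5, 0.8])   # hypothetical interpolation thresholds
y = np.array([0.1, 0.4, 0.9])
query = np.array([0.0, 1.0])    # both outside the domain [0.2, 0.8]

# 'clip'-like: np.interp clamps to the boundary values by default.
clipped = np.interp(query, x, y)

# 'nan'-like: flag out-of-domain queries instead of clamping.
out_of_domain = (query < x[0]) | (query > x[-1])
as_nan = np.where(out_of_domain, np.nan, np.interp(query, x, y))
```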



We should also add a .. versionadded:: here


lucyleeow commented Aug 25, 2020

I also just wanted to raise that with strict=True we start extrapolating even within the training domain, specifically after the second-largest train value. Not sure how much of a problem this is. Here is an example (the largest value is 0.919):

import numpy as np
from sklearn.isotonic import IsotonicRegression

X = np.array([0.49835541, 0.54572331, 0.91999828, 0.33876463, 0.87974298,
              0.02375396, 0.33838427, 0.43110351, 0.5300294, 0.80951779])
y = np.array([0, 1, 1, 1, 1, 0, 0, 1, 1, 1])

iso = IsotonicRegression(strict=False)
iso.fit(X, y)
iso.predict(np.array([0.9]))
# array([1.])

iso_strict = IsotonicRegression(strict=True)
iso_strict.fit(X, y)
iso_strict.predict(np.array([0.9]))
# array([1.64477914])

(Of course this is a big problem here, as we are predicting a probability > 1 for an x value within the training domain.)

We can understand why when we plot the thresholds:

[plot: fitted thresholds with strict=False]

[plot: fitted thresholds with strict=True]
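The inflated prediction comes from linearly extending the last retained segment once the flat top is removed. A toy sketch with made-up deduplicated thresholds (not the values from the example above):

```python
import numpy as np

# Hypothetical deduplicated thresholds: after removing the flat top
# segment, the last retained x sits well below the training maximum.
x = np.array([0.02, 0.43, 0.55])
y = np.array([0.0, 0.5, 1.0])

# Linearly extending the last segment out to x = 0.9 overshoots 1.0,
# which is nonsensical for a probability.
slope = (y[-1] - y[-2]) / (x[-1] - x[-2])
pred = y[-1] + slope * (0.9 - x[-1])
```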


dsleo commented Nov 6, 2020

Sorry for the delay in following up on this!

@lucyleeow, I've simplified the code: no more unnecessary extrapolation or issues with out_of_bounds, so the behavior is consistent for the two possible modes of strict.

Regarding the example you propose: when removing flat segments, we can fix both extremities of the interval (see here).

iso_strict = IsotonicRegression(strict=True)
iso_strict.fit(X, y)
iso_strict.predict(np.array([0.9]))
# array([0.98853113])

And this is the corresponding curve with strict=True:
[screenshot: the corresponding curve with strict=True]

This should now be ready for a second review then.

Base automatically changed from master to main January 22, 2021 10:52

ogrisel commented Oct 25, 2021

@dsleo there is an alternative to this PR proposed as #21454. Any comment would be appreciated.


dsleo commented Oct 27, 2021

@ogrisel thanks for the heads up. I am not familiar with centered isotonic regression and its uses, but if we can enforce monotonicity at the edges, that would also handle this. Perhaps through two parameters, centered and strict (where strict=True with centered=False wouldn't be an option)?


ogrisel commented Oct 29, 2021

> @ogrisel thanks for the heads up. I am not familiar with centered isotonic regression and its uses but if we can enforce monotonicity at the edges that would also handle this. Perhaps through two parameters, centered and strict (but where strict=True and centered=False wouldn't be an option)?

Let's centralize the discussion on the centered isotonic PR directly.

@lorentzenchr

See #21454.
