
Adaboost Classifier :- Stopped iterations when sample weights overflow and print a warning message #10096

Merged: 18 commits into scikit-learn:main, Aug 6, 2021

Conversation

Contributor

@fenilsuchak fenilsuchak commented Nov 9, 2017

Fixes #10077

The error is fixed in the AdaBoost classifier. At some iterations the weighted error was underflowing due to the high learning rate, making the error NaN and hence making the subsequent iterations useless.
This update halts the iterations when such an error is encountered and prints a warning message to inform the user about the underflow.

This is my first PR. I am open to all sorts of criticism.
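
For context, a minimal sketch of how to reproduce the kind of failure described above, based on the settings used in the regression test added later in this PR (iris data, SAMME, a deliberately high learning rate); it is illustrative only and not part of the PR diff:

from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier

X, y = load_iris(return_X_y=True)

# A very large learning rate makes the boosted sample weights blow up
# after a few iterations, which turns the weighted error into NaN.
clf = AdaBoostClassifier(n_estimators=30, learning_rate=5.,
                         algorithm="SAMME")
clf.fit(X, y)  # before this fix: NaN errors and useless later iterations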

Contributor

@massich massich left a comment

Take care of PEP 8.

@@ -145,7 +145,11 @@ def fit(self, X, y, sample_weight=None):
                 random_state)

             # Early termination
-            if sample_weight is None:
+            if sample_weight is None and math.isnan(estimator_error):
Contributor

no need to test sample_weight twice.

if sample_weight is None:
    if math.isnan(estimator_error):
        do_something_specific
    break

-            if sample_weight is None:
+            if sample_weight is None and math.isnan(estimator_error):
+                print("Underflow of weighted error occured during iterations! Iterations stopped ! High chances of Overfitting!, Try decreasing the learning rate or n_estimators to avoid this! ")
Contributor

use proper warning messages

Contributor

and lines should be <80 characters

@@ -497,6 +501,9 @@ def _boost_real(self, iboost, X, y, sample_weight, random_state):
         # Error fraction
         estimator_error = np.mean(
             np.average(incorrect, weights=sample_weight, axis=0))
+        if math.isnan(estimator_error):
+            return None,None,estimator_error
Contributor

space after commas (otherwise it's not PEP 8)

@@ -552,6 +559,8 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
         # Error fraction
         estimator_error = np.mean(
             np.average(incorrect, weights=sample_weight, axis=0))
+        if math.isnan(estimator_error):
+            return None,None,estimator_error
Contributor

idem

Contributor Author

what does "idem" mean? Sorry if the question is lame.

Contributor

massich commented Nov 9, 2017

Welcome. Tag your title with [WIP] and [MRG] when you are ready for review. Thanks.

@fenilsuchak fenilsuchak changed the title Adaboost Classifier :- Stopped iterations when underflow and printed a warning message [WIP}Adaboost Classifier :- Stopped iterations when underflow and printed a warning message Nov 10, 2017
@fenilsuchak fenilsuchak changed the title [WIP}Adaboost Classifier :- Stopped iterations when underflow and printed a warning message [WIP] Adaboost Classifier :- Stopped iterations when underflow and printed a warning message Nov 10, 2017
if sample_weight is None:
    if estimator_error is not None and math.isnan(estimator_error):
        print("Early termination due to underflow of estimated_error.Iterations stopped")
Contributor

We don't use print statements for warnings (see the contributing guide for a warning example).

Or see this example in the code:

if not np.allclose(mean_2, 0):
    warnings.warn("Numerical issues were encountered "
                  "when scaling the data "
                  "and might not be solved. The standard "
                  "deviation of the data is probably "
                  "very close to 0. ")

Moreover, if you introduce an if statement, that code path is not tested, so a test needs to be written. See the test for the previous example here:

w = "standard deviation of the data is probably very close to 0"
x_scaled = assert_warns_message(UserWarning, w, scale, x)

Contributor Author

Thanks! I'll work on the test

Contributor Author

@fenilsuchak fenilsuchak Nov 18, 2017

@massich I am not sure what a good test for this would look like. A separate function for testing?
Here's a snapshot; any suggestions would be great.
[screenshot: ada_f]

Contributor

@massich massich Nov 19, 2017

If there's no place where something similar is tested, create a new function. But I would use a more descriptive name like test_early_stop_when_estimator_error_becomes_nan. I would also add a reference to the original issue.

To avoid PEP 8 errors, I would also put the warning message in a variable to compare against. (And make sure that the warning raised and the one caught are the same; try to find a better name than w or early_stop_warnmsg.)

early_stop_warnmsg = ("Early termination due to underflow of "
                      "estimated_error. Iterations stopped")
assert_warns_message(UserWarning, early_stop_warnmsg, clf.fit,
                     iris.data, iris.target)

Contributor Author

Great. Thanks, I will make these changes.

Contributor Author

@fenilsuchak fenilsuchak Dec 1, 2017

@massich I don't think the estimator error is underflowing. If it were underflowing, wouldn't the estimated error be approximated to zero? What I suspect is that the sample weights are shooting up to infinity when summed (probably due to the high learning rate), causing an inf/inf division for the estimated error and giving a NaN. Any suggestions?
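
To illustrate the diagnosis with plain NumPy (not the estimator code; the values are made up): once both the weighted error numerator and the normalising weight sum overflow to infinity, the division yields NaN.

import numpy as np

# Hypothetical overflowed quantities standing in for the boosted weights.
weighted_error_sum = np.float64(np.inf)  # numerator after overflow
sample_weight_sum = np.float64(np.inf)   # normalising sum after overflow

with np.errstate(invalid="ignore"):
    estimator_error = weighted_error_sum / sample_weight_sum

print(estimator_error)  # nan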

@fenilsuchak fenilsuchak changed the title [WIP] Adaboost Classifier :- Stopped iterations when underflow and printed a warning message [MRG] Adaboost Classifier :- Stopped iterations when underflow and printed a warning message Dec 8, 2017
@fenilsuchak
Contributor Author

@massich Any changes needed? Actually there is no underflow of estimator but sample weights reaching infinite values are causing the Nan. The original issue title is not what the problem is. So should I change my PR title?

@jnothman
Member

Yes, please update the PR title.

Does normalizing the sample_weight vector after each boosting step help? We normalize them when they are first input, and sample weights should be relative, so I don't see any harm in normalizing except that it could result in underflow.

If this is about numerical precision, it might also help to perform the boosting in log space, so using sample_weight = np.log(np.exp(sample_weight) + estimator_weight * incorrect * ...) instead of sample_weight *= np.exp(estimator_weight * incorrect * ...).
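
A small sanity check of the log-space idea being discussed (illustrative only; the variable names mirror the snippet above and the values are made up): multiplying by an exponential is the same as adding in log space and exponentiating once.

import numpy as np

sample_weight = np.array([0.2, 0.5, 0.3])
estimator_weight = 2.0
incorrect = np.array([1.0, 0.0, 1.0])  # 1 where the estimator was wrong

# Multiplicative update, as in the current code.
direct = sample_weight * np.exp(estimator_weight * incorrect)

# The same update carried through log space:
# log(w_new) = log(w_old) + estimator_weight * incorrect.
via_logs = np.exp(np.log(sample_weight) + estimator_weight * incorrect)

assert np.allclose(direct, via_logs)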

@fenilsuchak fenilsuchak changed the title [MRG] Adaboost Classifier :- Stopped iterations when underflow and printed a warning message [WIP] Adaboost Classifier :- Stopped iterations when underflow and printed a warning message Dec 14, 2017
@fenilsuchak
Contributor Author

We do normalise the sample weights after each boosting step, which requires summing up all the sample weights, and that's where we get 'inf'.
Converting to log space seems workable. I will give it a try.
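
For reference, a sketch of how the normalising sum itself could be computed with log-sum-exp so that the intermediate sum never becomes inf (this assumes the weights are kept in log space and uses scipy.special.logsumexp; it is an illustration, not the code in this PR):

import numpy as np
from scipy.special import logsumexp

# Hypothetical log-domain weights that would overflow if exponentiated naively.
log_w = np.array([700.0, 705.0, 710.0])

# np.exp(log_w).sum() would be inf; logsumexp subtracts the max internally,
# so the normalisation stays finite.
log_w_normalised = log_w - logsumexp(log_w)
w_normalised = np.exp(log_w_normalised)

print(w_normalised.sum())  # 1.0, with no inf along the way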

@jnothman
Member

You're right, we do. Yes, please do try log-sum-exp.

@fenilsuchak
Contributor Author

@jnothman Converting to log space isn't working well; it is affecting accuracy and causing test_iris() to fail due to low accuracy.

[screenshot: some_jpeg]

So I think it's better to generate a warning and stop the iterations.

@amueller
Member

@Fenil3510 that sounds odd. Can you please commit the change so we can review?

@fenilsuchak
Contributor Author

@amueller please review

Contributor Author

@fenilsuchak fenilsuchak left a comment

Converted to log space; I hope the conversion is right.

@@ -581,7 +581,8 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
             # Only boost the weights if I will fit again
             if not iboost == self.n_estimators - 1:
                 # Only boost positive weights
-                sample_weight *= np.exp(estimator_weight * incorrect *
+                sample_weight = np.log(np.exp(sample_weight) +
Member

You have the exp and the log the wrong way around. You want to turn a product into a sum of logs.

Contributor Author

@fenilsuchak fenilsuchak Dec 21, 2017

Ok,
sample_weight = np.log(sample_weight * np.exp(estimator_weight*incorrect....))
This is fine I guess.
I'll commit the changes if the above is fine.

Member

No, you need to exp(log(sample_weight) + est_weight * incorrect * ...)

@@ -581,7 +581,7 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
             # Only boost the weights if I will fit again
             if not iboost == self.n_estimators - 1:
                 # Only boost positive weights
-                sample_weight = np.log(np.exp(sample_weight) +
+                sample_weight = np.exp(np.log(sample_weight) *
Member

no, you need a sum of logs. That * should be a +.

Contributor Author

@fenilsuchak fenilsuchak Dec 22, 2017

Oh! I am really sorry, will fix.

@jnothman
Member

Well Travis gave you green for everything but flake8

Member

@jnothman jnothman left a comment

But that still means we're getting overflow...

@fenilsuchak
Contributor Author

@jnothman Yes there is still an overflow. Anything else we could try?

Member

@thomasjpfan thomasjpfan left a comment

Thank you for the PR @fenilsuchak !

Please add an entry to the change log at doc/whats_new/v0.24.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.

-                sample_weight *= np.exp(estimator_weight * incorrect *
-                                        ((sample_weight > 0) |
-                                         (estimator_weight < 0)))
+                sample_weight = np.exp(np.log(sample_weight) +
Member

Let's include a comment here explaining why we are doing this in log space.

Contributor Author

Not sure if we gained an advantage by doing this; we still had an overflow.

sklearn/ensemble/weight_boosting.py (outdated review thread, resolved)
w = "Sample weights have reached infinite values"
clf = AdaBoostClassifier(n_estimators=30, learning_rate=5.,
algorithm="SAMME")
assert_warns_message(UserWarning, w, clf.fit, iris.data, iris.target)
Member

We have been moving toward using pytest.warns:

msg = "Sample weights have reached infinite values"
with pytest.warns(UserWarning, match=msg):
    clf.fit(iris.data, iris.target)
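
For readers following the thread, a rough sketch of the kind of check in the boosting loop that such a test exercises; this illustrates the approach (warn and stop when the weight sum is no longer finite), not the exact merged code, and the helper name is made up:

import warnings
import numpy as np

def _sample_weights_overflowed(sample_weight):
    """Illustrative helper: return True (after warning) when the boosted
    sample weights have overflowed, so the caller can stop boosting."""
    if not np.isfinite(np.sum(sample_weight)):
        warnings.warn("Sample weights have reached infinite values, "
                      "causing overflow. Iterations stopped. Try lowering "
                      "the learning rate.", UserWarning)
        return True
    return False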

@fenilsuchak
Contributor Author

@cmarmo not sure why the tests are failing.

@cmarmo
Contributor

cmarmo commented Jul 13, 2020

> @cmarmo not sure why the tests are failing.

Not sure either, but it's failing on upstream/master too with the same error...
@jeremiedbb I've found this old issue you fixed, #15443, are we in a similar situation?

@jeremiedbb
Member

It's not the same issue, but it's a macOS compilation issue, as usual :)
There's a temporary fix in #17913.

@fenilsuchak
Contributor Author

Ok, will wait for #17913 to be merged and then rerun the tests?

@cmarmo
Contributor

cmarmo commented Jul 18, 2020

> Ok, will wait for #17913 to be merged and then rerun the tests?

#17913 has been merged ... if you have some time to sync... thanks!

@cmarmo
Contributor

cmarmo commented Aug 15, 2020

Hi @thomasjpfan, this is a three-year-old PR ... do you mind checking whether your comments have been addressed? The check failure is unrelated to the PR itself. Thanks!

sample_weight = np.exp(np.log(sample_weight) +
                       estimator_weight * incorrect *
                       ((sample_weight > 0) |
                        (estimator_weight < 0)))
Member

What was the reason behind estimator_weight < 0?

Member

Yes, this looks like a backward incompatible change to me. I reverted this change.

Base automatically changed from master to main January 22, 2021 10:49
@rth rth changed the title [MRG+1] Adaboost Classifier :- Stopped iterations when sample weights overflow and print a warning message Adaboost Classifier :- Stopped iterations when sample weights overflow and print a warning message Aug 6, 2021
@rth rth removed the Stalled label Aug 6, 2021
Member

@rth rth left a comment

Thanks @fenilsuchak ! I fixed conflicts and added a changelog entry.

@rth rth merged commit cef0282 into scikit-learn:main Aug 6, 2021
samronsin pushed a commit to samronsin/scikit-learn that referenced this pull request Nov 30, 2021
Co-authored-by: Fenil Suchak <fenilsuchak@fenil.local>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Roman Yurchak <rth.yurchak@gmail.com>

Successfully merging this pull request may close these issues.

Underflow in weight boosting
8 participants