Warm start bug when fitting a LogisticRegression model on binary outcomes with `multi_class='multinomial'`. #10836
Comments
Thanks for the report. At a glance, that looks very plausible. Test and patch welcome!
I'm happy to do this, although I would be interested in opinions on the test. I could do either
The pros of (1) are that it's quick and easy; however, as mentioned previously, it doesn't really get to the essence of what is causing the bug. The only reason it is failing is because the
The pros of (2) are that it would correctly test that the warm starting occurred, but the con is that I don't know how I would do it, as the
Go for the simplest test first, open a PR and see where that leads you!
Description
Bug when fitting a LogisticRegression model on binary outcomes with `multi_class='multinomial'` when using warm start. Note that it is similar to the issue in #9889, i.e. only a 1D `coef` object is used on binary outcomes, even when using `multi_class='multinomial'`, as opposed to a 2D `coef` object.
Steps/Code to Reproduce
Expected Results
The predictions should be the same, since the model had already converged the first time it was run.
Actual Results
The predictions are different. In fact, the more times you re-run the fit, the worse it gets. This is actually the only reason I was able to catch the bug. It is caused by the line here: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/logistic.py#L678.
As `coef` is `(1, n_features)` but `w0` is `(2, n_features)`, the `coef` value is broadcast into `w0`, so both rows receive the same coefficients. This leads to some sort of singularity issue when training, resulting in worse performance. Note that had it not done exactly this, i.e. had `w0` simply been initialised with some random values, this bug would be very hard to catch, because each time the model would still converge, just not as fast as one would hope when warm starting.
Further Information
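The broadcasting problem can be reproduced with plain NumPy. This is a simplified sketch, not scikit-learn's code: it ignores the intercept column, uses made-up random data, and the names `softmax`, `X`, `coef`, `w0_buggy` and `w0_fixed` are illustrative. It assumes that, for a binary multinomial model, the stored decision `d = X @ coef` is expanded to the two-class scores `[-d, d]` before the softmax. Broadcasting the single row of `coef` into both rows of `w0` gives both classes identical scores, so the warm start begins from uniform 0.5 probabilities rather than from the previous solution. Filling the rows with `-coef` and `coef` (one possible initialisation, assumed here) reproduces the stored model's probabilities.

```python
import numpy as np


def softmax(z):
    # Row-wise softmax, shifted by the row max for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
n_samples, n_features = 5, 3
X = rng.normal(size=(n_samples, n_features))

# Binary problem: only one row of coefficients is stored, shape (1, n_features).
coef = rng.normal(size=(1, n_features))

# The multinomial solver works with one row per class, shape (2, n_features).
w0 = np.zeros((2, n_features))

# Buggy warm start: the single row is broadcast into BOTH class rows.
w0_buggy = w0.copy()
w0_buggy[:, :coef.shape[1]] = coef

# Assumed alternative: give the two classes opposite coefficients.
w0_fixed = w0.copy()
w0_fixed[0, :coef.shape[1]] = -coef[0]
w0_fixed[1, :coef.shape[1]] = coef[0]

# Positive-class probability implied by each initialisation.
p_buggy = softmax(X @ w0_buggy.T)[:, 1]
p_fixed = softmax(X @ w0_fixed.T)[:, 1]

# Probabilities implied by the stored binary coef, assuming scores [-d, d].
d = X @ coef[0]
p_expected = softmax(np.c_[-d, d])[:, 1]

print(np.round(p_buggy, 3))              # every entry is 0.5
print(np.allclose(p_fixed, p_expected))  # True
```

The identical rows cancel inside the softmax, which is why the buggy starting point carries no information from the previous fit.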
The fix, I believe, is very easy: just change the previous line to
Versions
Linux-4.13.0-37-generic-x86_64-with-Ubuntu-16.04-xenial
Python 3.5.2 (default, Nov 23 2017, 16:37:01)
NumPy 1.14.2
SciPy 1.0.0
Scikit-Learn 0.20.dev0 (built from latest master)