MNT improve the convergence warning message for LogisticRegression (s…
ogrisel authored and Pan Jan committed Mar 3, 2020
1 parent bf110d6 commit b30b32c
Showing 3 changed files with 32 additions and 13 deletions.
10 changes: 9 additions & 1 deletion sklearn/linear_model/_logistic.py
@@ -38,6 +38,12 @@
from ..metrics import get_scorer


_LOGISTIC_SOLVER_CONVERGENCE_MSG = (
"Please also refer to the documentation for alternative solver options:\n"
" https://scikit-learn.org/stable/modules/linear_model.html"
"#logistic-regression")


# .. some helper functions for logistic_regression_path ..
def _intercept_dot(w, X, y):
"""Computes y * np.dot(X, w).
@@ -928,7 +934,9 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
args=(X, target, 1. / C, sample_weight),
options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
)
n_iter_i = _check_optimize_result(solver, opt_res, max_iter)
n_iter_i = _check_optimize_result(
solver, opt_res, max_iter,
extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
w0, loss = opt_res.x, opt_res.fun
elif solver == 'newton-cg':
args = (X, target, 1. / C, sample_weight)
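For reference, a minimal sketch (not part of the commit) of how a user would hit the improved warning once this change is applied: fit LogisticRegression with the lbfgs solver, a deliberately low max_iter and badly scaled features, then inspect the recorded message.

# Sketch only: trigger the lbfgs ConvergenceWarning and read the new message.
# Assumes a scikit-learn build that includes this commit.
import warnings

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.randn(100, 5) * 1e3          # deliberately badly scaled features
y = rng.randint(0, 2, size=100)

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    LogisticRegression(solver="lbfgs", max_iter=1).fit(X, y)

# The first record is expected to be the ConvergenceWarning; its text should
# now end with the logistic-regression documentation link added above.
print(records[0].message)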
16 changes: 11 additions & 5 deletions sklearn/linear_model/tests/test_logistic.py
@@ -31,7 +31,6 @@
from sklearn.utils._testing import skip_if_no_parallel

from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.linear_model._logistic import (
LogisticRegression,
logistic_regression_path,
@@ -391,13 +390,20 @@ def test_logistic_regression_path_convergence_fail():
y = [1] * 100 + [-1] * 100
Cs = [1e3]

msg = (r"lbfgs failed to converge.+Increase the number of iterations or "
r"scale the data")

with pytest.warns(ConvergenceWarning, match=msg):
# Check that the convergence message points both to the model-agnostic
# advice (scaling the data) and to the logistic regression specific
# documentation that includes hints on the solver configuration.
with pytest.warns(ConvergenceWarning) as record:
_logistic_regression_path(
X, y, Cs=Cs, tol=0., max_iter=1, random_state=0, verbose=0)

assert len(record) == 1
warn_msg = record[0].message.args[0]
assert "lbfgs failed to converge" in warn_msg
assert "Increase the number of iterations" in warn_msg
assert "scale the data" in warn_msg
assert "linear_model.html#logistic-regression" in warn_msg


def test_liblinear_dual_random_state():
# random_state is relevant for liblinear solver only if dual=True
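As a usage note, a minimal sketch of the model-agnostic advice the warning gives (scale the data, raise max_iter if needed); the pipeline below is an illustration rather than part of this commit.

# Sketch only: follow the advice from the warning message.
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

clf = make_pipeline(
    StandardScaler(),
    LogisticRegression(solver="lbfgs", max_iter=1000),
)
# clf.fit(X, y)   # with scaled features, lbfgs typically converges quietly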
19 changes: 12 additions & 7 deletions sklearn/utils/optimize.py
@@ -213,7 +213,8 @@ def _newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4,
return xk, k


def _check_optimize_result(solver, result, max_iter=None):
def _check_optimize_result(solver, result, max_iter=None,
extra_warning_msg=None):
"""Check the OptimizeResult for successful convergence
Parameters
@@ -233,12 +234,16 @@ def _check_optimize_result(solver, result, max_iter=None):
# handle both scipy and scikit-learn solver names
if solver == "lbfgs":
if result.status != 0:
warnings.warn("{} failed to converge (status={}): {}. "
"Increase the number of iterations or scale the "
"data as shown in https://scikit-learn.org/stable/"
"modules/preprocessing.html"
.format(solver, result.status, result.message),
ConvergenceWarning, stacklevel=2)
warning_msg = (
"{} failed to converge (status={}):\n{}.\n\n"
"Increase the number of iterations (max_iter) "
"or scale the data as shown in:\n"
" https://scikit-learn.org/stable/modules/"
"preprocessing.html."
).format(solver, result.status, result.message.decode("latin1"))
if extra_warning_msg is not None:
warning_msg += "\n" + extra_warning_msg
warnings.warn(warning_msg, ConvergenceWarning, stacklevel=2)
if max_iter is not None:
# In scipy <= 1.0.0, nit may exceed maxiter for lbfgs.
# See https://github.com/scipy/scipy/issues/7854
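A rough sketch of how the updated helper can be exercised on its own; the private import path and the bytes message below are assumptions (the diff decodes result.message as latin-1 because older SciPy versions return the lbfgs message as bytes).

# Sketch only: feed a hand-built OptimizeResult to the helper.
from scipy.optimize import OptimizeResult
from sklearn.utils.optimize import _check_optimize_result

fake_result = OptimizeResult(
    status=1,                                   # non-zero: lbfgs stopped early
    message=b"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT",  # bytes in older SciPy
    nit=5,
)
# Emits a ConvergenceWarning whose text ends with the extra message.
n_iter = _check_optimize_result(
    "lbfgs", fake_result, max_iter=5,
    extra_warning_msg="Please also refer to the solver documentation.")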
