Merge pull request #6817 from TomDLT/logistic_class_weight
[MRG] use class_weight through sample_weight in LogisticRegression with liblinear
jnothman committed May 25, 2016
2 parents 20f89ef + 1107f22 commit af171b8
Showing 4 changed files with 44 additions and 52 deletions.
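
For context, the user-visible effect of the patch, as a minimal sketch (the dataset and parameter values here are illustrative, not taken from the PR): with the liblinear solver, LogisticRegressionCV now routes class_weight through sample_weight, so both of the calls below fit cleanly.

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegressionCV

    X, y = make_classification(n_samples=100, n_features=10, n_informative=5,
                               n_classes=3, random_state=0)

    # Before this commit the liblinear path raised ValueError for a dict
    # class_weight on a multiclass problem, and per the changelog below it
    # also failed to accept class_weight='balanced'.
    for cw in ('balanced', {0: 0.1, 1: 0.2, 2: 0.5}):
        LogisticRegressionCV(solver='liblinear', class_weight=cw).fit(X, y)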
5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -203,6 +203,11 @@ Bug fixes
 - Fix bug where expected and adjusted mutual information were incorrect if
   cluster contingency cells exceeded ``2**16``. By `Joel Nothman`_.
 
+- Fix bug in :class:`linear_model.LogisticRegressionCV` where
+  ``solver='liblinear'`` did not accept ``class_weight='balanced'``.
+  (`#6817 <https://github.com/scikit-learn/scikit-learn/pull/6817>`_).
+  By `Tom Dupre la Tour`_.
+
 
 API changes summary
 -------------------
20 changes: 3 additions & 17 deletions sklearn/linear_model/logistic.py
@@ -618,23 +618,9 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # are assigned to the original labels. If it is "balanced", then
     # the class_weights are assigned after masking the labels with a OvR.
     le = LabelEncoder()
 
-    if isinstance(class_weight, dict) or multi_class == 'multinomial':
-        if solver == "liblinear":
-            if classes.size == 2:
-                # Reconstruct the weights with keys 1 and -1
-                temp = {1: class_weight[pos_class],
-                        -1: class_weight[classes[0]]}
-                class_weight = temp.copy()
-            else:
-                raise ValueError("In LogisticRegressionCV the liblinear "
-                                 "solver cannot handle multiclass with "
-                                 "class_weight of type dict. Use the lbfgs, "
-                                 "newton-cg or sag solvers or set "
-                                 "class_weight='balanced'")
-        else:
-            class_weight_ = compute_class_weight(class_weight, classes, y)
-            sample_weight *= class_weight_[le.fit_transform(y)]
+    class_weight_ = compute_class_weight(class_weight, classes, y)
+    sample_weight *= class_weight_[le.fit_transform(y)]
 
     # For doing a ovr, we need to mask the labels first. for the
     # multinomial case this is not necessary.
@@ -740,7 +726,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                 maxiter=max_iter, tol=tol)
         elif solver == 'liblinear':
             coef_, intercept_, n_iter_i, = _fit_liblinear(
-                X, target, C, fit_intercept, intercept_scaling, class_weight,
+                X, target, C, fit_intercept, intercept_scaling, None,
                 penalty, dual, verbose, max_iter, tol, random_state,
                 sample_weight=sample_weight)
             if fit_intercept:
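
The simplification above rests on the fact that per-class weights can always be folded into per-sample weights, which every solver accepting sample_weight can consume. A standalone sketch of that identity (names are local to this example; the positional compute_class_weight signature matches the codebase of this era):

    import numpy as np
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_class_weight

    y = np.array([0, 0, 0, 1, 1, 2])
    classes = np.unique(y)

    # 'balanced' gives each class n_samples / (n_classes * class count).
    class_weight_ = compute_class_weight('balanced', classes, y)

    # Fold the class weights into one weight per sample, exactly as the
    # rewritten path does before handing off to the solver.
    sample_weight = np.ones(y.shape[0])
    sample_weight *= class_weight_[LabelEncoder().fit_transform(y)]
    print(sample_weight)  # [0.667 0.667 0.667 1. 1. 2.]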
67 changes: 33 additions & 34 deletions sklearn/linear_model/tests/test_logistic.py
@@ -547,35 +547,35 @@ def test_logistic_regression_solvers_multiclass():


 def test_logistic_regressioncv_class_weights():
-    X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
-                               n_classes=3, random_state=0)
-
-    msg = ("In LogisticRegressionCV the liblinear solver cannot handle "
-           "multiclass with class_weight of type dict. Use the lbfgs, "
-           "newton-cg or sag solvers or set class_weight='balanced'")
-    clf_lib = LogisticRegressionCV(class_weight={0: 0.1, 1: 0.2},
-                                   solver='liblinear')
-    assert_raise_message(ValueError, msg, clf_lib.fit, X, y)
-    y_ = y.copy()
-    y_[y == 2] = 1
-    clf_lib.fit(X, y_)
-    assert_array_equal(clf_lib.classes_, [0, 1])
-
-    # Test for class_weight=balanced
-    X, y = make_classification(n_samples=20, n_features=20, n_informative=10,
-                               random_state=0)
-    clf_lbf = LogisticRegressionCV(solver='lbfgs', fit_intercept=False,
-                                   class_weight='balanced')
-    clf_lbf.fit(X, y)
-    clf_lib = LogisticRegressionCV(solver='liblinear', fit_intercept=False,
-                                   class_weight='balanced')
-    clf_lib.fit(X, y)
-    clf_sag = LogisticRegressionCV(solver='sag', fit_intercept=False,
-                                   class_weight='balanced', max_iter=2000)
-    clf_sag.fit(X, y)
-    assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
-    assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
-    assert_array_almost_equal(clf_lib.coef_, clf_sag.coef_, decimal=4)
+    for weight in [{0: 0.1, 1: 0.2}, {0: 0.1, 1: 0.2, 2: 0.5}]:
+        n_classes = len(weight)
+        for class_weight in (weight, 'balanced'):
+            X, y = make_classification(n_samples=30, n_features=3,
+                                       n_repeated=0,
+                                       n_informative=3, n_redundant=0,
+                                       n_classes=n_classes, random_state=0)
+
+            clf_lbf = LogisticRegressionCV(solver='lbfgs', Cs=1,
+                                           fit_intercept=False,
+                                           class_weight=class_weight)
+            clf_ncg = LogisticRegressionCV(solver='newton-cg', Cs=1,
+                                           fit_intercept=False,
+                                           class_weight=class_weight)
+            clf_lib = LogisticRegressionCV(solver='liblinear', Cs=1,
+                                           fit_intercept=False,
+                                           class_weight=class_weight)
+            clf_sag = LogisticRegressionCV(solver='sag', Cs=1,
+                                           fit_intercept=False,
+                                           class_weight=class_weight,
+                                           tol=1e-5, max_iter=10000,
+                                           random_state=0)
+            clf_lbf.fit(X, y)
+            clf_ncg.fit(X, y)
+            clf_lib.fit(X, y)
+            clf_sag.fit(X, y)
+            assert_array_almost_equal(clf_lib.coef_, clf_lbf.coef_, decimal=4)
+            assert_array_almost_equal(clf_ncg.coef_, clf_lbf.coef_, decimal=4)
+            assert_array_almost_equal(clf_sag.coef_, clf_lbf.coef_, decimal=4)
 
 
 def test_logistic_regression_sample_weights():
@@ -926,7 +926,6 @@ def test_n_iter():
         assert_equal(clf.n_iter_.shape, (1, n_cv_fold, n_Cs))
 
 
-@ignore_warnings
 def test_warm_start():
     # A 1-iteration second fit on same data should give almost same result
     # with warm starting, and quite different result without warm starting.
@@ -947,11 +946,11 @@ def test_warm_start():
                                      solver=solver,
                                      random_state=42, max_iter=100,
                                      fit_intercept=fit_intercept)
-            clf.fit(X, y)
-            coef_1 = clf.coef_
+            with ignore_warnings(category=ConvergenceWarning):
+                clf.fit(X, y)
+                coef_1 = clf.coef_
 
-            clf.max_iter = 1
-            clf.fit(X, y)
+            with ignore_warnings():
+                clf.max_iter = 1
+                clf.fit(X, y)
             cum_diff = np.sum(np.abs(coef_1 - clf.coef_))
             msg = ("Warm starting issue with %s solver in %s mode "
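
The warm-start test now suppresses convergence noise only around the fits instead of decorating the whole test. A minimal sketch of the context-manager form used above (ignore_warnings lives in sklearn.utils.testing in this era of the codebase):

    import warnings
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.utils.testing import ignore_warnings

    with ignore_warnings(category=ConvergenceWarning):
        # Swallowed: matches the ignored category.
        warnings.warn('did not converge', ConvergenceWarning)
    # Outside the block, warnings propagate normally again.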
4 changes: 3 additions & 1 deletion sklearn/utils/optimize.py
@@ -17,6 +17,8 @@
 import warnings
 from scipy.optimize.linesearch import line_search_wolfe2, line_search_wolfe1
 
+from ..exceptions import ConvergenceWarning
+
 
 class _LineSearchError(RuntimeError):
     pass
@@ -198,5 +200,5 @@ def newton_cg(grad_hess, func, grad, x0, args=(), tol=1e-4,

     if warn and k >= maxiter:
         warnings.warn("newton-cg failed to converge. Increase the "
-                      "number of iterations.")
+                      "number of iterations.", ConvergenceWarning)
     return xk, k

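Because the warning now carries a concrete class, callers can filter or assert on ConvergenceWarning itself rather than matching message text. A hedged sketch (the deliberately tiny max_iter exists only to provoke the warning; whether it fires depends on the data):

    import warnings
    from sklearn.datasets import make_classification
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(random_state=0)

    with warnings.catch_warnings():
        # Escalate only convergence complaints to errors.
        warnings.simplefilter('error', ConvergenceWarning)
        try:
            LogisticRegression(solver='newton-cg', max_iter=1).fit(X, y)
        except ConvergenceWarning:
            print('newton-cg stopped before converging')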