
Issue #2559: Added normalize option to LogisticRegression #2567

Closed
10 changes: 7 additions & 3 deletions sklearn/linear_model/logistic.py
@@ -36,6 +36,9 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
        Specifies if a constant (a.k.a. bias or intercept) should be
        added to the decision function.

+    normalize : boolean, optional, default False
+        If True, the regressors X will be normalized before logistic regression.
+
    intercept_scaling : float, default: 1
        when self.fit_intercept is True, instance vector x becomes
        [x, self.intercept_scaling],
Expand Down Expand Up @@ -95,12 +98,13 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
"""

    def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
-                 fit_intercept=True, intercept_scaling=1, class_weight=None,
-                 random_state=None):
+                 fit_intercept=True, normalize=False, intercept_scaling=1,
+                 class_weight=None, random_state=None):

        super(LogisticRegression, self).__init__(
            penalty=penalty, dual=dual, loss='lr', tol=tol, C=C,
-            fit_intercept=fit_intercept, intercept_scaling=intercept_scaling,
+            fit_intercept=fit_intercept, normalize=normalize,
+            intercept_scaling=intercept_scaling,
            class_weight=class_weight, random_state=random_state)

    def predict_proba(self, X):
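For context, this is how the option added in this PR would be used. A minimal sketch, assuming a build of scikit-learn that includes this branch (the `normalize=` parameter does not exist in released versions):

from sklearn import datasets
from sklearn.linear_model import LogisticRegression

iris = datasets.load_iris()
X, y = iris.data, iris.target

# With normalize=True, X is standardized internally before fitting, and
# coef_ / intercept_ are rescaled back to the original feature space, so
# predictions are made on unscaled data as usual.
clf = LogisticRegression(normalize=True)
clf.fit(X, y)
print(clf.predict_proba(X[:3]))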
18 changes: 18 additions & 0 deletions sklearn/linear_model/tests/test_logistic.py
@@ -155,3 +155,21 @@ def test_liblinear_random_state():
    lr2 = logistic.LogisticRegression(random_state=0)
    lr2.fit(X, y)
    assert_array_almost_equal(lr1.coef_, lr2.coef_)
+
+def test_normalize():
+    """Test for normalize option in LogisticRegression
+    to verify that prediction of array that already normalize is same as if normalize option is enabled
Review comment (member):

line too long. Consider a pep8 checker in your editor.

Also it should be:

to verify that the prediction with an array that has already been normalized is the same as if normalize option is enabled.

"""
X, y = iris.data, iris.target
X_norm = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
for kwargs in (
{},
{'fit_intercept': False},
{'intercept_scaling': 0.01}):
lr1 = logistic.LogisticRegression(normalize=False, **kwargs)
lr1.fit(X_norm, y)
lr2 = logistic.LogisticRegression(normalize=True, **kwargs)
lr2.fit(X, y)
pred1 = lr1.predict_proba(X_norm)
pred2 = lr2.predict_proba(X)
assert_array_almost_equal(pred1, pred2)
17 changes: 14 additions & 3 deletions sklearn/svm/base.py
@@ -585,15 +585,17 @@ class BaseLibLinear(six.with_metaclass(ABCMeta, BaseEstimator)):

    @abstractmethod
    def __init__(self, penalty='l2', loss='l2', dual=True, tol=1e-4, C=1.0,
-                 multi_class='ovr', fit_intercept=True, intercept_scaling=1,
-                 class_weight=None, verbose=0, random_state=None):
+                 multi_class='ovr', fit_intercept=True, normalize=False,
+                 intercept_scaling=1, class_weight=None, verbose=0,
+                 random_state=None):

        self.penalty = penalty
        self.loss = loss
        self.dual = dual
        self.tol = tol
        self.C = C
        self.fit_intercept = fit_intercept
+        self.normalize = normalize
        self.intercept_scaling = intercept_scaling
        self.multi_class = multi_class
        self.class_weight = class_weight
@@ -660,7 +662,6 @@ def fit(self, X, y):
" one.")

X = atleast2d_or_csr(X, dtype=np.float64, order="C")

self.class_weight_ = compute_class_weight(self.class_weight,
self.classes_, y)

@@ -677,6 +678,12 @@

        # LibLinear wants targets as doubles, even for classification
        y = np.asarray(y, dtype=np.float64).ravel()
+
+        # Center and scale the data if self.normalize
+        if self.normalize:
+            X_mean, X_std = np.mean(X, axis=0), np.std(X, axis=0)
+            X = (X - X_mean) / X_std
Review comment (member):

I am not sure if there will be a consensus to merge this, but let me give feedback to help.

Doing this will break on sparse data; maybe preventing normalize on sparse data is the way to go. Also, if X has already been copied you should write

    X -= X_mean
    X /= X_std

This will prevent a reallocation of the size of X.
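A minimal sketch of the two suggestions above, written against the surrounding fit() code (np and sp are already in scope in base.py; the error message wording is illustrative, not from the thread):

        if self.normalize:
            # Reject sparse input up front: centering would densify it.
            if sp.isspmatrix(X):
                raise ValueError("normalize=True is not supported with "
                                 "sparse X; scale the data explicitly "
                                 "before fitting instead.")
            X_mean, X_std = np.mean(X, axis=0), np.std(X, axis=0)
            # X was already copied by atleast2d_or_csr above, so in-place
            # operations are safe and avoid reallocating an array the
            # size of X.
            X -= X_mean
            X /= X_std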


        self.raw_coef_ = liblinear.train_wrap(X, y,
                                              sp.isspmatrix(X),
                                              self._get_solver_type(),
@@ -696,6 +703,10 @@
            self.coef_ = self.raw_coef_
            self.intercept_ = 0.

+        if self.normalize:
+            self.coef_ = self.coef_ / X_std
+            self.intercept_ = self.intercept_ - np.dot(X_mean, self.coef_.T)
+
        if self.multi_class == "crammer_singer" and len(self.classes_) == 2:
            self.coef_ = (self.coef_[1] - self.coef_[0]).reshape(1, -1)
            if self.fit_intercept:
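For reference, the rescaling of coef_ and intercept_ in the hunk above follows from substituting x_norm = (x - X_mean) / X_std into the decision function: w . x_norm + b = (w / X_std) . x + (b - X_mean . (w / X_std)). A standalone sanity check of that identity (the variable names here are illustrative, not part of the PR):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(20, 3)
X_mean, X_std = X.mean(axis=0), X.std(axis=0)
X_norm = (X - X_mean) / X_std

w = rng.randn(3)   # stands in for coefficients learned on the normalized data
b = 0.5            # stands in for the intercept learned on the normalized data

w_orig = w / X_std                   # mirrors: self.coef_ = self.coef_ / X_std
b_orig = b - np.dot(X_mean, w_orig)  # mirrors: self.intercept_ -= np.dot(X_mean, self.coef_.T)

# Decision values agree in both parameterizations.
assert np.allclose(X_norm.dot(w) + b, X.dot(w_orig) + b_orig)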