[MRG] Add elastic net penalty to LogisticRegression (#11646)
* First draft on elasticnet penalty for LogisticRegression

* Some basic tests

* Doc update

* First draft for LogisticRegressionCV.

It seems to be working for binary classification and for multiclass when
multi_class='ovr'. I'm having a hard time figuring out the intricacies
of multi_class='multinomial'.

* Changed default to None for l1_ratio.

Added a warning message if the user sets l1_ratio while penalty is not
elastic-net

* Some more doc

* Updated example to plot elastic net sparsity

* Fixed flake8

* Fixed test by not modifying attribute in fit

* Fixed doc issues

* WIP

* Partially fixed logistic_reg_CV for multinomial.

Also added some comments that are hopefully clear.
Still need to fix refit=False

* Fixed doc issue

* WIP

* Fixed test for refit=False in LogisticRegressionCV

* Fixed Python 2 numpy version issue

* minor doc updates

* Weird doc error...

* Added test to ensure that elastic net is at least as good as L1 or L2
once l1_ratio has been optimized with grid search

Also addressed minor reviews

* Fixed test

* addressed comments

* Added back ignore warning on tests

* Added a functional test

* Scale data in test... Now failing

* elastic-net --> elasticnet

* Updated doc for some attributes and checked their shape in tests

* Added l1_ratio dimension to coefs_paths and scores attr

* improve example + fix test

* FIX incorrect lagged_update in SAGA

* Add non-regression test for SAGA's bug

* FIX flake8 and warning

* Re fixed warning

* Updated some tests

* Addressed comments

* more comments and added dimension to LogisticRegressionCV.n_iter_ attribute

* Updated whatsnew for 0.21

* better doc shape looks

* Fixed whatsnew entry after merges

* Added dot

* Addressed comments + standardized optional default param docstrings

* Addressed comments

* use swapaxes instead of unsupported moveaxis (hopefully fixes tests)
NicolasHug authored and amueller committed Nov 22, 2018
1 parent f6f7e3c commit c1f5874
Showing 8 changed files with 617 additions and 171 deletions.
47 changes: 29 additions & 18 deletions doc/modules/linear_model.rst
@@ -338,7 +338,7 @@ the algorithm to fit the coefficients.

.. _elastic_net:

Elastic Net
Elastic-Net
===========
:class:`ElasticNet` is a linear regression model trained with L1 and L2 prior
as regularizer. This combination allows for learning a sparse model where
@@ -390,7 +390,7 @@ the duality gap computation used for convergence control.

.. _multi_task_elastic_net:

Multi-task Elastic Net
Multi-task Elastic-Net
======================

The :class:`MultiTaskElasticNet` is an elastic-net model that estimates sparse
@@ -730,7 +730,7 @@ or the log-linear classifier. In this model, the probabilities describing the po

The implementation of logistic regression in scikit-learn can be accessed from
class :class:`LogisticRegression`. This implementation can fit binary, One-vs-
Rest, or multinomial logistic regression with optional L2 or L1
Rest, or multinomial logistic regression with optional L2, L1 or Elastic-Net
regularization.

As an optimization problem, binary class L2 penalized logistic regression
@@ -739,12 +739,22 @@ minimizes the following cost function:
.. math:: \min_{w, c} \frac{1}{2}w^T w + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1) .

Similarly, L1 regularized logistic regression solves the following
optimization problem
optimization problem:

.. math:: \min_{w, c} \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1).

Elastic-Net regularization is a combination of L1 and L2, and minimizes the
following cost function:

.. math:: \min_{w, c} \frac{1 - \rho}{2}w^T w + \rho \|w\|_1 + C \sum_{i=1}^n \log(\exp(- y_i (X_i^T w + c)) + 1),

where :math:`\rho` controls the strength of L1 regularization vs. L2
regularization (it corresponds to the `l1_ratio` parameter).

Note that, in this notation, it's assumed that the observation :math:`y_i` takes values in the set
:math:`{-1, 1}` at trial :math:`i`.
:math:`{-1, 1}` at trial :math:`i`. We can also see that Elastic-Net is
equivalent to L1 when :math:`\rho = 1` and equivalent to L2 when
:math:`\rho=0`.
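
As a minimal usage sketch (illustrative only, not part of this diff), the new
penalty is selected through the ``penalty`` and ``l1_ratio`` parameters of
:class:`LogisticRegression`; the synthetic data below is an assumption made
purely for demonstration::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=20, random_state=0)
    # l1_ratio plays the role of rho above: 1.0 is pure L1, 0.0 is pure L2
    clf = LogisticRegression(penalty='elasticnet', solver='saga',
                             l1_ratio=0.5, C=1.0, max_iter=1000)
    clf.fit(X, y)
    print(clf.coef_)  # may contain exact zeros due to the L1 part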

The solvers implemented in the class :class:`LogisticRegression`
are "liblinear", "newton-cg", "lbfgs", "sag" and "saga":
@@ -772,10 +782,12 @@ than other solvers for large datasets, when both the number of samples and the
number of features are large.

The "saga" solver [7]_ is a variant of "sag" that also supports the
non-smooth `penalty="l1"` option. This is therefore the solver of choice
for sparse multinomial logistic regression.
non-smooth `penalty="l1"`. This is therefore the solver of choice for sparse
multinomial logistic regression. It is also the only solver that supports
`penalty="elasticnet"`.

In a nutshell, the following table summarizes the penalties supported by each solver:
In a nutshell, the following table summarizes the penalties supported by
each solver:

+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| | **Solvers** |
@@ -790,6 +802,8 @@ In a nutshell, the following table summarizes the penalties supported by each so
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| OVR + L1 penalty | yes | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Elastic-Net | no | no | no | no | yes |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| **Behaviors** | |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+
| Penalize the intercept (bad) | yes | no | no | no | no |
@@ -799,8 +813,8 @@ In a nutshell, the following table summarizes the penalties supported by each so
| Robust to unscaled datasets | yes | yes | yes | no | no |
+------------------------------+-----------------+-------------+-----------------+-----------+------------+

The "saga" solver is often the best choice but requires scaling. The "liblinear" solver is
used by default for historical reasons.
The "saga" solver is often the best choice but requires scaling. The
"liblinear" solver is used by default for historical reasons.

For large datasets, you may also consider using :class:`SGDClassifier`
with 'log' loss.
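
For illustration only (not part of this diff), a rough SGD-based counterpart
might look as follows; note that :class:`SGDClassifier` controls regularization
strength through ``alpha`` rather than ``C``, and the data and values below are
arbitrary assumptions::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier

    X, y = make_classification(n_samples=10000, n_features=50, random_state=0)
    # loss='log' gives logistic regression fitted by stochastic gradient descent
    clf = SGDClassifier(loss='log', penalty='elasticnet', l1_ratio=0.5,
                        alpha=1e-4, max_iter=1000, tol=1e-3)
    clf.fit(X, y)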
@@ -838,14 +852,11 @@ with 'log' loss.
thus be used to perform feature selection, as detailed in
:ref:`l1_feature_selection`.

:class:`LogisticRegressionCV` implements Logistic Regression with
builtin cross-validation to find out the optimal C parameter.
"newton-cg", "sag", "saga" and "lbfgs" solvers are found to be faster
for high-dimensional dense data, due to warm-starting. For the
multiclass case, if `multi_class` option is set to "ovr", an optimal C
is obtained for each class and if the `multi_class` option is set to
"multinomial", an optimal C is obtained by minimizing the cross-entropy
loss.
:class:`LogisticRegressionCV` implements Logistic Regression with built-in
cross-validation support, to find the optimal `C` and `l1_ratio` parameters
according to the ``scoring`` attribute. The "newton-cg", "sag", "saga" and
"lbfgs" solvers are found to be faster for high-dimensional dense data, due
to warm-starting (see :term:`Glossary <warm_start>`).
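
A minimal sketch of the cross-validated estimator (illustrative only, with
assumed toy data and candidate grids) could be::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegressionCV

    X, y = make_classification(n_samples=500, n_features=20, random_state=0)
    # Cs and l1_ratios define the candidate grids searched by cross-validation
    clf = LogisticRegressionCV(Cs=5, l1_ratios=[0.1, 0.5, 0.9], cv=3,
                               penalty='elasticnet', solver='saga',
                               max_iter=1000)
    clf.fit(X, y)
    print(clf.C_, clf.l1_ratio_)  # best values found, one per class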

.. topic:: References:

3 changes: 2 additions & 1 deletion doc/tutorial/statistical_inference/supervised_learning.rst
@@ -183,6 +183,7 @@ Linear models: :math:`y = X\beta + \epsilon`
[ 0.30349955 -237.63931533 510.53060544 327.73698041 -814.13170937
492.81458798 102.84845219 184.60648906 743.51961675 76.09517222]


>>> # The mean square error
>>> np.mean((regr.predict(diabetes_X_test) - diabetes_y_test)**2)
... # doctest: +ELLIPSIS
@@ -378,7 +379,7 @@ function or **logistic** function:
... multi_class='multinomial')
>>> log.fit(iris_X_train, iris_y_train) # doctest: +NORMALIZE_WHITESPACE
LogisticRegression(C=100000.0, class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=100,
multi_class='multinomial', n_jobs=None, penalty='l2', random_state=None,
solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)

12 changes: 12 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -22,6 +22,9 @@ random sampling procedures.
- Decision trees and derived ensembles when both `max_depth` and
`max_leaf_nodes` are set. |Fix|
- :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` with 'saga' solver. |Fix|


Details are listed in the changelog below.

@@ -146,6 +149,15 @@ Support for Python 3.4 and below has been officially dropped.
affects all ensemble methods using decision trees.
:pr:`12344` by :user:`Adrin Jalali <adrinjalali>`.

:mod:`sklearn.linear_model`
...........................

- |Feature| :class:`linear_model.LogisticRegression` and
:class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.

- |Fix| Fixed a bug in the 'saga' solver where the weights would not be
correctly updated in some cases. :issue:`11646` by `Tom Dupre la Tour`_.

Multiple modules
................
59 changes: 35 additions & 24 deletions examples/linear_model/plot_logistic_l1_l2_sparsity.py
@@ -4,10 +4,11 @@
==============================================
Comparison of the sparsity (percentage of zero coefficients) of solutions when
L1 and L2 penalty are used for different values of C. We can see that large
values of C give more freedom to the model. Conversely, smaller values of C
constrain the model more. In the L1 penalty case, this leads to sparser
solutions.
L1, L2 and Elastic-Net penalty are used for different values of C. We can see
that large values of C give more freedom to the model. Conversely, smaller
values of C constrain the model more. In the L1 penalty case, this leads to
sparser solutions. As expected, the Elastic-Net penalty sparsity is between
that of L1 and L2.
We classify 8x8 images of digits into two classes: 0-4 against 5-9.
The visualization shows coefficients of the models for varying C.
@@ -35,45 +36,55 @@
# classify small against large digits
y = (y > 4).astype(np.int)

l1_ratio = 0.5 # L1 weight in the Elastic-Net regularization

fig, axes = plt.subplots(3, 3)

# Set regularization parameter
for i, C in enumerate((1, 0.1, 0.01)):
for i, (C, axes_row) in enumerate(zip((1, 0.1, 0.01), axes)):
    # turn down tolerance for short training time
    clf_l1_LR = LogisticRegression(C=C, penalty='l1', tol=0.01, solver='saga')
    clf_l2_LR = LogisticRegression(C=C, penalty='l2', tol=0.01, solver='saga')
    clf_en_LR = LogisticRegression(C=C, penalty='elasticnet', solver='saga',
                                   l1_ratio=l1_ratio, tol=0.01)
    clf_l1_LR.fit(X, y)
    clf_l2_LR.fit(X, y)
    clf_en_LR.fit(X, y)

    coef_l1_LR = clf_l1_LR.coef_.ravel()
    coef_l2_LR = clf_l2_LR.coef_.ravel()
    coef_en_LR = clf_en_LR.coef_.ravel()

    # coef_l1_LR contains zeros due to the
    # L1 sparsity inducing norm

    sparsity_l1_LR = np.mean(coef_l1_LR == 0) * 100
    sparsity_l2_LR = np.mean(coef_l2_LR == 0) * 100
    sparsity_en_LR = np.mean(coef_en_LR == 0) * 100

    print("C=%.2f" % C)
    print("Sparsity with L1 penalty: %.2f%%" % sparsity_l1_LR)
    print("score with L1 penalty: %.4f" % clf_l1_LR.score(X, y))
    print("Sparsity with L2 penalty: %.2f%%" % sparsity_l2_LR)
    print("score with L2 penalty: %.4f" % clf_l2_LR.score(X, y))
    print("{:<40} {:.2f}%".format("Sparsity with L1 penalty:", sparsity_l1_LR))
    print("{:<40} {:.2f}%".format("Sparsity with Elastic-Net penalty:",
                                  sparsity_en_LR))
    print("{:<40} {:.2f}%".format("Sparsity with L2 penalty:", sparsity_l2_LR))
    print("{:<40} {:.2f}".format("Score with L1 penalty:",
                                 clf_l1_LR.score(X, y)))
    print("{:<40} {:.2f}".format("Score with Elastic-Net penalty:",
                                 clf_en_LR.score(X, y)))
    print("{:<40} {:.2f}".format("Score with L2 penalty:",
                                 clf_l2_LR.score(X, y)))

    l1_plot = plt.subplot(3, 2, 2 * i + 1)
    l2_plot = plt.subplot(3, 2, 2 * (i + 1))
    if i == 0:
        l1_plot.set_title("L1 penalty")
        l2_plot.set_title("L2 penalty")

    l1_plot.imshow(np.abs(coef_l1_LR.reshape(8, 8)), interpolation='nearest',
                   cmap='binary', vmax=1, vmin=0)
    l2_plot.imshow(np.abs(coef_l2_LR.reshape(8, 8)), interpolation='nearest',
                   cmap='binary', vmax=1, vmin=0)
    plt.text(-8, 3, "C = %.2f" % C)

    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
    l2_plot.set_xticks(())
    l2_plot.set_yticks(())
        axes_row[0].set_title("L1 penalty")
        axes_row[1].set_title("Elastic-Net\nl1_ratio = %s" % l1_ratio)
        axes_row[2].set_title("L2 penalty")

    for ax, coefs in zip(axes_row, [coef_l1_LR, coef_en_LR, coef_l2_LR]):
        ax.imshow(np.abs(coefs.reshape(8, 8)), interpolation='nearest',
                  cmap='binary', vmax=1, vmin=0)
        ax.set_xticks(())
        ax.set_yticks(())

    axes_row[0].set_ylabel('C = %s' % C)

plt.show()
