[MRG+2] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge #8446

Merged

merged 1 commit into scikit-learn:master on Mar 27, 2017

Conversation

9 participants
@arthurmensch
Contributor

arthurmensch commented Feb 23, 2017

This PR proposes slight adaptations of the existing sag_fast.pyx module to be able to use SAGA in addition to the SAG algorithm. This is the way to go if we want to propose enet/l1 penalties for fast incremental solvers for ridge and logistic regression.

A SAGA implementation is already available in lightning, adapted from the original paper. This one is slightly different, as it is built around understanding the SAGA update as a "corrected" version of SAG.
I believe it would also be possible to slightly adapt the module to have SVRG in addition to SAGA, if this interests people.

I tried to keep the changes made to sag_fast.pyx as small as possible. I reckon that the sag_fast.pyx module could be made a little more readable by using 2d memoryviews instead of strided pointers everywhere; this is left for further work.

For the moment I adapted the test_logistic.py file to ensure correctness of the algorithm, but the saga algorithm should also be tested within test_sag.py.

SAGA paper
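For readers who have not seen the two update rules side by side, here is a minimal NumPy-style sketch of one step of SAG and of SAGA on a generic smooth loss. The names (grad_memory, avg_grad) are illustrative only and do not match the variables used in sag_fast.pyx.

def sag_saga_step(w, grad_i_new, grad_memory, avg_grad, i, step_size,
                  n_samples, saga=False):
    # grad_i_new  : gradient of the i-th loss term at the current w (NumPy array)
    # grad_memory : last stored gradient for each sample, shape (n_samples, n_features)
    # avg_grad    : running average of the stored gradients
    if saga:
        # SAGA: unbiased direction, built as a "corrected" SAG direction.
        direction = grad_i_new - grad_memory[i] + avg_grad
    else:
        # SAG: biased direction, the average of the stored gradients with
        # the i-th entry replaced by the fresh gradient.
        direction = avg_grad + (grad_i_new - grad_memory[i]) / n_samples
    w = w - step_size * direction
    # Update the gradient table and its running average.
    avg_grad = avg_grad + (grad_i_new - grad_memory[i]) / n_samples
    grad_memory[i] = grad_i_new
    return w, grad_memory, avg_grad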

TODO

  • Documentation
  • Reference for step size
  • Check optimal step size
  • Ridge API + tests
  • Test module sag.py directly
  • Implement l1 penalty (simple projection)
  • (Different PR) Use minibatches? I cannot recall whether this is actually interesting.
  • Add nice benchmarks.
  • Benchmarks against liblinear and lightning
  • (Different PR) Add a SAGA solver for Lasso (which might imply a bit of refactoring...)
  • Add rcv1 example with multinomial + L1
  • (Different PR) Add elastic net l1_ratio in LogisticRegression

ping @TomDLT @agramfort, you might be interested in this :)
@GaelVaroquaux

GaelVaroquaux commented Feb 23, 2017

@GaelVaroquaux

GaelVaroquaux commented Feb 23, 2017

@arthurmensch

Contributor

arthurmensch commented Feb 23, 2017

Indeed, quite a few benchmarks would be necessary. Added!

sklearn/linear_model/logistic.py
-from ..externals import six
-from ..metrics import SCORERS
+from ..utils.optimize import newton_cg
+from ..utils.validation import check_X_y

@arthurmensch

arthurmensch Feb 23, 2017

Contributor

I guess my editor just put these in alphabetical order automatically; I can revert if necessary.

@GaelVaroquaux

GaelVaroquaux commented Feb 23, 2017

@codecov

codecov bot commented Feb 23, 2017

Codecov Report

Merging #8446 into master will increase coverage by <.01%.
The diff coverage is 100%.

@@            Coverage Diff             @@
##           master    #8446      +/-   ##
==========================================
+ Coverage   95.47%   95.48%   +<.01%     
==========================================
  Files         342      342              
  Lines       60907    61007     +100     
==========================================
+ Hits        58154    58255     +101     
+ Misses       2753     2752       -1

Impacted Files | Coverage Δ
sklearn/linear_model/tests/test_sag.py | 98.6% <100%> (+0.07%)
sklearn/linear_model/sag.py | 94.36% <100%> (+0.52%)
sklearn/linear_model/tests/test_logistic.py | 100% <100%> (ø)
sklearn/linear_model/logistic.py | 97.65% <100%> (+0.03%)
sklearn/linear_model/tests/test_ridge.py | 100% <100%> (ø)
sklearn/linear_model/ridge.py | 93.88% <100%> (ø)
sklearn/linear_model/coordinate_descent.py | 96.94% <0%> (ø)
sklearn/tree/tree.py | 98.41% <0%> (ø)
sklearn/metrics/classification.py | 97.77% <0%> (ø)
sklearn/decomposition/tests/test_pca.py | 100% <0%> (ø)
... and 4 more

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update fcb1403...2a5c4bf.

@arthurmensch

Contributor

arthurmensch commented Feb 23, 2017

A quick benchmark:

Classification performance:
===========================
Classifier               train-time   test-time   error-rate
------------------------------------------------------------
LinearRegression-SAGA        17.25s       0.07s       0.0810
LinearRegression-SAG         20.64s       0.10s       0.0824
@arthurmensch

Contributor

arthurmensch commented Feb 24, 2017

figure_1

I added some code to do benchmarks, which should be removed before merging.

With the conservative auto step size that is used, SAGA performs better in the first epochs and SAG gets better for finer convergence (a behavior already observed in the literature). With more aggressive step sizes, SAGA performs better than SAG. We could use line search as it tends to produce better results; is it used in SGD in scikit-learn?

@GaelVaroquaux

GaelVaroquaux commented Feb 24, 2017

@agramfort

Member

agramfort commented Feb 26, 2017

@arthurmensch can you show a benchmark with $\log_{10}(f(x^k) - f(x^*))$ on the y axis and time on the x axis?
I suspect marginal improvements for L2 logistic regression, so it would need a convincing figure for L1-regularized logistic regression, i.e. a comparison with liblinear, which we use presently.
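For reference, a minimal matplotlib sketch of the kind of figure requested here, with synthetic stand-in curves; in a real benchmark, results would hold measured (time, objective) pairs per solver, and f_star the best objective value found by any of them:

import numpy as np
import matplotlib.pyplot as plt

# Stand-in measurements: {solver: (wall-clock times, objective values)}.
results = {
    'saga': (np.linspace(0, 10, 50), np.exp(-np.linspace(0, 8, 50))),
    'sag': (np.linspace(0, 10, 50), np.exp(-np.linspace(0, 6, 50))),
}
f_star = min(obj.min() for _, obj in results.values())  # best objective found

plt.figure()
for solver, (times, obj) in results.items():
    # log10 suboptimality vs. time; the small epsilon avoids log10(0)
    # at the best point.
    plt.plot(times, np.log10(obj - f_star + 1e-12), label=solver)
plt.xlabel('time (s)')
plt.ylabel(r'$\log_{10}(f(x^k) - f(x^*))$')
plt.legend()
plt.show()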

@arthurmensch

Contributor

arthurmensch commented Feb 27, 2017

@agramfort can you recall a dataset on which SAG would beat liblinear with l2 penalty?

@agramfort

agramfort commented Feb 27, 2017

@arthurmensch

Contributor

arthurmensch commented Feb 28, 2017

I did some benchmarks against lightning for the moment, on rcv1. For some reason we are 10x faster than lightning with L1 penalty, any thoughts @fabianp?

l2 logistic:

log_l2

l1 logistic:

log l1

Credit for the shortcut in the composition of prox operators goes to @fabianp and @mblondel, but I think it looks a bit cleaner this way.

I will post benchmarks against liblinear ASAP.

@GaelVaroquaux

Member

GaelVaroquaux commented Feb 28, 2017

This is looking great. And the fun part is that it seems that reimplementing code and comparing teaches us a lot, even when the two implementations are done by people who share so much.

@arthurmensch arthurmensch changed the title from [WIP] SAGA support for LogisticRegression and Ridge to [MRG] SAGA support for LogisticRegression and Ridge Feb 28, 2017

@arthurmensch

Contributor

arthurmensch commented Feb 28, 2017

For single-class l1/l2 logistic regression liblinear is hard to beat. But for multiclass, speed improvements look great:

figure_1-1

There is a discrepancy in the final accuracy because the two models are not the same (the loss is different, and liblinear penalizes the intercept), but I reckon they perform the same with correct cross-validation.

Two-class problem (there is a problem there, as the final training score should not be that different):

figure_1-2

One of the advantages of using SAGA instead of liblinear is that we can perform CV with memory-mapped data, which can be crucial when datasets are huge.

@mblondel

Member

mblondel commented Feb 28, 2017

It could be interesting to see if both implementations return more or less the same weight vectors.

I remember it took @fabianp and @zermelozf several iterations to implement the lazy updates correctly. We have a naive pure-Python implementation in our tests to ensure correctness:

https://github.com/scikit-learn-contrib/lightning/blob/master/lightning/impl/tests/test_sag.py#L85
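To make the idea of a correctness oracle concrete, here is a hedged, naive NumPy SAGA loop for the squared loss with L2 penalty only. It is merely an illustration of the kind of reference implementation mentioned above, not the lightning or scikit-learn test code:

import numpy as np

def naive_saga_squared(X, y, alpha, step_size, n_epochs, seed=0):
    # Minimizes mean_i 0.5*(x_i.w - y_i)^2 + 0.5*alpha*||w||^2 (reference only).
    rng = np.random.RandomState(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    grad_memory = np.zeros((n_samples, n_features))
    avg_grad = np.zeros(n_features)
    for _ in range(n_epochs * n_samples):
        i = rng.randint(n_samples)
        g_new = (X[i].dot(w) - y[i]) * X[i] + alpha * w
        w -= step_size * (g_new - grad_memory[i] + avg_grad)
        avg_grad += (g_new - grad_memory[i]) / n_samples
        grad_memory[i] = g_new
    return w

Comparing the weight vector it returns against the Cython solver, up to a tolerance, is the kind of check suggested here.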

@arthurmensch

Contributor

arthurmensch commented Mar 1, 2017

single_target_l2
single_target_l1
multi_target_l2
multi_target_l1

These are the right benches, liblinear vs lightning vs sklearn saga, on 1/20 of rcv1.

  • SAGA sklearn is better for multi-target regression for l2 and l1
  • SAGA sklearn is better for single-target regression for l2, and a little slower than liblinear with l1
  • SAGA sklearn and lightning are roughly on par with l2 regularisation, single-target
  • SAGA sklearn is ten times faster than lightning with l1 regularisation, single-target. This is surprising

I am currently running the benches on the whole dataset, to see whether saga gets better than liblinear at some point.

The benchmark file that is in this PR does not use callbacks or any hacks, so I think it can stay in the repo for future reference.

@GaelVaroquaux

GaelVaroquaux commented Mar 1, 2017

@arthurmensch

Contributor

arthurmensch commented Mar 2, 2017

Should I try to add a solver for Lasso? For the moment scikit-learn has quite a variety of Lasso solvers (Lasso, LassoLars, CV versions, IC versions, RandomizedLasso), so I am not sure where to put it. Should we add a solver option in Lasso?

@TomDLT

Nice work! The code looks good so far.
I still need to understand the JIT prox update.

"Use minibatches? I cannot recall whether this is actually interesting."

Mark Schmidt says it's interesting for SAG (slide 73), yet I am not sure we need it in scikit-learn.

sklearn/linear_model/logistic.py
-
+ if penalty == 'l1':
+ if solver == 'sag':
+ raise ValueError("Unsupported penalty. Use `saga` instead.")

@TomDLT

TomDLT Mar 2, 2017

Member

This is never reached, thanks to _check_solver_option, isn't it?

-def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept):
+def get_auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept,
+ n_samples=None,

@TomDLT

TomDLT Mar 2, 2017

Member

update docstring

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

Done

sklearn/linear_model/sag_fast.pyx
+
+
+
+cdef double lagged_update(double* weights, double wscale, int xnnz,

@TomDLT

TomDLT Mar 2, 2017

Member

update the docstring description

+ else:
+ for class_ind in range(n_classes):
+ idx = f_idx + class_ind
+ if fabs(sum_gradient[idx] * cum_sum) < cum_sum_prox:

@TomDLT

TomDLT Mar 2, 2017

Member

can you explain this? and maybe add some comments

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

This is a nice trick from lightning: instead of unrolling the whole delayed composition soft_threshold(soft_threshold(soft_threshold(w - grad_update) - grad_update) ...) in the loop below, we factorize it, since we never cross the non-linearity of the soft-thresholding. There is no academic reference for this, but we should indeed add some comments.
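A hedged sketch of the algebraic shortcut described here, for a single coordinate with several pending (gradient step, L1 prox) updates. The helper names are made up; the real sag_fast.pyx code works with strided arrays and per-feature lag counters, and only takes the shortcut when it is safe.

import numpy as np

def soft_threshold(w, t):
    # Prox of t * |.| (soft-thresholding).
    return np.sign(w) * np.maximum(np.abs(w) - t, 0.0)

def pending_updates_naive(w, grad_steps, prox_steps):
    # Unroll the whole delayed composition prox(prox(prox(w - a1) - a2) ...).
    for a, b in zip(grad_steps, prox_steps):
        w = soft_threshold(w - a, b)
    return w

def pending_updates_collapsed(w, grad_steps, prox_steps):
    # Factorized form: one gradient step with the accumulated step sizes,
    # then one soft-thresholding with the accumulated thresholds. This is
    # exact as long as the intermediate iterates never cross zero; when
    # that cannot be guaranteed, the code falls back to the unrolled loop.
    return soft_threshold(w - np.sum(grad_steps), np.sum(prox_steps))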

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

It would be nice if there was a small blog post with the derivations for this on the web. @fabianp sent me an unfinished draft for this, but I gathered it was to stay a draft.

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

If you do not have time, I can write a small blog post with the mathematical derivation for this part, with due reference to @fabianp @zermelozf @mblondel. Then we can reference it in a comment.

@fabianp

fabianp Mar 3, 2017

Member

That would be awesome. I'm sending you the .tex sources of that draft so you can use it as you please.

@ogrisel

ogrisel Mar 21, 2017

Member

Even without the full blog post with the derivation, it would be great to have an inline comment to state:

sklearn/linear_model/sag_fast.pyx
+ if reset:
+ cumulative_sums[sample_itr - 1] = 0.0
+ if prox:
+ cumulative_sums_prox[sample_itr - 1] = 0.0
# reset wscale to 1.0

@TomDLT

TomDLT Mar 2, 2017

Member

You should return void and return 1.0 only in scale_weights

sklearn/linear_model/sag_fast.pyx
+ cdef np.ndarray[double, ndim=1] cumulative_sums_prox_array
+ cdef double* cumulative_sums_prox
+
+ cdef bint prox = beta > 0

@TomDLT

TomDLT Mar 2, 2017

Member

it would be safer to check that saga is True

sklearn/linear_model/sag.py
else:
raise ValueError("Unknown loss function for SAG solver, got %s "
"instead of 'log' or 'squared'" % loss)
+ if is_saga:
+ mun = min(2 * n_samples * alpha_scaled, L)

@TomDLT

TomDLT Mar 2, 2017

Member

where does it come from?

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

SAGA original paper: the proofs of convergence require step_size < 1/(3L) or step_size < 1/(2(L + mu*n)), where mu is the strong-convexity modulus of the objective. We could use 1/(3L), but the other bound is better in the low-dimensional regime. I will add a reference. By the way, SAG uses 1/L, whereas the only step size for which proofs are available is 1/(16L). I think this is a sound heuristic, but we should add a reference as well. 1/L is also a good heuristic for SAGA in most cases, but it actually makes one test fail in the test suite :P We could allow the user to specify the step size, but it would complicate the API.
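A hedged Python sketch of this heuristic, following the mun = min(2 * n_samples * alpha_scaled, L) line from the diff above; it mirrors the intent of get_auto_step_size but is not a verbatim copy of it:

def auto_step_size(max_squared_sum, alpha_scaled, loss, fit_intercept,
                   n_samples=None, is_saga=False):
    # Lipschitz constant of the gradient of the smooth part (data term + L2).
    if loss == 'log':
        # The logistic loss has a 1/4-Lipschitz gradient per sample.
        L = 0.25 * (max_squared_sum + int(fit_intercept)) + alpha_scaled
    elif loss == 'squared':
        L = max_squared_sum + int(fit_intercept) + alpha_scaled
    else:
        raise ValueError("loss must be 'log' or 'squared'")
    if is_saga:
        # Take the larger of the two admissible SAGA step sizes,
        # 1/(3L) and 1/(2(L + mu*n)), with mu = alpha_scaled.
        mun = min(2 * n_samples * alpha_scaled, L)
        return 1.0 / (2 * L + mun)
    # SAG heuristic discussed above: 1/L (theory only covers 1/(16L)).
    return 1.0 / L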

@TomDLT

TomDLT Mar 2, 2017

Member

Ok. For SAG, I used the recommendation of Mark Schmidt (slide 65).

@@ -261,26 +272,42 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1.,
if max_squared_sum is None:
max_squared_sum = row_norms(X, squared=True).max()
step_size = get_auto_step_size(max_squared_sum, alpha_scaled, loss,

@TomDLT

TomDLT Mar 2, 2017

Member

should the step size depend on beta_scaled in the L1 case?

@arthurmensch

arthurmensch Mar 2, 2017

Contributor

No, it depends on L, which is the Lipschitz constant of the gradient of the smooth part of the objective (i.e. logistic + optional l2).
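Concretely, for L2-penalized logistic regression with samples $x_i$ and regularization $\alpha$ (ignoring the intercept term), this constant is

$$L = \frac{1}{4}\max_i \|x_i\|_2^2 + \alpha,$$

since the logistic loss has a $1/4$-Lipschitz gradient in $x_i^\top w$ and the $\frac{\alpha}{2}\|w\|_2^2$ term contributes $\alpha$; the L1 term is handled by the proximal step and does not enter $L$.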

@TomDLT

TomDLT Mar 2, 2017

Member

oh yes of course

@arthurmensch arthurmensch changed the title from [MRG] SAGA support for LogisticRegression and Ridge to [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ l1 support) Mar 2, 2017

@arthurmensch

Contributor

arthurmensch commented Mar 2, 2017

Should we add an example for SAGA + some narrative documentation? I suggest multinomial logistic + l1 on rcv1, with a comparison against one-versus-all with liblinear, although this would take a few minutes on the whole dataset (because of ovr).

I suggest we work on merging this PR before working on using SAGA for Lasso.

There is also the question of whether we should propose both sag and saga to the end user, or deprecate sag? @GaelVaroquaux.

@arthurmensch

Contributor

arthurmensch commented Mar 2, 2017

ping @TomDLT @agramfort @GaelVaroquaux I think this is ready for review. Reviews welcome from @fabianp and @mblondel for the core code if you have time.

A few UX questions to answer before I move on:

  • Do we agree on a multinomial + L1 example on rcv1, with a comparison with ovr + liblinear?
  • Do we deprecate the sag solver or keep both sag and saga? Is the sag solver useful in any way compared to saga @fabianp?
  • Do we add an l1_ratio to LogisticRegression, as we now can? C + l1_ratio sounds a bit original but it is still meaningful.
  • Naming is now a bit fishy (sag_solver for saga). I guess we can keep it that way for the moment, but it will require refactoring if we add, for instance, an svrg solver to the code base.

I reckon adding l1_ratio can be done in a new PR to keep this one minimalistic.

@arthurmensch arthurmensch changed the title from [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ l1 support) to [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) Mar 2, 2017

@arthurmensch

Contributor

arthurmensch commented Mar 3, 2017

On the whole rcv1 dataset, single target with high regularization (C=0.1), saga is faster than liblinear:

single_target_l1

This is not true for low regularization (C=1):

single_target_l1

This is expected (with high regularization the soft-thresholding shortcut trick works more often).

@TomDLT

Member

TomDLT commented Mar 3, 2017

How does SAGA perform on small datasets (e.g. iris)?
It would be great if it could compete with liblinear, so we could change the default solver in LogisticRegression.
The regularization of the intercept in liblinear is often confusing for users.

@GaelVaroquaux

Member

GaelVaroquaux commented Mar 3, 2017

We will need a very clear paragraph in the documentation that explains which solver to choose when.

@arthurmensch

Contributor

arthurmensch commented Mar 3, 2017

I guess it can be set as default on small datasets, given the small loss in performance. On iris (the bounce in log loss is due to the approximation of the logistic function in sgd_fast.pyx):

single_target_l1
single_target_l2

@GaelVaroquaux

GaelVaroquaux commented Mar 3, 2017

@arthurmensch

Contributor

arthurmensch commented Mar 3, 2017

Digits. The final loss is different due to the intercept penalty in liblinear. No loss in performance.

single_target_l1
single_target_l2

@arthurmensch

Contributor

arthurmensch commented Mar 3, 2017

Zoom on digits:

single_target_l1
single_target_l2

@GaelVaroquaux

Member

GaelVaroquaux commented Mar 3, 2017

You've got to admit that liblinear works well in the l1 case.

@arthurmensch

arthurmensch commented Mar 3, 2017

@arthurmensch

arthurmensch commented Mar 3, 2017

@GaelVaroquaux

GaelVaroquaux commented Mar 3, 2017

@fabianp

Member

fabianp commented Mar 3, 2017

Great work @arthurmensch. +1 to deprecate SAG: I don't see a use case for it now that we have SAGA. In my experience the difference between the two is marginal, but SAGA is much more versatile.

Also, thanks for the comparisons. Lightning seems to have a large overhead in the first iteration, we should look into that :-)

@amueller

Member

amueller commented Mar 3, 2017

Does this support the elastic net penalty? Is there a reason not to? [see #8288]

@arthurmensch

Contributor

arthurmensch commented Mar 4, 2017

It is indeed cheap to do, but it would require some explanation regarding which solver to use. If we make saga the default, l1_ratio will be usable, but we should explain somewhere that liblinear might be a sound choice for large regularization (and also state when to use newton-cg and lbfgs).

However this would make this PR larger and it is already quite big, so it might be better to focus on this one first (adding an elastic net penalty would require an example + some extra documentation not related to the optimizer).

@ogrisel

Member

ogrisel commented Mar 6, 2017

I am not sure it's worth deprecating SAG. Most of the code is shared with the SAGA implementation, and I don't think it adds much complexity to the code base (@arthurmensch tell us if you disagree), so I am fine with keeping SAG. We should just make it clear in the description that SAG and SAGA usually have similar convergence behaviors on l2-penalized models, but only SAGA can fit non-smooth l1-penalized models.

@ogrisel

Member

ogrisel commented Mar 6, 2017

+1 for adding solver='saga' as an option to the Lasso and LassoCV classes. I think this could be done as part of this PR to give us more opportunity to test the saga solver on a different loss + penalty combination.

+1 for elastic net, but I am fine keeping that for a later PR.

@GaelVaroquaux

GaelVaroquaux commented Mar 6, 2017

@ogrisel

Member

ogrisel commented Mar 7, 2017

There is a problem in the benchmark script:

Traceback (most recent call last):
  File "benchmarks/bench_saga.py", line 241, in <module>
    max_iter=20)
  File "benchmarks/bench_saga.py", line 156, in exp
    for solver in solvers
  File "/home/ogrisel/code/scikit-learn/sklearn/externals/joblib/parallel.py", line 758, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/ogrisel/code/scikit-learn/sklearn/externals/joblib/parallel.py", line 603, in dispatch_one_batch
    tasks = BatchedCalls(itertools.islice(iterator, batch_size))
  File "/home/ogrisel/code/scikit-learn/sklearn/externals/joblib/parallel.py", line 127, in __init__
    self.items = list(iterator_slice)
  File "benchmarks/bench_saga.py", line 158, in <genexpr>
    for single_target in single_targets)
TypeError: 'bool' object is not iterable
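The last two frames suggest that single_targets is a plain bool where the generator expression expects an iterable. A minimal, hedged reconstruction of that failure mode, with stand-in names rather than the actual bench_saga.py code:

def fit_single(solver, single_target):
    # Stand-in for the benchmark worker function.
    return solver, single_target

solvers = ['saga', 'liblinear']
single_targets = True  # a bare bool, as apparently passed down to exp()

jobs = (fit_single(solver, single_target)
        for solver in solvers
        for single_target in single_targets)
# Consuming this generator (as joblib's Parallel does) raises
# TypeError: 'bool' object is not iterable, matching the traceback above.

# Likely fix: pass an iterable of flags instead of a bare bool.
single_targets = [True, False]
jobs = [fit_single(solver, single_target)
        for solver in solvers
        for single_target in single_targets]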
@ogrisel

Member

ogrisel commented Mar 7, 2017

The sparse MNIST filters obtained by L1 logistic regression with SAGA are really beautiful:

image

Impressive to get this in less than 10s.

@ogrisel

Partial review, more comments later.

benchmarks/bench_saga.py
+def fit_single(solver, X, y, penalty='l2',
+ single_target=True,
+ C=1,
+ max_iter=10):

@ogrisel

ogrisel Mar 7, 2017

Member

style: please do not insert unnecessary line breaks.

benchmarks/bench_mnist.py
@@ -92,6 +92,8 @@ def load_data(dtype=np.float32, order='F'):
'SampledRBF-SVM': make_pipeline(
RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)),
'LinearRegression-SAG': LogisticRegression(solver='sag', tol=1e-1, C=1e4),
+ 'LinearRegression-SAGA': LogisticRegression(solver='saga', tol=1e-1,
+ C=1e4),

@ogrisel

ogrisel Mar 7, 2017

Member

Those two should be renamed to "LogisticRegression". Linear regression should be reserved for the ridge, huber or absolute losses.

benchmarks/bench_saga.py
+ print('Solving %s logistic regression with penalty %s, solver %s.'
+ % ('binary' if single_target else 'multinomial',
+ penalty,
+ solver))

@ogrisel

ogrisel Mar 7, 2017

Member

style: solver should stay on the same line as penalty.

benchmarks/bench_saga.py
+ y = rcv1.target
+ y = lbin.inverse_transform(y)
+ X = X
+ y = y


@ogrisel

ogrisel Mar 7, 2017

Member

Why?

+
+Performance of multinomial logistic regression with
+L1 penalty. We use the SAGA algorithm for this purpose, which is fast. Test
+accuracy reaches > 0.8, while classification vectors remains *sparse* and

@ogrisel

ogrisel Mar 7, 2017

Member

weight vectors for each class.

+ models[model]['sparsities'] = sparsities
+ models[model]['accuracies'] = accuracies
+ print('Test accuracy for model %s: %.4f' % (model, accuracies[-1]))
+ print('Model sparsity for model %s: % .2f' % (model, sparsities[-1]))

@ogrisel

ogrisel Mar 7, 2017

Member

Maybe you could print the number of non-zero weights selected for each class label.

sklearn/linear_model/logistic.py
-
+ else:
+ # Benchmark
+ self.coef_ = np.zeros((n_classes, n_features))


@ogrisel

ogrisel Mar 7, 2017

Member

Why?

sklearn/linear_model/logistic.py
if warm_start_coef is None:
warm_start_coef = [None] * n_classes
path_func = delayed(logistic_regression_path)
# The SAG solver releases the GIL so it's more efficient to use
# threads for this solver.
- backend = 'threading' if self.solver == 'sag' else 'multiprocessing'
+ backend = 'threading' if self.solver in ['sag', 'saga']\
+ else 'multiprocessing'

@ogrisel

ogrisel Mar 7, 2017

Member

Please do not use line-break escaping. If the ternary conditional expression does not fit on a single line, just use a traditional multiline conditional. It's more readable:

if self.solver in ['sag', 'saga']:
    backend = 'threading'
else:
    backend = 'multiprocessing'

BTW @TomDLT did you try to use the threading backend with the other solvers? Don't they release the GIL?

For liblinear we should release the GIL in our bindings if this is not already the case. For the scipy solvers I am pretty sure they do it already. If not we should report this as an issue on their tracker.

sklearn/linear_model/logistic.py
@@ -1612,7 +1638,8 @@ def fit(self, X, y, sample_weight=None):
# The SAG solver releases the GIL so it's more efficient to use
# threads for this solver.
- backend = 'threading' if self.solver == 'sag' else 'multiprocessing'
+ backend = 'threading' if self.solver in ['sag', 'saga']\
+ else 'multiprocessing'

@ogrisel

ogrisel Mar 7, 2017

Member

Same comment on the multiline conditional.

- same scale. You can preprocess the data with a scaler from
- sklearn.preprocessing.
+ - 'sag' uses a Stochastic Average Gradient descent, and 'saga' uses
+ its improved, unbiased version named SAGA. Both methods also use an

@ogrisel

ogrisel Mar 7, 2017

Member

"improved" is subjective and Nicolas Le Roux won't be our friend anymore ;) Let's stick to objective qualifiers such as "unbiased" and "more flexible" (because of the support for non-smooth penalties).

@arthurmensch

arthurmensch Mar 7, 2017

Contributor

Ok, but some people will tell you that the A of SAGA stands for "amélioré" (French for "improved").

@ogrisel

ogrisel Mar 7, 2017

Member

héhé

@fabianp

fabianp Mar 14, 2017

Member

I've been told the A stands for "augmented".

@arthurmensch

Contributor

arthurmensch commented Mar 7, 2017

So should I go for Lasso? Heavy refactoring expected: the ElasticNet/Lasso classes are tightly coupled to the coordinate descent solver.

@ogrisel

Member

ogrisel commented Mar 7, 2017

I would introduce private _fit_coordinate_descent(self, X, y) and _fit_saga(self, X, y) methods to keep the diff readable. The input checks can stay in the public fit method.
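A minimal sketch of the structure suggested here; the class name and solver values are placeholders, not the actual ElasticNet/Lasso code:

class LassoLike:
    # Illustration of splitting fit() into solver-specific private methods.

    def __init__(self, solver='cd'):
        self.solver = solver

    def fit(self, X, y):
        # Input validation stays here, in the public method.
        if self.solver == 'saga':
            return self._fit_saga(X, y)
        return self._fit_coordinate_descent(X, y)

    def _fit_coordinate_descent(self, X, y):
        # existing coordinate-descent path
        return self

    def _fit_saga(self, X, y):
        # new sag/saga path
        return self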

+for i in range(1, 10):
+ l1_plot = plt.subplot(3, 3, i)
+ l1_plot.imshow(np.abs(coef[i].reshape(28, 28)), interpolation='nearest',
+ cmap='binary', vmin=0, vmax=1)

@ogrisel

ogrisel Mar 7, 2017

Member

Please use the red-blue colormap of matplotlib such that 0 is white. This should better highlight the sparsity of the filters.
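For instance, something along these lines (a sketch with stand-in data; coef is assumed to have shape (n_classes, 784) as in the MNIST example):

import numpy as np
import matplotlib.pyplot as plt

coef = np.random.randn(10, 784) * (np.random.rand(10, 784) > 0.9)  # stand-in weights
scale = np.abs(coef).max()

for i in range(1, 10):
    l1_plot = plt.subplot(3, 3, i)
    # Diverging red/blue colormap with a symmetric range, so that 0 maps to
    # white and the sparsity pattern stands out.
    l1_plot.imshow(coef[i].reshape(28, 28), interpolation='nearest',
                   cmap=plt.cm.RdBu, vmin=-scale, vmax=scale)
    l1_plot.set_xticks(())
    l1_plot.set_yticks(())
plt.show()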

@@ -0,0 +1,114 @@
+"""
+=====================================================
+Multiclass logisitic regression on newgroups20

@ogrisel

ogrisel Mar 20, 2017

Member

Please also put "sparse" in the title.

+features to zero. This is good if the goal is to extract the strongly
+discriminative vocabulary of each class. If the goal is to get the best
+predictive accuracy, it is better to use the non sparsity-inducing
+l2 penalty instead.

@ogrisel

ogrisel Mar 20, 2017

Member

I think you should also mention univariate feature selection as an alternative way to extract sparse discriminative vocabularies.

Maybe you could even extend the example by adding a pipeline of a sparse univariate feature selection model + l2-penalized logistic regression, to showcase a classification model with a similar sparsity level as the l1-penalized variant.
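A hedged sketch of such a pipeline on 20 newsgroups; the value of k and the other parameters are placeholders, and the calls follow the scikit-learn API targeted by this PR:

from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

data = fetch_20newsgroups_vectorized(subset='train')
X, y = data.data, data.target

# Univariate selection down to a sparsity level comparable to the
# l1-penalized model, followed by an l2-penalized multinomial logistic
# regression fitted with saga.
clf = make_pipeline(
    SelectKBest(chi2, k=10000),
    LogisticRegression(solver='saga', penalty='l2',
                       multi_class='multinomial', C=1.0),
)
clf.fit(X, y)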

+Performance of multinomial logistic regression with
+L1 penalty. We use the SAGA algorithm for this purpose, which is fast. Test
+accuracy reaches > 0.8, while weight vectors remains *sparse* and
+*interpretable*.

@ogrisel

ogrisel Mar 20, 2017

Member

Please add a note such that:

Note that this accuracy is far below what can be reached by a non-penalized linear model (I think ~0.93, but this should be checked), and even further below the accuracy of non-linear models such as a multi-layer perceptron (0.98+).

sklearn/linear_model/sag_fast.pyx
np.ndarray[double, ndim=2, mode='c'] sum_gradient_init,
np.ndarray[double, ndim=2, mode='c'] gradient_memory_init,
np.ndarray[bint, ndim=1, mode='c'] seen_init,
int num_seen,
bint fit_intercept,
np.ndarray[double, ndim=1, mode='c'] intercept_sum_gradient_init,
double intercept_decay,
+ bint saga,
bint verbose):
"""Stochastic Average Gradient (SAG) solver.

@ogrisel

ogrisel Mar 21, 2017

Member

Please update this docstring to make it explicit that this function implements both SAG and SAGA.

@ogrisel

Member

ogrisel commented Mar 22, 2017

I pushed better checks for the l1-penalty logistic regression tests. I still want to review other parts of the code / tests, but maybe we should do the Lasso part in a separate PR.

@ogrisel

I did a review and pushed some fixes (mostly missing updates in the documentation). I think we can merge this PR without waiting for elasticnet penalty and integration in the Lasso* classes.

@ogrisel ogrisel changed the title from [MRG] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) to [MRG+1] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) Mar 27, 2017

@ogrisel ogrisel changed the title from [MRG+1] ENH : SAGA support for LogisticRegression, Ridge and Lasso (+ L1 support) to [MRG+1] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge Mar 27, 2017

@TomDLT

TomDLT approved these changes Mar 27, 2017 (edited)

LGTM except for minor nitpicks

I think we can merge this PR without waiting for elasticnet penalty and integration in the Lasso* classes.

I agree

doc/modules/linear_model.rst
+Very Large dataset (`n_samples`) "sag" or "saga"
+================================= =====================================
+
+The "saga" solver is almost always a the best choice. The "liblinear"

@TomDLT

TomDLT Mar 27, 2017

Member

typo: drop the stray "a" in "almost always a the best choice"

+regression yields more accurate results and is faster to train on the larger
+scale dataset.
+
+Here we use the l1 sparsity that trims the weights of no to informative

@TomDLT

TomDLT Mar 27, 2017

Member

-> of not informative

sklearn/linear_model/logistic.py
@@ -967,6 +976,9 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
Used to specify the norm used in the penalization. The 'newton-cg',
'sag' and 'lbfgs' solvers support only l2 penalties.
+ .. versionadded:: 0.19
+ l1 penalty with SAGA solver (allowing 'multinomial + L1)

@TomDLT

TomDLT Mar 27, 2017

Member

fix the '

sklearn/linear_model/sag_fast.pyx
@@ -6,16 +6,16 @@
# Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
#

@TomDLT

TomDLT Mar 27, 2017

Member

add author

@@ -860,6 +895,71 @@ def test_logreg_intercept_scaling_zero():
assert_equal(clf.intercept_, 0.)
+def test_logreg_l1():
+ # Because liblinear penalizes the intercept and saga does not, we do

@TomDLT

TomDLT Mar 27, 2017

Member

we do not

+
+
+def test_logreg_l1_sparse_data():
+ # Because liblinear penalizes the intercept and saga does not, we do

@TomDLT

TomDLT Mar 27, 2017

Member

we do not

@ogrisel ogrisel changed the title from [MRG+1] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge to [MRG+2] ENH : SAGA support for LogisticRegression (+ L1 support), Ridge Mar 27, 2017

@ogrisel

Member

ogrisel commented Mar 27, 2017

I rebased and squashed everything down to a single commit. If CI is still green, let's merge.

@ogrisel ogrisel merged commit 5147fd0 into scikit-learn:master Mar 27, 2017

5 checks passed

ci/circleci: Your tests passed on CircleCI!
codecov/patch: 100% of diff hit (target 95.49%)
codecov/project: 95.5% (+0.01%) compared to 7877f3c
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@ogrisel

Member

ogrisel commented Mar 27, 2017

Merged! Thanks @arthurmensch!

@arthurmensch

arthurmensch commented Mar 27, 2017

@TomDLT

Member

TomDLT commented Mar 27, 2017

🍻

@GaelVaroquaux

GaelVaroquaux commented Mar 28, 2017

@jnothman

jnothman commented Mar 28, 2017

@agramfort

agramfort commented Mar 28, 2017

@fabianp

Member

fabianp commented Mar 28, 2017

Congrats @arthurmensch and co. I think this is a great example of development that started in scikit-learn-contrib and (with a lot of work and improvements) ended up upstream.

massich added a commit to massich/scikit-learn that referenced this pull request Apr 26, 2017

Sundrique added a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017

NelleV added a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017

paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017

maskani-moh added a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017

jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
