
Speedup MultiTaskLasso #17021

Merged
merged 24 commits into from
May 2, 2020

Conversation

agramfort (Member)

This is WIP for now.

A good PR for @jeremiedbb as it's related to findings in kmeans.

Basically, some BLAS Level 2 calls in enet_coordinate_descent_multi_task slow things down because they are made on small vectors.

It needs more tests / benchmarks, but I have a 10x speedup on a relevant dataset for my research.

Done with @mathurinm.

@jeremiedbb (Member) left a comment

Makes sense since n_tasks is usually small. BLAS libraries try to adjust their threading scheme according to the sizes of the matrices involved but I think there's still a lot of work to do in that direction.
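A plain-numpy sketch of the trade-off being discussed (shapes and variable names here are illustrative, not the PR's actual Cython code): the rank-1 residual update R += outer(X[:, ii], w_ii) can be done either with one BLAS Level 2 call (ger) or with n_samples Level 1 updates (axpy), and for small n_tasks the Level 1 loop sidesteps BLAS threading overhead on tiny vectors.

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_tasks = 6, 2
x_col = rng.randn(n_samples)           # stands in for X[:, ii]
w_ii = rng.randn(n_tasks)              # current row of W for feature ii
R = rng.randn(n_samples, n_tasks)      # residual matrix

# Level 2 style: the whole rank-1 update in one call (what _ger computes).
R_level2 = R + np.outer(x_col, w_ii)

# Level 1 style: one axpy-like update per sample, on a length-n_tasks vector.
R_level1 = R.copy()
for jj in range(n_samples):
    R_level1[jj, :] += x_col[jj] * w_ii

assert np.allclose(R_level2, R_level1)  # both compute the same update
```

Both variants are mathematically identical; the PR's point is purely about which one the BLAS executes faster when n_tasks is tiny.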


# if np.sum(W[:, ii] ** 2) != 0.0: # can do better
if _nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
if (W[0, ii] != 0.): # faster than testing full col norm
@jeremiedbb (Member) commented Apr 24, 2020

It's not equivalent.

  • before: equivalent to testing whether all of W[:, ii] == 0. If so, no need to compute the dot. Makes sense.
  • after: only checks whether W[0, ii] == 0.

I think you are missing some computations here.

Contributor

you're right. It's something we haven't cleaned up.

Member Author

It's actually on purpose: it's unlikely that w_ii[0] == 0 while some values in w_ii[1:] are nonzero, so it should do just as well and avoid computations.

Contributor

What really needs to be avoided is the computation of the norm of W[:, ii]. I think a safe and not too heavy way to do the check is just a loop over W[:, ii] (as in https://github.com/mathurinm/celer/blob/master/celer/multitask_fast.pyx#L360).
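A rough Python sketch of that loop-based check (the helper name is hypothetical; the linked celer code does this in Cython): scan the column and exit at the first nonzero entry, so the full norm is never computed and the unsafe W[0, ii] shortcut is not needed.

```python
import numpy as np

def col_has_nonzero(W, ii):
    # Early-exit check: return True as soon as any entry of W[:, ii]
    # is nonzero, without computing the full column norm.
    for k in range(W.shape[0]):
        if W[k, ii] != 0.0:
            return True
    return False

W = np.zeros((3, 4))
W[2, 1] = 0.5
assert not col_has_nonzero(W, 0)  # all-zero column: skip the update
assert col_has_nonzero(W, 1)      # W[0, 1] == 0 but the column is nonzero
```

The second assertion is exactly the case where testing only W[0, ii] would wrongly skip the update.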

Member

> it's unlikely

Unlikely is not always. If it can happen that only w_ii[0] == 0, the computation would be wrong. I like @mathurinm's idea.

Contributor

implemented

@mathurinm (Contributor)

I changed the doctest because it was a pathological case with a duplicated feature. It failed because of numerical errors (1e-16 instead of an exact 0).

@mathurinm (Contributor)

The test uses an alpha which is greater than alpha_max:

In [7]: %paste                                                                  
        random_state = np.random.RandomState(0)
        X = random_state.standard_normal((1000, 500))
        y = random_state.standard_normal((1000, 3))
## -- End pasted text --

In [8]: np.max(norm(X.T @ y, axis=1)) / len(y)                                  
Out[8]: 0.12329191379277193

So you can (and should, up to numerical errors) get a duality gap of 0 even with 0 iterations.

I suggest changing tol to -1 in the test, or using a smaller alpha.
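For reference, a sketch reproducing that computation: under the scikit-learn scaling (data-fit term divided by n_samples), the smallest penalty that zeroes the whole multi-task Lasso solution is alpha_max = max_j ||X[:, j].T @ Y||_2 / n_samples, which is what the snippet above evaluates.

```python
import numpy as np
from numpy.linalg import norm

# Same data as in the test under discussion.
random_state = np.random.RandomState(0)
X = random_state.standard_normal((1000, 500))
y = random_state.standard_normal((1000, 3))

# Smallest alpha for which the multi-task Lasso solution is exactly 0:
alpha_max = np.max(norm(X.T @ y, axis=1)) / len(y)
print(alpha_max)  # ~0.123, well below the default alpha=1.0
```

Since the estimator's default alpha=1.0 exceeds this alpha_max, the zero initialization is already the solution, which is why 0 iterations suffice.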

@agramfort (Member, Author)

agramfort commented Apr 25, 2020 via email

@mathurinm (Contributor)

mathurinm commented Apr 25, 2020

The current test is wrong in master and passes only because of numerical errors: we fit with alpha > alpha_max, so we should get a duality gap of 0 and no convergence warning.

In master, we get a gap of 1e-13, so a warning is raised. My belief is that in the PR there is no numerical error and we get exactly 0.

@agramfort (Member, Author)

agramfort commented Apr 26, 2020

a benchmark to confirm the speed gains:

                              time (old) / time (new)
n_samples n_features n_tasks
100       300        2                      11.658016
                     10                      4.530553
                     20                      2.371121
                     50                      1.337869
          1000       2                       5.732973
                     10                      4.290821
                     20                      3.245888
                     50                      2.114958
          4000       2                       7.559168
                     10                      4.522091
                     20                      1.930551
                     50                      1.676207
500       300        2                       7.787371
                     10                      4.261558
                     20                      9.872827
                     50                      6.685665
          1000       2                      28.021876
                     10                      8.992789
                     20                     12.636008
                     50                      7.281548
          4000       2                      14.004219
                     10                      7.636893
                     20                     14.989917
                     50                      9.752695

code https://gist.github.com/agramfort/ca54cc1bc12a37d7a426a7799cc236ce

you can see that the speed gains are not minor.

@agramfort agramfort marked this pull request as ready for review April 26, 2020 12:38
@agramfort (Member, Author)

ready for review @jeremiedbb et al :)

maybe we can even squeeze this in before the release ...

Comment on lines 625 to 626
floating[::1, :] X,
floating[::1, :] Y,
Member

We can't squeeze that into the release. We need to have Cython 3.0 as the minimum version.

@@ -881,7 +881,7 @@ def test_convergence_warnings():

# check that the model fails to converge
with pytest.warns(ConvergenceWarning):
MultiTaskElasticNet(max_iter=1, tol=0).fit(X, y)
MultiTaskElasticNet(alpha=0.001, max_iter=1, tol=0).fit(X, y)
Member

Why is changing alpha necessary here? Does it converge in only 1 iteration with the default alpha?

Contributor

The default alpha is greater than alpha_max for this pair (X, y), so a duality gap of 0 (up to numerical errors) is reached in 0 iterations.

Member

I must be misunderstanding something, but the test does not fail on master. How can it fail with this PR?

Contributor

It is tricky.
The test should fail on master. It passes only because of numerical errors (a duality gap of 1e-13 instead of 0 is returned).
I believe it fails on the PR because a duality gap of exactly 0 is returned (which is the correct value).

Run:

import numpy as np
from sklearn.linear_model import MultiTaskElasticNet

random_state = np.random.RandomState(0)
X = random_state.standard_normal((1000, 500))
y = random_state.standard_normal((1000, 3))

clf = MultiTaskElasticNet(max_iter=1, tol=0).fit(X, y)
print(clf.dual_gap_)  # should be 0, is 1e-13
print(np.linalg.norm(clf.coef_))  # we have converged: the initialization is already the solution because alpha >= alpha_max

Member

Thanks for the clarification. tol=0 is rarely a good idea because it's subject to numerical precision issues. We can change the tol or the alpha, but in either case could you add a comment to explain that?

Contributor

IMO it is more logical to use tol=-1 (e.g., the value of alpha we used could still be >= alpha_max; here it is not).

Where should I put the comment?

Member

just above this line, when creating the model instance.

@jeremiedbb (Member)

I'm happy with the current state of this PR, but as I said in a comment, we can't merge the switch to memoryviews until we bump our Cython min version to 3.0 (which is still in alpha).

So we can either wait a little bit, which means it will be for the next release, or put back the ndarrays and merge it for this release. I'm fine with both, your call.

@mathurinm (Contributor)

Putting back the ndarrays only means changing the signature of the Cython function; we could still use W[ii, jj]?
If so, I think we can put back the ndarrays to ship it faster and switch to memoryviews after the release.

@@ -622,8 +622,8 @@ def enet_coordinate_descent_gram(floating[::1] w,

def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
Contributor

@jeremiedbb I already see the syntax floating[::1, :] W in master, did I miss something?

Member

We can use memoryviews when we are sure that the array has been created in scikit-learn, because then we know it's not read-only. When it's a user-provided array (X, y), we don't have that guarantee, and fused-typed memoryviews don't work with read-only arrays.
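A plain-Python illustration of the read-only case (this is just numpy, not the Cython signature itself): user-provided arrays can be non-writeable, e.g. when explicitly frozen or backed by a memory map, and a buffer taken on such an array must not allow writes.

```python
import numpy as np

a = np.arange(3.0)
a.flags.writeable = False  # simulate a read-only user-provided array

# Writing through the array now fails; a typed memoryview declared as
# writable would hit the same restriction when constructed on it.
try:
    a[0] = 1.0
except ValueError:
    print("array is read-only")
```

This is the guarantee scikit-learn lacks for X and y, hence the np.ndarray signature until the Cython minimum version allows handling read-only buffers.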

@jeremiedbb (Member)

> Putting back the ndarrays only means changing the signature of the Cython function, we could still use W[ii, jj]?

I'm not sure about that; if I remember correctly, it's slower to access ndarray elements, so it would mean keeping the pointers. We'd have to rerun the benchmarks to make sure.

@mathurinm (Contributor)

I reverted to np.ndarray for X and Y. @agramfort can you check that the benchmarks are still favorable?
We can release this version and revert to memoryviews afterwards.

@jeremiedbb (Member) left a comment

lgtm

@NicolasHug (Member) left a comment

Thanks for the PR, made a first pass

@@ -1990,13 +1989,13 @@ class MultiTaskLasso(MultiTaskElasticNet):
--------
>>> from sklearn import linear_model
>>> clf = linear_model.MultiTaskLasso(alpha=0.1)
>>> clf.fit([[0,0], [1, 1], [2, 2]], [[0, 0], [1, 1], [2, 2]])
>>> clf.fit([[0, 1], [1, 2], [2, 4]], [[0, 0], [1, 1], [2, 3]])
Member

Was this change needed?
Just trying to understand.

Contributor

Having a duplicated feature resulted in the second coefficient being 0, but with numerical errors it was 1e-16.

# _ger(RowMajor, n_samples, n_tasks, 1.0,
# &X[0, ii], 1,
# &w_ii[0], 1, &R[0, 0], n_tasks)
# Using Blas Level1 and for loop for avoid slower threads
Member

Suggested change
# Using Blas Level1 and for loop for avoid slower threads
# Using Blas Level1 and for loop to avoid slower threads


# if np.sum(w_ii ** 2) != 0.0: # can do better
if _nrm2(n_tasks, wii_ptr, 1) != 0.0:
# if (w_ii[0] != 0.): # faster than testing full norm for non-zeros, yet unsafe
Member

I'm discovering the code so I lack context, but I'm not sure what this is supposed to mean.

In general I feel like "We don't do X because it wouldn't work" is mostly confusing because one would just wonder why we would even want to do X in the first place.

Member

I agree. I think we can even remove the 2 comments:
if np.sum(w_ii ** 2) != 0.0: # can do better -> we do better, so no need for that any more
# faster than testing full norm for non-zeros, yet unsafe -> Nicolas's argument

X_ptr + ii * n_samples, 1,
wii_ptr, 1, &R[0, 0], n_tasks)

# Using Blas Level2:
Member

Should this whole comment block be indented to the left now?

Contributor

addressed

# Using BLAS Level 2:
# _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
# n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
# Using BLAS Level 1 (faster small vectors like here):
Member

Suggested change
# Using BLAS Level 1 (faster small vectors like here):
# Using BLAS Level 1 (faster for small vectors like here):

Contributor

ok

Comment on lines 748 to 749
# if (W[0, ii] != 0.): # faster than testing full col norm, but unsafe
# Using numpy:
Member

same comment as above

@agramfort (Member, Author)

@mathurinm your latest changes LGTM, but can you rerun my benchmark on this version to see what the speedup is with the last version of the code?

@mathurinm (Contributor)

On my machine the gain for 50 tasks is not always present.

                              time (old) / time (new)
n_samples n_features n_tasks                         
100       300        2                       5.027327
                     10                      3.161739
                     20                      1.661380
                     50                      0.887943
          1000       2                       4.375420
                     10                      2.920799
                     20                      1.573692
                     50                      0.828126
          4000       2                       3.144004
                     10                      2.019573
                     20                      1.055497
                     50                      0.774573
500       300        2                      12.277192
                     10                      5.607055
                     20                      6.842124
                     50                      2.947654
          1000       2                       9.618764
                     10                      3.642353
                     20                      4.788487
                     50                      1.909268
          4000       2                       6.168548
                     10                      4.363385
                     20                      4.574179
                     50                      2.849952

@agramfort (Member, Author)

great. good to go from my end.

needs 2nd approval

@mathurinm (Contributor)

@NicolasHug thank you for the review, does it look good to you now?

@NicolasHug (Member) left a comment

Thanks @mathurinm @agramfort

Maybe we could add a TODO comment and/or open an issue about using views instead of numpy arrays in these places, since it does seem that the benchmarks with the views were even faster.

Also, this probably needs a what's new entry?

I'm not sure what to think about the small regressions with n_samples=100 and n_tasks=50. I guess if the absolute running time is pretty slow already, that's still OK?

@agramfort (Member, Author)

done @NicolasHug

@@ -319,6 +319,13 @@ Changelog
random noise to the target. This might help with stability in some edge
cases. :pr:`15179` by :user:`angelaambroz`.

- |Efficiency| Speed up :class:`linear_model.MultiTaskLasso`,
Member

I'm not sure this can make it to 0.23

@adrinjalali this is ready for a merge, so your choice basically ;)

@mathurinm (Contributor)

mathurinm commented May 2, 2020

For completeness, here are the old (master) running times on my machine:

n_samples n_features n_tasks        time
100       300        2          0.343438
                     10         0.825290
                     20         0.848199
                     50         1.225283
          1000       2          0.794713
                     10         2.052545
                     20         2.241742
                     50         3.235364
          4000       2          1.767550
                     10         4.096383
                     20         4.637427
                     50         8.098676
500       300        2          0.045238
                     10         0.090502
                     20         0.274975
                     50         0.305215
          1000       2          1.683292
                     10         5.846135
                     20        16.392532
                     50        31.372721
          4000       2         13.146297
                     10        37.338017
                     20        84.895918
                     50       127.569109

@ogrisel (Member) left a comment

Also +1 for being part of 0.23 if not too late @adrinjalali. Merging and tagging (#17010), but feel free to let us know if you would rather move it to 0.24 instead.

@ogrisel ogrisel merged commit 04d2e32 into scikit-learn:master May 2, 2020
@ogrisel (Member)

ogrisel commented May 2, 2020

Thank you @mathurinm!

@agramfort (Member, Author)

agramfort commented May 2, 2020 via email

@adrinjalali (Member)

adrinjalali commented May 2, 2020

I was hoping this one would get in, I intentionally hadn't tagged yet :) #17010

adrinjalali pushed a commit to adrinjalali/scikit-learn that referenced this pull request May 4, 2020
adrinjalali pushed a commit that referenced this pull request May 5, 2020
gio8tisu pushed a commit to gio8tisu/scikit-learn that referenced this pull request May 15, 2020
viclafargue pushed a commit to viclafargue/scikit-learn that referenced this pull request Jun 26, 2020