Speedup MultiTaskLasso (#17021)
agramfort committed May 2, 2020
1 parent c71a1c2 commit 04d2e32
Showing 4 changed files with 86 additions and 72 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -404,6 +404,13 @@ Changelog
using joblib loky backend. :pr:`14264` by
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |Efficiency| Speed up :class:`linear_model.MultiTaskLasso`,
:class:`linear_model.MultiTaskLassoCV`, :class:`linear_model.MultiTaskElasticNet`,
:class:`linear_model.MultiTaskElasticNetCV` by avoiding slower
BLAS Level 2 calls on small arrays
:pr:`17021` by :user:`Alex Gramfort <agramfort>` and
:user:`Mathurin Massias <mathurinm>`.
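To illustrate the change with plain NumPy (an editorial sketch, not part of the diff): a BLAS Level 2 rank-1 update such as `_ger` touches all tasks at once and may dispatch to threaded kernels, while the new code issues one Level 1 axpy per task and skips zero coefficients, which is cheaper when n_tasks is small.

    import numpy as np

    rng = np.random.default_rng(0)
    n_samples, n_tasks = 10_000, 4
    R = np.asfortranarray(rng.standard_normal((n_samples, n_tasks)))
    x_ii = rng.standard_normal(n_samples)  # one column of X
    w_ii = rng.standard_normal(n_tasks)    # coefficients of feature ii across tasks

    # BLAS Level 2 style: a single rank-1 update (what _ger did)
    R_ger = R + np.outer(x_ii, w_ii)

    # BLAS Level 1 style: one axpy per task, skipping zeros (the new approach)
    R_axpy = R.copy(order='F')
    for jj in range(n_tasks):
        if w_ii[jj] != 0:
            R_axpy[:, jj] += w_ii[jj] * x_ii

    np.testing.assert_allclose(R_axpy, R_ger)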

:mod:`sklearn.metrics`
......................

116 changes: 62 additions & 54 deletions sklearn/linear_model/_cd_fast.pyx
Expand Up @@ -19,7 +19,7 @@ from cython cimport floating
import warnings
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
_copy, _scal)
from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans

@@ -154,7 +154,7 @@ def enet_coordinate_descent(floating[::1] w,
with nogil:
# R = y - np.dot(X, w)
_copy(n_samples, &y[0], 1, &R[0], 1)
_gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
n_samples, &w[0], 1, 1.0, &R[0], 1)

# tol *= np.dot(y, y)
@@ -620,18 +620,17 @@ def enet_coordinate_descent_gram(floating[::1] w,
return np.asarray(w), gap, tol, n_iter + 1


def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
floating l2_reg,
np.ndarray[floating, ndim=2, mode='fortran'] X,
np.ndarray[floating, ndim=2] Y,
int max_iter, floating tol, object rng,
bint random=0):
def enet_coordinate_descent_multi_task(
floating[::1, :] W, floating l1_reg, floating l2_reg,
np.ndarray[floating, ndim=2, mode='fortran'] X, # TODO: use views in 0.24
np.ndarray[floating, ndim=2, mode='fortran'] Y,
int max_iter, floating tol, object rng, bint random=0):
"""Cython version of the coordinate descent algorithm
for Elastic-Net mult-task regression
We minimize
(1/2) * norm(y - X w, 2)^2 + l1_reg ||w||_21 + (1/2) * l2_reg norm(w, 2)^2
0.5 * norm(Y - X W.T, 2)^2 + l1_reg ||W.T||_21 + 0.5 * l2_reg norm(W.T, 2)^2
"""

@@ -651,11 +650,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
cdef floating dual_norm_XtA

# initial value of the residuals
cdef floating[:, ::1] R = np.zeros((n_samples, n_tasks), dtype=dtype)
cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F')

cdef floating[:] norm_cols_X = np.zeros(n_features, dtype=dtype)
cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
cdef floating[:] w_ii = np.zeros(n_tasks, dtype=dtype)
cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype)
cdef floating d_w_max
cdef floating w_max
cdef floating d_w_ii
@@ -675,9 +674,7 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
cdef UINT32_t* rand_r_state = &rand_r_state_seed

cdef floating* X_ptr = &X[0, 0]
cdef floating* W_ptr = &W[0, 0]
cdef floating* Y_ptr = &Y[0, 0]
cdef floating* wii_ptr = &w_ii[0]

if l1_reg == 0:
warnings.warn("Coordinate descent with l1_reg=0 may lead to unexpected"
@@ -686,15 +683,15 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
with nogil:
# norm_cols_X = (np.asarray(X) ** 2).sum(axis=0)
for ii in range(n_features):
for jj in range(n_samples):
norm_cols_X[ii] += X[jj, ii] ** 2
norm_cols_X[ii] = _nrm2(n_samples, X_ptr + ii * n_samples, 1) ** 2

# R = Y - np.dot(X, W.T)
for ii in range(n_samples):
_copy(n_samples * n_tasks, Y_ptr, 1, &R[0, 0], 1)
for ii in range(n_features):
for jj in range(n_tasks):
R[ii, jj] = Y[ii, jj] - (
_dot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
)
if W[jj, ii] != 0:
_axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
&R[0, jj], 1)

# tol = tol * linalg.norm(Y, ord='fro') ** 2
tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
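The new initialization (copy Y, then subtract one column contribution per nonzero coefficient) matches the removed per-entry `_dot` loop; a NumPy sketch with made-up shapes:

    import numpy as np

    rng = np.random.default_rng(0)
    n_samples, n_features, n_tasks = 50, 20, 3
    X = np.asfortranarray(rng.standard_normal((n_samples, n_features)))
    W = np.asfortranarray(rng.standard_normal((n_tasks, n_features)))
    Y = np.asfortranarray(rng.standard_normal((n_samples, n_tasks)))

    R_old = Y - X @ W.T        # what the removed double loop computed entry by entry
    R_new = Y.copy(order='F')  # the _copy call; R and Y are both F-ordered now
    for ii in range(n_features):
        for jj in range(n_tasks):
            if W[jj, ii] != 0:                        # skip zero coefficients
                R_new[:, jj] -= W[jj, ii] * X[:, ii]  # one _axpy per (feature, task)
    np.testing.assert_allclose(R_new, R_old)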
@@ -712,42 +709,59 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
continue

# w_ii = W[:, ii] # Store previous value
_copy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)

# if np.sum(w_ii ** 2) != 0.0: # can do better
if _nrm2(n_tasks, wii_ptr, 1) != 0.0:
# R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
_ger(RowMajor, n_samples, n_tasks, 1.0,
X_ptr + ii * n_samples, 1,
wii_ptr, 1, &R[0, 0], n_tasks)

_copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1)

# Using Numpy:
# R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
# Using BLAS Level 2:
# _ger(RowMajor, n_samples, n_tasks, 1.0,
# &X[0, ii], 1,
# &w_ii[0], 1, &R[0, 0], n_tasks)
# Using BLAS Level 1 and a for loop to avoid the slower threaded
# BLAS Level 2 calls on such small vectors
for jj in range(n_tasks):
if w_ii[jj] != 0:
_axpy(n_samples, w_ii[jj], X_ptr + ii * n_samples, 1,
&R[0, jj], 1)

# Using numpy:
# tmp = np.dot(X[:, ii][None, :], R).ravel()
_gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
n_tasks, X_ptr + ii * n_samples, 1, 0.0, &tmp[0], 1)
# Using BLAS Level 2:
# _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
# n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
# Using BLAS Level 1 (faster for small vectors like here):
for jj in range(n_tasks):
tmp[jj] = _dot(n_samples, X_ptr + ii * n_samples, 1,
&R[0, jj], 1)

# nn = sqrt(np.sum(tmp ** 2))
nn = _nrm2(n_tasks, &tmp[0], 1)

# W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
_copy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
_copy(n_tasks, &tmp[0], 1, &W[0, ii], 1)
_scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
W_ptr + ii * n_tasks, 1)

# if np.sum(W[:, ii] ** 2) != 0.0: # can do better
if _nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
# R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
# Update residual: rank 1 update
_ger(RowMajor, n_samples, n_tasks, -1.0,
X_ptr + ii * n_samples, 1, W_ptr + ii * n_tasks, 1,
&R[0, 0], n_tasks)
&W[0, ii], 1)

# Using numpy:
# R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
# Using BLAS Level 2:
# Update residual: rank 1 update
# _ger(RowMajor, n_samples, n_tasks, -1.0,
# &X[0, ii], 1, &W[0, ii], 1,
# &R[0, 0], n_tasks)
# Using BLAS Level 1 (faster for small vectors like here):
for jj in range(n_tasks):
if W[jj, ii] != 0:
_axpy(n_samples, -W[jj, ii], X_ptr + ii * n_samples, 1,
&R[0, jj], 1)

# update the maximum absolute coefficient update
d_w_ii = diff_abs_max(n_tasks, W_ptr + ii * n_tasks, wii_ptr)
d_w_ii = diff_abs_max(n_tasks, &W[0, ii], &w_ii[0])

if d_w_ii > d_w_max:
d_w_max = d_w_ii

W_ii_abs_max = abs_max(n_tasks, W_ptr + ii * n_tasks)
W_ii_abs_max = abs_max(n_tasks, &W[0, ii])
if W_ii_abs_max > w_max:
w_max = W_ii_abs_max

@@ -760,24 +774,22 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
for ii in range(n_features):
for jj in range(n_tasks):
XtA[ii, jj] = _dot(
n_samples, X_ptr + ii * n_samples, 1,
&R[0, 0] + jj, n_tasks
n_samples, X_ptr + ii * n_samples, 1, &R[0, jj], 1
) - l2_reg * W[jj, ii]

# dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
dual_norm_XtA = 0.0
for ii in range(n_features):
# np.sqrt(np.sum(XtA ** 2, axis=1))
XtA_axis1norm = _nrm2(n_tasks,
&XtA[0, 0] + ii * n_tasks, 1)
XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
if XtA_axis1norm > dual_norm_XtA:
dual_norm_XtA = XtA_axis1norm

# TODO: use squared L2 norm directly
# R_norm = linalg.norm(R, ord='fro')
# w_norm = linalg.norm(W, ord='fro')
R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
w_norm = _nrm2(n_features * n_tasks, W_ptr, 1)
w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
if (dual_norm_XtA > l1_reg):
const = l1_reg / dual_norm_XtA
A_norm = R_norm * const
@@ -787,16 +799,12 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
gap = R_norm ** 2

# ry_sum = np.sum(R * y)
ry_sum = 0.0
for ii in range(n_samples):
for jj in range(n_tasks):
ry_sum += R[ii, jj] * Y[ii, jj]
ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)

# l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum()
l21_norm = 0.0
for ii in range(n_features):
# np.sqrt(np.sum(W ** 2, axis=0))
l21_norm += _nrm2(n_tasks, W_ptr + n_tasks * ii, 1)
l21_norm += _nrm2(n_tasks, &W[0, ii], 1)

gap += l1_reg * l21_norm - const * ry_sum + \
0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
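For reference, the duality gap computed above can be written in NumPy as (a sketch; the function name is hypothetical):

    import numpy as np

    def multitask_dual_gap(W, X, Y, R, l1_reg, l2_reg):
        XtA = X.T @ R - l2_reg * W.T                           # (n_features, n_tasks)
        dual_norm_XtA = np.sqrt((XtA ** 2).sum(axis=1)).max()
        R_norm = np.linalg.norm(R)
        w_norm = np.linalg.norm(W)
        if dual_norm_XtA > l1_reg:
            const = l1_reg / dual_norm_XtA
            gap = 0.5 * (R_norm ** 2) * (1.0 + const ** 2)
        else:
            const = 1.0
            gap = R_norm ** 2
        ry_sum = (R * Y).sum()                                 # now a single _dot call
        l21_norm = np.sqrt((W ** 2).sum(axis=0)).sum()
        gap += (l1_reg * l21_norm - const * ry_sum
                + 0.5 * l2_reg * (1.0 + const ** 2) * w_norm ** 2)
        return gap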
31 changes: 15 additions & 16 deletions sklearn/linear_model/_coordinate_descent.py
Expand Up @@ -1733,9 +1733,9 @@ class MultiTaskElasticNet(Lasso):
Where::
||W||_21 = sum_i sqrt(sum_j w_ij ^ 2)
||W||_21 = sum_i sqrt(sum_j W_ij ^ 2)
i.e. the sum of norm of each row.
i.e. the sum of norms of each row.
Read more in the :ref:`User Guide <multi_task_elastic_net>`.
@@ -1829,8 +1829,8 @@ class MultiTaskElasticNet(Lasso):
-----
The algorithm used to fit the model is coordinate descent.
To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
@_deprecate_positional_args
def __init__(self, alpha=1.0, *, l1_ratio=0.5, fit_intercept=True,
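A caller can honor the note above by converting once up front (a usage sketch):

    import numpy as np
    from sklearn.linear_model import MultiTaskElasticNet

    rng = np.random.default_rng(0)
    X = np.asfortranarray(rng.standard_normal((50, 10)))
    y = np.asfortranarray(rng.standard_normal((50, 3)))
    # X and y are already Fortran-contiguous, so validation in fit
    # does not need an extra conversion copy
    MultiTaskElasticNet(alpha=0.1).fit(X, y)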
@@ -1867,12 +1867,11 @@ def fit(self, X, y):
To avoid memory re-allocation it is advised to allocate the
initial data in memory directly using that format.
"""

# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be csr.
check_X_params = dict(dtype=[np.float64, np.float32], order='F',
copy=self.copy_X and self.fit_intercept)
check_y_params = dict(ensure_2d=False)
check_y_params = dict(ensure_2d=False, order='F')
X, y = self._validate_data(X, y, validate_separately=(check_X_params,
check_y_params))
y = y.astype(X.dtype)
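With `order='F'` added to `check_y_params`, validation now hands the Cython solver a Fortran-ordered y, copying only when the input is not already in that layout (a sketch using the public `check_array`):

    import numpy as np
    from sklearn.utils import check_array

    y_c = np.ones((5, 2))               # C-ordered by default
    y_f = check_array(y_c, order='F')   # converted (copied) to Fortran order
    assert y_f.flags['F_CONTIGUOUS']
    y_f2 = check_array(y_f, order='F')  # already Fortran-ordered: no conversion copy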
@@ -2000,13 +1999,13 @@ class MultiTaskLasso(MultiTaskElasticNet):
--------
>>> from sklearn import linear_model
>>> clf = linear_model.MultiTaskLasso(alpha=0.1)
>>> clf.fit([[0,0], [1, 1], [2, 2]], [[0, 0], [1, 1], [2, 2]])
>>> clf.fit([[0, 1], [1, 2], [2, 4]], [[0, 0], [1, 1], [2, 3]])
MultiTaskLasso(alpha=0.1)
>>> print(clf.coef_)
[[0.89393398 0. ]
[0.89393398 0. ]]
[[0. 0.60809415]
[0. 0.94592424]]
>>> print(clf.intercept_)
[0.10606602 0.10606602]
[-0.41888636 -0.87382323]
See also
--------
Expand All @@ -2018,8 +2017,8 @@ class MultiTaskLasso(MultiTaskElasticNet):
-----
The algorithm used to fit the model is coordinate descent.
To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
@_deprecate_positional_args
def __init__(self, alpha=1.0, *, fit_intercept=True, normalize=False,
@@ -2196,8 +2195,8 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
-----
The algorithm used to fit the model is coordinate descent.
To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
path = staticmethod(enet_path)

@@ -2368,8 +2367,8 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
-----
The algorithm used to fit the model is coordinate descent.
To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
path = staticmethod(lasso_path)

4 changes: 2 additions & 2 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -882,9 +882,9 @@ def test_convergence_warnings():
X = random_state.standard_normal((1000, 500))
y = random_state.standard_normal((1000, 3))

# check that the model fails to converge
# check that the model fails to converge: the dual gap is never negative, so it cannot reach tol=-1
with pytest.warns(ConvergenceWarning):
MultiTaskElasticNet(max_iter=1, tol=0).fit(X, y)
MultiTaskElasticNet(max_iter=1, tol=-1).fit(X, y)

# check that the model converges w/o warnings
with pytest.warns(None) as record:
