Speedup MultiTaskLasso #17021

Merged
merged 24 commits into from May 2, 2020
98 changes: 52 additions & 46 deletions sklearn/linear_model/_cd_fast.pyx
@@ -19,7 +19,7 @@ from cython cimport floating
import warnings
from ..exceptions import ConvergenceWarning

from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
_copy, _scal)
from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans

@@ -154,7 +154,7 @@ def enet_coordinate_descent(floating[::1] w,
with nogil:
# R = y - np.dot(X, w)
_copy(n_samples, &y[0], 1, &R[0], 1)
_gemv(ColMajor, NoTrans, n_samples, n_features, -1.0, &X[0, 0],
n_samples, &w[0], 1, 1.0, &R[0], 1)

# tol *= np.dot(y, y)
@@ -622,8 +622,8 @@ def enet_coordinate_descent_gram(floating[::1] w,

def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
Contributor: @jeremiedbb I already see the syntax floating[::1, :] W in master, did I miss something?

Member: We can use memoryviews when we are sure that the array has been created inside scikit-learn, because then we know it is not read-only. When it is a user-provided array (X, y), we don't have that guarantee, and fused-typed memoryviews don't work with read-only arrays.
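A small illustration of that constraint may help. The sketch below (pure NumPy, hypothetical array names) only shows how a user-supplied array can end up read-only; the actual failure happens inside Cython when a non-const typed memoryview such as floating[::1, :] is requested for such a buffer.

import numpy as np

# A user-provided array can legitimately be read-only, e.g. a broadcast view
# or a file opened with np.load(..., mmap_mode='r').
X_user = np.broadcast_to(np.arange(3.0), (4, 3))
print(X_user.flags.writeable)        # False

# Arrays allocated inside scikit-learn are ordinary writable buffers.
X_internal = np.zeros((4, 3), order='F')
print(X_internal.flags.writeable)    # True

# Passing X_user to a Cython function typed with a non-const fused memoryview
# would raise "buffer source array is read-only"; arrays created internally
# never hit that case, which is why they can safely be typed this way.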

floating l2_reg,
np.ndarray[floating, ndim=2, mode='fortran'] X,
np.ndarray[floating, ndim=2] Y,
floating[::1, :] X,
floating[::1, :] Y,
Member: We can't squeeze that into the release. We need to have Cython 3.0 as the minimum version.

int max_iter, floating tol, object rng,
bint random=0):
"""Cython version of the coordinate descent algorithm
@@ -651,11 +651,11 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
cdef floating dual_norm_XtA

# initial value of the residuals
cdef floating[:, ::1] R = np.zeros((n_samples, n_tasks), dtype=dtype)
cdef floating[::1, :] R = np.zeros((n_samples, n_tasks), dtype=dtype, order='F')

cdef floating[:] norm_cols_X = np.zeros(n_features, dtype=dtype)
cdef floating[::1] norm_cols_X = np.zeros(n_features, dtype=dtype)
cdef floating[::1] tmp = np.zeros(n_tasks, dtype=dtype)
cdef floating[:] w_ii = np.zeros(n_tasks, dtype=dtype)
cdef floating[::1] w_ii = np.zeros(n_tasks, dtype=dtype)
cdef floating d_w_max
cdef floating w_max
cdef floating d_w_ii
@@ -674,30 +674,24 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
cdef UINT32_t* rand_r_state = &rand_r_state_seed

cdef floating* X_ptr = &X[0, 0]
cdef floating* W_ptr = &W[0, 0]
cdef floating* Y_ptr = &Y[0, 0]
cdef floating* wii_ptr = &w_ii[0]

if l1_reg == 0:
warnings.warn("Coordinate descent with l1_reg=0 may lead to unexpected"
" results and is discouraged.")

with nogil:
# norm_cols_X = (np.asarray(X) ** 2).sum(axis=0)
for ii in range(n_features):
for jj in range(n_samples):
norm_cols_X[ii] += X[jj, ii] ** 2
norm_cols_X[ii] = _nrm2(n_samples, &X[0, ii], 1) ** 2

# R = Y - np.dot(X, W.T)
for ii in range(n_samples):
_copy(n_samples * n_tasks, &Y[0, 0], 1, &R[0, 0], 1)
for ii in range(n_features):
for jj in range(n_tasks):
R[ii, jj] = Y[ii, jj] - (
_dot(n_features, X_ptr + ii, n_samples, W_ptr + jj, n_tasks)
)
if W[jj, ii] != 0:
_axpy(n_samples, -W[jj, ii], &X[0, ii], 1, &R[0, jj], 1)

# tol = tol * linalg.norm(Y, ord='fro') ** 2
tol = tol * _nrm2(n_samples * n_tasks, Y_ptr, 1) ** 2
tol = tol * _nrm2(n_samples * n_tasks, &Y[0, 0], 1) ** 2
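For readers less used to BLAS calls, here is a rough NumPy equivalent of the initialization block above, on synthetic data (a sketch only, not the scikit-learn implementation; array names follow the Cython code).

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_features, n_tasks = 20, 5, 3
X = np.asfortranarray(rng.randn(n_samples, n_features))
Y = np.asfortranarray(rng.randn(n_samples, n_tasks))
W = np.zeros((n_tasks, n_features), order='F')
tol = 1e-4

norm_cols_X = (X ** 2).sum(axis=0)          # squared column norms (_nrm2 per column)
R = Y - X @ W.T                              # residuals (_copy of Y, then per-task _axpy)
tol = tol * np.linalg.norm(Y, 'fro') ** 2    # scale the tolerance by ||Y||_F ** 2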

for n_iter in range(max_iter):
w_max = 0.0
@@ -712,42 +706,60 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
continue

# w_ii = W[:, ii] # Store previous value
_copy(n_tasks, W_ptr + ii * n_tasks, 1, wii_ptr, 1)
_copy(n_tasks, &W[0, ii], 1, &w_ii[0], 1)

# if np.sum(w_ii ** 2) != 0.0: # can do better
if _nrm2(n_tasks, wii_ptr, 1) != 0.0:
# if (w_ii[0] != 0.): # faster than testing full norm for non-zeros, yet unsafe
Member: I'm discovering the code so I lack context, but I'm not sure what this is supposed to mean.

In general I feel like "We don't do X because it wouldn't work" is mostly confusing, because one would just wonder why we would even want to do X in the first place.

Member: I agree. I think we can even remove the 2 comments:

# if np.sum(w_ii ** 2) != 0.0:  # can do better  ->  we do better, so no need for that any more
# faster than testing full norm for non-zeros, yet unsafe  ->  nicola's argument

# Using Numpy:
# R += np.dot(X[:, ii][:, None], w_ii[None, :]) # rank 1 update
_ger(RowMajor, n_samples, n_tasks, 1.0,
X_ptr + ii * n_samples, 1,
wii_ptr, 1, &R[0, 0], n_tasks)

# Using Blas Level2:
Member: Should this whole comment block be indented to the left now?

Contributor: addressed

# _ger(RowMajor, n_samples, n_tasks, 1.0,
# &X[0, ii], 1,
# &w_ii[0], 1, &R[0, 0], n_tasks)
# Using Blas Level1 and for loop for avoid slower threads
Member: Suggested change:
- # Using Blas Level1 and for loop for avoid slower threads
+ # Using Blas Level1 and for loop to avoid slower threads

# for such small vectors
for jj in range(n_tasks):
if w_ii[jj] != 0:
_axpy(n_samples, w_ii[jj], &X[0, ii], 1, &R[0, jj], 1)

# Using numpy:
# tmp = np.dot(X[:, ii][None, :], R).ravel()
_gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
n_tasks, X_ptr + ii * n_samples, 1, 0.0, &tmp[0], 1)
# Using BLAS Level 2:
# _gemv(RowMajor, Trans, n_samples, n_tasks, 1.0, &R[0, 0],
# n_tasks, &X[0, ii], 1, 0.0, &tmp[0], 1)
# Using BLAS Level 1 (faster small vectors like here):
Member: Suggested change:
- # Using BLAS Level 1 (faster small vectors like here):
+ # Using BLAS Level 1 (faster for small vectors like here):

Contributor: ok
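Beyond the wording, the point of this exchange is that a per-task BLAS Level 1 loop replaces the rank-1 (Level 2) update. A rough NumPy check of that equivalence, on synthetic shapes (illustration only):

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_tasks = 20, 3
R = rng.randn(n_samples, n_tasks)
x_ii = rng.randn(n_samples)          # stands for the column X[:, ii]
w_ii = rng.randn(n_tasks)            # stands for the previous W[:, ii]

R_ger = R + np.outer(x_ii, w_ii)     # rank-1 update, what _ger computes

R_axpy = R.copy()                    # per-task axpy, what the loop computes
for jj in range(n_tasks):
    if w_ii[jj] != 0:
        R_axpy[:, jj] += w_ii[jj] * x_ii

print(np.allclose(R_ger, R_axpy))    # True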

for jj in range(n_tasks):
tmp[jj] = _dot(n_samples, &X[0, ii], 1, &R[0, jj], 1)

# nn = sqrt(np.sum(tmp ** 2))
nn = _nrm2(n_tasks, &tmp[0], 1)

# W[:, ii] = tmp * fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg)
_copy(n_tasks, &tmp[0], 1, W_ptr + ii * n_tasks, 1)
_copy(n_tasks, &tmp[0], 1, &W[0, ii], 1)
_scal(n_tasks, fmax(1. - l1_reg / nn, 0) / (norm_cols_X[ii] + l2_reg),
W_ptr + ii * n_tasks, 1)
&W[0, ii], 1)

# if np.sum(W[:, ii] ** 2) != 0.0: # can do better
if _nrm2(n_tasks, W_ptr + ii * n_tasks, 1) != 0.0:
# if (W[0, ii] != 0.): # faster than testing full col norm, but unsafe
# Using numpy:
Member: same comment as above

# R -= np.dot(X[:, ii][:, None], W[:, ii][None, :])
# Using BLAS Level 2:
# Update residual : rank 1 update
_ger(RowMajor, n_samples, n_tasks, -1.0,
X_ptr + ii * n_samples, 1, W_ptr + ii * n_tasks, 1,
&R[0, 0], n_tasks)
# _ger(RowMajor, n_samples, n_tasks, -1.0,
# &X[0, ii], 1, &W[0, ii], 1,
# &R[0, 0], n_tasks)
# Using BLAS Level 1 (faster small vectors like here):
for jj in range(n_tasks):
if W[jj, ii] != 0:
_axpy(n_samples, -W[jj, ii], &X[0, ii], 1, &R[0, jj], 1)

# update the maximum absolute coefficient update
d_w_ii = diff_abs_max(n_tasks, W_ptr + ii * n_tasks, wii_ptr)
d_w_ii = diff_abs_max(n_tasks, &W[0, ii], &w_ii[0])

if d_w_ii > d_w_max:
d_w_max = d_w_ii

W_ii_abs_max = abs_max(n_tasks, W_ptr + ii * n_tasks)
W_ii_abs_max = abs_max(n_tasks, &W[0, ii])
if W_ii_abs_max > w_max:
w_max = W_ii_abs_max

@@ -760,24 +772,22 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
for ii in range(n_features):
for jj in range(n_tasks):
XtA[ii, jj] = _dot(
n_samples, X_ptr + ii * n_samples, 1,
&R[0, 0] + jj, n_tasks
n_samples, &X[0, ii], 1, &R[0, jj], 1
) - l2_reg * W[jj, ii]

# dual_norm_XtA = np.max(np.sqrt(np.sum(XtA ** 2, axis=1)))
dual_norm_XtA = 0.0
for ii in range(n_features):
# np.sqrt(np.sum(XtA ** 2, axis=1))
XtA_axis1norm = _nrm2(n_tasks,
&XtA[0, 0] + ii * n_tasks, 1)
XtA_axis1norm = _nrm2(n_tasks, &XtA[ii, 0], 1)
if XtA_axis1norm > dual_norm_XtA:
dual_norm_XtA = XtA_axis1norm

# TODO: use squared L2 norm directly
# R_norm = linalg.norm(R, ord='fro')
# w_norm = linalg.norm(W, ord='fro')
R_norm = _nrm2(n_samples * n_tasks, &R[0, 0], 1)
w_norm = _nrm2(n_features * n_tasks, W_ptr, 1)
w_norm = _nrm2(n_features * n_tasks, &W[0, 0], 1)
if (dual_norm_XtA > l1_reg):
const = l1_reg / dual_norm_XtA
A_norm = R_norm * const
Expand All @@ -787,16 +797,12 @@ def enet_coordinate_descent_multi_task(floating[::1, :] W, floating l1_reg,
gap = R_norm ** 2

# ry_sum = np.sum(R * y)
ry_sum = 0.0
for ii in range(n_samples):
for jj in range(n_tasks):
ry_sum += R[ii, jj] * Y[ii, jj]
ry_sum = _dot(n_samples * n_tasks, &R[0, 0], 1, &Y[0, 0], 1)

# l21_norm = np.sqrt(np.sum(W ** 2, axis=0)).sum()
l21_norm = 0.0
for ii in range(n_features):
# np.sqrt(np.sum(W ** 2, axis=0))
l21_norm += _nrm2(n_tasks, W_ptr + n_tasks * ii, 1)
l21_norm += _nrm2(n_tasks, &W[0, ii], 1)

gap += l1_reg * l21_norm - const * ry_sum + \
0.5 * l2_reg * (1 + const ** 2) * (w_norm ** 2)
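For reference, a NumPy sketch of the duality-gap check above (W has shape (n_tasks, n_features), as in the Cython code); it mirrors the diff but is not the scikit-learn implementation:

import numpy as np

def multi_task_enet_gap(W, X, Y, R, l1_reg, l2_reg):
    # XtA[ii, jj] = <X[:, ii], R[:, jj]> - l2_reg * W[jj, ii]
    XtA = X.T @ R - l2_reg * W.T
    dual_norm_XtA = np.max(np.sqrt((XtA ** 2).sum(axis=1)))
    R_norm = np.linalg.norm(R, 'fro')
    w_norm = np.linalg.norm(W, 'fro')
    if dual_norm_XtA > l1_reg:
        const = l1_reg / dual_norm_XtA
        gap = 0.5 * (R_norm ** 2 + (R_norm * const) ** 2)
    else:
        const = 1.0
        gap = R_norm ** 2
    ry_sum = np.sum(R * Y)                          # _dot over the flattened arrays
    l21_norm = np.sqrt((W ** 2).sum(axis=0)).sum()  # per-feature norm across tasks
    gap += (l1_reg * l21_norm - const * ry_sum
            + 0.5 * l2_reg * (1 + const ** 2) * w_norm ** 2)
    return gap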
31 changes: 15 additions & 16 deletions sklearn/linear_model/_coordinate_descent.py
@@ -1723,9 +1723,9 @@ class MultiTaskElasticNet(Lasso):

Where::

||W||_21 = sum_i sqrt(sum_j w_ij ^ 2)
||W||_21 = sum_i sqrt(sum_j W_ij ^ 2)

i.e. the sum of norm of each row.
i.e. the sum of norms of each row.
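For concreteness, the row-wise L21 norm in NumPy (a small illustration):

import numpy as np

W = np.array([[3.0, 4.0],
              [0.0, 1.0]])
l21_norm = np.sqrt((W ** 2).sum(axis=1)).sum()   # ||row 0|| + ||row 1|| = 5.0 + 1.0
print(l21_norm)                                   # 6.0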

Read more in the :ref:`User Guide <multi_task_elastic_net>`.

@@ -1819,8 +1819,8 @@ class MultiTaskElasticNet(Lasso):
-----
The algorithm used to fit the model is coordinate descent.

To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
@_deprecate_positional_args
def __init__(self, alpha=1.0, *, l1_ratio=0.5, fit_intercept=True,
@@ -1857,12 +1857,11 @@ def fit(self, X, y):
To avoid memory re-allocation it is advised to allocate the
initial data in memory directly using that format.
"""

# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be csr.
check_X_params = dict(dtype=[np.float64, np.float32], order='F',
copy=self.copy_X and self.fit_intercept)
check_y_params = dict(ensure_2d=False)
check_y_params = dict(ensure_2d=False, order='F')
X, y = self._validate_data(X, y, validate_separately=(check_X_params,
check_y_params))
y = y.astype(X.dtype)
@@ -1990,13 +1989,13 @@ class MultiTaskLasso(MultiTaskElasticNet):
--------
>>> from sklearn import linear_model
>>> clf = linear_model.MultiTaskLasso(alpha=0.1)
>>> clf.fit([[0,0], [1, 1], [2, 2]], [[0, 0], [1, 1], [2, 2]])
>>> clf.fit([[0, 1], [1, 2], [2, 4]], [[0, 0], [1, 1], [2, 3]])
Member: Was this change needed? Just trying to understand.

Contributor: Having a duplicated feature resulted in the second coefficient being 0, but with numerical errors it was 1e-16.
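A hedged reproduction of what the contributor describes (the exact stored value depends on floating-point details, so only the rounded behaviour is checked):

import numpy as np
from sklearn.linear_model import MultiTaskLasso

# The old doctest data: the two features are duplicated.
clf = MultiTaskLasso(alpha=0.1)
clf.fit([[0, 0], [1, 1], [2, 2]], [[0, 0], [1, 1], [2, 2]])
# The second column of coef_ is zero only up to numerical error (it can be
# on the order of 1e-16), which made the exact doctest output fragile.
print(np.allclose(clf.coef_[:, 1], 0.0))   # True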

MultiTaskLasso(alpha=0.1)
>>> print(clf.coef_)
[[0.89393398 0. ]
[0.89393398 0. ]]
[[0. 0.60809415]
[0. 0.94592424]]
>>> print(clf.intercept_)
[0.10606602 0.10606602]
[-0.41888636 -0.87382323]

See also
--------
@@ -2008,8 +2007,8 @@ class MultiTaskLasso(MultiTaskElasticNet):
-----
The algorithm used to fit the model is coordinate descent.

To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
@_deprecate_positional_args
def __init__(self, alpha=1.0, *, fit_intercept=True, normalize=False,
@@ -2186,8 +2185,8 @@ class MultiTaskElasticNetCV(RegressorMixin, LinearModelCV):
-----
The algorithm used to fit the model is coordinate descent.

To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
path = staticmethod(enet_path)

@@ -2358,8 +2357,8 @@ class MultiTaskLassoCV(RegressorMixin, LinearModelCV):
-----
The algorithm used to fit the model is coordinate descent.

To avoid unnecessary memory duplication the X argument of the fit method
should be directly passed as a Fortran-contiguous numpy array.
To avoid unnecessary memory duplication the X and y arguments of the fit
method should be directly passed as Fortran-contiguous numpy arrays.
"""
path = staticmethod(lasso_path)

4 changes: 2 additions & 2 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -879,9 +879,9 @@ def test_convergence_warnings():
X = random_state.standard_normal((1000, 500))
y = random_state.standard_normal((1000, 3))

# check that the model fails to converge
# check that the model fails to converge (a negative dual gap cannot occur)
with pytest.warns(ConvergenceWarning):
MultiTaskElasticNet(max_iter=1, tol=0).fit(X, y)
MultiTaskElasticNet(max_iter=1, tol=-1).fit(X, y)

# check that the model converges w/o warnings
with pytest.warns(None) as record: