[MRG+2] Make enet_coordinate_descent_gram support fused types (#7218)
* Make cd_gram support fused types

* Add test
yenchenlin authored and ogrisel committed Aug 29, 2016
1 parent 5b20d48 commit 845e702
Showing 2 changed files with 64 additions and 33 deletions.
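For context: a Cython fused type lets one function body compile into several C specializations, one per concrete type. This commit switches enet_coordinate_descent_gram from hard-coded double to the floating fused type (float or double), so float32 input can be handled natively instead of being upcast to float64. A minimal sketch of the mechanism (illustrative only, not code from this commit):

    from cython cimport floating  # built-in fused type: float or double

    def scale(floating[:] x, floating a):
        # Cython emits a float32 and a float64 specialization of this
        # function; the right one is chosen from the argument dtype.
        cdef Py_ssize_t i
        for i in range(x.shape[0]):
            x[i] = x[i] * a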
74 changes: 46 additions & 28 deletions sklearn/linear_model/cd_fast.pyx
@@ -528,11 +528,11 @@ def sparse_enet_coordinate_descent(floating [:] w,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
-def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
-                                 np.ndarray[double, ndim=2, mode='c'] Q,
-                                 np.ndarray[double, ndim=1, mode='c'] q,
-                                 np.ndarray[double, ndim=1] y,
-                                 int max_iter, double tol, object rng,
+def enet_coordinate_descent_gram(floating[:] w, floating alpha, floating beta,
+                                 np.ndarray[floating, ndim=2, mode='c'] Q,
+                                 np.ndarray[floating, ndim=1, mode='c'] q,
+                                 np.ndarray[floating, ndim=1] y,
+                                 int max_iter, floating tol, object rng,
                                 bint random=0, bint positive=0):
    """Cython version of the coordinate descent algorithm
        for Elastic-Net regression
@@ -546,34 +546,52 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
        q = X^T y
    """

+    # fused types version of BLAS functions
+    cdef DOT dot
+    cdef AXPY axpy
+    cdef ASUM asum
+
+    if floating is float:
+        dtype = np.float32
+        dot = sdot
+        axpy = saxpy
+        asum = sasum
+    else:
+        dtype = np.float64
+        dot = ddot
+        axpy = daxpy
+        asum = dasum
+
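The DOT, AXPY and ASUM names used above are fused-type function-pointer typedefs declared elsewhere in cd_fast.pyx, outside this diff. They are assumed to have roughly this shape, so that each pointer can bind to either the single- or double-precision BLAS routine once floating is specialized (a sketch, not the verbatim declarations):

    ctypedef floating (*DOT)(int N, floating *X, int incX,
                             floating *Y, int incY) nogil
    ctypedef void (*AXPY)(int N, floating alpha, floating *X, int incX,
                          floating *Y, int incY) nogil
    ctypedef floating (*ASUM)(int N, floating *X, int incX) nogil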
    # get the data information into easy vars
    cdef unsigned int n_samples = y.shape[0]
    cdef unsigned int n_features = Q.shape[0]

    # initial value "Q w" which will be kept up to date in the iterations
-    cdef double[:] H = np.dot(Q, w)
+    cdef floating[:] H = np.dot(Q, w)

-    cdef double[:] XtA = np.zeros(n_features)
-    cdef double tmp
-    cdef double w_ii
-    cdef double d_w_max
-    cdef double w_max
-    cdef double d_w_ii
-    cdef double gap = tol + 1.0
-    cdef double d_w_tol = tol
-    cdef double dual_norm_XtA
+    cdef floating[:] XtA = np.zeros(n_features, dtype=dtype)
+    cdef floating tmp
+    cdef floating w_ii
+    cdef floating d_w_max
+    cdef floating w_max
+    cdef floating d_w_ii
+    cdef floating q_dot_w
+    cdef floating w_norm2
+    cdef floating gap = tol + 1.0
+    cdef floating d_w_tol = tol
+    cdef floating dual_norm_XtA
    cdef unsigned int ii
    cdef unsigned int n_iter = 0
    cdef unsigned int f_iter
    cdef UINT32_t rand_r_state_seed = rng.randint(0, RAND_R_MAX)
    cdef UINT32_t* rand_r_state = &rand_r_state_seed

-    cdef double y_norm2 = np.dot(y, y)
-    cdef double* w_ptr = <double*>&w[0]
-    cdef double* Q_ptr = &Q[0, 0]
-    cdef double* q_ptr = <double*>q.data
-    cdef double* H_ptr = &H[0]
-    cdef double* XtA_ptr = &XtA[0]
+    cdef floating y_norm2 = np.dot(y, y)
+    cdef floating* w_ptr = <floating*>&w[0]
+    cdef floating* Q_ptr = &Q[0, 0]
+    cdef floating* q_ptr = <floating*>q.data
+    cdef floating* H_ptr = &H[0]
+    cdef floating* XtA_ptr = &XtA[0]
    tol = tol * y_norm2

    if alpha == 0:
@@ -597,8 +615,8 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,

                if w_ii != 0.0:
                    # H -= w_ii * Q[ii]
-                    daxpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
-                          H_ptr, 1)
+                    axpy(n_features, -w_ii, Q_ptr + ii * n_features, 1,
+                         H_ptr, 1)

                tmp = q[ii] - H[ii]

@@ -610,8 +628,8 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,

                if w[ii] != 0.0:
                    # H += w[ii] * Q[ii] # Update H = X.T X w
-                    daxpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
-                          H_ptr, 1)
+                    axpy(n_features, w[ii], Q_ptr + ii * n_features, 1,
+                         H_ptr, 1)

                # update the maximum absolute coefficient update
                d_w_ii = fabs(w[ii] - w_ii)
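Between these hunks the unchanged (collapsed) context applies the usual coordinate-wise soft-thresholding update; schematically, with t = q[ii] - H[ii] computed above,

    w_{ii} \leftarrow \operatorname{sgn}(t)\,\max(|t| - \alpha,\, 0) \,/\, (Q_{ii,ii} + \beta),

so only the scalar dtype changes in this commit, not the update rule.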
@@ -627,7 +645,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
                # criterion

                # q_dot_w = np.dot(w, q)
-                q_dot_w = ddot(n_features, w_ptr, 1, q_ptr, 1)
+                q_dot_w = dot(n_features, w_ptr, 1, q_ptr, 1)

                for ii in range(n_features):
                    XtA[ii] = q[ii] - H[ii] - beta * w[ii]
@@ -643,7 +661,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
                R_norm2 = y_norm2 + tmp - 2.0 * q_dot_w

                # w_norm2 = np.dot(w, w)
-                w_norm2 = ddot(n_features, &w[0], 1, &w[0], 1)
+                w_norm2 = dot(n_features, &w[0], 1, &w[0], 1)

                if (dual_norm_XtA > alpha):
                    const = alpha / dual_norm_XtA
@@ -654,7 +672,7 @@ def enet_coordinate_descent_gram(double[:] w, double alpha, double beta,
                    gap = R_norm2

                # The call to dasum is equivalent to the L1 norm of w
-                gap += (alpha * dasum(n_features, &w[0], 1) -
+                gap += (alpha * asum(n_features, &w[0], 1) -
                        const * y_norm2 + const * q_dot_w +
                        0.5 * beta * (1 + const ** 2) * w_norm2)

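The quantity assembled above is the Elastic-Net duality gap, which the solver uses as its stopping criterion. Reading the hunks together (a gloss on the surrounding code, not part of the diff): with R = y - Xw and s = const = min(1, \alpha / \|X^\top R - \beta w\|_\infty), the accumulated value is

    \mathrm{gap} = \tfrac{1}{2}\|R\|_2^2 + \tfrac{1}{2}\|sR\|_2^2 + \alpha\|w\|_1
                   - s\,R^\top y + \tfrac{\beta}{2}(1 + s^2)\|w\|_2^2,

where the Gram identities \|R\|_2^2 = y^\top y + w^\top Q w - 2\,q^\top w and R^\top y = y^\top y - q^\top w explain the -const * y_norm2 + const * q_dot_w terms in the code.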
23 changes: 18 additions & 5 deletions sklearn/linear_model/tests/test_coordinate_descent.py
@@ -682,9 +682,11 @@ def test_enet_float_precision():
        for fit_intercept in [True, False]:
            coef = {}
            intercept = {}
-            clf = ElasticNet(alpha=0.5, max_iter=100, precompute=False,
-                             fit_intercept=fit_intercept, normalize=normalize)
            for dtype in [np.float64, np.float32]:
+                clf = ElasticNet(alpha=0.5, max_iter=100, precompute=False,
+                                 fit_intercept=fit_intercept,
+                                 normalize=normalize)
+
                X = dtype(X)
                y = dtype(y)
                ignore_warnings(clf.fit)(X, y)
@@ -694,8 +696,19 @@ def test_enet_float_precision():

                assert_equal(clf.coef_.dtype, dtype)

+                # test precompute Gram array
+                Gram = X.T.dot(X)
+                clf_precompute = ElasticNet(alpha=0.5, max_iter=100,
+                                            precompute=Gram,
+                                            fit_intercept=fit_intercept,
+                                            normalize=normalize)
+                ignore_warnings(clf_precompute.fit)(X, y)
+                assert_array_almost_equal(clf.coef_, clf_precompute.coef_)
+                assert_array_almost_equal(clf.intercept_,
+                                          clf_precompute.intercept_)
+
            assert_array_almost_equal(coef[np.float32], coef[np.float64],
-                                      decimal=4)
+                                      decimal=4)
            assert_array_almost_equal(intercept[np.float32],
-                                      intercept[np.float64],
-                                      decimal=4)
+                                      intercept[np.float64],
+                                      decimal=4)
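For illustration, a minimal end-to-end sketch of what the new test exercises (hypothetical data and tolerances, not from this diff):

    import numpy as np
    from sklearn.linear_model import ElasticNet

    rng = np.random.RandomState(0)
    X = rng.randn(100, 20).astype(np.float32)
    y = rng.randn(100).astype(np.float32)

    # plain float32 fit: coefficients should keep the input dtype
    clf = ElasticNet(alpha=0.5, max_iter=100, precompute=False,
                     fit_intercept=False)
    clf.fit(X, y)

    # precomputed-Gram fit, which now also runs in float32 through the
    # fused-type enet_coordinate_descent_gram
    clf_gram = ElasticNet(alpha=0.5, max_iter=100, precompute=X.T.dot(X),
                          fit_intercept=False)
    clf_gram.fit(X, y)

    print(clf.coef_.dtype)                                     # float32
    print(np.allclose(clf.coef_, clf_gram.coef_, atol=1e-4))   # should agree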
