ENH: speeding up the LARS
agramfort committed Sep 20, 2010
1 parent e8cc69e commit f1530fb
Showing 2 changed files with 98 additions and 91 deletions.
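
For orientation before the diff: the speedup comes mainly from avoiding full recomputation inside the main loop. The old code rebuilt the residual y - X*beta at every iteration; the new code precomputes Xty once for the Gram variant and keeps the residual up to date with one vector operation per step. A minimal sketch of that difference, illustrative rather than the committed code:

import numpy as np

def residual_old(X, y, beta):
    # before: a full (n_samples x n_features) product every iteration
    return y - np.dot(X, beta)

def residual_new(res, gamma_, eqdir):
    # after: res is carried across iterations and only shifted along
    # the equiangular direction eqdir = np.dot(X[:, active], b)
    return res - gamma_ * eqdir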
178 changes: 91 additions & 87 deletions scikits/learn/glm/lars.py
@@ -8,16 +8,15 @@
#
# License: BSD Style.

from math import fabs, sqrt
import numpy as np
from scipy import linalg
import scipy.sparse as sp # needed by LeastAngleRegression

from .base import LinearModel
from ..utils.fixes import copysign
from ..utils import arrayfuncs

def lars_path(X, y, Gram=None, max_iter=None, alpha_min=0,
method="lar", precompute=True):
method="lar"):
""" Compute Least Angle Regression and LASSO path
Parameters
@@ -62,8 +61,6 @@ def lars_path(X, y, Gram=None, max_iter=None, alpha_min=0,
"""
# TODO: detect stationary points.
# Lasso variant
# store full path

X = np.atleast_2d(X)
y = np.atleast_1d(y)
@@ -78,66 +75,68 @@
# because of some restrictions in Cython, boolean values are
# simulated using np.int8

beta = np.zeros ((max_iter + 1, X.shape[1]))
alphas = np.zeros (max_iter + 1)
betas = np.zeros((max_iter + 1, n_features))
alphas = np.zeros(max_iter + 1)
n_iter, n_pred = 0, 0
active = list()
unactive = range (X.shape[1])
active_mask = np.zeros (X.shape[1], dtype=np.uint8)
active = list()
unactive = range(n_features)
active_mask = np.zeros(n_features, dtype=np.uint8)
# holds the sign of covariance
sign_active = np.empty (max_pred, dtype=np.int8)
Cov = np.empty (X.shape[1])
a = np.empty (X.shape[1])
Cov = np.empty(n_features)
a = np.empty(n_features)
drop = False

# will hold the cholesky factorization
# only lower part is referenced. We do not create it as
# empty array because chol_solve calls chkfinite on the
# whole array, which can cause problems.
L = np.zeros ((max_pred, max_pred), dtype=np.float64)
L = np.zeros((max_pred, max_pred), dtype=np.float64)
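
The comment above deserves a word: scipy's cho_solve validates the entire input array, so an np.empty allocation whose never-referenced upper triangle happens to hold inf or nan would make that check raise. A tiny illustration, with made-up values, of why the zeroed factor is safe to pass:

import numpy as np
from scipy import linalg

L = np.zeros((3, 3))
L[0, 0], L[1, 1], L[2, 2] = 2.0, 1.0, 1.0  # lower factor of diag(4, 1, 1)
x = linalg.cho_solve((L, True), np.array([2.0, 1.0, 1.0]))
# x == [0.5, 1.0, 1.0]; with np.empty the finiteness check could fail
# on garbage in the unused upper triangle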

Xt = X.T

if Gram is not None:
res_init = np.dot (X.T, y)
Xty = np.dot(Xt, y)
else:
res = y.copy() # Residual to be kept up to date

while 1:

n_unactive = X.shape[1] - n_pred # number of unactive elements
n_unactive = n_features - n_pred # number of unactive elements

if n_unactive:
# Calculate covariance matrix and get maximum
if Gram is None:
res = y - np.dot (X, beta[n_iter]) # there are better ways
arrayfuncs.dot_over (X.T, res, active_mask, np.False_, Cov)
# Compute X[:,inactive].T * res where res = y - X beta
# To get the most correlated variable not already in the active set
arrayfuncs.dot_over(Xt, res, active_mask, np.False_, Cov)
else:
# could use dot_over
arrayfuncs.dot_over (Gram, beta[n_iter], active_mask, np.False_, a)
Cov = res_init[unactive] - a[:n_unactive]
arrayfuncs.dot_over(Gram, betas[n_iter], active_mask, np.False_, a)
Cov = Xty[unactive] - a[:n_unactive]

imax = np.argmax (np.abs(Cov[:n_unactive])) #rename
C_ = Cov [imax]
imax = np.argmax(np.abs(Cov[:n_unactive])) #rename
C_ = Cov[imax]
# np.delete (Cov, imax) # very ugly, has to be fixed
else:
# special case when all elements are in the active set
if Gram is None:
res = y - np.dot (X, beta[n_iter])
C_ = np.dot (X.T[0], res)
C_ = np.dot(X.T[0], res)
else:
C_ = np.dot(Gram[0], beta[n_iter]) - res_init[0]
C_ = np.dot(Gram[0], betas[n_iter]) - Xty[0]

alpha = np.abs(C_) # ugly alpha vs alphas
alphas [n_iter] = alpha
alpha = fabs(C_) # ugly alpha vs alphas
alphas[n_iter] = alpha

if (n_iter >= max_iter or n_pred >= max_pred ):
print alpha, n_pred

if (n_iter >= max_iter or n_pred >= max_pred):
break

if (alpha < alpha_min): break

if not drop:

imax = unactive.pop (imax)

imax = unactive.pop(imax)

# Update the Cholesky factorization of (Xa * Xa') #
# #
@@ -147,104 +146,109 @@ def lars_path(X, y, Gram=None, max_iter=None, alpha_min=0,
# #
# where u is the last added to the active set #


sign_active [n_pred] = np.sign (C_)
sign_active[n_pred] = np.sign(C_)

if Gram is None:
X_max = Xt[imax]
c = np.dot (X_max, X_max)
b = np.dot (X_max, X[:, active])
c = linalg.norm(X_max)**2
b = np.dot(X_max, X[:, active])
else:
c = Gram[imax, imax]
b = Gram[imax, active]

n_pred += 1
# Do cholesky update of the Gram matrix of the active set
L[n_pred, n_pred] = c
active.append(imax)
if n_pred > 0:
arrayfuncs.solve_triangular(L[:n_pred, :n_pred], b)
L[n_pred, :n_pred] = b[:]
v = np.dot(L[n_pred, :n_pred], L[n_pred, :n_pred])
L[n_pred, n_pred] = np.sqrt(c - v)

L [n_pred-1, n_pred-1] = c

if n_pred > 1:

# please refactor me, using linalg.solve is overkill
#L [n_pred-1, :n_pred-1] = linalg.solve (L[:n_pred-1, :n_pred-1], b)
arrayfuncs.solve_triangular (L[:n_pred-1, :n_pred-1],
b)
L [n_pred-1, :n_pred-1] = b[:]
v = np.dot(L [n_pred-1, :n_pred-1], L [n_pred - 1, :n_pred -1])
L [n_pred-1, n_pred-1] = np.sqrt (c - v)
n_pred += 1
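
The block above extends the Cholesky factor of the active set's Gram matrix by one row per added variable instead of refactorizing from scratch. A standalone sketch of that update, substituting scipy.linalg.solve_triangular for the project's arrayfuncs.solve_triangular; the helper name cholesky_append is hypothetical:

import numpy as np
from scipy.linalg import solve_triangular

def cholesky_append(L, n, b, c):
    # L[:n, :n] is the lower Cholesky factor of the current active Gram
    # matrix; b holds cross products of the entering variable with the
    # active ones, c its Gram diagonal entry (squared norm)
    if n > 0:
        w = solve_triangular(L[:n, :n], b, lower=True)  # solve L w = b
        L[n, :n] = w
        L[n, n] = np.sqrt(c - np.dot(w, w))
    else:
        L[n, n] = np.sqrt(c)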

# Now we go into the normal equations dance.
# (Golub & Van Loan, 1996)
C = fabs(C_)

b = copysign (C_.repeat(n_pred), sign_active[:n_pred])
b = linalg.cho_solve ((L[:n_pred, :n_pred], True), b)

C = A = np.abs(C_)
# compute equiangular vector
if Gram is None:
u = np.dot (Xt[active].T, b)
arrayfuncs.dot_over (X.T, u, active_mask, np.False_, a)

b = linalg.cho_solve((L[:n_pred, :n_pred], True), sign_active[:n_pred])
AA = 1. / sqrt(np.sum(b * sign_active[:n_pred]))
b *= AA
else:
# Not sure that this is not buggy ...
arrayfuncs.dot_over (Gram[active].T, b, active_mask, np.False_, a)

# equation 2.13, there's probably a simpler way
g1 = (C - Cov[:n_unactive]) / (A - a[:n_unactive])
g2 = (C + Cov[:n_unactive]) / (A + a[:n_unactive])
S = sign_active[:n_pred][:,None] * sign_active[:n_pred][None,:]
b = linalg.inv(Gram[active][:,active] * S)
b = np.sum(b, axis=1)
AA = 1. / sqrt(b.sum())
b *= sign_active[:n_pred]
b *= AA

eqdir = np.dot(X[:,active], b) # equiangular direction (unit vector)
# correlation between active variables and equiangular vector
u = np.dot(X[:,active], b)
arrayfuncs.dot_over(X.T, u, active_mask, np.False_, a)
arrayfuncs.dot_over(X.T, eqdir, active_mask, np.False_, a)
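
In the notation of Efron, Hastie, Johnstone & Tibshirani (2004), the lines above form the equiangular direction: with s the signs of the active correlations and G_A = L L^T the active Gram matrix, solve G_A w = s, normalize by A_A = 1 / sqrt(s^T w), and map back through the active columns. A sketch under those definitions, not the committed code:

import numpy as np
from scipy import linalg

def equiangular_direction(X_active, L, s):
    # G_A w = s, solved via the Cholesky factor (lower part referenced)
    w = linalg.cho_solve((L, True), s)
    AA = 1.0 / np.sqrt(np.dot(s, w))  # rate at which correlations decay
    b = AA * w                        # coefficient direction
    eqdir = np.dot(X_active, b)       # unit vector, equiangular to X_active
    return eqdir, b, AA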

if not drop:
# Quickfix
active_mask [imax] = np.True_
active_mask[imax] = np.True_
else:
drop = False

# one for the border cases
g = np.concatenate((g1, g2, [1.]))

g = g[g > 0.]
gamma_ = np.min (g)

if n_pred >= X.shape[1]:
gamma_ = 1.
if n_pred >= n_features:
gamma_ = C / AA
else:
# equation 2.13
g1 = (C - Cov[:n_unactive]) / (AA - a[:n_unactive])
g2 = (C + Cov[:n_unactive]) / (AA + a[:n_unactive])
gamma_ = np.r_[g1[g1 > 0], g2[g2 > 0], C / AA].min()
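
A step of size gamma shrinks every active correlation to C - gamma * AA while an inactive one moves to Cov_j - gamma * a_j; g1 and g2 above are the two roots of C - gamma * AA = +/-(Cov_j - gamma * a_j), and C / AA is the full step that drives the active correlations to zero, i.e. the least squares fit on the active set. A scalar version of the catch-up computation, as an illustrative helper only:

def catch_up_step(C, AA, cov_j, a_j):
    # smallest positive gamma with C - gamma*AA == |cov_j - gamma*a_j|
    g1 = (C - cov_j) / (AA - a_j)
    g2 = (C + cov_j) / (AA + a_j)
    candidates = [g for g in (g1, g2) if g > 0]
    return min(candidates) if candidates else None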

if method == 'lasso':
drop = False
z = - betas[n_iter, active] / b
z_pos = z[z > 0]
if z_pos.size > 0:
gamma_tilde_ = np.r_[z_pos, gamma_].min()
if gamma_tilde_ < gamma_:
idx = np.where(z == gamma_tilde_)[0]
gamma_ = gamma_tilde_
drop = True

z = - beta[n_iter, active] / b
z[z <= 0.] = np.inf

idx = np.argmin(z)
n_iter += 1
betas[n_iter, active] = betas[n_iter - 1, active] + gamma_ * b

if z[idx] < gamma_:
gamma_ = z[idx]
drop = True
if Gram is None:
res -= gamma_ * eqdir # update residual

n_iter += 1
beta[n_iter, active] = beta[n_iter - 1, active] + gamma_ * b
if n_pred > n_features:
break

if drop:
arrayfuncs.cholesky_delete (L[:n_pred, :n_pred], idx)
arrayfuncs.cholesky_delete(L[:n_pred, :n_pred], idx)
n_pred -= 1
drop_idx = active.pop (idx)
unactive.append(drop_idx)
active_mask[drop_idx] = False
sign_active = np.delete (sign_active, idx) # do an append to maintain size
sign_active = np.append (sign_active, 0.)
# should be done using cholesky deletes
# do an append to maintain size
sign_active = np.delete(sign_active, idx)
sign_active = np.append(sign_active, 0.)
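
The lasso branch truncates the step whenever a coefficient would change sign: coefficients move as betas + gamma * b, so variable j crosses zero at gamma = -betas_j / b_j, and the smallest positive crossing smaller than the LAR step wins, after which the variable is dropped and its row removed from the Cholesky factor. A sketch of just the crossing test; the helper name is hypothetical:

import numpy as np

def first_zero_crossing(beta_active, b, gamma_):
    # returns the (possibly truncated) step and the index to drop, if any
    z = -beta_active / b
    z_pos = z[z > 0]
    if z_pos.size and z_pos.min() < gamma_:
        idx = int(np.where(z == z_pos.min())[0][0])
        return z_pos.min(), idx
    return gamma_, None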


if alpha < alpha_min: # interpolate
# interpolation factor 0 <= ss < 1
ss = (alphas[n_iter-1] - alpha_min) / (alphas[n_iter-1] - alphas[n_iter])
beta[n_iter] = beta[n_iter-1] + ss*(beta[n_iter] - beta[n_iter-1]);
betas[n_iter] = betas[n_iter-1] + ss*(betas[n_iter] - betas[n_iter-1]);
alphas[n_iter] = alpha_min
alphas = alphas[:n_iter+1]
beta = beta[:n_iter+1]

return alphas, active, beta.T
alphas = alphas[:n_iter+1]
betas = betas[:n_iter+1]

return alphas, active, betas.T
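
For reference, the call pattern the updated tests below exercise; the dataset loader location is an assumption about this 2010-era tree, while the lars_path signature is the one defined above:

import numpy as np
from scikits.learn import datasets  # load_diabetes location assumed
from scikits.learn.glm.lars import lars_path

diabetes = datasets.load_diabetes()
X, y = diabetes.data, diabetes.target

# full LAR path: alphas decrease and one variable enters per iteration
alphas, active, coef_path = lars_path(X, y, method="lar")
print(coef_path.shape)  # (n_features, n_steps)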


class LARS (LinearModel):
""" Least Angle Regression model a.k.a. LAR
class LARS(LinearModel):
"""Least Angle Regression model a.k.a. LAR
Parameters
----------
11 changes: 7 additions & 4 deletions scikits/learn/glm/tests/test_lars.py
@@ -20,7 +20,8 @@ def test_simple():
Principle of LARS is to keep covariances tied and decreasing
"""
max_pred = 10
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target, max_iter=max_pred, method="lar")
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target,
max_iter=max_pred, method="lar")
for (i, coef_) in enumerate(coef_path_.T):
res = y - np.dot(X, coef_)
cov = np.dot(X.T, res)
@@ -40,7 +41,8 @@ def test_simple_precomputed():
"""
max_pred = 10
G = np.dot (diabetes.data.T, diabetes.data)
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target, Gram=G, max_iter=max_pred, method="lar")
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target,
Gram=G, max_iter=max_pred, method="lar")
for (i, coef_) in enumerate(coef_path_.T):
res = y - np.dot(X, coef_)
cov = np.dot(X.T, res)
@@ -54,14 +56,14 @@ def test_simple_precomputed():
assert ocur == max_pred



def test_lars_lstsq():
"""
Test that LARS gives least square solution at the end
of the path
"""
# test that it arrives to a least squares solution
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target, method="lar")
alphas_, active, coef_path_ = lars_path(diabetes.data, diabetes.target,
method="lar")
coef_lstsq = np.linalg.lstsq(X, y)[0]
assert_array_almost_equal(coef_path_.T[-1], coef_lstsq)

@@ -76,6 +78,7 @@ def test_lasso_gives_lstsq_solution():
coef_lstsq = np.linalg.lstsq(X, y)[0]
assert_array_almost_equal(coef_lstsq , coef_path_[:,-1])


def test_lasso_lars_vs_lasso_cd(verbose=False):
"""
Test that LassoLars and Lasso using coordinate descent give the
