Merge pull request #218 from fabianp/fix_lars
Fix lars
Fabian Pedregosa committed Jul 19, 2011
2 parents 58d5b7f + 21f3659 commit 1355100
Showing 4 changed files with 37 additions and 35 deletions.
doc/modules/linear_model.rst (2 changes: 1 addition & 1 deletion)
@@ -269,7 +269,7 @@ function of the norm of its coefficients.
LassoLARS(normalize=True, verbose=False, fit_intercept=True, max_iter=500,
precompute='auto', alpha=0.1)
>>> clf.coef_
-array([ 0.50710678, 0. ])
+array([ 0.71715729, 0. ])

.. topic:: Examples:

scikits/learn/linear_model/least_angle.py (11 changes: 5 additions & 6 deletions)
@@ -122,10 +122,7 @@ def lars_path(X, y, Xy=None, Gram=None, max_features=None, max_iter=500,
C = np.fabs(C_)
# to match a for computing gamma_
else:
-if Gram is None:
-    C -= gamma_ * np.abs(np.dot(X.T[0], eq_dir))
-else:
-    C -= gamma_ * np.abs(np.dot(Gram[0], least_squares))
+C = 0.

alphas[n_iter] = C / n_samples

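Context for the `C = 0.` fix above: once every feature has entered the active set, the LARS estimate is the ordinary least-squares fit, and a least-squares residual is orthogonal to every column of X, so the maximal absolute covariance tracked by C is exactly zero. The deleted branch approximated this with incremental updates, which can leave floating-point residue. A minimal sketch of the orthogonality fact (illustrative names, not code from this patch):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(20, 5)
    y = rng.randn(20)

    # The least-squares residual is orthogonal to the column space of X...
    coef = np.linalg.lstsq(X, y)[0]
    residual = y - np.dot(X, coef)

    # ...so the quantity lars_path calls C vanishes (up to rounding) at the
    # end of the path, and setting it to 0. directly is the exact value.
    print(np.max(np.abs(np.dot(X.T, residual))))  # ~1e-14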
@@ -347,7 +344,7 @@ class LARS(LinearModel):
LARS(normalize=True, precompute='auto', max_iter=500, verbose=False,
fit_intercept=True)
>>> print clf.coef_
-[ 0. -0.81649658]
+[ 0. -1.]
References
----------
@@ -414,6 +411,8 @@ def fit(self, X, y, max_features=None, overwrite_X=False, **params):
method=self.method, verbose=self.verbose,
max_features=max_features, max_iter=self.max_iter)

+if self.normalize:
+    self.coef_path_ /= norms[:, np.newaxis]
self.coef_ = self.coef_path_[:, -1]

self._set_intercept(Xmean, ymean)
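What the two added lines do: with normalize=True, fit rescales each column of X to unit norm before calling lars_path, so the returned path lives in the rescaled space and must be divided by `norms` to apply to the raw X. A small sketch of the identity being used, with illustrative variable names:

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(10, 3)
    y = rng.randn(10)

    norms = np.sqrt(np.sum(X ** 2, axis=0))
    w_scaled = np.linalg.lstsq(X / norms, y)[0]  # coefficients in the scaled space

    # (X / norms) dot w_scaled == X dot (w_scaled / norms), so dividing by
    # the column norms recovers coefficients for the un-normalized data.
    assert np.allclose(np.dot(X / norms, w_scaled), np.dot(X, w_scaled / norms))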
@@ -468,7 +467,7 @@ class LassoLARS (LARS):
LassoLARS(normalize=True, verbose=False, fit_intercept=True, max_iter=500,
precompute='auto', alpha=0.01)
>>> print clf.coef_
-[ 0. -0.78649658]
+[ 0. -0.96325765]
References
----------
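A note on the conventions in this file: the docstring outputs above changed because coefficients are now reported on the original feature scale (previous hunk), and `alphas[n_iter] = C / n_samples` encodes the objective (1 / (2 * n_samples)) * ||y - Xw||^2 + alpha * ||w||_1, whose optimality condition pins the maximal absolute covariance at n_samples * alpha. A hedged sketch of that identity (written against the modern sklearn import path; the 2011 tree used scikits.learn):

    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.RandomState(0)
    X = rng.randn(50, 4)
    y = rng.randn(50)
    n_samples = len(y)

    alpha = 0.05
    w = Lasso(alpha=alpha, fit_intercept=False).fit(X, y).coef_
    residual = y - np.dot(X, w)

    # KKT condition: max_j |x_j . residual| == n_samples * alpha whenever at
    # least one coefficient is active (strictly below it if all are zero).
    print(np.max(np.abs(np.dot(X.T, residual))) / n_samples, alpha)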
scikits/learn/linear_model/tests/test_bayes.py (11 changes: 6 additions & 5 deletions)
@@ -13,6 +13,7 @@

from scikits.learn import datasets


def test_bayesian_on_diabetes():
"""
Test BayesianRidge on diabetes
@@ -29,7 +30,7 @@ def test_bayesian_on_diabetes():
assert_array_equal(np.diff(clf.scores_) > 0, True)

# Test with more features than samples
-X = X[:5,:]
+X = X[:5, :]
y = y[:5]
clf.fit(X, y)
# Test that scores are increasing at each iteration
@@ -45,7 +46,7 @@ def test_toy_bayesian_ridge_object():
clf = BayesianRidge(compute_score=True)
clf.fit(X, Y)
X_test = [[1], [3], [4]]
-assert(np.abs(clf.predict(X_test)-[1, 3, 4]).sum() < 1.e-2) # identity
+assert(np.abs(clf.predict(X_test) - [1, 3, 4]).sum() < 1.e-2) # identity


def test_toy_ard_object():
@@ -54,7 +55,7 @@ def test_toy_ard_object():
"""
X = np.array([[1], [2], [3]])
Y = np.array([1, 2, 3])
-clf = ARDRegression(compute_score = True)
+clf = ARDRegression(compute_score=True)
clf.fit(X, Y)
-Test = [[1], [3], [4]]
-assert(np.abs(clf.predict(Test)-[1, 3, 4]).sum() < 1.e-3) # identity
+test = [[1], [3], [4]]
+assert(np.abs(clf.predict(test) - [1, 3, 4]).sum() < 1.e-3) # identity
scikits/learn/linear_model/tests/test_least_angle.py (48 changes: 25 additions & 23 deletions)
@@ -18,13 +18,13 @@ def test_simple():
diabetes.data, diabetes.target, method="lar")

for (i, coef_) in enumerate(coef_path_.T):
-res = y - np.dot(X, coef_)
+res = y - np.dot(X, coef_)
cov = np.dot(X.T, res)
C = np.max(abs(cov))
eps = 1e-3
-ocur = len(cov[ C - eps < abs(cov)])
+ocur = len(cov[C - eps < abs(cov)])
if i < X.shape[1]:
-assert ocur == i+1
+assert ocur == i + 1
else:
# no more than max_pred variables can go into the active set
assert ocur == X.shape[1]
@@ -35,18 +35,18 @@ def test_simple_precomputed():
The same, with precomputed Gram matrix
"""

-G = np.dot (diabetes.data.T, diabetes.data)
+G = np.dot(diabetes.data.T, diabetes.data)
alphas_, active, coef_path_ = linear_model.lars_path(
diabetes.data, diabetes.target, Gram=G, method="lar")

-for (i, coef_) in enumerate(coef_path_.T):
-res = y - np.dot(X, coef_)
+for i, coef_ in enumerate(coef_path_.T):
+res = y - np.dot(X, coef_)
cov = np.dot(X.T, res)
C = np.max(abs(cov))
eps = 1e-3
-ocur = len(cov[ C - eps < abs(cov)])
+ocur = len(cov[C - eps < abs(cov)])
if i < X.shape[1]:
-assert ocur == i+1
+assert ocur == i + 1
else:
# no more than max_pred variables can go into the active set
assert ocur == X.shape[1]
@@ -57,35 +57,33 @@ def test_lars_lstsq():
Test that LARS gives least square solution at the end
of the path
"""
-# test that it arrives to a least squares solution
-alphas_, active, coef_path_ = linear_model.lars_path(diabetes.data, diabetes.target,
-                                                     method="lar")
-coef_lstsq = np.linalg.lstsq(X, y)[0]
-assert_array_almost_equal(coef_path_.T[-1], coef_lstsq)
+X1 = 3 * diabetes.data  # use un-normalized dataset
+clf = linear_model.LassoLARS(alpha=0.)
+clf.fit(X1, y)
+coef_lstsq = np.linalg.lstsq(X1, y)[0]
+assert_array_almost_equal(clf.coef_, coef_lstsq)


def test_lasso_gives_lstsq_solution():
"""
Test that LARS Lasso gives least square solution at the end
of the path
"""

alphas_, active, coef_path_ = linear_model.lars_path(X, y, method="lasso")
coef_lstsq = np.linalg.lstsq(X, y)[0]
-assert_array_almost_equal(coef_lstsq , coef_path_[:,-1])
+assert_array_almost_equal(coef_lstsq, coef_path_[:, -1])


def test_collinearity():
"""Check that lars_path is robust to collinearity in input"""

X = np.array([[3., 3., 1.],
[2., 2., 0.],
[1., 1., 0]])
y = np.array([1., 0., 0])

_, _, coef_path_ = linear_model.lars_path(X, y)
assert (not np.isnan(coef_path_).any())
-assert_array_almost_equal(np.dot(X, coef_path_[:,-1]), y)
+assert_array_almost_equal(np.dot(X, coef_path_[:, -1]), y)


def test_singular_matrix():
@@ -103,6 +101,8 @@ def test_lasso_lars_vs_lasso_cd(verbose=False):
Test that LassoLars and Lasso using coordinate descent give the
same results
"""
+X = 3 * diabetes.data
+
alphas, _, lasso_path = linear_model.lars_path(X, y, method='lasso')
lasso_cd = linear_model.Lasso(fit_intercept=False)
for (c, a) in zip(lasso_path.T, alphas):
@@ -113,10 +113,11 @@

# similar test, with the classifiers
for alpha in np.linspace(1e-2, 1 - 1e-2):
-clf1 = linear_model.LassoLARS(alpha=alpha).fit(X, y)
-clf2 = linear_model.Lasso(alpha=alpha).fit(X, y, tol=1e-8)
-err = np.linalg.norm(clf1.coef_ - clf2.coef_)
-assert err < 1e-3
+clf1 = linear_model.LassoLARS(alpha=alpha, normalize=False).fit(X, y)
+clf2 = linear_model.Lasso(alpha=alpha).fit(X, y, tol=1e-8)
+err = np.linalg.norm(clf1.coef_ - clf2.coef_)
+assert err < 1e-3

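The added normalize=False is what makes this comparison apples-to-apples: with normalization on, LassoLARS penalizes coefficients in the rescaled space, so on the un-normalized X = 3 * diabetes.data it would no longer minimize the same objective as the coordinate-descent Lasso. A rough modern re-run of the same check (today's spelling is LassoLars, tol moved to the constructor, and the generous max_iter is an assumption to avoid convergence warnings at small alphas):

    import numpy as np
    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import Lasso, LassoLars

    X, y = load_diabetes(return_X_y=True)
    X = 3 * X  # mirror the test's un-normalized design

    for alpha in np.linspace(1e-2, 1 - 1e-2, num=5):
        w_lars = LassoLars(alpha=alpha).fit(X, y).coef_
        w_cd = Lasso(alpha=alpha, tol=1e-8, max_iter=100000).fit(X, y).coef_
        print(alpha, np.linalg.norm(w_lars - w_cd))  # expected to be small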

def test_lasso_lars_vs_lasso_cd_early_stopping(verbose=False):
"""
@@ -131,9 +132,10 @@
lasso_cd = linear_model.Lasso(fit_intercept=False)
lasso_cd.alpha = alphas[-1]
lasso_cd.fit(X, y, tol=1e-8)
-error = np.linalg.norm(lasso_path[:,-1] - lasso_cd.coef_)
+error = np.linalg.norm(lasso_path[:, -1] - lasso_cd.coef_)
assert error < 0.01


def test_lars_add_features(verbose=False):
"""
assure that at least some features get added if necessary
@@ -155,7 +157,7 @@ def test_lars_add_features(verbose=False):
[-0.12951744, 0.21978613, -0.04762174, -0.27227304, -0.02722684, 0.57449581]]),
np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]))


if __name__ == '__main__':
import nose
nose.runmodule()
