Permalink
Browse files

adding LARS with Gram to benchmark

  • Loading branch information...
agramfort committed Sep 20, 2010
1 parent 0e193e8 commit e174b32c7035277210276ee946ac8f2453c981e7
Showing with 19 additions and 11 deletions.
  1. +19 −11 scikits/learn/benchmarks/bench_lasso.py
@@ -17,12 +17,15 @@
from bench_glm import make_data
-def bench(clf, X_train, Y_train, X_test, Y_test):
+def bench(clf, X_train, Y_train, X_test, Y_test, Gram=None):
gc.collect()
# start time
tstart = time()
- clf = clf.fit(X_train, Y_train)
+ if Gram is not None:
+ clf = clf.fit(X_train, Y_train, Gram=Gram)
+ else:
+ clf = clf.fit(X_train, Y_train)
delta = (time() - tstart)
# stop time
@@ -37,11 +40,11 @@ def LassoFactory(alpha):
return Lasso(alpha=alpha, fit_intercept=False)
def LassoLARSFactory(alpha):
- return LassoLARS(alpha=alpha, normalize=False)
- # return LassoLARS(alpha=alpha, fit_intercept=False, normalize=False)
+ return LassoLARS(alpha=alpha, fit_intercept=False, normalize=False)
lasso_results = []
larslasso_results = []
+ larslasso_gram_results = []
n_tests = 1000
it = 0
@@ -65,8 +68,11 @@ def LassoLARSFactory(alpha):
print "benching LassoLARS: "
larslasso_results.append(bench(LassoLARSFactory(alpha),
X, Y, X_test, Y_test))
+ print "benching LassoLARS (precomp. Gram): "
+ larslasso_gram_results.append(bench(LassoLARSFactory(alpha),
+ X, Y, X_test, Y_test, Gram=np.dot(X.T, X)))
- return lasso_results, larslasso_results
+ return lasso_results, larslasso_results, larslasso_gram_results
if __name__ == '__main__':
from scikits.learn.glm import Lasso, LassoLARS
@@ -76,28 +82,30 @@ def LassoLARSFactory(alpha):
n_features = 500
list_n_samples = range(500, 10001, 500);
- lasso_results, larslasso_results = compute_bench(alpha, list_n_samples,
- [n_features])
+ lasso_results, larslasso_results, larslasso_gram_results = \
+ compute_bench(alpha, list_n_samples, [n_features])
pl.close('all')
pl.title('Lasso benchmark (%d features - alpha=%s)' % (n_features, alpha))
pl.plot(list_n_samples, lasso_results, 'b-', label='Lasso')
pl.plot(list_n_samples, larslasso_results,'r-', label='LassoLARS')
- pl.legend()
+ pl.plot(list_n_samples, larslasso_gram_results,'g-', label='LassoLARS (Gram)')
+ pl.legend(loc='upper left')
pl.xlabel('number of samples')
pl.ylabel('time (in seconds)')
pl.show()
n_samples = 500
list_n_features = range(500, 3001, 500);
- lasso_results, larslasso_results = compute_bench(alpha, [n_samples],
- list_n_features)
+ lasso_results, larslasso_results, larslasso_gram_results = \
+ compute_bench(alpha, [n_samples], list_n_features)
pl.figure()
pl.title('Lasso benchmark (%d samples - alpha=%s)' % (n_samples, alpha))
pl.plot(list_n_features, lasso_results, 'b-', label='Lasso')
pl.plot(list_n_features, larslasso_results,'r-', label='LassoLARS')
- pl.legend()
+ pl.plot(list_n_features, larslasso_gram_results,'g-', label='LassoLARS (Gram)')
+ pl.legend(loc='upper left')
pl.xlabel('number of features')
pl.ylabel('time (in seconds)')
pl.show()

0 comments on commit e174b32

Please sign in to comment.