[READY] ENH - Add Gram Solver for single task Quadratic datafit #59

Merged: 27 commits, Aug 26, 2022

Commits
b8fd539
init commit
Badr-MOUFAD Aug 24, 2022
c2aecba
gram solver && unit test
Badr-MOUFAD Aug 24, 2022
507fc8a
fix bug gram solver && tighten test
Badr-MOUFAD Aug 24, 2022
c9b64c2
add anderson acceleration
Badr-MOUFAD Aug 24, 2022
20c1911
bug ``stop_criter`` && refactor
Badr-MOUFAD Aug 24, 2022
f2e985d
refactoring of var names
Badr-MOUFAD Aug 25, 2022
2dbc8e4
handle ``w_init``
Badr-MOUFAD Aug 25, 2022
8ca7a41
refactor ``_gram_cd_``
Badr-MOUFAD Aug 25, 2022
3453233
gram epoch greedy and cyclic strategy
Badr-MOUFAD Aug 25, 2022
8d3dbc1
extend to sparse case && unitest
Badr-MOUFAD Aug 25, 2022
cdd7e34
one implementation of _gram_cd && unittest
Badr-MOUFAD Aug 25, 2022
f4bfeaf
greedy_cd arg instead of cd_strategy
Badr-MOUFAD Aug 25, 2022
95cf1d4
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Aug 25, 2022
4c0acca
add docs
Badr-MOUFAD Aug 25, 2022
dcab054
script fast gram, not faster than scipy
mathurinm Aug 25, 2022
e8bc96e
fast gram timing
Badr-MOUFAD Aug 25, 2022
61a67c4
keep grads instead
Badr-MOUFAD Aug 25, 2022
1b6c169
refactor ``chosen_j``
Badr-MOUFAD Aug 25, 2022
c9c5575
script to profile
Badr-MOUFAD Aug 25, 2022
68a0458
potential improvements, docstring
mathurinm Aug 26, 2022
3788cc4
warnings.warn arguments in correct order
mathurinm Aug 26, 2022
1ce391d
cleanups: ann files
Badr-MOUFAD Aug 26, 2022
2476a34
fix ``p_obj`` computation
Badr-MOUFAD Aug 26, 2022
0f766e9
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Aug 26, 2022
3208dfa
typos + less cases in test, smaller X in tests
mathurinm Aug 26, 2022
16f6ee4
typo: ``XtXw`` --> ``grad``
Badr-MOUFAD Aug 26, 2022
e9b7224
Merge branch 'gram-solver' of https://github.com/Badr-MOUFAD/skglm in…
Badr-MOUFAD Aug 26, 2022
157 changes: 157 additions & 0 deletions skglm/solvers/gram_cd.py
@@ -0,0 +1,157 @@
import warnings
import numpy as np
from numba import njit
from scipy.sparse import issparse

from skglm.utils import AndersonAcceleration


def gram_cd_solver(X, y, penalty, max_iter=100, w_init=None,
use_acc=True, greedy_cd=True, tol=1e-4, verbose=False):
r"""Run coordinate descent while keeping the gradients up-to-date with Gram updates.

This solver should be used when n_features < n_samples, and computes the
(n_features, n_features) Gram matrix which comes with an overhead. It is only
suited to Quadratic datafits.

It minimizes::
1 / (2*n_samples) * norm(y - Xw)**2 + penalty(w)

which can be rewritten as::
w.T @ Q @ w / (2*n_samples) - q.T @ w / n_samples + penalty(w)

where::
Q = X.T @ X (gram matrix), and q = X.T @ y

Parameters
----------
X : array or sparse CSC matrix, shape (n_samples, n_features)
Design matrix.

y : array, shape (n_samples,)
Target vector.

penalty : instance of BasePenalty
Penalty object.

max_iter : int, default 100
Maximum number of iterations.

w_init : array, shape (n_features,), default None
Initial value of coefficients.
If set to None, a zero vector is used instead.

use_acc : bool, default True
Extrapolate the iterates based on the past 5 iterates if set to True.

greedy_cd : bool, default True
Use a greedy strategy to select features to update in coordinate descent epochs
if set to True. A cyclic strategy is used otherwise.

tol : float, default 1e-4
Tolerance for convergence.

verbose : bool, default False
Amount of verbosity. 0/False is silent.

Returns
-------
w : array, shape (n_features,)
Solution that minimizes the problem defined by datafit and penalty.

objs_out : array, shape (n_iter,)
The objective values at every outer iteration.

stop_crit : float
The value of the stopping criterion when the solver stops.
"""
n_samples, n_features = X.shape

if issparse(X):
scaled_gram = X.T.dot(X)
scaled_gram = scaled_gram.toarray() / n_samples
scaled_Xty = X.T.dot(y) / n_samples
else:
scaled_gram = X.T @ X / n_samples
scaled_Xty = X.T @ y / n_samples
# TODO potential improvement: allow to pass scaled_gram (e.g. for path computation)

scaled_y_norm2 = np.linalg.norm(y)**2 / (2*n_samples)

all_features = np.arange(n_features)
stop_crit = np.inf # prevent ref before assign
p_objs_out = []

w = np.zeros(n_features) if w_init is None else w_init
grad = - scaled_Xty if w_init is None else scaled_gram @ w_init - scaled_Xty
opt = penalty.subdiff_distance(w, grad, all_features)

if use_acc:
if greedy_cd:
warnings.warn(
"Anderson acceleration does not work with greedy_cd, set use_acc=False",
UserWarning)
accelerator = AndersonAcceleration(K=5)
w_acc = np.zeros(n_features)
grad_acc = np.zeros(n_features)

for t in range(max_iter):
# check convergences
stop_crit = np.max(opt)
if verbose:
p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w +
scaled_y_norm2 + penalty.value(w))
print(
f"Iteration {t+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit <= tol:
if verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break

# inplace update of w, grad
opt = _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd)

# perform Anderson extrapolation
if use_acc:
w_acc, grad_acc, is_extrapolated = accelerator.extrapolate(w, grad)

if is_extrapolated:
# omit constant term for comparison
p_obj_acc = (0.5 * w_acc @ (scaled_gram @ w_acc) - scaled_Xty @ w_acc +
penalty.value(w_acc))
p_obj = 0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + penalty.value(w)
if p_obj_acc < p_obj:
w[:] = w_acc
grad[:] = grad_acc

# store p_obj
p_obj = (0.5 * w @ (scaled_gram @ w) - scaled_Xty @ w + scaled_y_norm2 +
penalty.value(w))
p_objs_out.append(p_obj)
return w, np.array(p_objs_out), stop_crit


@njit
def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
all_features = np.arange(len(w))
for cd_iter in all_features:
# select feature j
if greedy_cd:
opt = penalty.subdiff_distance(w, grad, all_features)
j = np.argmax(opt)
else: # cyclic
j = cd_iter

# update w_j
old_w_j = w[j]
step = 1 / scaled_gram[j, j] # 1 / lipschitz_j
w[j] = penalty.prox_1d(old_w_j - step * grad[j], step, j)

# gradient update with Gram matrix
if w[j] != old_w_j:
grad += (w[j] - old_w_j) * scaled_gram[:, j]

return penalty.subdiff_distance(w, grad, all_features)
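
The objective reformulation stated in the docstring of gram_cd_solver can be checked numerically. The snippet below is a minimal sketch (not part of the PR) using random data and only NumPy; Q, q and the problem sizes are illustrative names, not identifiers from the diff.

import numpy as np

# sketch: the least-squares objective equals its Gram-matrix rewriting
rng = np.random.default_rng(0)
n_samples, n_features = 20, 5
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)
w = rng.standard_normal(n_features)

Q = X.T @ X            # Gram matrix
q = X.T @ y

obj_lsq = np.linalg.norm(y - X @ w) ** 2 / (2 * n_samples)
obj_gram = (w @ Q @ w / (2 * n_samples) - q @ w / n_samples
            + np.linalg.norm(y) ** 2 / (2 * n_samples))
np.testing.assert_allclose(obj_lsq, obj_gram)  # equal up to float error
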
32 changes: 32 additions & 0 deletions skglm/tests/test_gram_solver.py
@@ -0,0 +1,32 @@
import pytest
from itertools import product

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import make_correlated_data, compiled_clone


@pytest.mark.parametrize("rho, X_density, greedy_cd",
product([1e-1, 1e-3], [1., 0.8], [True, False]))
def test_vs_lasso_sklearn(rho, X_density, greedy_cd):
X, y, _ = make_correlated_data(
n_samples=18, n_features=8, random_state=0, X_density=X_density)
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = rho * alpha_max

sk_lasso = Lasso(alpha, fit_intercept=False, tol=1e-9)
sk_lasso.fit(X, y)

l1_penalty = compiled_clone(L1(alpha))
w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0,
max_iter=1000, greedy_cd=greedy_cd)[0]

np.testing.assert_allclose(w, sk_lasso.coef_.flatten(), rtol=1e-7, atol=1e-7)


if __name__ == '__main__':
pass
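
For completeness, a minimal usage sketch of the new solver outside the test suite, reusing the imports and call signature from the test above. The problem sizes and the 0.1 * alpha_max regularization level are illustrative assumptions, and use_acc is disabled because the solver warns that Anderson acceleration is not used with the greedy strategy.

import numpy as np
from numpy.linalg import norm

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import make_correlated_data, compiled_clone

# toy problem with n_features < n_samples, the regime targeted by the solver
X, y, _ = make_correlated_data(n_samples=100, n_features=40, random_state=0)

# regularization level as a fraction of alpha_max (illustrative choice)
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
penalty = compiled_clone(L1(0.1 * alpha_max))

# greedy coordinate selection; Anderson acceleration only pairs with the
# cyclic strategy, hence use_acc=False here
w, p_objs, stop_crit = gram_cd_solver(X, y, penalty, max_iter=1000, tol=1e-8,
                                      greedy_cd=True, use_acc=False, verbose=1)
print(f"final objective: {p_objs[-1]:.6f}, stopping criterion: {stop_crit:.2e}")
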