[READY] ENH - Add Gram Solver for single task Quadratic datafit #59

Merged — 27 commits, merged on Aug 26, 2022

Changes from 10 commits

Commits (27)
b8fd539  init commit — Badr-MOUFAD, Aug 24, 2022
c2aecba  gram solver && unit test — Badr-MOUFAD, Aug 24, 2022
507fc8a  fix bug gram solver && tighten test — Badr-MOUFAD, Aug 24, 2022
c9b64c2  add anderson acceleration — Badr-MOUFAD, Aug 24, 2022
20c1911  bug ``stop_criter`` && refactor — Badr-MOUFAD, Aug 24, 2022
f2e985d  refactoring of var names — Badr-MOUFAD, Aug 25, 2022
2dbc8e4  handle ``w_init`` — Badr-MOUFAD, Aug 25, 2022
8ca7a41  refactor ``_gram_cd_`` — Badr-MOUFAD, Aug 25, 2022
3453233  gram epoch greedy and cyclic strategy — Badr-MOUFAD, Aug 25, 2022
8d3dbc1  extend to sparse case && unitest — Badr-MOUFAD, Aug 25, 2022
cdd7e34  one implementation of _gram_cd && unittest — Badr-MOUFAD, Aug 25, 2022
f4bfeaf  greedy_cd arg instead of cd_strategy — Badr-MOUFAD, Aug 25, 2022
95cf1d4  Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm … — Badr-MOUFAD, Aug 25, 2022
4c0acca  add docs — Badr-MOUFAD, Aug 25, 2022
dcab054  script fast gram, not faster than scipy — mathurinm, Aug 25, 2022
e8bc96e  fast gram timing — Badr-MOUFAD, Aug 25, 2022
61a67c4  keep grads instead — Badr-MOUFAD, Aug 25, 2022
1b6c169  refactor ``chosen_j`` — Badr-MOUFAD, Aug 25, 2022
c9c5575  script to profile — Badr-MOUFAD, Aug 25, 2022
68a0458  potential improvements, docstring — mathurinm, Aug 26, 2022
3788cc4  warnings.warn arguments in correct order — mathurinm, Aug 26, 2022
1ce391d  cleanups: ann files — Badr-MOUFAD, Aug 26, 2022
2476a34  fix ``p_obj`` computation — Badr-MOUFAD, Aug 26, 2022
0f766e9  Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm … — Badr-MOUFAD, Aug 26, 2022
3208dfa  typos + less cases in test, smaller X in tests — mathurinm, Aug 26, 2022
16f6ee4  typo: ``XtXw`` --> ``grad`` — Badr-MOUFAD, Aug 26, 2022
e9b7224  Merge branch 'gram-solver' of https://github.com/Badr-MOUFAD/skglm in… — Badr-MOUFAD, Aug 26, 2022
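For readers landing on the diffs below: the solver's core trick is to precompute Q = X.T @ X / n_samples and q = X.T @ y / n_samples once, then keep the gradient Q @ w - q current with a single column update after each coordinate change, instead of recomputing X.T @ (X @ w - y). A minimal, self-contained sketch of that idea for a plain Lasso step follows; the variable names and the explicit soft-thresholding prox are illustrative, not the PR's code.

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 50, 20
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)

# precompute once, as the solver does
Q = X.T @ X / n_samples    # scaled_gram in the PR
q = X.T @ y / n_samples    # scaled_Xty in the PR

alpha = 0.1                # illustrative L1 strength
w = np.zeros(n_features)
Qw = np.zeros(n_features)  # running value of Q @ w

for j in range(n_features):     # one cyclic epoch
    grad_j = Qw[j] - q[j]       # j-th gradient entry, no X @ w needed
    step = 1 / Q[j, j]
    z = w[j] - step * grad_j
    old_w_j, w[j] = w[j], np.sign(z) * max(abs(z) - alpha * step, 0.)
    if w[j] != old_w_j:
        # refresh Q @ w in O(n_features): only coordinate j changed
        Qw += (w[j] - old_w_j) * Q[:, j]

np.testing.assert_allclose(Qw, Q @ w)  # maintained gradient stays exact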
2 changes: 1 addition & 1 deletion skglm/estimators.py
@@ -1142,7 +1142,7 @@ def path(self, yXT, y, Cs, coef_init=None, return_n_iter=True, **params):
    Target vector relative to X.

    Cs : ndarray shape (n_Cs,)
-       Values of regularization strenghts for which solutions are
+       Values of regularization strengths for which solutions are
        computed.

    coef_init : array, shape (n_features,), optional
117 changes: 117 additions & 0 deletions skglm/solvers/gram_cd.py
@@ -0,0 +1,117 @@
import numpy as np
from numba import njit
from scipy.sparse import issparse
from skglm.utils import AndersonAcceleration


def gram_cd_solver(X, y, penalty, max_iter=20, w_init=None,
use_acc=True, cd_type='greedy', tol=1e-4, verbose=False):
"""Run coordinate descent while keeping the gradients up-to-date with Gram updates.

Minimize::
1 / (2*n_samples) * norm(y - Xw)**2 + penalty(w)

Up to the constant term norm(y)**2 / (2*n_samples), this can be rewritten as::
w.T @ Q @ w / (2*n_samples) - q.T @ w / n_samples + penalty(w)

where::
Q = X.T @ X (Gram matrix)
q = X.T @ y
"""
n_samples, n_features = X.shape
scaled_gram = X.T @ X / n_samples
scaled_Xty = X.T @ y / n_samples
scaled_y_norm2 = np.linalg.norm(y)**2 / (2*n_samples)

if issparse(X):
scaled_gram = scaled_gram.toarray()

all_features = np.arange(n_features)
stop_crit = np.inf # prevent ref before assign
p_objs_out = []

w = np.zeros(n_features) if w_init is None else w_init
scaled_gram_w = np.zeros(n_features) if w_init is None else scaled_gram @ w_init
# initial gradient: scaled_gram_w - scaled_Xty (reduces to -scaled_Xty when w_init is None)
opt = penalty.subdiff_distance(w, scaled_gram_w - scaled_Xty, all_features)

if use_acc:
accelerator = AndersonAcceleration(K=5)
w_acc = np.zeros(n_features)
scaled_gram_w_acc = np.zeros(n_features)

for t in range(max_iter):
# check convergence
stop_crit = np.max(opt)
if verbose:
p_obj = (0.5 * w @ scaled_gram_w - scaled_Xty @ w +
scaled_y_norm2 + penalty.value(w))
print(
f"Iteration {t+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit <= tol:
if verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break

# inplace update of w and scaled_gram_w (one epoch of coordinate descent)
_gram_cd_epoch = _gram_cd_greedy if cd_type == 'greedy' else _gram_cd_cyclic
opt = _gram_cd_epoch(scaled_gram, scaled_Xty, w, scaled_gram_w,
penalty, all_features)

# perform Anderson extrapolation
if use_acc:
w_acc, scaled_gram_w_acc, is_extrapolated = accelerator.extrapolate(
w, scaled_gram_w)

if is_extrapolated:
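# scaled_y_norm2 is omitted from both objectives below: the constant cancels in the comparison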
p_obj_acc = (0.5 * w_acc @ scaled_gram_w_acc - scaled_Xty @ w_acc +
penalty.value(w_acc))
p_obj = 0.5 * w @ scaled_gram_w - scaled_Xty @ w + penalty.value(w)
if p_obj_acc < p_obj:
w[:] = w_acc
scaled_gram_w[:] = scaled_gram_w_acc

p_obj = (0.5 * w @ scaled_gram_w - scaled_Xty @ w + scaled_y_norm2
         + penalty.value(w))
p_objs_out.append(p_obj)
return w, np.array(p_objs_out), stop_crit


@njit
def _gram_cd_greedy(scaled_gram, scaled_Xty, w, scaled_gram_w, penalty, ws):
# inplace update of w and scaled_gram_w; returns the optimality violations
# greedy CD: at each update, pick the coordinate with the largest violation
for _ in range(len(w)):
grad = scaled_gram_w - scaled_Xty
opt = penalty.subdiff_distance(w, grad, ws)
j_max = np.argmax(opt)

old_w_j = w[j_max]
step = 1 / scaled_gram[j_max, j_max]  # 1 / Lipschitz constant of coordinate j_max
w[j_max] = penalty.prox_1d(old_w_j - step * grad[j_max], step, j_max)

# O(n_features) gradient refresh using column j_max of the Gram matrix
if w[j_max] != old_w_j:
scaled_gram_w += (w[j_max] - old_w_j) * scaled_gram[:, j_max]
return opt


@njit
def _gram_cd_cyclic(scaled_gram, scaled_Xty, w, scaled_gram_w, penalty, ws):
# inplace update of w and scaled_gram_w; returns the optimality violations
# cyclic CD: sweep over coordinates in order
for j in range(len(w)):
grad = scaled_gram_w - scaled_Xty

old_w_j = w[j]
step = 1 / scaled_gram[j, j]  # 1 / Lipschitz constant of coordinate j
w[j] = penalty.prox_1d(old_w_j - step * grad[j], step, j)

# O(n_features) gradient refresh using column j of the Gram matrix
if w[j] != old_w_j:
scaled_gram_w += (w[j] - old_w_j) * scaled_gram[:, j]

# optimality violations, computed once at the end of the epoch
grad = scaled_gram_w - scaled_Xty
return penalty.subdiff_distance(w, grad, ws)
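The tests below show the intended API; a standalone call mirroring them could look like this (the 10%-of-alpha_max regularization level is an illustrative choice, everything else is taken from the diff and the test file):

import numpy as np
from numpy.linalg import norm

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import compiled_clone, make_correlated_data

X, y, _ = make_correlated_data(100, 50, random_state=0)

# 10% of alpha_max; at alpha_max and above, the Lasso solution is all-zero
alpha = 0.1 * norm(X.T @ y, ord=np.inf) / len(y)

l1_penalty = compiled_clone(L1(alpha))
w, p_objs, stop_crit = gram_cd_solver(
    X, y, l1_penalty, max_iter=1000, tol=1e-9, verbose=True)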
45 changes: 45 additions & 0 deletions skglm/tests/test_gram_solver.py
@@ -0,0 +1,45 @@
import pytest
from itertools import product

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import make_correlated_data, compiled_clone


@pytest.mark.parametrize("n_samples, n_features, X_density",
product([100, 200], [50, 90], [1., 0.6]))
def test_alpha_max(n_samples, n_features, X_density):
X, y, _ = make_correlated_data(n_samples, n_features,
random_state=0, X_density=X_density)
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples

l1_penalty = compiled_clone(L1(alpha_max))
w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0)[0]

np.testing.assert_equal(w, 0)


@pytest.mark.parametrize("n_samples, n_features, rho, X_density",
product([500, 100], [30, 80], [1e-1, 1e-2, 1e-3], [1., 0.8]))
def test_vs_lasso_sklearn(n_samples, n_features, rho, X_density):
X, y, _ = make_correlated_data(n_samples, n_features,
random_state=0, X_density=X_density)
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
alpha = rho * alpha_max

sk_lasso = Lasso(alpha, fit_intercept=False, tol=1e-9)
sk_lasso.fit(X, y)

l1_penalty = compiled_clone(L1(alpha))
w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0, max_iter=1000)[0]

np.testing.assert_allclose(w, sk_lasso.coef_.flatten(), rtol=1e-7, atol=1e-7)


if __name__ == '__main__':
    test_vs_lasso_sklearn(100, 10, 0.01, X_density=1.)
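As a quick sanity check of the objective rewrite stated in the gram_cd_solver docstring (a throwaway verification, not part of the PR):

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((30, 10))
y = rng.standard_normal(30)
w = rng.standard_normal(10)
n = X.shape[0]

# 1 / (2*n) * ||y - Xw||**2  ==  w @ Q @ w / (2*n) - q @ w / n + ||y||**2 / (2*n)
datafit = np.linalg.norm(y - X @ w) ** 2 / (2 * n)
rewritten = (w @ (X.T @ X) @ w / (2 * n) - (X.T @ y) @ w / n
             + np.linalg.norm(y) ** 2 / (2 * n))
np.testing.assert_allclose(datafit, rewritten)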