Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[READY] ENH - add working set strategy to group_bcd_solver #28

Merged
merged 25 commits into from
Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 31 additions & 4 deletions skglm/penalties/block_separable.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,23 @@ def prox_1group(self, value, stepsize, g):
"""Compute the proximal operator of group ``g``."""
return BST(value, self.alpha * stepsize * self.weights[g])

def subdiff_distance(self, w, grad, ws):
"""Compute distance of negative gradient to the subdifferential at ``w``."""
def subdiff_distance(self, w, grad_ws, ws):
"""Compute distance to the subdifferential at ``w`` of negative gradient.
mathurinm marked this conversation as resolved.
Show resolved Hide resolved

Note: ``grad_ws`` is a stacked array of ``-``gradients.
([-grad_ws_1, -grad_ws_2, ...])
"""
alpha, weights = self.alpha, self.weights
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices

scores = np.zeros(len(ws))
grad_ptr = 0
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
for idx, g in enumerate(ws):
grad_g = grad[idx]

grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]

grad_g = grad_ws[grad_ptr: grad_ptr + len(grp_g_indices)]
grad_ptr += len(grp_g_indices)

w_g = w[grp_g_indices]
norm_w_g = norm(w_g)
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -232,3 +239,23 @@ def subdiff_distance(self, w, grad, ws):
scores[idx] = norm(grad_g - subdiff)

return scores

def is_penalized(self, n_groups):
return np.ones(n_groups, dtype=np.bool_)

def generalized_support(self, w):
grp_indices, grp_ptr = self.grp_indices, self.grp_ptr
n_groups = len(grp_ptr) - 1
is_penalized = self.is_penalized(n_groups)

gsupp = np.zeros(n_groups, dtype=np.bool_)
for g in range(n_groups):
if not is_penalized[g]:
gsupp[g] = True
continue

grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
if np.any(w[grp_g_indices]):
gsupp[g] = True

return gsupp
2 changes: 1 addition & 1 deletion skglm/solvers/cd_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
# X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X)

if alphas is None:
raise ValueError('alphas should be passed explicitely')
raise ValueError('alphas should be passed explicitly')
# if hasattr(penalty, "alpha_max"):
# if sparse.issparse(X):
# grad0 = construct_grad_sparse(
Expand Down
77 changes: 56 additions & 21 deletions skglm/solvers/group_bcd_solver.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
from numba import njit

from skglm.utils import check_group_compatible

def bcd_solver(X, y, datafit, penalty, w_init=None,
max_iter=1000, max_epochs=100, tol=1e-7, verbose=False):

def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
max_iter=1000, max_epochs=100, tol=1e-4, verbose=False):
"""Run a group BCD solver.

Parameters
Expand All @@ -24,13 +26,16 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
Initial value of coefficients.
If set to None, a zero vector is used instead.

p0 : int, default 10
Minimum number of groups to be included in the working set.

max_iter : int, default 1000
Maximum number of iterations.

max_epochs : int, default 100
Maximum number of epochs.

tol : float, default 1e-6
tol : float, default 1e-4
Tolerance for convergence.

verbose : bool, default False
Expand All @@ -47,6 +52,9 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
stop_crit: float
The value of the stop criterion.
"""
check_group_compatible(datafit)
check_group_compatible(penalty)

n_features = X.shape[1]
n_groups = len(penalty.grp_ptr) - 1

Expand All @@ -56,51 +64,62 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
datafit.initialize(X, y)
all_groups = np.arange(n_groups)
p_objs_out = np.zeros(max_iter)
stop_crit = 0. # prevent ref before assign when max_iter == 0
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved

for t in range(max_iter):
if t == 0: # avoid computing p_obj twice
prev_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
if t == 0: # avoid computing grad and opt twice
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
opt = penalty.subdiff_distance(w, grad, all_groups)
stop_crit = np.max(opt)

if stop_crit <= tol:
break

gsupp_size = penalty.generalized_support(w).sum()
ws_size = max(min(p0, n_groups),
min(n_groups, 2 * gsupp_size))
ws = np.argpartition(opt, -ws_size)[-ws_size:] # k-largest items (no sort)

for epoch in range(max_epochs):
_bcd_epoch(X, y, w, Xw, datafit, penalty, all_groups)
_bcd_epoch(X, y, w, Xw, datafit, penalty, ws)

if epoch % 10 == 0:
current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
stop_crit_in = prev_p_obj - current_p_obj
grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
opt_in = penalty.subdiff_distance(w, grad_ws, ws)
stop_crit_in = np.max(opt_in)

if max(verbose - 1, 0):
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
print(
f"Epoch {epoch+1}: {current_p_obj:.10f} "
f"Epoch {epoch+1}: {p_obj:.10f} "
f"obj. variation: {stop_crit_in:.2e}"
)

if stop_crit_in <= tol:
print("Early exit")
if stop_crit_in <= 0.3 * stop_crit:
break
prev_p_obj = current_p_obj

current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
stop_crit = prev_p_obj - current_p_obj
p_obj = datafit.value(y, w, Xw) + penalty.value(w)
grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
opt = penalty.subdiff_distance(w, grad, all_groups)
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
stop_crit = np.max(opt)

if max(verbose, 0):
if verbose:
print(
f"Iteration {t+1}: {current_p_obj:.10f}, "
f"stopping crit: {stop_crit:.2f}"
f"Iteration {t+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit <= tol:
print("Outer solver: Early exit")
break

prev_p_obj = current_p_obj
p_objs_out[t] = current_p_obj
p_objs_out[t] = p_obj

return w, p_objs_out, stop_crit


@njit
def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
"""Perform a single BCD epoch on groups in ws."""
# perform a single BCD epoch on groups in ws
grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices

for g in ws:
Expand All @@ -119,3 +138,19 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
if old_w_g[idx] != w[j]:
Xw += (w[j] - old_w_g[idx]) * X[:, j]
return


@njit
def _construct_grad(X, y, w, Xw, datafit, ws):
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
# compute the -gradient according to each group in ws
# note: -gradients are stacked in a 1d array ([-grad_ws_1, -grad_ws_2, ...])
grp_ptr = datafit.grp_ptr
n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])

grads = np.zeros(n_features_ws)
grad_ptr = 0
for g in ws:
grad_g = datafit.gradient_g(X, y, w, Xw, g)
grads[grad_ptr: grad_ptr+len(grad_g)] = -grad_g
grad_ptr += len(grad_g)
return grads
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 11 additions & 0 deletions skglm/tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from numpy.linalg import norm

from skglm.penalties import L1
from skglm.datafits import Quadratic
from skglm.penalties.block_separable import WeightedGroupL2
from skglm.datafits.group import QuadraticGroup
from skglm.solvers.group_bcd_solver import bcd_solver
Expand All @@ -26,6 +28,15 @@ def _generate_random_grp(n_groups, n_features, shuffle=True):
return grp_indices, splits, groups


def test_check_group_compatible():
l1_penalty = L1(1e-3)
Badr-MOUFAD marked this conversation as resolved.
Show resolved Hide resolved
quad_datafit = Quadratic()
X, y = np.random.randn(5, 5), np.random.randn(5)

with np.testing.assert_raises(Exception):
bcd_solver(X, y, quad_datafit, l1_penalty)


@pytest.mark.parametrize("n_groups, n_features, shuffle",
[[10, 50, True], [10, 50, False], [17, 53, False]])
def test_alpha_max(n_groups, n_features, shuffle):
Expand Down
20 changes: 20 additions & 0 deletions skglm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,23 @@ def grp_converter(groups, n_features):
else:
raise ValueError("Unsupported group format.")
return grp_indices.astype(np.int32), grp_ptr.astype(np.int32)


def check_group_compatible(obj):
"""Check whether ``obj`` is compatible with ``bcd_solver``.

Parameters
----------
obj : instance of BaseDatafit or BasePenalty
Object to check.
"""
obj_name = obj.__class__.__name__
group_attrs = ('grp_ptr', 'grp_indices')

for attr in group_attrs:
if not hasattr(obj, attr):
raise Exception(
f"datafit and penalty must be compatible with 'bcd_solver'.\n"
f"'{obj_name}' is not block-separable. "
f"Missing '{attr}' attribute."
)