diff --git a/skglm/penalties/block_separable.py b/skglm/penalties/block_separable.py index 0a9bd2743..dc9029147 100644 --- a/skglm/penalties/block_separable.py +++ b/skglm/penalties/block_separable.py @@ -212,16 +212,23 @@ def prox_1group(self, value, stepsize, g): """Compute the proximal operator of group ``g``.""" return BST(value, self.alpha * stepsize * self.weights[g]) - def subdiff_distance(self, w, grad, ws): - """Compute distance of negative gradient to the subdifferential at ``w``.""" + def subdiff_distance(self, w, grad_ws, ws): + """Compute distance to the subdifferential at ``w`` of negative gradient. + + Note: ``grad_ws`` is a stacked array of ``-``gradients. + ([-grad_ws_1, -grad_ws_2, ...]) + """ alpha, weights = self.alpha, self.weights grp_ptr, grp_indices = self.grp_ptr, self.grp_indices scores = np.zeros(len(ws)) + grad_ptr = 0 for idx, g in enumerate(ws): - grad_g = grad[idx] - grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + + grad_g = grad_ws[grad_ptr: grad_ptr + len(grp_g_indices)] + grad_ptr += len(grp_g_indices) + w_g = w[grp_g_indices] norm_w_g = norm(w_g) @@ -232,3 +239,23 @@ def subdiff_distance(self, w, grad, ws): scores[idx] = norm(grad_g - subdiff) return scores + + def is_penalized(self, n_groups): + return np.ones(n_groups, dtype=np.bool_) + + def generalized_support(self, w): + grp_indices, grp_ptr = self.grp_indices, self.grp_ptr + n_groups = len(grp_ptr) - 1 + is_penalized = self.is_penalized(n_groups) + + gsupp = np.zeros(n_groups, dtype=np.bool_) + for g in range(n_groups): + if not is_penalized[g]: + gsupp[g] = True + continue + + grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] + if np.any(w[grp_g_indices]): + gsupp[g] = True + + return gsupp diff --git a/skglm/solvers/cd_solver.py b/skglm/solvers/cd_solver.py index 2ab16cfdb..83584ee1e 100644 --- a/skglm/solvers/cd_solver.py +++ b/skglm/solvers/cd_solver.py @@ -92,7 +92,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None, # X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X) if alphas is None: - raise ValueError('alphas should be passed explicitely') + raise ValueError('alphas should be passed explicitly') # if hasattr(penalty, "alpha_max"): # if sparse.issparse(X): # grad0 = construct_grad_sparse( diff --git a/skglm/solvers/group_bcd_solver.py b/skglm/solvers/group_bcd_solver.py index 7c6a6a5e6..941a60b21 100644 --- a/skglm/solvers/group_bcd_solver.py +++ b/skglm/solvers/group_bcd_solver.py @@ -1,9 +1,11 @@ import numpy as np from numba import njit +from skglm.utils import check_group_compatible -def bcd_solver(X, y, datafit, penalty, w_init=None, - max_iter=1000, max_epochs=100, tol=1e-7, verbose=False): + +def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10, + max_iter=1000, max_epochs=100, tol=1e-4, verbose=False): """Run a group BCD solver. Parameters @@ -24,13 +26,16 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, Initial value of coefficients. If set to None, a zero vector is used instead. + p0 : int, default 10 + Minimum number of groups to be included in the working set. + max_iter : int, default 1000 Maximum number of iterations. max_epochs : int, default 100 Maximum number of epochs. - tol : float, default 1e-6 + tol : float, default 1e-4 Tolerance for convergence. verbose : bool, default False @@ -47,6 +52,9 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, stop_crit: float The value of the stop criterion. """ + check_group_compatible(datafit) + check_group_compatible(penalty) + n_features = X.shape[1] n_groups = len(penalty.grp_ptr) - 1 @@ -56,51 +64,62 @@ def bcd_solver(X, y, datafit, penalty, w_init=None, datafit.initialize(X, y) all_groups = np.arange(n_groups) p_objs_out = np.zeros(max_iter) + stop_crit = 0. # prevent ref before assign when max_iter == 0 for t in range(max_iter): - if t == 0: # avoid computing p_obj twice - prev_p_obj = datafit.value(y, w, Xw) + penalty.value(w) + if t == 0: # avoid computing grad and opt twice + grad = _construct_grad(X, y, w, Xw, datafit, all_groups) + opt = penalty.subdiff_distance(w, grad, all_groups) + stop_crit = np.max(opt) + + if stop_crit <= tol: + break + + gsupp_size = penalty.generalized_support(w).sum() + ws_size = max(min(p0, n_groups), + min(n_groups, 2 * gsupp_size)) + ws = np.argpartition(opt, -ws_size)[-ws_size:] # k-largest items (no sort) for epoch in range(max_epochs): - _bcd_epoch(X, y, w, Xw, datafit, penalty, all_groups) + _bcd_epoch(X, y, w, Xw, datafit, penalty, ws) if epoch % 10 == 0: - current_p_obj = datafit.value(y, w, Xw) + penalty.value(w) - stop_crit_in = prev_p_obj - current_p_obj + grad_ws = _construct_grad(X, y, w, Xw, datafit, ws) + opt_in = penalty.subdiff_distance(w, grad_ws, ws) + stop_crit_in = np.max(opt_in) if max(verbose - 1, 0): + p_obj = datafit.value(y, w, Xw) + penalty.value(w) print( - f"Epoch {epoch+1}: {current_p_obj:.10f} " + f"Epoch {epoch+1}: {p_obj:.10f} " f"obj. variation: {stop_crit_in:.2e}" ) - if stop_crit_in <= tol: - print("Early exit") + if stop_crit_in <= 0.3 * stop_crit: break - prev_p_obj = current_p_obj - current_p_obj = datafit.value(y, w, Xw) + penalty.value(w) - stop_crit = prev_p_obj - current_p_obj + p_obj = datafit.value(y, w, Xw) + penalty.value(w) + grad = _construct_grad(X, y, w, Xw, datafit, all_groups) + opt = penalty.subdiff_distance(w, grad, all_groups) + stop_crit = np.max(opt) - if max(verbose, 0): + if verbose: print( - f"Iteration {t+1}: {current_p_obj:.10f}, " - f"stopping crit: {stop_crit:.2f}" + f"Iteration {t+1}: {p_obj:.10f}, " + f"stopping crit: {stop_crit:.2e}" ) if stop_crit <= tol: - print("Outer solver: Early exit") break - prev_p_obj = current_p_obj - p_objs_out[t] = current_p_obj + p_objs_out[t] = p_obj return w, p_objs_out, stop_crit @njit def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws): - """Perform a single BCD epoch on groups in ws.""" + # perform a single BCD epoch on groups in ws grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices for g in ws: @@ -119,3 +138,19 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws): if old_w_g[idx] != w[j]: Xw += (w[j] - old_w_g[idx]) * X[:, j] return + + +@njit +def _construct_grad(X, y, w, Xw, datafit, ws): + # compute the -gradient according to each group in ws + # note: -gradients are stacked in a 1d array ([-grad_ws_1, -grad_ws_2, ...]) + grp_ptr = datafit.grp_ptr + n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws]) + + grads = np.zeros(n_features_ws) + grad_ptr = 0 + for g in ws: + grad_g = datafit.gradient_g(X, y, w, Xw, g) + grads[grad_ptr: grad_ptr+len(grad_g)] = -grad_g + grad_ptr += len(grad_g) + return grads diff --git a/skglm/tests/test_group.py b/skglm/tests/test_group.py index 794b05cad..81ec1c2e3 100644 --- a/skglm/tests/test_group.py +++ b/skglm/tests/test_group.py @@ -2,6 +2,8 @@ import numpy as np from numpy.linalg import norm +from skglm.penalties import L1 +from skglm.datafits import Quadratic from skglm.penalties.block_separable import WeightedGroupL2 from skglm.datafits.group import QuadraticGroup from skglm.solvers.group_bcd_solver import bcd_solver @@ -26,6 +28,15 @@ def _generate_random_grp(n_groups, n_features, shuffle=True): return grp_indices, splits, groups +def test_check_group_compatible(): + l1_penalty = L1(1e-3) + quad_datafit = Quadratic() + X, y = np.random.randn(5, 5), np.random.randn(5) + + with np.testing.assert_raises(Exception): + bcd_solver(X, y, quad_datafit, l1_penalty) + + @pytest.mark.parametrize("n_groups, n_features, shuffle", [[10, 50, True], [10, 50, False], [17, 53, False]]) def test_alpha_max(n_groups, n_features, shuffle): diff --git a/skglm/utils.py b/skglm/utils.py index b4751d637..d9ff10207 100644 --- a/skglm/utils.py +++ b/skglm/utils.py @@ -234,3 +234,23 @@ def grp_converter(groups, n_features): else: raise ValueError("Unsupported group format.") return grp_indices.astype(np.int32), grp_ptr.astype(np.int32) + + +def check_group_compatible(obj): + """Check whether ``obj`` is compatible with ``bcd_solver``. + + Parameters + ---------- + obj : instance of BaseDatafit or BasePenalty + Object to check. + """ + obj_name = obj.__class__.__name__ + group_attrs = ('grp_ptr', 'grp_indices') + + for attr in group_attrs: + if not hasattr(obj, attr): + raise Exception( + f"datafit and penalty must be compatible with 'bcd_solver'.\n" + f"'{obj_name}' is not block-separable. " + f"Missing '{attr}' attribute." + )