diff --git a/skglm/datafits/group.py b/skglm/datafits/group.py index ab941f311..383d3206f 100644 --- a/skglm/datafits/group.py +++ b/skglm/datafits/group.py @@ -22,9 +22,6 @@ class QuadraticGroup(BaseDatafit): grp_ptr : array, shape (n_groups + 1,) The group pointers such that two consecutive elements delimit the indices of a group in ``grp_indices``. - - lipschitz : array, shape (n_groups,) - The lipschitz constants for each group. """ def __init__(self, grp_ptr, grp_indices): @@ -34,7 +31,6 @@ def get_spec(self): spec = ( ('grp_ptr', int32[:]), ('grp_indices', int32[:]), - ('lipschitz', float64[:]) ) return spec @@ -42,7 +38,7 @@ def params_to_dict(self): return dict(grp_ptr=self.grp_ptr, grp_indices=self.grp_indices) - def initialize(self, X, y): + def get_lipschitz(self, X, y): grp_ptr, grp_indices = self.grp_ptr, self.grp_indices n_groups = len(grp_ptr) - 1 @@ -52,7 +48,7 @@ def initialize(self, X, y): X_g = X[:, grp_g_indices] lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - self.lipschitz = lipschitz + return lipschitz def value(self, y, w, Xw): return norm(y - Xw) ** 2 / (2 * len(y)) diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py index 7824e6b63..e005c8276 100644 --- a/skglm/solvers/group_bcd.py +++ b/skglm/solvers/group_bcd.py @@ -67,6 +67,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): raise ValueError(val_error_message) datafit.initialize(X, y) + lipschitz = datafit.get_lipschitz(X, y) + all_groups = np.arange(n_groups) p_objs_out = np.zeros(self.max_iter) stop_crit = 0. # prevent ref before assign when max_iter == 0 @@ -100,7 +102,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): for epoch in range(self.max_epochs): # inplace update of w and Xw - _bcd_epoch(X, y, w[:n_features], Xw, datafit, penalty, ws) + _bcd_epoch(X, y, w[:n_features], Xw, lipschitz, datafit, penalty, ws) # update intercept if self.fit_intercept: @@ -140,7 +142,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): @njit -def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws): +def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws): # perform a single BCD epoch on groups in ws grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices @@ -148,7 +150,7 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws): grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]] old_w_g = w[grp_g_indices].copy() - lipschitz_g = datafit.lipschitz[g] + lipschitz_g = lipschitz[g] grad_g = datafit.gradient_g(X, y, w, Xw, g) w[grp_g_indices] = penalty.prox_1group(