Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions skglm/datafits/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ class QuadraticGroup(BaseDatafit):
grp_ptr : array, shape (n_groups + 1,)
The group pointers such that two consecutive elements delimit
the indices of a group in ``grp_indices``.

lipschitz : array, shape (n_groups,)
The lipschitz constants for each group.
"""

def __init__(self, grp_ptr, grp_indices):
Expand All @@ -34,15 +31,14 @@ def get_spec(self):
spec = (
('grp_ptr', int32[:]),
('grp_indices', int32[:]),
('lipschitz', float64[:])
)
return spec

def params_to_dict(self):
return dict(grp_ptr=self.grp_ptr,
grp_indices=self.grp_indices)

def initialize(self, X, y):
def get_lipschitz(self, X, y):
grp_ptr, grp_indices = self.grp_ptr, self.grp_indices
n_groups = len(grp_ptr) - 1

Expand All @@ -52,7 +48,7 @@ def initialize(self, X, y):
X_g = X[:, grp_g_indices]
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)

self.lipschitz = lipschitz
return lipschitz

def value(self, y, w, Xw):
return norm(y - Xw) ** 2 / (2 * len(y))
Expand Down
8 changes: 5 additions & 3 deletions skglm/solvers/group_bcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
raise ValueError(val_error_message)

datafit.initialize(X, y)
lipschitz = datafit.get_lipschitz(X, y)

all_groups = np.arange(n_groups)
p_objs_out = np.zeros(self.max_iter)
stop_crit = 0. # prevent ref before assign when max_iter == 0
Expand Down Expand Up @@ -100,7 +102,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):

for epoch in range(self.max_epochs):
# inplace update of w and Xw
_bcd_epoch(X, y, w[:n_features], Xw, datafit, penalty, ws)
_bcd_epoch(X, y, w[:n_features], Xw, lipschitz, datafit, penalty, ws)

# update intercept
if self.fit_intercept:
Expand Down Expand Up @@ -140,15 +142,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):


@njit
def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws):
# perform a single BCD epoch on groups in ws
grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices

for g in ws:
grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
old_w_g = w[grp_g_indices].copy()

lipschitz_g = datafit.lipschitz[g]
lipschitz_g = lipschitz[g]
grad_g = datafit.gradient_g(X, y, w, Xw, g)

w[grp_g_indices] = penalty.prox_1group(
Expand Down