[READY] ENH - add working set strategy to group_bcd_solver #28

Merged · 25 commits · Jun 11, 2022
Changes from 13 commits
28 changes: 26 additions & 2 deletions skglm/penalties/block_separable.py
@@ -218,10 +218,13 @@ def subdiff_distance(self, w, grad, ws):
         grp_ptr, grp_indices = self.grp_ptr, self.grp_indices

         scores = np.zeros(len(ws))
+        grad_ptr = 0
         for idx, g in enumerate(ws):
-            grad_g = grad[idx]
-
             grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
+
+            grad_g = grad[grad_ptr: grad_ptr + len(grp_g_indices)]
+            grad_ptr += len(grp_g_indices)
+
             w_g = w[grp_g_indices]
             norm_w_g = norm(w_g)

@@ -232,3 +235,24 @@ def subdiff_distance(self, w, grad, ws):
             scores[idx] = norm(grad_g - subdiff)

         return scores
+
+    def is_penalized(self, n_groups):
+        return np.ones(n_groups, dtype=np.bool_)
+
+    def generalized_support(self, w):
+        grp_indices, grp_ptr = self.grp_indices, self.grp_ptr
+        n_groups = len(grp_ptr) - 1
+        is_penalized = self.is_penalized(n_groups)
+
+        gsupp = np.zeros(n_groups, dtype=np.bool_)
+        for g in range(n_groups):
+            if not is_penalized[g]:
+                gsupp[g] = True
+                continue
+
+            grp_g_indices = grp_indices[grp_ptr[g]: grp_ptr[g+1]]
+            w_g = w[grp_g_indices]
+            if np.any(w_g):
+                gsupp[g] = True
+
+        return gsupp
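
Aside (not part of the diff): the new generalized_support method flags the groups whose coefficients are not identically zero; the solver below uses this to size its working set. A minimal standalone sketch of the same logic, assuming hypothetical grp_ptr/grp_indices arrays laid out as in the penalty class (grp_indices[grp_ptr[g]: grp_ptr[g+1]] gives the feature indices of group g):

import numpy as np

# Hypothetical group layout: 3 groups over 6 features, same CSR-like layout
# as grp_ptr/grp_indices in the penalty class.
grp_ptr = np.array([0, 2, 4, 6])
grp_indices = np.array([0, 1, 2, 3, 4, 5])
w = np.array([0., 0., 1.5, -0.2, 0., 0.])  # only group 1 has nonzero coefficients

n_groups = len(grp_ptr) - 1
gsupp = np.zeros(n_groups, dtype=np.bool_)
for g in range(n_groups):
    # Group g is in the generalized support if any of its coefficients is nonzero.
    w_g = w[grp_indices[grp_ptr[g]: grp_ptr[g+1]]]
    gsupp[g] = np.any(w_g)

print(gsupp)  # [False  True False]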
2 changes: 1 addition & 1 deletion skglm/solvers/cd_solver.py
@@ -92,7 +92,7 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
     # X_dense, X_data, X_indices, X_indptr = _sparse_and_dense(X)

     if alphas is None:
-        raise ValueError('alphas should be passed explicitely')
+        raise ValueError('alphas should be passed explicitly')
     # if hasattr(penalty, "alpha_max"):
     #     if sparse.issparse(X):
     #         grad0 = construct_grad_sparse(
59 changes: 47 additions & 12 deletions skglm/solvers/group_bcd_solver.py
@@ -2,7 +2,7 @@
 from numba import njit


-def bcd_solver(X, y, datafit, penalty, w_init=None,
+def bcd_solver(X, y, datafit, penalty, w_init=None, p0=10,
                max_iter=1000, max_epochs=100, tol=1e-7, verbose=False):
     """Run a group BCD solver.

@@ -24,6 +24,9 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
         Initial value of coefficients.
         If set to None, a zero vector is used instead.

+    p0 : int, default 10
+        Minimum number of groups to be included in the working set.
+
     max_iter : int, default 1000
         Maximum number of iterations.

@@ -56,51 +59,65 @@ def bcd_solver(X, y, datafit, penalty, w_init=None,
     datafit.initialize(X, y)
     all_groups = np.arange(n_groups)
     p_objs_out = np.zeros(max_iter)
+    stop_crit = 0.  # prevent ref before assign when max_iter == 0

     for t in range(max_iter):
-        if t == 0:  # avoid computing p_obj twice
-            prev_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+        if t == 0:  # avoid computing grad and opt twice
+            grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
+            opt = penalty.subdiff_distance(w, grad, all_groups)
+            stop_crit = np.max(opt)

+        if stop_crit <= tol:
+            print("Outer solver: Early exit")
+            break
+
+        gsupp_size = penalty.generalized_support(w).sum()
+        ws_size = max(min(p0, n_groups),
+                      min(n_groups, 2 * gsupp_size))
+        ws = np.argpartition(opt, -ws_size)[-ws_size:]  # k-largest items [no sort]
+
         for epoch in range(max_epochs):
-            _bcd_epoch(X, y, w, Xw, datafit, penalty, all_groups)
+            _bcd_epoch(X, y, w, Xw, datafit, penalty, ws)

             if epoch % 10 == 0:
-                current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
-                stop_crit_in = prev_p_obj - current_p_obj
+                grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
+                opt_in = penalty.subdiff_distance(w, grad_ws, ws)
+                stop_crit_in = np.max(opt_in)

                 if max(verbose - 1, 0):
+                    current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
                     print(
                         f"Epoch {epoch+1}: {current_p_obj:.10f} "
                         f"obj. variation: {stop_crit_in:.2e}"
                     )

-                if stop_crit_in <= tol:
+                if stop_crit_in <= 0.3 * stop_crit:
                     print("Early exit")
                     break
-                prev_p_obj = current_p_obj

         current_p_obj = datafit.value(y, w, Xw) + penalty.value(w)
-        stop_crit = prev_p_obj - current_p_obj
+        grad = _construct_grad(X, y, w, Xw, datafit, all_groups)
+        opt = penalty.subdiff_distance(w, grad, all_groups)
+        stop_crit = np.max(opt)

         if max(verbose, 0):
             print(
                 f"Iteration {t+1}: {current_p_obj:.10f}, "
-                f"stopping crit: {stop_crit:.2f}"
+                f"stopping crit: {stop_crit:.2e}"
             )

         if stop_crit <= tol:
             print("Outer solver: Early exit")
             break

-        prev_p_obj = current_p_obj
         p_objs_out[t] = current_p_obj

     return w, p_objs_out, stop_crit


 @njit
 def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
-    """Perform a single BCD epoch on groups in ws."""
+    """Perform a single BCD epoch on groups in ``ws``."""
     grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices

     for g in ws:
@@ -119,3 +136,21 @@ def _bcd_epoch(X, y, w, Xw, datafit, penalty, ws):
             if old_w_g[idx] != w[j]:
                 Xw += (w[j] - old_w_g[idx]) * X[:, j]
     return
+
+
+@njit
+def _construct_grad(X, y, w, Xw, datafit, ws):
+    """Compute the -gradient according to each group in ``ws``.
+
+    Note: -gradients are stacked in a 1d array (e.g. [-grad_g1, -grad_g2, ...]).
+    """
+    grp_ptr = datafit.grp_ptr
+    n_features_ws = sum([grp_ptr[g+1] - grp_ptr[g] for g in ws])
+
+    grads = np.zeros(n_features_ws)
+    grad_ptr = 0
+    for g in ws:
+        grad_g = datafit.gradient_g(X, y, w, Xw, g)
+        grads[grad_ptr: grad_ptr+len(grad_g)] = -grad_g
+        grad_ptr += len(grad_g)
+    return grads
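
Aside (not part of the diff): _construct_grad and penalty.subdiff_distance now share a layout contract: per-group negative gradients are concatenated into one flat 1d array in the order of ws, and the consumer walks it with a running grad_ptr offset. The working set itself is picked as the groups with the largest subdifferential distances via np.argpartition. A minimal NumPy sketch of both ideas, with made-up group sizes and scores:

import numpy as np

# Hypothetical working set of 3 groups with sizes 2, 3 and 1.
ws = np.array([4, 0, 7])
group_sizes = {4: 2, 0: 3, 7: 1}

# Stand-ins for -datafit.gradient_g(X, y, w, Xw, g): constant blocks here.
fake_grads = {g: np.full(size, float(g)) for g, size in group_sizes.items()}

# Producer side (_construct_grad): stack per-group gradients in ws order.
grads = np.concatenate([fake_grads[g] for g in ws])

# Consumer side (subdiff_distance): walk the flat array with grad_ptr.
grad_ptr = 0
for g in ws:
    size = group_sizes[g]
    grad_g = grads[grad_ptr: grad_ptr + size]
    grad_ptr += size
    print(g, grad_g)  # each group recovers exactly its own block

# Working-set selection: indices of the ws_size largest scores, unsorted.
opt = np.array([0.1, 3.0, 0.2, 5.0, 0.05])
ws_size = 2
print(np.argpartition(opt, -ws_size)[-ws_size:])  # indices 1 and 3, in some order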