ENH flexible gram solver with penalty and using datafit #16

Closed

Conversation

mathurinm
Collaborator

This is a smaller version of #4: it drops group support, but reuses more code and supports any penalty.

@PABannier
Collaborator

@mathurinm ready for a quick review ;)

skglm/solvers/cd_solver.py Outdated
if (isinstance(datafit, (Quadratic, Quadratic_32)) and n_samples > n_features
        and n_features < 10_000) or solver in ("cd_gram", "fista"):
    # Gram matrix must fit in memory, hence the restriction n_features < 10_000
    if not isinstance(datafit, (Quadratic, Quadratic_32)):
Collaborator Author

I think this bit is unreachable because the check is already performed at L155.

Collaborator

@PABannier PABannier May 11, 2022

I put it there because the outer condition is "isinstance(...) OR solver in ...". If the user explicitly passes "cd_gram", we enter the if block regardless of the datafit, and I want to catch an incompatible datafit, hence L158. Maybe overkill? Should we even expose solver? I think it is convenient for benchmarks.
WDYT?

Collaborator Author

OK, understood, thanks.
Maybe we can re-indent the first if, breaking the line before "or solver", to make it more visible?

Collaborator

I tried to make it more obvious. WDYT?
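
For readers following this thread, one way to make the two branches easier to see is to separate the user-forced choice from the automatic one. This is only a sketch: `pick_solver`, the `is_quadratic` flag and the `solver="auto"` default are hypothetical names introduced for illustration and are not taken from the PR. Only the solver names and the 10_000 threshold come from the diff above.

```python
def pick_solver(is_quadratic, n_samples, n_features, solver="auto",
                max_gram_features=10_000):
    """Return the solver name to use, validating a user-forced choice.

    Hypothetical helper: the name, the flag and the "auto" default are
    illustrative assumptions, not part of the PR.
    """
    gram_solvers = ("cd_gram", "fista")

    if solver in gram_solvers:
        # The user explicitly asked for a Gram-based solver: the datafit
        # must be quadratic, so fail loudly instead of falling back.
        if not is_quadratic:
            raise ValueError(
                f"solver={solver!r} requires a quadratic datafit.")
        return solver

    if solver == "auto":
        # Automatic choice: use the Gram matrix only when the problem is
        # overdetermined and the matrix is small enough to fit in memory.
        use_gram = (is_quadratic and n_samples > n_features
                    and n_features < max_gram_features)
        return "cd_gram" if use_gram else "cd_ws"

    if solver == "cd_ws":
        return solver

    raise ValueError(f"Unknown solver {solver!r}.")
```

With this layout the datafit check sits only on the path where the user forced a Gram-based solver, which is the case discussed above.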


coefs : array, shape (n_features, n_alphas)
Coefficients along the path.
obj_out : array, shape (n_iter,)
Collaborator Author

Do we really return this, or the optimality condition violation instead?

Collaborator

We do return this. See L371.

skglm/solvers/gram.py Outdated


@njit
def prox_vec(penalty, z, stepsize, n_features):
Collaborator Author

Argh, I thought we had penalty.prox

make this function private, remove n_features (access as z.shape[1])

We need to think this through for the solvers, but probably all penalties will need to implement it. We can do so in basepenalty, but I fear that looping over all coordinates will be slower than performing it in one step as ST_vec does.

Collaborator

In [16]: %%time
    ...: out = _prox_vec(pen, z, 0.01)
CPU times: user 28 µs, sys: 1e+03 ns, total: 29 µs
Wall time: 34.1 µs

In [17]: %%time
    ...: out2 = ST_vec(z, 0.01)
CPU times: user 23 µs, sys: 0 ns, total: 23 µs
Wall time: 25.7 µs

Not a big difference in my experiments. I tried several thresholds.

Collaborator Author

With @QB3, we had an issue a while ago on flashcd with finance where this caused a big overhead. Just something to keep in mind.
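
To make the trade-off in this thread concrete, here is a rough sketch of the two ways of applying an L1 prox to a whole vector: one scalar prox call per coordinate versus a single vectorized soft-thresholding. It uses plain NumPy rather than Numba, and `L1Sketch`, `prox_loop` and `soft_threshold_vec` are stand-in names, not the PR's `_prox_vec` or `ST_vec`. The `prox_1d(value, stepsize, j)` signature is likewise an assumption made for the example.

```python
import numpy as np


class L1Sketch:
    """Stand-in L1 penalty exposing a scalar prox (illustrative only)."""

    def __init__(self, alpha):
        self.alpha = alpha

    def prox_1d(self, value, stepsize, j):
        # Scalar soft-thresholding at level alpha * stepsize.
        thresh = self.alpha * stepsize
        return np.sign(value) * max(abs(value) - thresh, 0.0)


def prox_loop(penalty, z, stepsize):
    """Apply the scalar prox coordinate by coordinate (works for any penalty)."""
    out = np.empty_like(z)
    for j in range(z.shape[0]):
        out[j] = penalty.prox_1d(z[j], stepsize, j)
    return out


def soft_threshold_vec(z, thresh):
    """Vectorized soft-thresholding (L1 only, one pass over the array)."""
    return np.sign(z) * np.maximum(np.abs(z) - thresh, 0.0)


rng = np.random.default_rng(0)
z = rng.standard_normal(1_000)
pen = L1Sketch(alpha=1.0)

out_loop = prox_loop(pen, z, 0.01)
out_vec = soft_threshold_vec(z, pen.alpha * 0.01)
```

Both calls give the same result. The loop version generalizes to any penalty with a scalar prox, while the vectorized one is specific to L1, so the relative cost is worth re-timing on large inputs before settling on one.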

skglm/solvers/gram.py Outdated
PABannier and others added 5 commits May 11, 2022 09:53
Co-authored-by: mathurinm <mathurinm@users.noreply.github.com>
Co-authored-by: mathurinm <mathurinm@users.noreply.github.com>
skglm/solvers/cd_solver.py Outdated
@mathurinm mathurinm requested a review from Klopfe May 13, 2022 08:37
@PABannier PABannier requested a review from QB3 May 16, 2022 20:49
Collaborator

@PABannier PABannier left a comment

Overall LGTM.
Tests are missing for the solvers though, I can write some if needed.

@@ -52,6 +56,9 @@ def cd_solver_path(X, y, datafit, penalty, alphas=None,
return_n_iter : bool, optional
If True, number of iterations along the path are returned.

solver : ('cd_ws'|'cd_gram'|'fista'), optional
Collaborator

FISTA is not a CD solver, so it's confusing to expose it to the user like this.
@mathurinm WDYT?

@njit
def _cd_epoch_gram(XtX, grad, w, datafit, penalty, n_samples, n_features):
    lc = datafit.lipschitz
    for j in range(n_features):
Collaborator Author

since we have complete access to grad at each iteration, it would be interesting to use a greedy selection rule here: do not pick j cyclically, but instead take j = np.argmax(np.abs(grad))

One "epoch" in this setting would only be the update of n_features coordinates.
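
A minimal sketch of what such a greedy epoch could look like for the Lasso, assuming the objective 1/(2n) ||y - Xw||^2 + alpha ||w||_1 and plain NumPy instead of Numba. The raw rule j = np.argmax(np.abs(grad)) ignores the penalty, so this sketch swaps in the per-coordinate optimality violation as the score. That score reduces to |grad_j| when alpha = 0. All function names here are hypothetical and not part of the PR.

```python
import numpy as np


def soft_threshold(x, thresh):
    """Scalar soft-thresholding."""
    return np.sign(x) * max(abs(x) - thresh, 0.0)


def l1_violation(w, grad, alpha):
    """Distance of -grad_j to the subdifferential of alpha * |w_j|."""
    at_zero = np.maximum(np.abs(grad) - alpha, 0.0)
    nonzero = np.abs(grad + alpha * np.sign(w))
    return np.where(w == 0, at_zero, nonzero)


def greedy_cd_gram(XtX, Xty, alpha, n_samples, n_epochs=100, tol=1e-10):
    """Greedy coordinate descent for the Lasso using only XtX and Xty."""
    n_features = XtX.shape[0]
    lc = np.diag(XtX) / n_samples        # coordinate-wise Lipschitz constants
    w = np.zeros(n_features)
    grad = -Xty / n_samples              # full gradient of the datafit at w = 0
    for _ in range(n_epochs):
        for _ in range(n_features):      # one "epoch" = n_features updates
            scores = l1_violation(w, grad, alpha)
            j = np.argmax(scores)        # greedy pick instead of cyclic order
            if scores[j] <= tol:
                return w                 # every coordinate is (nearly) optimal
            old = w[j]
            w[j] = soft_threshold(old - grad[j] / lc[j], alpha / lc[j])
            if w[j] != old:
                # Rank-one update keeps the full gradient available at
                # O(n_features) cost per coordinate update.
                grad += (w[j] - old) * XtX[:, j] / n_samples
    return w


rng = np.random.default_rng(0)
n_samples, n_features = 50, 10
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)
XtX, Xty = X.T @ X, X.T @ y
alpha = 0.1 * np.max(np.abs(Xty)) / n_samples
w_hat = greedy_cd_gram(XtX, Xty, alpha, n_samples, n_epochs=500)
```

Scoring costs O(n_features) per update, the same order as the gradient update itself, so the greedy rule does not change the per-update complexity in the Gram setting.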

@Badr-MOUFAD
Collaborator

closing in favor of #59

@mathurinm mathurinm deleted the gram_penalty_nogroup branch August 26, 2022 12:25