
Qs: Mapping different solvers to leaves, parameter normalisation, different parameter scales #20

Closed
LouisDesdoigts opened this issue Nov 1, 2023 · 2 comments
Labels: question (User queries)

@LouisDesdoigts

Hey, first up, another awesome package in the JAX ecosystem! I've been meaning to incorporate this kind of solver in my work for a long time, so thanks for making it easy 😛. This is partly a discussion post, as I am relatively unfamiliar with these algorithms; I have done my best to parse the docs in detail, but feel free to point me to any external resources, as I would love to learn more.


Mapping solvers to different leaves

Is there a way that we can map different solvers to each leaf of a pytree? Let's say we know one parameter will be initialised in the smooth bowl of the loss space and can be solved with BFGS, but the other parameter has a 'noisy' loss topology and is best tackled with a regular optax optimiser. This is actually quite typical for the sort of optical models I work with, although not super common in general AFAIK.

It is simple to apply each of these algorithms one at a time to each pytree leaf with eqx.partition and eqx.combine, along the lines of the sketch below. This approach works, but it can't 'jointly' optimise these leaves, and it results in redundant calculation of the model gradients, since the grads from a single evaluation could have been passed to both algorithms.
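For concreteness, here's roughly what I mean as a minimal sketch (the two-leaf model and loss are made up purely for illustration):

```python
import equinox as eqx
import jax.numpy as jnp
import optax
import optimistix as optx

def loss(params):
    # Made-up loss: a smooth bowl in "smooth", a wiggly landscape in "noisy".
    return jnp.sum(params["smooth"] ** 2) + jnp.sum(jnp.cos(10.0 * params["noisy"]) ** 2)

params = {"smooth": jnp.array([1.0, 2.0]), "noisy": jnp.array([0.5])}

# Partition the pytree: True leaves go to BFGS, False leaves to optax.
smooth, noisy = eqx.partition(params, {"smooth": True, "noisy": False})

# 1) Solve the smooth leaves with BFGS, holding the noisy leaves fixed.
def smooth_loss(y, args):
    return loss(eqx.combine(y, noisy))

sol = optx.minimise(smooth_loss, optx.BFGS(rtol=1e-6, atol=1e-6), smooth)
smooth = sol.value

# 2) Then take optax steps on the noisy leaves, holding the smooth leaves fixed.
opt = optax.adam(1e-2)
opt_state = opt.init(noisy)
for _ in range(100):
    grads = eqx.filter_grad(lambda y: loss(eqx.combine(smooth, y)))(noisy)
    updates, opt_state = opt.update(grads, opt_state, noisy)
    noisy = eqx.apply_updates(noisy, updates)

params = eqx.combine(smooth, noisy)
```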

Now, I recognise that a 'joint' approach would pose a problem for algorithms like BFGS, since it would be trying to find the minimum of a dynamic topology that changes as the other leaves are updated throughout the optimisation. I would be curious what you think the right approach to this kind of problem might be; maybe there are solvers designed for this sort of problem? If not, what approach might you take? I'm very excited about the flexibility and extensibility of this software for building much better custom solvers for my rather niche set of optimisation problems.


Parameter normalisation during the solve loop

So during a gradient descent loop we commonly need to apply some normalisation/regularisation to our updated parameters to ensure they stay 'physical'. An example would be normalising relative spectral weights to sum to 1 after the updates have been applied. I am wondering if there is a way to enforce these constraints during the solve; the simplest example here would be preventing some values from rising above some threshold.
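For concreteness, the two constraints I have in mind might look like this (hypothetical helper names):

```python
import jax.numpy as jnp

def normalise_weights(weights):
    # Renormalise relative spectral weights so they sum to one.
    return weights / jnp.sum(weights)

def cap_values(values, threshold):
    # Prevent values from rising above some threshold.
    return jnp.minimum(values, threshold)
```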

I would guess this would likely be possible through a user-defined solver class that applies the custom regularisation. If something like this is possible, how would it be implemented? From a quick look at the code, it seems this could be done within the step function of the AbstractGradientDescent class?


Parameters with large scale variation

So this one is more of an open discussion than a specific question. It's very common for the models I work with to have parameters on vastly different scales (everything from 1e10 to 1e-10). This is a problem for these algorithms in general, so I was hoping to get your thoughts on what the right way to approach a solution would be.

There is the 'naive' solution where you apply a scaling to each parameter of the pytree before passing it into the minimisation function, and then invert the scaling once inside the function (sketched below). This works, but it is far from what I would consider ideal: it still requires a degree of prior knowledge of the model, and it sort of just shifts the tunable hyper-parameter from a learning rate to a scaling. Granted, this is still going to be more robust in general, but I feel like there is something more elegant... I'm wondering if you have any thoughts or ideas about this!
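In code, the rescaling I mean looks something like this (the scales, leaves, and loss are all made up for illustration):

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import optimistix as optx

# Hypothetical magnitudes, known a priori, one per leaf of the pytree.
scales = {"flux": 1e10, "aberration": 1e-10}
params = {"flux": 3e10, "aberration": 2e-10}  # initial guess, in physical units

def physical_loss(p):
    # Made-up objective, defined in physical units.
    return (p["flux"] / 1e10 - 1.0) ** 2 + (p["aberration"] / 1e-10 - 0.5) ** 2

def scaled_loss(y, args):
    # Undo the scaling inside the objective, so the solver only sees O(1) leaves.
    physical = jtu.tree_map(lambda leaf, s: leaf * s, y, scales)
    return physical_loss(physical)

y0 = jtu.tree_map(lambda leaf, s: leaf / s, params, scales)   # scale down
sol = optx.minimise(scaled_loss, optx.BFGS(rtol=1e-6, atol=1e-6), y0)
fitted = jtu.tree_map(lambda leaf, s: leaf * s, sol.value, scales)  # scale back up
```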


Anyway thanks again for the excellent software and the help!

@packquickly (Collaborator)

Thanks for the excellent questions!

For mapping solvers to different leaves, it sounds like what you're trying to do is take the parameters you're optimising over, let's call them $P$, and partition them into two sets of parameters $P_1$ and $P_2$. Then you'd like to apply one optimisation algorithm to $P_1$ and another to $P_2$. BFGS does have some theoretical issues in this regime, as you pointed out, but I don't see why there should be redundant computation if the solvers are not coupled: BFGS will simply not compute gradients with respect to $P_2$, and Optax will not compute gradients with respect to $P_1$.

Setting up the solves to be done in tandem should be straightforward, if a little tedious. Roughly following the example of interactive solving in the docs, I was able to throw together a toy example in a gist which did the tandem solving for a test problem from CUTE. Let me know if you meant something else! You could also define a custom solver to do this using the optx.minimise/optx.least_squares APIs, but it's a bit more work.
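In outline, the tandem loop looks something like the sketch below: one BFGS step on $P_1$ interleaved with one Adam step on $P_2$ per iteration. The loss is a toy, and the init/step signatures follow the interactive-solve example in the docs, so double-check against that example rather than treating this as authoritative.

```python
import jax
import jax.numpy as jnp
import optax
import optimistix as optx

p1 = jnp.array([5.0, -3.0])   # the "smooth" parameters, for BFGS
p2 = jnp.array([0.7])         # the "noisy" parameters, for Adam

def loss(p1, p2):
    # Toy joint loss: smooth in p1, wiggly in p2.
    return jnp.sum((p1 - 1.0) ** 2) + jnp.sum(jnp.sin(5.0 * p2) ** 2 + 0.1 * p2**2)

# fn closes over p2 (a module-level variable), so each BFGS step sees p2 at its
# current value. The interactive API expects fn to return (value, aux).
def fn(y, args):
    return loss(y, p2), None

bfgs = optx.BFGS(rtol=1e-6, atol=1e-6)
adam = optax.adam(1e-2)
adam_state = adam.init(p2)

args, options, tags = None, {}, frozenset()
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
bfgs_state = bfgs.init(fn, p1, args, options, f_struct, None, tags)

for _ in range(100):
    # One BFGS step on p1, holding p2 fixed. (The caveat from the question
    # applies: the landscape BFGS sees shifts whenever p2 moves.)
    p1, bfgs_state, _ = bfgs.step(fn, p1, args, options, bfgs_state, tags)
    # One Adam step on p2, holding p1 fixed.
    grads = jax.grad(lambda q: loss(p1, q))(p2)
    updates, adam_state = adam.update(grads, adam_state, p2)
    p2 = optax.apply_updates(p2, updates)
```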

For parameter normalisation, you could either use the method of interactively stepping through the solve, as above, and normalise `y` at each step, or create a custom solver which does this. Take a look at the implementation of BestSoFarMinimiser for an idea of how to implement this as a custom solver, if you want to take that approach. You shouldn't need to rewrite the main logic of each algorithm, just apply the normalisation after each update (illustrated below).
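As a concrete illustration of "apply the normalisation after each update", here is the idea in a plain optax loop (the loss and target weights are made up); in a custom Optimistix solver, the same projection would go at the end of the step method instead:

```python
import jax
import jax.numpy as jnp
import optax

def loss(weights):
    # Made-up loss over relative spectral weights.
    return jnp.sum((weights - jnp.array([0.2, 0.3, 0.5])) ** 2)

weights = jnp.full(3, 1.0 / 3.0)
opt = optax.adam(1e-2)
opt_state = opt.init(weights)

for _ in range(200):
    grads = jax.grad(loss)(weights)
    updates, opt_state = opt.update(grads, opt_state, weights)
    weights = optax.apply_updates(weights, updates)
    weights = weights / jnp.sum(weights)  # the normalisation, applied after each update
```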

Finally, for handling parameters with large scale variation, there are a number of techniques, but in my opinion it's an open problem. As far as I know, there is no "right" way to handle large variations in parameter scales. If you're using first-order techniques, then incorporating more curvature information/approximating a second-order method seems to help (e.g. Adam, Adagrad, and pals do this).

When already using a second-order optimiser, it's less obvious what to do. Section 8.3 of Training Deep and Recurrent Networks with Hessian-Free Optimization has a great discussion of some approaches using Tikhonov damping (Levenberg-Marquardt and related trust-region methods) that I've been meaning to try as experimental extensions to Optimistix. I'm also aware that Mark Transtrum has some work in this area that's far more theoretical.

@LouisDesdoigts (Author)

Awesome, thanks for this detailed response!

Very much appreciate the small example of the tandem solving; this will make things much easier to pull together. It seems like you built this package with this flexibility in mind, and I'm very glad you did 😛.

> in my opinion it's an open problem. As far as I know, there is no "right" way to handle large variations in parameter scales.

Okay, I'm glad I haven't just been missing a body of work on this, as it has been plaguing me for a while (for the most part it's Adam to the rescue)! Thanks for the resources on this; I will tinker on it in the background and see what I can make of it.

Cheers!
