optim.LBFGS convergence problem for batch function minimization #49993
Comments
cc @fehiepsi, perhaps
I am not sure what the expected behavior is. It makes more sense to me that they are different. Did you find the same behavior in other frameworks?
Thanks for your feedback @fehiepsi. The expected behavior is that torch.optim converges to the minimum of the Rosenbrock function, as jax.scipy.optimize does in the script below; torch.optim, however, does not. Note that in jax.scipy.optimize I am using BFGS instead of L-BFGS, since the latter is not yet available in jax.scipy.optimize. This is the output of the script:
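The original script and its output are not reproduced in this thread; a minimal sketch of the JAX side of the comparison (the starting point and the sum-form Rosenbrock definition are assumptions, not the author's exact code) could look like:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def rosenbrock(x):
    # Sum-form Rosenbrock; global minimum 0 at x = [1, ..., 1].
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2
                   + (1.0 - x[:-1]) ** 2)

# Starting point assumed for illustration (the thread varies the last element).
x0 = jnp.array([76.0, 97.0, 20.0, 20.0, 0.01, 100.0])

# jax.scipy.optimize.minimize currently supports method="BFGS" only.
res = minimize(rosenbrock, x0, method="BFGS")
print(res.x, res.fun, res.success)
```

Note that JAX defaults to float32; enabling `jax.config.update("jax_enable_x64", True)` before building arrays would make this comparable to the double-precision PyTorch runs discussed below.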
Thanks
Thanks, @joacorapela! I did some quick tests by changing the last element of `x0`.
In all of my tests, the PyTorch version performs better, in either float32 or float64 mode (probably because of the limited-memory feature?). PyTorch does a pretty good job in float64. I think the algorithm is not intended to be used in float32. The algorithm is pretty complicated to track for implementation details and to decide which options are good or bad. Back then, in a PR for line search, I tried to stick as close as possible to minFunc. I didn't play much with the theory and just followed the references, so it is hard for me to say whether running L-BFGS multiple times (which is intended for stochastic training) performs similarly to running it once... cc @Joshuaalbert who might be interested in improving the performance of jax bfgs.
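The float32-versus-float64 sensitivity mentioned above is easy to check directly. A small sketch (the starting point, `max_iter`, and use of `line_search_fn="strong_wolfe"` are assumptions for illustration, not the commenter's actual test):

```python
import torch

def rosenbrock(x):
    # Sum-form Rosenbrock; global minimum 0 at x = [1, ..., 1].
    return (100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2).sum()

def run(dtype):
    # Same optimization in the requested precision.
    x = torch.full((6,), 2.0, dtype=dtype, requires_grad=True)
    optim = torch.optim.LBFGS([x], max_iter=500,
                              line_search_fn="strong_wolfe")
    def closure():
        optim.zero_grad()
        loss = rosenbrock(x)
        loss.backward()
        return loss
    optim.step(closure)
    return rosenbrock(x.detach()).item()

final32 = run(torch.float32)
final64 = run(torch.float64)
print(final32, final64)
```

In double precision the final loss should be essentially zero; in single precision the line search and the `tolerance_change`/`tolerance_grad` stopping tests operate near the limit of float32 resolution, which is consistent with the observation that the algorithm is not intended for float32.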
Hello @fehiepsi, I repeated my analysis with different initial conditions, as you suggested, but obtained very different results. I am appending my evaluation code, in case the difference in our results is due to a mistake in my use of PyTorch, which I doubt. Also, if you can, please post your evaluation code and I will study why our evaluations differ. I will comment only on results using double precision. Briefly, PyTorch did not converge using 1e2, 1e3 or 1e4 as the last element of x0, while JAX converged using 1e2 and 1e4 but not using 1e3. PyTorch was faster than JAX for 1e2 but slower for 1e4. The runtimes of the two libraries were similar for 1e3. Thanks for your help. PS: I hope we can fix this convergence problem of PyTorch, as I have written quite a lot of code based on PyTorch and I would prefer not to have to migrate to JAX or scipy.optimize. Here is the output of the code below:
and here is the script
cc @Joshuaalbert in case he can help
Here you are
PyTorch converges (with tol 1e-4) in all 3 cases. You might think that 1e-4 is not a good precision, but I don't have an intuition about what a good precision is here.
Very helpful hints @fehiepsi. To summarize, the convergence problem was due to using a float initial condition instead of double. Convergence within 1e-4 is good for me. Thanks a lot. I will close this issue.
🐛 Bug
We can (batch) minimize a function with LBFGS in two (equivalent?) ways: calling optim.step() once with a large max_iter, or calling optim.step() multiple times with a smaller max_iter.
I expected the two approaches to give similar results. But they don't, as shown next.
I used LBFGS to minimize the Rosenbrock function in six dimensions, with starting points closer to and further from the unique global minimum (at x=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]). With the closer starting point, the two approaches converged to the minimum with a similar number of function calls. However, with the further starting point, the approach using only one step() call did not converge, while the approach using multiple step() calls did (code below).
Is this an expected behavior or a bug?
lbfgs.py code details
I looked into lbfgs.py and saw that, with the further starting point and only one step() call, the algorithm returned after a few iterations because the loss did not decrease sufficiently between successive iterations (line 458):
Since LBFGS keeps state between different calls to step(), I thought that a new step() call after the algorithm had reached this state would do nothing. But this was not the case. When you call step() again, LBFGS does a new line search, with a new learning rate, which can lead to a loss decrease, so the algorithm can continue (line 397).
Code
In the code below I use two starting points. For each of them, I call optim.step() once (i.e., one epoch) with a large max_iter, and then call optim.step() several times (i.e., multiple epochs) with a smaller max_iter. For each call I print whether optim converged to the global minimum, the number of function calls, and the number of internal LBFGS iterations.
The code illustrates that only for the further starting point (x0=[76.0, 97.0, 20.0, 20.0, 0.01, 10000.0]) did the one-epoch call fail to approximate the global minimum, while the multiple-epochs call succeeded.
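The original script is not reproduced here; a minimal sketch of the two calling patterns described above (the closure structure, `max_iter` values, and `line_search_fn="strong_wolfe"` are assumptions — only the Rosenbrock function and the further starting point come from this report) might look like:

```python
import torch

def rosenbrock(x):
    # Sum-form Rosenbrock; global minimum 0 at x = [1, ..., 1].
    return (100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2).sum()

def minimize(x0, max_iter, n_epochs):
    # One "epoch" = one optim.step(closure) call; LBFGS runs up to
    # max_iter internal iterations per call and keeps state across calls.
    x = x0.clone().detach().requires_grad_(True)
    optim = torch.optim.LBFGS([x], max_iter=max_iter,
                              line_search_fn="strong_wolfe")
    def closure():
        optim.zero_grad()
        loss = rosenbrock(x)
        loss.backward()
        return loss
    for _ in range(n_epochs):
        optim.step(closure)
    return x.detach()

# Further starting point from the report; double precision matters here.
x0 = torch.tensor([76.0, 97.0, 20.0, 20.0, 0.01, 10000.0],
                  dtype=torch.double)
x_one = minimize(x0, max_iter=1000, n_epochs=1)   # one epoch, large max_iter
x_many = minimize(x0, max_iter=10, n_epochs=100)  # many epochs, small max_iter
print(x_one, x_many)
```

The key difference is that each new step() call restarts the line search from the current state, so the multi-epoch variant can escape the early return triggered by the tolerance_change test inside a single long step() call.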
Thanks
cc @vincentqb