optim.LBFGS convergence problem for batch function minimization #49993

Closed
joacorapela opened this issue Jan 1, 2021 · 7 comments
Labels
module: optimizer Related to torch.optim needs research We need to decide whether or not this merits inclusion, based on research world triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments


joacorapela commented Jan 1, 2021

🐛 Bug

We can (batch) minimize a function with LBFGS in two (equivalent?) ways:

  1. use a relatively large max_iter parameter value when constructing the optimizer and call optimizer.step() only once. For example:
optimizer = torch.optim.LBFGS(x, max_iter=200)
optimizer.step(closure)
  2. use a smaller max_iter parameter value when constructing the optimizer and call optimizer.step() multiple times. For example:
optimizer = torch.optim.LBFGS(x, max_iter=20)
for epoch in range(10):
   optimizer.step(closure)

I expected the two approaches to give similar results. But they don't, as shown next.

I used LBFGS to minimize the Rosenbrock function in six dimensions, with starting points closer to and further from the unique global minimum (at x=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]). With the closer starting point, the two approaches converged to the minimum with a similar number of function calls. However, with the further starting point, the approach using only one step() call did not converge, while the approach using multiple step() calls did (code below).

Is this an expected behavior or a bug?

lbfgs.py code details

I looked into lbfgs.py and saw that, with the further starting point and only one step() call, the algorithm returned after a few iterations because the loss did not decrease sufficiently in successive iterations (line 458):

            if abs(loss - prev_loss) < tolerance_change:
                break

Since LBFGS keeps state between calls to step(), I thought that a new step() call after the algorithm had reached this state would do nothing. But this was not the case. When you call step() again, LBFGS does a new line search with a new learning rate, which can lead to a loss decrease, so the algorithm can continue (line 397).

           if state['n_iter'] == 1:
               t = min(1., 1. / flat_grad.abs().sum()) * lr
           else:
               t = lr
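The effect of the quoted step-size rule can be sketched in plain Python. This is only an illustration under stated assumptions: rosenbrock_grad and trial_step_size are hypothetical helper names, lr=1.0 is assumed to be the learning rate in use, and the sketch assumes that state['n_iter'] keeps counting across step() calls, so that every line search after the very first iteration starts from t = lr.

```python
def rosenbrock_grad(x):
    """Gradient of sum(100*(x[i+1] - x[i]**2)**2 + (1 - x[i])**2)."""
    grad = [0.0] * len(x)
    for i in range(len(x) - 1):
        residual = x[i + 1] - x[i] ** 2
        grad[i] += -400.0 * x[i] * residual - 2.0 * (1.0 - x[i])
        grad[i + 1] += 200.0 * residual
    return grad


def trial_step_size(grad, n_iter, lr=1.0):
    """Mirror of the quoted snippet: scaled on the first iteration, lr afterwards.

    Hypothetical helper; assumes the gradient is not all zeros.
    """
    if n_iter == 1:
        return min(1.0, 1.0 / sum(abs(g) for g in grad)) * lr
    return lr


# Further starting point used in this issue.
x0_far = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
grad = rosenbrock_grad(x0_far)
print("first iteration:", trial_step_size(grad, n_iter=1))
print("later iteration:", trial_step_size(grad, n_iter=2))
```

With a gradient this large, the first trial step is tiny, while any later line search starts from the full lr.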

Code

In the code below I use two starting points. For each of them I call optim.step() once (i.e., one epoch) with a large max_iter and then call optim.step() several times (i.e., multiple epochs) with a smaller max_iter. For each call I print if optim converged to the global minimum, the number of function calls and the number of internal LBFGS iterations.

The code illustrates that only for the further starting point (x0=[76.0, 97.0, 20.0, 120.0, 0.01, 10000.0]) did the one-epoch call fail to approximate the global minimum, while the multiple-epochs call succeeded.

Results for x0= [76.0, 97.0, 20.0, 120.0, 0.01, 100.0]
Results for one epochs:
Converged: True
Function evaluations: 126
Iterations: 108

Results for multiple epochs:
Converged: True
Function evaluations: 135
Iterations: 116

Results for x0= [76.0, 97.0, 20.0, 120.0, 0.01, 10000.0]
Results for one epochs:
Converged: False
Function evaluations: 43
Iterations: 37

Results for multiple epochs:
Converged: True
Function evaluations: 492
Iterations: 392

import pdb
import copy
import scipy.optimize
import torch
 
def rosenbrock(x):
    answer = sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
    return answer
 
evalFunc = rosenbrock
 
x0s = [[76.0, 97.0, 20.0, 120.0, 0.01, 1e2], [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]]
trueMins = [torch.ones(len(x0)) for x0 in x0s]
 
toleranceGrad = 1e-5
toleranceChange = 1e-9
xConvTol = 1e-4
lineSearchFn = "strong_wolfe"
maxIterOneEpoch = 1000
maxIterMultipleEpochs = 100
nEpochs = 10
assert maxIterOneEpoch==nEpochs*maxIterMultipleEpochs
 
def closure():
    optimizer.zero_grad()
    curEval = evalFunc(x=x[0])
    curEval.backward(retain_graph=True)
    return curEval
 
for x0, trueMin in zip(x0s, trueMins):
    print("Results for x0=", x0)
 
    xOneEpoch = torch.tensor(copy.deepcopy(x0))
    xOneEpoch.requires_grad = True
    x = [xOneEpoch]
    optimizer = torch.optim.LBFGS(x, max_iter=maxIterOneEpoch, line_search_fn=lineSearchFn, tolerance_grad=toleranceGrad, tolerance_change=toleranceChange)
    optimizer.step(closure)
    stateOneEpoch = optimizer.state[optimizer._params[0]]
    funcEvalsOneEpoch = stateOneEpoch["func_evals"]
    nIterOneEpoch = stateOneEpoch["n_iter"]
    print("\tResults for one epochs:")
    print("\t\tConverged: {}".format(torch.norm(xOneEpoch-trueMin, p=2)<xConvTol))
    print("\t\tFunction evaluations: {:d}".format(funcEvalsOneEpoch))
    print("\t\tIterations: {:d}\n".format(nIterOneEpoch))
 
    xMultipleEpochs = torch.tensor(copy.deepcopy(x0))
    xMultipleEpochs.requires_grad = True
    x = [xMultipleEpochs]
    optimizer = torch.optim.LBFGS(x, max_iter=maxIterMultipleEpochs, line_search_fn=lineSearchFn, tolerance_grad=toleranceGrad, tolerance_change=toleranceChange)
    for epoch in range(nEpochs):
        optimizer.step(closure)
    stateMultipleEpochs = optimizer.state[optimizer._params[0]]
    funcEvalsMultipleEpochs = stateMultipleEpochs["func_evals"]
    nIterMultipleEpochs = stateMultipleEpochs["n_iter"]
    print("\tResults for multiple epochs:")
    print("\t\tConverged: {}".format(torch.norm(xMultipleEpochs-trueMin, p=2)<xConvTol))
    print("\t\tFunction evaluations: {:d}".format(funcEvalsMultipleEpochs))
    print("\t\tIterations: {:d}\n".format(nIterMultipleEpochs))
 

Thanks

cc @vincentqb

@ezyang ezyang added module: optimizer Related to torch.optim triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jan 5, 2021
Contributor

ezyang commented Jan 5, 2021

cc @fehiepsi, perhaps

@ezyang ezyang added the needs research We need to decide whether or not this merits inclusion, based on research world label Jan 5, 2021
Contributor

fehiepsi commented Jan 5, 2021

> Is this an expected behavior or a bug?

I am not sure what the expected behavior is. It makes more sense to me that they are different. Did you find the same behavior in other frameworks?

@joacorapela
Author

Thanks for your feedback @fehiepsi

The expected behavior is that torch.optim converges to the minimum of the Rosenbrock function, as jax.scipy.optimize does in the script below. torch.optim does not.

Note, however, that in jax.scipy.optimize I am using BFGS, instead of L-BFGS, since the latter is not available in jax.scipy.optimize yet.

This is the output of the script:

Results for troch.optim:
Converged: False
Function evaluations: 43
Iterations: 37

Results for jax.scipy.otpimize:
Converged: True
Function evaluations: 351
Iterations: 210

import pdb
import copy
import scipy.optimize
import torch
import numpy as np
import jax.numpy as jnp
import jax.scipy.optimize
jax.config.update("jax_enable_x64", True)

def rosenbrock(x):
    answer = sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
    return answer

evalFunc = rosenbrock

x0 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
ptTrueMin = torch.ones(len(x0))
jnpTrueMin = jnp.ones(len(x0))

toleranceGrad = 1e-5
toleranceChange = 1e-9
xConvTol = 1e-6
lineSearchFn = "strong_wolfe"
maxIter = 1000

def closure():
    optimizer.zero_grad()
    curEval = evalFunc(x=x[0])
    curEval.backward(retain_graph=True)
    return curEval

xPT = torch.tensor(copy.deepcopy(x0))
xPT.requires_grad = True
x = [xPT]
optimizer = torch.optim.LBFGS(x, max_iter=maxIter, line_search_fn=lineSearchFn, tolerance_grad=toleranceGrad, tolerance_change=toleranceChange)
optimizer.step(closure)

statePT = optimizer.state[optimizer._params[0]]
funcEvalsPT = statePT["func_evals"]
nIterPT = statePT["n_iter"]
print("Results for troch.optim:")
print("\tConverged: {}".format(torch.norm(xPT-ptTrueMin, p=2)<xConvTol))
print("\tFunction evaluations: {:d}".format(funcEvalsPT))
print("\tIterations: {:d}\n".format(nIterPT))

minimizeOptions = {'gtol': toleranceGrad, 'maxiter': maxIter}
jx0 = jnp.array(x0)
optimRes = jax.scipy.optimize.minimize(fun=evalFunc, x0=jx0, method='BFGS', options=minimizeOptions)

print("Results for jax.scipy.otpimize:")
print("\tConverged: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)<xConvTol))
print("\tFunction evaluations: {:d}".format(optimRes.nfev))
print("\tIterations: {:d}\n".format(optimRes.nit))

Thanks

Contributor

fehiepsi commented Jan 7, 2021

Thanks, @joacorapela! I did some quick tests by changing the last element of x0 from 1e4 to 1e3 and 1e2 and found that

  • In float32, pytorch version somehow converges for 1e2 and 1e3 but not for 1e4; jax version does not converge at all
  • In float64, pytorch version converges for all; jax version only converges for 1e4

In all of my tests, the PyTorch version performs better, in both float32 and float64 modes (probably because of the limited-memory feature?). PyTorch does a pretty good job in float64. I think the algorithm is not intended to be used in float32. For example, tolerance_change = 1e-9 does not make sense at all in float32. What do you think?
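The remark about tolerance_change in float32 can be made concrete with a small standard-library sketch. It is an illustration only: float32_gap is a hypothetical helper that measures the distance between adjacent float32 values, and rosenbrock is re-implemented in plain Python rather than taken from the scripts in this thread.

```python
import struct


def rosenbrock(x):
    """Plain-Python Rosenbrock function, same formula as in the scripts."""
    return sum(100.0 * (b - a ** 2) ** 2 + (1.0 - a) ** 2
               for a, b in zip(x[:-1], x[1:]))


def float32_gap(value):
    """Distance from a positive value (rounded to float32) to the next float32."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    lower = struct.unpack("<f", struct.pack("<I", bits))[0]
    upper = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return upper - lower


tolerance_change = 1e-9

# Loss at the three starting points tested above, and the float32 resolution there.
for last in (1e2, 1e3, 1e4):
    x0 = [76.0, 97.0, 20.0, 120.0, 0.01, last]
    loss = rosenbrock(x0)
    print(f"last={last:g}  loss={loss:.3e}  float32 gap={float32_gap(loss):.1f}")

# Even for a loss of order one, float32 cannot represent a change of 1e-9.
print(f"float32 gap near 1.0: {float32_gap(1.0):.3e}  "
      f"tolerance_change: {tolerance_change:.0e}")
```

For losses of order one or larger, adjacent float32 values are further apart than 1e-9, so a check of the form abs(loss - prev_loss) < tolerance_change can only pass when the two float32 losses are exactly equal.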

The algorithm is pretty complicated to track for implementation details and to decide which options are good or bad. Back then, in a PR for line search, I tried to stick as close as possible to minFunc. I didn't play much with the theory and just followed the references so it is hard for me to say if running lbfgs multiple times (which is intended for stochastic training) has a similar performance as running lbfgs one time...

cc @Joshuaalbert who might be interested in improving the performance of jax bfgs.

@joacorapela
Author

Hello @fehiepsi,

I repeated my analysis with different initial conditions, as you suggested, but obtained very different results. I am appending my evaluation code in case the difference is due to a mistake in my use of PyTorch, which I doubt. Also, if you can, please post your evaluation code and I will study why our evaluations differ.

I will comment only on results using double precision. Briefly, PyTorch did not converge using 1e2, 1e3 or 1e4 as the last element of x0, while JAX converged using 1e2 and 1e4 but not using 1e3. PyTorch was faster than JAX for 1e2 but slower for 1e4. The runtime of the two libraries was similar for 1e3.

Thanks for your help

PS: I hope we can fix this convergence problem of PyTorch, as I have written quite a lot of code based on PyTorch and I would prefer not to have to migrate to JAX or scipy.optimize.

Here is the output of the code below:

x0=[76.0, 97.0, 20.0, 120.0, 0.01, 100.0]

Pytorch
Converged: False
Function evaluations: 126
Iterations: 105
ElapsedTime: 0.5592169761657715

jax.scipy.optimize
Converged: True
Function evaluations: 263
Iterations: 148
ElapsedTime: 1.9521615505218506

x0=[76.0, 97.0, 20.0, 120.0, 0.01, 1000.0]

Pytorch
Converged: False
Function evaluations: 241
Iterations: 201
ElapsedTime: 1.4168665409088135

jax.scipy.optimize
Converged: False
Function evaluations: 96
Iterations: 4
ElapsedTime: 1.4578144550323486

x0=[76.0, 97.0, 20.0, 120.0, 0.01, 10000.0]

Pytorch
Converged: False
Function evaluations: 479
Iterations: 387
ElapsedTime: 2.889092206954956

jax.scipy.optimize
Converged: True
Function evaluations: 351
Iterations: 210
ElapsedTime: 1.4964311122894287

and here is the script

import pdb
import time
import copy
import scipy.optimize
import torch
import numpy as np
import jax.numpy as jnp
import jax.scipy.optimize
jax.config.update("jax_enable_x64", True)

def rosenbrock(x):
    answer = sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)
    return answer

evalFunc = rosenbrock
jEvalFunc = jax.jit(evalFunc)

x0_1e2 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e2]
x0_1e3 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e3]
x0_1e4 = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
x0s = [x0_1e2, x0_1e3, x0_1e4]

toleranceGrad = 1e-5
toleranceChange = 1e-9
xConvTol = 1e-6
lineSearchFn = "strong_wolfe"
maxIter = 1000

def closure():
    optimizer.zero_grad()
    curEval = evalFunc(x=x[0])
    curEval.backward(retain_graph=True)
    return curEval

for x0 in x0s:
    ptTrueMin = torch.ones(len(x0))
    jnpTrueMin = jnp.ones(len(x0))
    xPT = torch.tensor(copy.deepcopy(x0), dtype=torch.double)
    xPT.requires_grad = True
    x = [xPT]
    optimizer = torch.optim.LBFGS(x, max_iter=maxIter, line_search_fn=lineSearchFn, tolerance_grad=toleranceGrad, tolerance_change=toleranceChange)
    tStart = time.time()
    optimizer.step(closure)
    elapsedTime = time.time()-tStart
    statePT = optimizer.state[optimizer._params[0]]
    funcEvalsPT = statePT["func_evals"]
    nIterPT = statePT["n_iter"]
    print("x0={}\n".format(x0))
    print("\tPytorch")
    print("\t\tConverged: {}".format(torch.norm(xPT-ptTrueMin, p=2)<xConvTol))
    print("\t\tFunction evaluations: {:d}".format(funcEvalsPT))
    print("\t\tIterations: {:d}".format(nIterPT))
    print("\t\tElapsedTime: {}\n".format(elapsedTime))

    minimizeOptions = {'gtol': toleranceGrad, 'maxiter': maxIter}
    jx0 = jnp.array(x0)
    tStart = time.time()
    optimRes = jax.scipy.optimize.minimize(fun=jEvalFunc, x0=jx0, method='BFGS', options=minimizeOptions)
    elapsedTime = time.time()-tStart

    print("\tjax.scipy.optimize")
    print("\t\tConverged: {}".format(jnp.linalg.norm(optimRes.x-jnpTrueMin, ord=2)<xConvTol))
    print("\t\tFunction evaluations: {:d}".format(optimRes.nfev))
    print("\t\tIterations: {:d}".format(optimRes.nit))
    print("\t\tElapsedTime: {}\n\n".format(elapsedTime))

cc @Joshuaalbert in case he can help

Contributor

fehiepsi commented Feb 4, 2021

Here you are

x0=[76.0, 97.0, 20.0, 120.0, 0.01, 100.0]

	Pytorch
		Converged: False
xPT and ptTrueMin: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], dtype=torch.float64) tensor([1., 1., 1., 1., 1., 1.])
		Function evaluations: 126
		Iterations: 108
		ElapsedTime: 0.20932483673095703

	jax.scipy.optimize
		Converged: True
optimRes.x and jnpTrueMin: [1.         0.99999999 0.99999998 0.99999997 0.99999993 0.99999987] [1. 1. 1. 1. 1. 1.]
		Function evaluations: 262
		Iterations: 144
		ElapsedTime: 3.520697593688965


x0=[76.0, 97.0, 20.0, 120.0, 0.01, 1000.0]

	Pytorch
		Converged: False
xPT and ptTrueMin: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], dtype=torch.float64) tensor([1., 1., 1., 1., 1., 1.])
		Function evaluations: 234
		Iterations: 197
		ElapsedTime: 0.4817698001861572

	jax.scipy.optimize
		Converged: False
optimRes.x and jnpTrueMin: [-147512.64142182   56785.65609855 -534771.80798748    5180.08798277
 3284576.29656548 -197067.97604404] [1. 1. 1. 1. 1. 1.]
		Function evaluations: 96
		Iterations: 4
		ElapsedTime: 0.6329975128173828


x0=[76.0, 97.0, 20.0, 120.0, 0.01, 10000.0]

	Pytorch
		Converged: True
xPT and ptTrueMin: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], dtype=torch.float64) tensor([1., 1., 1., 1., 1., 1.])
		Function evaluations: 478
		Iterations: 393
		ElapsedTime: 1.0849919319152832

	jax.scipy.optimize
		Converged: True
optimRes.x and jnpTrueMin: [1.         1.         1.         1.         0.99999999 0.99999998] [1. 1. 1. 1. 1. 1.]
		Function evaluations: 364
		Iterations: 219
		ElapsedTime: 0.6550674438476562

PyTorch converges (with tol 1e-4) in all 3 cases. You might think that 1e-4 is not a good precision. I don't have an intuition on what is a good precision here.
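How much the "Converged" flag depends on the tolerance can be seen with a short plain-Python sketch. The jax_x values are copied from the first jax.scipy.optimize result printed above. hypothetical_x is an invented iterate, used only for illustration, in which every coordinate is off by 1e-5. It is not a value produced by PyTorch.

```python
import math


def distance_to_minimum(x, minimum):
    """Euclidean distance, the same quantity the scripts compare to xConvTol."""
    return math.sqrt(sum((xi - mi) ** 2 for xi, mi in zip(x, minimum)))


true_min = [1.0] * 6

# Printed jax.scipy.optimize result for x0 ending in 100.0.
jax_x = [1.0, 0.99999999, 0.99999998, 0.99999997, 0.99999993, 0.99999987]
# Hypothetical iterate (illustration only): each coordinate is off by 1e-5.
hypothetical_x = [1.0 + 1e-5] * 6

for name, x in (("jax_x", jax_x), ("hypothetical_x", hypothetical_x)):
    err = distance_to_minimum(x, true_min)
    print(f"{name}: error={err:.3e}  "
          f"tol 1e-6: {err < 1e-6}  tol 1e-4: {err < 1e-4}")
```

An iterate that agrees with the minimum to four printed decimals can therefore be reported as not converged under xConvTol = 1e-6 and as converged under 1e-4.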

Author

joacorapela commented Feb 4, 2021

Very helpful hints @fehiepsi. To summarize, the convergence problem was due to using a float initial condition instead of double. Convergence within 1e-4 is good for me. Thanks a lot. I will close this issue.
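A compact way to repeat the comparison for both precisions and both calling patterns is sketched below. It is a sketch, not the code used in this thread: run_lbfgs is a hypothetical helper, the tolerances are copied from the scripts above, and the printed numbers may differ across PyTorch versions and platforms.

```python
import torch


def rosenbrock(x):
    return torch.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2)


def run_lbfgs(x0, dtype, max_iter, n_steps):
    """Minimize from x0 and return (error, function evaluations, iterations)."""
    x = torch.tensor(x0, dtype=dtype, requires_grad=True)
    optimizer = torch.optim.LBFGS(
        [x], max_iter=max_iter, line_search_fn="strong_wolfe",
        tolerance_grad=1e-5, tolerance_change=1e-9)

    def closure():
        optimizer.zero_grad()
        loss = rosenbrock(x)
        loss.backward()
        return loss

    for _ in range(n_steps):
        optimizer.step(closure)

    # LBFGS stores its counters in the state of the first parameter.
    state = optimizer.state[optimizer._params[0]]
    error = torch.norm(x.detach() - torch.ones_like(x)).item()
    return error, state["func_evals"], state["n_iter"]


x0_far = [76.0, 97.0, 20.0, 120.0, 0.01, 1e4]
for dtype in (torch.float32, torch.float64):
    for max_iter, n_steps in ((1000, 1), (100, 10)):
        error, n_evals, n_iter = run_lbfgs(x0_far, dtype, max_iter, n_steps)
        print(f"{dtype}, {n_steps} step() call(s): error={error:.3e}, "
              f"evals={n_evals}, iterations={n_iter}")
```

Printing the error itself, rather than only a True/False flag, makes it easier to judge which tolerance is reasonable for each dtype.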
