Use lower case for strong wolfe option. (#22092)
Summary:
Pull Request resolved: #22092
ghimport-source-id: ccc53ed

Test Plan: Imported from OSS

Differential Revision: D15955996

Pulled By: vincentqb

fbshipit-source-id: 8ffbea3b9ef8ff7021d42524fa46112da8a3438e
vincentqb authored and facebook-github-bot committed Jun 26, 2019
1 parent 9f22805 commit f176950
Showing 2 changed files with 42 additions and 15 deletions.
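
Usage note (not part of the commit): after this change the Wolfe line search is selected with the lowercase string "strong_wolfe"; as the lbfgs.py diff below shows, any other value, including the old "strong_Wolfe", raises a RuntimeError. A minimal sketch of the call, with an illustrative toy objective and hyperparameters:

    import torch
    from torch import optim

    w = torch.randn(10, requires_grad=True)
    # Lowercase spelling introduced by this commit.
    optimizer = optim.LBFGS([w], lr=1, max_iter=20, line_search_fn="strong_wolfe")

    def closure():
        # LBFGS re-evaluates the objective during the line search, so it needs a closure.
        optimizer.zero_grad()
        loss = (w ** 2).sum()  # toy quadratic, illustration only
        loss.backward()
        return loss

    optimizer.step(closure)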
test/test_optim.py (2 changes: 1 addition & 1 deletion)
@@ -448,7 +448,7 @@ def test_lbfgs(self):
ignore_multidevice=True
)
self._test_basic_cases(
- lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
+ lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_wolfe"),
ignore_multidevice=True
)

torch/optim/lbfgs.py (55 changes: 41 additions & 14 deletions)
@@ -19,7 +19,7 @@ def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
- d2_square = d1 ** 2 - g1 * g2
+ d2_square = d1**2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
@@ -31,7 +31,16 @@ def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
return (xmin_bound + xmax_bound) / 2.


- def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
+ def _strong_wolfe(obj_func,
+ x,
+ t,
+ d,
+ f,
+ g,
+ gtd,
+ c1=1e-4,
+ c2=0.9,
+ tolerance_change=1e-9,
max_ls=25):
# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
d_norm = d.abs().max()
@@ -72,8 +81,14 @@ def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_chang
min_step = t + 0.01 * (t - t_prev)
max_step = t * 10
tmp = t
- t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
- bounds=(min_step, max_step))
+ t = _cubic_interpolate(
+ t_prev,
+ f_prev,
+ gtd_prev,
+ t,
+ f_new,
+ gtd_new,
+ bounds=(min_step, max_step))

# next step
t_prev = tmp
@@ -193,17 +208,28 @@ class LBFGS(Optimizer):
tolerance_change (float): termination tolerance on function
value/parameter changes (default: 1e-9).
history_size (int): update history size (default: 100).
- line_search_fn (str): either 'strong_Wolfe' or None (default: None).
+ line_search_fn (str): either 'strong_wolfe' or None (default: None).
"""

- def __init__(self, params, lr=1, max_iter=20, max_eval=None,
- tolerance_grad=1e-5, tolerance_change=1e-9, history_size=100,
+ def __init__(self,
+ params,
+ lr=1,
+ max_iter=20,
+ max_eval=None,
+ tolerance_grad=1e-5,
+ tolerance_change=1e-9,
+ history_size=100,
line_search_fn=None):
if max_eval is None:
max_eval = max_iter * 5 // 4
- defaults = dict(lr=lr, max_iter=max_iter, max_eval=max_eval,
- tolerance_grad=tolerance_grad, tolerance_change=tolerance_change,
- history_size=history_size, line_search_fn=line_search_fn)
+ defaults = dict(
+ lr=lr,
+ max_iter=max_iter,
+ max_eval=max_eval,
+ tolerance_grad=tolerance_grad,
+ tolerance_change=tolerance_change,
+ history_size=history_size,
+ line_search_fn=line_search_fn)
super(LBFGS, self).__init__(params, defaults)

if len(self.param_groups) != 1:
@@ -384,15 +410,16 @@ def step(self, closure):
ls_func_evals = 0
if line_search_fn is not None:
# perform line search, using user function
- if line_search_fn != "strong_Wolfe":
- raise RuntimeError("only 'strong_Wolfe' is supported")
+ if line_search_fn != "strong_wolfe":
+ raise RuntimeError("only 'strong_wolfe' is supported")
else:
x_init = self._clone_param()

def obj_func(x, t, d):
return self._directional_evaluate(closure, x, t, d)
- loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
- loss, flat_grad, gtd)
+
+ loss, flat_grad, t, ls_func_evals = _strong_wolfe(
+ obj_func, x_init, t, d, loss, flat_grad, gtd)
self._add_grad(t, d)
opt_cond = flat_grad.abs().max() <= tolerance_grad
else:
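Aside (not part of the commit): the first lbfgs.py hunk above only reformats the cubic interpolation, but its comments quote the closed-form minimizer used by the line search. A small numeric check of those formulas, using plain Python floats and a hand-picked cubic (both purely illustrative):

    import math

    # Sample f(x) = x**3 - x at the bracket endpoints; its minimizer on [0, 1] is 1/sqrt(3).
    x1, f1, g1 = 0.0, 0.0, -1.0   # f(0), f'(0)
    x2, f2, g2 = 1.0, 0.0, 2.0    # f(1), f'(1)

    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2 = math.sqrt(d1 ** 2 - g1 * g2)
    min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
    print(min_pos)  # 0.5773..., i.e. 1/sqrt(3): the cubic fit recovers the exact minimizer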
