diff --git a/test/test_optim.py b/test/test_optim.py
index e52317645917..aff5baf825ef 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -447,6 +447,10 @@ def test_lbfgs(self):
             lambda weight, bias: optim.LBFGS([weight, bias]),
             ignore_multidevice=True
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_Wolfe"),
+            ignore_multidevice=True
+        )
 
     @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
     def test_lbfgs_return_type(self):
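
Beyond the unit test above, here is a minimal end-to-end sketch of the option under test. The least-squares objective, tensor shapes, and iteration count are illustrative assumptions, not part of the patch:

    import torch
    from torch import optim

    # Hypothetical toy problem, chosen only to exercise the new keyword.
    X, y = torch.randn(64, 10), torch.randn(64)
    w = torch.zeros(10, requires_grad=True)

    optimizer = optim.LBFGS([w], line_search_fn="strong_Wolfe")

    def closure():
        # L-BFGS (and the line search) re-invokes the closure, so it must
        # recompute the loss and its gradients from scratch on every call.
        optimizer.zero_grad()
        loss = ((X @ w - y) ** 2).mean()
        loss.backward()
        return loss

    for _ in range(10):
        optimizer.step(closure)

Note that `step` must always be passed a closure with L-BFGS; enabling the line search only adds more closure evaluations per step.
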
diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py
index 3ccfc13af88d..d104ea994ff3 100644
--- a/torch/optim/lbfgs.py
+++ b/torch/optim/lbfgs.py
@@ -3,8 +3,171 @@
 from .optimizer import Optimizer
 
 
+def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
+    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
+    # Compute bounds of interpolation area
+    if bounds is not None:
+        xmin_bound, xmax_bound = bounds
+    else:
+        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
+
+    # Code for most common case: cubic interpolation of 2 points
+    #   w/ function and derivative values for both
+    # Solution in this case (where x2 is the farthest point):
+    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
+    #   d2 = sqrt(d1^2 - g1*g2);
+    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
+    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
+    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
+    d2_square = d1 ** 2 - g1 * g2
+    if d2_square >= 0:
+        d2 = d2_square.sqrt()
+        if x1 <= x2:
+            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
+        else:
+            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
+        return min(max(min_pos, xmin_bound), xmax_bound)
+    else:
+        return (xmin_bound + xmax_bound) / 2.
+
+
+def _strong_Wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9,
+                  max_ls=25):
+    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
+    d_norm = d.abs().max()
+    g = g.clone()
+    # evaluate objective and gradient using initial step
+    f_new, g_new = obj_func(x, t, d)
+    ls_func_evals = 1
+    gtd_new = g_new.dot(d)
+
+    # bracket an interval containing a point satisfying the Wolfe criteria
+    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
+    done = False
+    ls_iter = 0
+    while ls_iter < max_ls:
+        # check conditions
+        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        if abs(gtd_new) <= -c2 * gtd:
+            bracket = [t]
+            bracket_f = [f_new]
+            bracket_g = [g_new]
+            done = True
+            break
+
+        if gtd_new >= 0:
+            bracket = [t_prev, t]
+            bracket_f = [f_prev, f_new]
+            bracket_g = [g_prev, g_new.clone()]
+            bracket_gtd = [gtd_prev, gtd_new]
+            break
+
+        # interpolate
+        min_step = t + 0.01 * (t - t_prev)
+        max_step = t * 10
+        tmp = t
+        t = _cubic_interpolate(t_prev, f_prev, gtd_prev, t, f_new, gtd_new,
+                               bounds=(min_step, max_step))
+
+        # next step
+        t_prev = tmp
+        f_prev = f_new
+        g_prev = g_new.clone()
+        gtd_prev = gtd_new
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+    # reached max number of iterations?
+    if ls_iter == max_ls:
+        bracket = [0, t]
+        bracket_f = [f, f_new]
+        bracket_g = [g, g_new]
+
+    # zoom phase: we now have a point satisfying the criteria, or
+    # a bracket around it. We refine the bracket until we find the
+    # exact point satisfying the criteria
+    insuf_progress = False
+    # find high and low points in bracket
+    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
+    while not done and ls_iter < max_ls:
+        # compute new trial value
+        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
+                               bracket[1], bracket_f[1], bracket_gtd[1])
+
+        # test that we are making sufficient progress:
+        # if `t` is too close to a boundary, we mark that we are making
+        # insufficient progress, and if
+        #   + we made insufficient progress in the last step, or
+        #   + `t` lies at one of the boundaries,
+        # we move `t` to a position `0.1 * (max(bracket) - min(bracket))`
+        # away from the nearest boundary point.
+        eps = 0.1 * (max(bracket) - min(bracket))
+        if min(max(bracket) - t, t - min(bracket)) < eps:
+            # interpolation close to boundary
+            if insuf_progress or t >= max(bracket) or t <= min(bracket):
+                # move `t` to `eps` away from the nearest boundary
+                if abs(t - max(bracket)) < abs(t - min(bracket)):
+                    t = max(bracket) - eps
+                else:
+                    t = min(bracket) + eps
+                insuf_progress = False
+            else:
+                insuf_progress = True
+        else:
+            insuf_progress = False
+
+        # Evaluate new point
+        f_new, g_new = obj_func(x, t, d)
+        ls_func_evals += 1
+        gtd_new = g_new.dot(d)
+        ls_iter += 1
+
+        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
+            # Armijo condition not satisfied or not lower than lowest point
+            bracket[high_pos] = t
+            bracket_f[high_pos] = f_new
+            bracket_g[high_pos] = g_new.clone()
+            bracket_gtd[high_pos] = gtd_new
+            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
+        else:
+            if abs(gtd_new) <= -c2 * gtd:
+                # Wolfe conditions satisfied
+                done = True
+            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
+                # old high becomes new low
+                bracket[high_pos] = bracket[low_pos]
+                bracket_f[high_pos] = bracket_f[low_pos]
+                bracket_g[high_pos] = bracket_g[low_pos]
+                bracket_gtd[high_pos] = bracket_gtd[low_pos]
+
+            # new point becomes new low
+            bracket[low_pos] = t
+            bracket_f[low_pos] = f_new
+            bracket_g[low_pos] = g_new.clone()
+            bracket_gtd[low_pos] = gtd_new
+
+        # stop if the line-search bracket is smaller than the tolerance
+        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
+            break
+
+    # return stuff
+    t = bracket[low_pos]
+    f_new = bracket_f[low_pos]
+    g_new = bracket_g[low_pos]
+    return f_new, g_new, t, ls_func_evals
+
+
 class LBFGS(Optimizer):
-    """Implements L-BFGS algorithm.
+    """Implements L-BFGS algorithm, heavily inspired by `minFunc
+    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`.
 
     .. warning::
         This optimizer doesn't support per-parameter options and parameter
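
As a quick sanity check of the interpolation helper added above (a hypothetical snippet that imports the private `_cubic_interpolate` exactly as defined in this patch; it is not public API): for f(x) = (x - 1)**2 sampled at x = 0 and x = 2, the cubic fit recovers the exact minimizer.

    import torch
    from torch.optim.lbfgs import _cubic_interpolate  # private helper, illustration only

    # f(x) = (x - 1)**2: f(0) = 1, f'(0) = -2; f(2) = 1, f'(2) = 2.
    # Gradients are passed as 0-dim tensors because the helper calls
    # .sqrt() on a tensor-valued intermediate.
    t = _cubic_interpolate(0., 1., torch.tensor(-2.),
                           2., 1., torch.tensor(2.))
    print(t)  # tensor(1.), the parabola's true minimizer

The same helper drives both phases of `_strong_Wolfe`: bracketing (with explicit `bounds`) and zooming (bounded by the current bracket).
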
""" def __init__(self, params, lr=1, max_iter=20, max_eval=None, @@ -58,11 +222,11 @@ def _gather_flat_grad(self): views = [] for p in self._params: if p.grad is None: - view = p.data.new(p.data.numel()).zero_() - elif p.grad.data.is_sparse: - view = p.grad.data.to_dense().view(-1) + view = p.new(p.numel()).zero_() + elif p.grad.is_sparse: + view = p.grad.to_dense().view(-1) else: - view = p.grad.data.view(-1) + view = p.grad.view(-1) views.append(view) return torch.cat(views, 0) @@ -75,6 +239,20 @@ def _add_grad(self, step_size, update): offset += numel assert offset == self._numel() + def _clone_param(self): + return [p.clone() for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + p.data.copy_(pdata) + + def _directional_evaluate(self, closure, x, t, d): + self._add_grad(t, d) + loss = float(closure()) + flat_grad = self._gather_flat_grad() + self._set_param(x) + return loss, flat_grad + def step(self, closure): """Performs a single optimization step. @@ -106,9 +284,10 @@ def step(self, closure): state['func_evals'] += 1 flat_grad = self._gather_flat_grad() - abs_grad_sum = flat_grad.abs().sum() + opt_cond = flat_grad.abs().max() <= tolerance_grad - if abs_grad_sum <= tolerance_grad: + # optimal condition + if opt_cond: return orig_loss # tensors cached in state (for tracing) @@ -116,6 +295,7 @@ def step(self, closure): t = state.get('t') old_dirs = state.get('old_dirs') old_stps = state.get('old_stps') + ro = state.get('ro') H_diag = state.get('H_diag') prev_flat_grad = state.get('prev_flat_grad') prev_loss = state.get('prev_loss') @@ -134,6 +314,7 @@ def step(self, closure): d = flat_grad.neg() old_dirs = [] old_stps = [] + ro = [] H_diag = 1 else: # do lbfgs update (update memory) @@ -146,10 +327,12 @@ def step(self, closure): # shift history by one (limited-memory) old_dirs.pop(0) old_stps.pop(0) + ro.pop(0) # store new direction/step old_dirs.append(y) old_stps.append(s) + ro.append(1. / ys) # update scale of initial Hessian approximation H_diag = ys / y.dot(y) # (y*y) @@ -158,15 +341,10 @@ def step(self, closure): # multiplied by the gradient num_old = len(old_dirs) - if 'ro' not in state: - state['ro'] = [None] * history_size + if 'al' not in state: state['al'] = [None] * history_size - ro = state['ro'] al = state['al'] - for i in range(num_old): - ro[i] = 1. / old_dirs[i].dot(old_stps[i]) - # iteration in L-BFGS loop collapsed to use just one buffer q = flat_grad.neg() for i in range(num_old - 1, -1, -1): @@ -191,18 +369,32 @@ def step(self, closure): ############################################################ # reset initial guess for step size if state['n_iter'] == 1: - t = min(1., 1. / abs_grad_sum) * lr + t = min(1., 1. 
@@ -106,9 +284,10 @@ def step(self, closure):
         state['func_evals'] += 1
 
         flat_grad = self._gather_flat_grad()
-        abs_grad_sum = flat_grad.abs().sum()
+        opt_cond = flat_grad.abs().max() <= tolerance_grad
 
-        if abs_grad_sum <= tolerance_grad:
+        # optimality condition
+        if opt_cond:
             return orig_loss
 
         # tensors cached in state (for tracing)
@@ -116,6 +295,7 @@ def step(self, closure):
         t = state.get('t')
         old_dirs = state.get('old_dirs')
         old_stps = state.get('old_stps')
+        ro = state.get('ro')
         H_diag = state.get('H_diag')
         prev_flat_grad = state.get('prev_flat_grad')
         prev_loss = state.get('prev_loss')
@@ -134,6 +314,7 @@ def step(self, closure):
                 d = flat_grad.neg()
                 old_dirs = []
                 old_stps = []
+                ro = []
                 H_diag = 1
             else:
                 # do lbfgs update (update memory)
@@ -146,10 +327,12 @@ def step(self, closure):
                         # shift history by one (limited-memory)
                         old_dirs.pop(0)
                         old_stps.pop(0)
+                        ro.pop(0)
 
                     # store new direction/step
                     old_dirs.append(y)
                     old_stps.append(s)
+                    ro.append(1. / ys)
 
                     # update scale of initial Hessian approximation
                     H_diag = ys / y.dot(y)  # (y*y)
@@ -158,15 +341,10 @@ def step(self, closure):
             # multiplied by the gradient
             num_old = len(old_dirs)
 
-            if 'ro' not in state:
-                state['ro'] = [None] * history_size
+            if 'al' not in state:
                 state['al'] = [None] * history_size
-            ro = state['ro']
             al = state['al']
 
-            for i in range(num_old):
-                ro[i] = 1. / old_dirs[i].dot(old_stps[i])
-
             # iteration in L-BFGS loop collapsed to use just one buffer
             q = flat_grad.neg()
             for i in range(num_old - 1, -1, -1):
@@ -191,18 +369,32 @@ def step(self, closure):
             ############################################################
             # reset initial guess for step size
             if state['n_iter'] == 1:
-                t = min(1., 1. / abs_grad_sum) * lr
+                t = min(1., 1. / flat_grad.abs().sum()) * lr
             else:
                 t = lr
 
             # directional derivative
             gtd = flat_grad.dot(d)  # g * d
 
+            # directional derivative is below tolerance
+            if gtd > -tolerance_change:
+                break
+
             # optional line search: user function
             ls_func_evals = 0
             if line_search_fn is not None:
                 # perform line search, using user function
-                raise RuntimeError("line search function is not supported yet")
+                if line_search_fn != "strong_Wolfe":
+                    raise RuntimeError("only 'strong_Wolfe' is supported")
+                else:
+                    x_init = self._clone_param()
+
+                    def obj_func(x, t, d):
+                        return self._directional_evaluate(closure, x, t, d)
+                    loss, flat_grad, t, ls_func_evals = _strong_Wolfe(obj_func, x_init, t, d,
+                                                                      loss, flat_grad, gtd)
+                self._add_grad(t, d)
+                opt_cond = flat_grad.abs().max() <= tolerance_grad
             else:
                 # no line search, simply move with fixed-step
                 self._add_grad(t, d)
@@ -212,7 +404,7 @@ def step(self, closure):
                     # no use to re-evaluate that function here
                     loss = float(closure())
                     flat_grad = self._gather_flat_grad()
-                    abs_grad_sum = flat_grad.abs().sum()
+                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                     ls_func_evals = 1
 
             # update func eval
@@ -228,13 +420,12 @@ def step(self, closure):
             if current_evals >= max_eval:
                 break
 
-            if abs_grad_sum <= tolerance_grad:
-                break
-
-            if gtd > -tolerance_change:
+            # optimality condition
+            if opt_cond:
                 break
 
-            if d.mul(t).abs_().sum() <= tolerance_change:
+            # lack of progress
+            if d.mul(t).abs().max() <= tolerance_change:
                 break
 
             if abs(loss - prev_loss) < tolerance_change:
@@ -244,6 +435,7 @@ def step(self, closure):
         state['t'] = t
         state['old_dirs'] = old_dirs
         state['old_stps'] = old_stps
+        state['ro'] = ro
        state['H_diag'] = H_diag
         state['prev_flat_grad'] = prev_flat_grad
         state['prev_loss'] = prev_loss
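
One behavioral change buried in the diff is worth calling out: the convergence test moves from the L1 norm of the gradient (`abs().sum()`) to the infinity norm (`abs().max()`), so termination no longer scales with parameter count. A hypothetical illustration (the 1e-5 tolerance is assumed for the demo, not taken from the patch):

    import torch

    tolerance_grad = 1e-5               # assumed value, illustration only
    g = torch.full((10000,), 1e-7)      # uniformly tiny per-element gradient

    print(g.abs().sum() <= tolerance_grad)  # tensor(False): L1 norm grows with size
    print(g.abs().max() <= tolerance_grad)  # tensor(True): max norm does not
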