diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index ae9286a5cb2f..e090f62012b4 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -1,5 +1,6 @@
 import torch
+from . import functional as F
 from .optimizer import Optimizer
 
 
 class Adadelta(Optimizer):
@@ -49,32 +50,41 @@ def step(self, closure=None):
                 loss = closure()
 
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            square_avgs = []
+            acc_deltas = []
+
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad
-                if grad.is_sparse:
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
                     raise RuntimeError('Adadelta does not support sparse gradients')
+                grads.append(p.grad)
+
                 state = self.state[p]
 
-                # State initialization
+                # Lazy state initialization
                 if len(state) == 0:
                     state['step'] = 0
                     state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-                square_avg, acc_delta = state['square_avg'], state['acc_delta']
-                rho, eps = group['rho'], group['eps']
+                square_avgs.append(state['square_avg'])
+                acc_deltas.append(state['acc_delta'])
 
-                state['step'] += 1
+                lr, rho, eps, weight_decay = group['lr'], group['rho'], group['eps'], group['weight_decay']
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(p, alpha=group['weight_decay'])
+                state['step'] += 1
 
-                square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
-                std = square_avg.add(eps).sqrt_()
-                delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
-                p.add_(delta, alpha=-group['lr'])
-                acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
+            F.adadelta(params_with_grad,
+                       grads,
+                       square_avgs,
+                       acc_deltas,
+                       lr,
+                       rho,
+                       eps,
+                       weight_decay)
 
         return loss
diff --git a/torch/optim/functional.py b/torch/optim/functional.py
index dd6cd3ca7b3c..725dead7e37a 100644
--- a/torch/optim/functional.py
+++ b/torch/optim/functional.py
@@ -132,3 +132,27 @@ def sgd(params: List[Tensor],
                 d_p = buf
 
         param.add_(d_p, alpha=-lr)
+
+
+def adadelta(params: List[Tensor],
+             grads: List[Tensor],
+             square_avgs: List[Tensor],
+             acc_deltas: List[Tensor],
+             lr: float,
+             rho: float,
+             eps: float,
+             weight_decay: float):
+    r"""Functional API that performs Adadelta algorithm computation.
+
+    See :class:`~torch.optim.Adadelta` for details.
+    """
+
+    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
+        std = square_avg.add(eps).sqrt_()
+        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
+        param.add_(delta, alpha=-lr)
+        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
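
A quick way to exercise the new functional entry point in isolation is a minimal sketch along these lines (not part of the patch; it assumes the module is importable as `torch.optim.functional` at this point in the tree, and the tensor shapes and hyperparameter values are arbitrary illustration):

```python
import torch
from torch.optim import functional as F

# One fake parameter (requires_grad left False so in-place updates are allowed),
# its gradient, and the two Adadelta state buffers the optimizer would hold.
param = torch.randn(3)
grad = torch.randn(3)
square_avg = torch.zeros_like(param)
acc_delta = torch.zeros_like(param)

before = param.clone()
F.adadelta([param], [grad], [square_avg], [acc_delta],
           lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)

# param, square_avg and acc_delta are all updated in place by the functional call.
print("param before:", before)
print("param after :", param)
```

This is the same call that `Adadelta.step()` now delegates to once it has gathered `params_with_grad`, `grads`, and the per-parameter `square_avg`/`acc_delta` state buffers.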