[optimizer] refactor Adadelta to use functional API #50409

Closed · wants to merge 9 commits
torch/optim/adadelta.py — 36 changes: 23 additions & 13 deletions
@@ -1,5 +1,6 @@
 import torch
 
+from . import functional as F
 from .optimizer import Optimizer
 
 
@@ -49,32 +50,41 @@ def step(self, closure=None):
                 loss = closure()
 
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            square_avgs = []
+            acc_deltas = []
+
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad
-                if grad.is_sparse:
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
                     raise RuntimeError('Adadelta does not support sparse gradients')
+                grads.append(p.grad)
 
                 state = self.state[p]
 
-                # State initialization
+                # Lazy state initialization
                 if len(state) == 0:
                     state['step'] = 0
                     state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-                square_avg, acc_delta = state['square_avg'], state['acc_delta']
-                rho, eps = group['rho'], group['eps']
+                square_avgs.append(state['square_avg'])
+                acc_deltas.append(state['acc_delta'])
 
-                state['step'] += 1
+                lr, rho, eps, weight_decay = group['lr'], group['rho'], group['eps'], group['weight_decay']
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(p, alpha=group['weight_decay'])
+                state['step'] += 1
 
-                square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
-                std = square_avg.add(eps).sqrt_()
-                delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
-                p.add_(delta, alpha=-group['lr'])
-                acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
+            F.adadelta(params_with_grad,
+                       grads,
+                       square_avgs,
+                       acc_deltas,
+                       lr,
+                       rho,
+                       eps,
+                       weight_decay)
 
         return loss
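
After this change, Adadelta.step() only gathers each parameter, its gradient, and its square_avg/acc_delta state tensors into flat lists and hands them to F.adadelta; the user-facing behavior of the optimizer is unchanged. A minimal usage sketch (not from the PR; the toy model and loss are invented for illustration):

import torch

model = torch.nn.Linear(4, 2)                # hypothetical toy model
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-6)

inp = torch.randn(8, 4)
loss = model(inp).pow(2).mean()              # arbitrary scalar loss
loss.backward()
optimizer.step()                             # step() now defers the per-tensor math to F.adadelta
optimizer.zero_grad()
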
torch/optim/functional.py — 24 changes: 24 additions & 0 deletions
@@ -132,3 +132,27 @@ def sgd(params: List[Tensor],
                 d_p = buf
 
         param.add_(d_p, alpha=-lr)
+
+
+def adadelta(params: List[Tensor],
+             grads: List[Tensor],
+             square_avgs: List[Tensor],
+             acc_deltas: List[Tensor],
+             lr: float,
+             rho: float,
+             eps: float,
+             weight_decay: float):
+    r"""Functional API that performs Adadelta algorithm computation.
+
+    See :class:`~torch.optim.Adadelta` for details.
+    """
+
+    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
+        std = square_avg.add(eps).sqrt_()
+        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
+        param.add_(delta, alpha=-lr)
+        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
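
The functional kernel is stateless: it mutates square_avgs, acc_deltas, and params in place and leaves step counting and state-dict bookkeeping to the Optimizer subclass. Per tensor it applies the standard Adadelta update: square_avg = rho*square_avg + (1-rho)*grad**2, delta = sqrt(acc_delta + eps)/sqrt(square_avg + eps) * grad, param -= lr*delta, acc_delta = rho*acc_delta + (1-rho)*delta**2. A standalone sketch of driving it directly, assuming the signature exactly as added in this PR (the tensors are invented and this bypasses the usual Optimizer machinery):

import torch
from torch.optim import functional as F

param = torch.zeros(3)                      # tensor to be updated in place
grad = torch.ones(3)                        # pretend gradient for that tensor
square_avg = torch.zeros_like(param)        # running average of squared gradients
acc_delta = torch.zeros_like(param)         # running average of squared updates

F.adadelta([param], [grad], [square_avg], [acc_delta],
           lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)
print(param)                                # param has moved against the gradient

One apparent motivation for the split is that other entry points (for example scripted or multi-tensor variants) can reuse the same per-tensor update without going through the Optimizer's state dictionaries.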