[optimizer] refactor Adadelta to use functional API (#50409)
Summary: Pull Request resolved: #50409

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D25932780

Pulled By: wanchaol

fbshipit-source-id: 2fc025f66a0e0863f21689892e19d8a5681f2f2f
wanchaol authored and facebook-github-bot committed Jan 21, 2021
1 parent a0cf556 commit d6fb27c
Showing 2 changed files with 47 additions and 13 deletions.
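For context: this refactor only changes how Adadelta.step() is implemented internally (it now gathers per-parameter tensors and delegates the arithmetic to torch.optim.functional.adadelta); the optimizer's public interface is unchanged. A minimal usage sketch, with an illustrative toy model and data that are not part of this commit:

import torch
from torch import nn

model = nn.Linear(10, 2)  # illustrative toy model, not from this commit
optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9,
                                 eps=1e-6, weight_decay=0)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # after this commit, step() collects tensor lists and calls F.adadelta
optimizer.zero_grad()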
36 changes: 23 additions & 13 deletions torch/optim/adadelta.py
@@ -1,5 +1,6 @@
 import torch
 
+from . import functional as F
 from .optimizer import Optimizer
 

@@ -49,32 +50,41 @@ def step(self, closure=None):
                 loss = closure()
 
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            square_avgs = []
+            acc_deltas = []
+
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad
-                if grad.is_sparse:
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
                     raise RuntimeError('Adadelta does not support sparse gradients')
+                grads.append(p.grad)
+
                 state = self.state[p]
 
-                # State initialization
+                # Lazy state initialization
                 if len(state) == 0:
                     state['step'] = 0
                     state['square_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     state['acc_delta'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-                square_avg, acc_delta = state['square_avg'], state['acc_delta']
-                rho, eps = group['rho'], group['eps']
+                square_avgs.append(state['square_avg'])
+                acc_deltas.append(state['acc_delta'])
 
-                state['step'] += 1
+                lr, rho, eps, weight_decay = group['lr'], group['rho'], group['eps'], group['weight_decay']
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(p, alpha=group['weight_decay'])
+                state['step'] += 1
 
-                square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
-                std = square_avg.add(eps).sqrt_()
-                delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
-                p.add_(delta, alpha=-group['lr'])
-                acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
+            F.adadelta(params_with_grad,
+                       grads,
+                       square_avgs,
+                       acc_deltas,
+                       lr,
+                       rho,
+                       eps,
+                       weight_decay)
 
         return loss
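For reference, the inline update removed above and the F.adadelta helper added below compute the same Adadelta rule (Zeiler, 2012). In the commit's variable names, and with grad first augmented by weight_decay * p when weight decay is non-zero:

\begin{aligned}
\text{square\_avg} &\leftarrow \rho \cdot \text{square\_avg} + (1 - \rho)\, \text{grad}^2 \\
\text{delta} &\leftarrow \frac{\sqrt{\text{acc\_delta} + \epsilon}}{\sqrt{\text{square\_avg} + \epsilon}} \cdot \text{grad} \\
p &\leftarrow p - \text{lr} \cdot \text{delta} \\
\text{acc\_delta} &\leftarrow \rho \cdot \text{acc\_delta} + (1 - \rho)\, \text{delta}^2
\end{aligned}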
24 changes: 24 additions & 0 deletions torch/optim/functional.py
@@ -132,3 +132,27 @@ def sgd(params: List[Tensor],
                 d_p = buf
 
         param.add_(d_p, alpha=-lr)
+
+
+def adadelta(params: List[Tensor],
+             grads: List[Tensor],
+             square_avgs: List[Tensor],
+             acc_deltas: List[Tensor],
+             lr: float,
+             rho: float,
+             eps: float,
+             weight_decay: float):
+    r"""Functional API that performs Adadelta algorithm computation.
+
+    See :class:`~torch.optim.Adadelta` for details.
+    """
+
+    for (param, grad, square_avg, acc_delta) in zip(params, grads, square_avgs, acc_deltas):
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
+        std = square_avg.add(eps).sqrt_()
+        delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
+        param.add_(delta, alpha=-lr)
+        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
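To show the new functional entry point in isolation, here is a minimal sketch of calling F.adadelta directly with hand-built state. The tensors and hyperparameter values below are illustrative only; in normal use, Adadelta.step() assembles these lists from its param_groups and per-parameter state.

import torch
from torch.optim import functional as F

# Illustrative standalone state; normally Adadelta.step() builds these lists.
param = torch.randn(3)
grad = torch.randn(3)
square_avg = torch.zeros(3)   # running average of squared gradients
acc_delta = torch.zeros(3)    # running average of squared update steps

F.adadelta([param], [grad], [square_avg], [acc_delta],
           lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0)
# param, square_avg and acc_delta are all updated in place.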
