[optimizer] refactor RMSProp to use functional API (#50410)
Summary: Pull Request resolved: #50410

Test Plan: Imported from OSS

Reviewed By: izdeby

Differential Revision: D25932779

Pulled By: wanchaol

fbshipit-source-id: b0d6007ea83d77e2d70d04681163ea7e4632c5cd
wanchaol authored and facebook-github-bot committed Jan 21, 2021
1 parent d6fb27c commit ce1781d
Showing 2 changed files with 69 additions and 21 deletions.
40 changes: 40 additions & 0 deletions torch/optim/functional.py
@@ -156,3 +156,43 @@ def adadelta(params: List[Tensor],
         delta = acc_delta.add(eps).sqrt_().div_(std).mul_(grad)
         param.add_(delta, alpha=-lr)
         acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
+
+
+def rmsprop(params: List[Tensor],
+            grads: List[Tensor],
+            square_avgs: List[Tensor],
+            grad_avgs: List[Tensor],
+            momentum_buffer_list: List[Tensor],
+            lr: float,
+            alpha: float,
+            eps: float,
+            weight_decay: float,
+            momentum: float,
+            centered: bool):
+    r"""Functional API that performs rmsprop algorithm computation.
+
+    See :class:`~torch.optim.RMSprop` for details.
+    """
+
+    for i, param in enumerate(params):
+        grad = grads[i]
+        square_avg = square_avgs[i]
+
+        if weight_decay != 0:
+            grad = grad.add(param, alpha=weight_decay)
+
+        square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
+
+        if centered:
+            grad_avg = grad_avgs[i]
+            grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
+            avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
+        else:
+            avg = square_avg.sqrt().add_(eps)
+
+        if momentum > 0:
+            buf = momentum_buffer_list[i]
+            buf.mul_(momentum).addcdiv_(grad, avg)
+            param.add_(buf, alpha=-lr)
+        else:
+            param.addcdiv_(grad, avg, value=-lr)
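
A minimal sketch of how the new functional entry point could be driven directly (editor's illustration, not part of the commit): the tensors below are hypothetical placeholders, the module path torch.optim.functional is the one this commit adds to, and the public torch.optim.RMSprop class remains the intended interface.

import torch
from torch.optim import functional as F

# Hypothetical single parameter whose gradient has already been computed.
param = torch.randn(3, requires_grad=True)
param.grad = torch.randn(3)

# Per-parameter state that the RMSprop optimizer class normally manages.
square_avg = torch.zeros_like(param)
grad_avg = torch.zeros_like(param)         # only read when centered=True
momentum_buffer = torch.zeros_like(param)  # only read when momentum > 0

with torch.no_grad():  # parameter updates are in-place, so disable autograd tracking
    F.rmsprop([param], [param.grad], [square_avg], [grad_avg], [momentum_buffer],
              lr=1e-2, alpha=0.99, eps=1e-8,
              weight_decay=0.0, momentum=0.0, centered=False)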
50 changes: 29 additions & 21 deletions torch/optim/rmsprop.py
@@ -1,4 +1,5 @@
 import torch
+from . import functional as F
 from .optimizer import Optimizer
 
 
@@ -66,12 +67,21 @@ def step(self, closure=None):
                 loss = closure()
 
         for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            square_avgs = []
+            grad_avgs = []
+            momentum_buffer_list = []
+
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad
-                if grad.is_sparse:
+                params_with_grad.append(p)
+
+                if p.grad.is_sparse:
                     raise RuntimeError('RMSprop does not support sparse gradients')
+                grads.append(p.grad)
+
                 state = self.state[p]
 
                 # State initialization
@@ -83,28 +93,26 @@ def step(self, closure=None):
                     if group['centered']:
                         state['grad_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-                square_avg = state['square_avg']
-                alpha = group['alpha']
-
-                state['step'] += 1
+                square_avgs.append(state['square_avg'])
 
-                if group['weight_decay'] != 0:
-                    grad = grad.add(p, alpha=group['weight_decay'])
+                if group['momentum'] > 0:
+                    momentum_buffer_list.append(state['momentum_buffer'])
+                if group['centered']:
+                    grad_avgs.append(state['grad_avg'])
 
-                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
+                state['step'] += 1
 
-                if group['centered']:
-                    grad_avg = state['grad_avg']
-                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
-                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps'])
-                else:
-                    avg = square_avg.sqrt().add_(group['eps'])
 
-                if group['momentum'] > 0:
-                    buf = state['momentum_buffer']
-                    buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                    p.add_(buf, alpha=-group['lr'])
-                else:
-                    p.addcdiv_(grad, avg, value=-group['lr'])
+            F.rmsprop(params_with_grad,
+                      grads,
+                      square_avgs,
+                      grad_avgs,
+                      momentum_buffer_list,
+                      group['lr'],
+                      group['alpha'],
+                      group['eps'],
+                      group['weight_decay'],
+                      group['momentum'],
+                      group['centered'])
 
         return loss
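
From the caller's side the refactor is intended to be behavior-preserving: existing training code keeps using the optimizer class as before, roughly as in this sketch (model and data are hypothetical), with step() now gathering per-parameter state into lists and delegating to F.rmsprop.

import torch

model = torch.nn.Linear(10, 1)   # hypothetical model
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2, alpha=0.99,
                                momentum=0.9, centered=True)

inputs, targets = torch.randn(4, 10), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()   # collects grads/state per group and calls the functional rmsprop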
