Broadcast optimizer options in addition to parameter state (#562)
* Broadcast optimizer options in addition to parameter state

* Added comment

* Added tests for all the optimizer subclasses

* Added comment
tgaddair committed Oct 15, 2018
1 parent 6629f5f commit d90d9e8
Showing 2 changed files with 83 additions and 2 deletions.
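For context, the API touched here is normally called once, right after the optimizer is wrapped and before training begins, so that every rank starts from the root rank's parameter state and, with this change, its options such as the learning rate. A minimal usage sketch follows; the model and hyperparameters are illustrative and not taken from this commit:

import torch
import horovod.torch as hvd

hvd.init()

# Illustrative model and optimizer.
model = torch.nn.Linear(100, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

# Broadcast initial parameters and optimizer state (now including options
# like lr and momentum) from rank 0 so all workers start out identical.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)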
horovod/torch/__init__.py (21 additions, 2 deletions)
@@ -213,8 +213,27 @@ def _from_tensor():
             state_dict['state'][pid][name] = t(p.numpy()[0])
         return _from_tensor
 
-    # Groups are unordered, but their params will be distinct
-    for group in state_dict['param_groups']:
+    def _create_option_callback(index, option_key, option_tensor, dtype):
+        def _from_tensor():
+            optimizer.param_groups[index][option_key] = dtype(option_tensor.numpy()[0])
+        return _from_tensor
+
+    # Param groups are an ordered list, normally there is only one per model,
+    # but users can add additional param groups for example to train
+    # previously frozen layers
+    for index, group in enumerate(state_dict['param_groups']):
+        # Broadcast options like learning rate
+        for option_key, option_value in group.items():
+            if option_key == 'params':
+                continue
+
+            # Options like the learning rate are scalar, and need to be wrapped in tensors
+            key = '%s.%d' % (option_key, index)
+            dtype = type(option_value)
+            option_tensor = torch.Tensor([option_value])
+            callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtype)
+            params.append((key, option_tensor))
+
         # The params list here is ordered by the layers in the model
         for pid in group['params']:
            param_state = state_dict['state'][pid]
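The new loop keys each broadcast option as '%s.%d' % (option_key, index), e.g. 'lr.0' or 'momentum.1', so options stay distinct across param groups. As the comment in the diff notes, most models have a single param group, but extra groups appear when several parameter sets are passed to the optimizer with different options, for instance to fine-tune previously frozen layers at a lower learning rate. A short sketch of such an optimizer; the layers and values are hypothetical and not part of this commit:

import torch

backbone = torch.nn.Linear(100, 50)  # e.g. previously frozen layers
head = torch.nn.Linear(50, 10)       # newly added layers

# Two param groups with different options; with this change both groups'
# options are broadcast from the root rank, keyed 'lr.0', 'lr.1', 'momentum.0', ...
optimizer = torch.optim.SGD([
    {'params': backbone.parameters(), 'lr': 0.001},
    {'params': head.parameters(), 'lr': 0.1},
], momentum=0.9)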
test/test_torch.py (62 additions, 0 deletions)
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 from distutils.version import LooseVersion
+import collections
 import inspect
 import itertools
 import numpy as np
@@ -829,6 +830,67 @@ def get_optimizer_param_values(optimizer):
             else:
                 self.assertEqual(opt_param_value, opt_param_value_after)
 
+    def test_broadcast_state_options(self):
+        hvd.init()
+
+        N, D_in, H, D_out = 64, 100, 10, 10
+        x = torch.randn(N, D_in).requires_grad_()
+        y = torch.randn(N, D_out).requires_grad_()
+
+        params_0 = dict(lr=0.1, momentum=0.8, weight_decay=0.2, nesterov=True,
+                        etas=(0.8, 2.4), step_sizes=(1e-5, 100))
+        params_1 = dict(lr=0.2, momentum=0.9, weight_decay=0.1, nesterov=False,
+                        etas=(0.25, 1.75), step_sizes=(1e-7, 5))
+
+        def create_model(opt_class):
+            model = torch.nn.Sequential(
+                torch.nn.Linear(D_in, H),
+                torch.nn.ReLU(),
+                torch.nn.Linear(H, D_out),
+            )
+
+            params = params_0 if hvd.rank() == 0 else params_1
+            p = {
+                k: v for k, v in params.items()
+                if k in inspect.getargspec(opt_class.__init__).args
+            }
+            opt = opt_class(model.parameters(), **p)
+            opt = hvd.DistributedOptimizer(opt, named_parameters=model.named_parameters())
+
+            return model, opt
+
+        # Include subclass name so we can sort them lexicographically, otherwise different
+        # ranks will have different optimizer orderings
+        optimizers = [
+            (subclass.__name__, subclass)
+            for subclass in torch.optim.Optimizer.__subclasses__()
+            if subclass.__module__.startswith('torch.optim') and
+               subclass != torch.optim.LBFGS and
+               subclass != torch.optim.SparseAdam
+        ]
+        optimizers.sort()
+
+        for _, opt_class in optimizers:
+            model, optimizer = create_model(opt_class)
+            y_pred = model(x)
+            loss = F.mse_loss(y_pred, y, size_average=False)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+            p0 = {
+                k: v for k, v in params_0.items()
+                if k in inspect.getargspec(opt_class.__init__).args
+            }
+            for k, p in p0.items():
+                p_actual = optimizer.param_groups[0][k]
+                if not isinstance(p, collections.Iterable):
+                    p_actual = [p_actual]
+                    p = [p]
+                for i in range(len(p)):
+                    self.assertAlmostEqual(p_actual[i], p[i], delta=1e-5)
+
     def test_compression_fp16(self):
         valid_dtypes = [torch.float32, torch.float64]
         invalid_dtypes = [torch.uint8, torch.int8, torch.int16,
