
Fixed broadcast_optimizer_state to properly convert nested types (#608)

tgaddair committed Nov 3, 2018
1 parent 0d10da3 commit 62d2869047ee8ccab3d559bee35b8f5e392936fb
Showing with 30 additions and 6 deletions.
  1. +20 −4 horovod/torch/__init__.py
  2. +10 −2 test/test_torch.py
horovod/torch/__init__.py
@@ -227,6 +227,22 @@ def broadcast_optimizer_state(optimizer, root_rank):
     callbacks = {}
     occurrences = collections.defaultdict(int)
 
+    # Returns the full type structure of the possibly nested objects for recursive casting back
+    def _get_types(x):
+        if isinstance(x, collections.Iterable):
+            return type(x), [_get_types(xi) for xi in x]
+        else:
+            return type(x)
+
+    # Casts an object encoded in a tensor back into its original type and subtypes
+    def _recursive_cast(x, dtype):
+        if isinstance(dtype, tuple):
+            t, dtypes = dtype
+            x = t(x)
+            return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))])
+        else:
+            return dtype(x)
+
     # Some optimizer parameters may be represented as scalars instead of
     # tensors. In such cases, we need to wrap the scalar in a tensor, then
     # broadcast, then update the appropriate value in the state_dict with the
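To make the intent of the two new helpers concrete, here is a minimal standalone sketch (mine, not part of the commit; it uses collections.abc.Iterable on newer Python, while the diff uses the older collections.Iterable alias) that round-trips a nested option such as Adam's betas through a flat tensor and back:

import collections.abc
import torch

def _get_types(x):
    # Record the container type plus the type of each element, recursively.
    if isinstance(x, collections.abc.Iterable):
        return type(x), [_get_types(xi) for xi in x]
    else:
        return type(x)

def _recursive_cast(x, dtype):
    # Rebuild the recorded (possibly nested) type structure from the decoded value.
    if isinstance(dtype, tuple):
        t, dtypes = dtype
        x = t(x)
        return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))])
    else:
        return dtype(x)

betas = (0.9, 0.999)
dtypes = _get_types(betas)                    # (tuple, [float, float])
option_tensor = torch.Tensor([betas])         # encoded the same way the diff does
restored = _recursive_cast(option_tensor.numpy()[0], dtypes)
print(restored, type(restored))               # a plain tuple of Python floats again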
@@ -236,9 +252,9 @@ def _from_tensor():
             state_dict['state'][pid][name] = t(p.numpy()[0])
         return _from_tensor
 
-    def _create_option_callback(index, option_key, option_tensor, dtype):
+    def _create_option_callback(index, option_key, option_tensor, dtypes):
         def _from_tensor():
-            optimizer.param_groups[index][option_key] = dtype(option_tensor.numpy()[0])
+            optimizer.param_groups[index][option_key] = _recursive_cast(option_tensor.numpy()[0], dtypes)
         return _from_tensor
 
     # Param groups are an ordered list, normally there is only one per model,
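For readers unfamiliar with the surrounding machinery: the write-back into optimizer.param_groups is deferred through a closure that runs only after the broadcast has filled the tensor. A compressed sketch of that pattern (my simplification; param_groups is stubbed out and the cast argument stands in for dtypes):

import torch

param_groups = [{'lr': 0.1}]            # stand-in for optimizer.param_groups
callbacks = {}

def _create_option_callback(index, option_key, option_tensor, cast):
    # The closure captures option_tensor; it is only invoked after the
    # collective broadcast has overwritten the tensor with the root rank's value.
    def _from_tensor():
        param_groups[index][option_key] = cast(option_tensor.numpy()[0])
    return _from_tensor

option_tensor = torch.Tensor([param_groups[0]['lr']])
callbacks['lr.0'] = _create_option_callback(0, 'lr', option_tensor, float)

# ... the broadcast would update option_tensor in place here ...
callbacks['lr.0']()                     # write the broadcast value back as a float
print(param_groups[0]['lr'])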
@@ -252,9 +268,9 @@ def _from_tensor():
 
             # Options like the learning rate are scalar, and need to be wrapped in tensors
             key = '%s.%d' % (option_key, index)
-            dtype = type(option_value)
+            dtypes = _get_types(option_value)
             option_tensor = torch.Tensor([option_value])
-            callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtype)
+            callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtypes)
             params.append((key, option_tensor))
 
         # The params list here is ordered by the layers in the model
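The practical difference is easiest to see on a tuple-valued option. A short before/after illustration (my construction, not part of the diff):

import torch

option_value = (0.8, 2.4)                       # e.g. Rprop's 'etas' option
option_tensor = torch.Tensor([option_value])    # wrapped exactly as the diff does
decoded = option_tensor.numpy()[0]

# Old behaviour: dtype = type(option_value) casts only the outer container,
# so the tuple comes back filled with numpy.float32 scalars.
old = type(option_value)(decoded)
print([type(v).__name__ for v in old])          # ['float32', 'float32']

# New behaviour: _get_types records (tuple, [float, float]) and _recursive_cast
# rebuilds the nested structure, leaving plain Python floats inside.
new = tuple(float(v) for v in decoded)          # equivalent outcome for this example
print([type(v).__name__ for v in new])          # ['float', 'float']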
test/test_torch.py
@@ -873,9 +873,9 @@ def test_broadcast_state_options(self):
         y = torch.randn(N, D_out).requires_grad_()
 
         params_0 = dict(lr=0.1, momentum=0.8, weight_decay=0.2, nesterov=True,
-                        etas=(0.8, 2.4), step_sizes=(1e-5, 100))
+                        betas=(0.9, 0.999), etas=(0.8, 2.4), step_sizes=(1e-5, 100))
         params_1 = dict(lr=0.2, momentum=0.9, weight_decay=0.1, nesterov=False,
-                        etas=(0.25, 1.75), step_sizes=(1e-7, 5))
+                        betas=(0.8, 0.9), etas=(0.25, 1.75), step_sizes=(1e-7, 5))
 
         def create_model(opt_class):
             model = torch.nn.Sequential(
@@ -924,8 +924,16 @@ def create_model(opt_class):
                     p_actual = [p_actual]
                     p = [p]
                 for i in range(len(p)):
                     self.assertEqual(type(p_actual[i]), type(p[i]))
                     self.assertAlmostEqual(p_actual[i], p[i], delta=1e-5)
 
+            # Ensure that the parameter option types are compatible with ops
+            y_pred = model(x)
+            loss = F.mse_loss(y_pred, y, size_average=False)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
     def test_compression_fp16(self):
         valid_dtypes = [torch.float32, torch.float64]
         invalid_dtypes = [torch.uint8, torch.int8, torch.int16,
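Putting the test change in context: it now runs an actual optimization step after the broadcast, so the restored option types must be usable by the optimizer. A minimal usage sketch of the API being exercised (model, optimizer, and data are my own stand-ins, not from the test):

import torch
import torch.nn.functional as F
import horovod.torch as hvd

hvd.init()

model = torch.nn.Linear(10, 2)
# Adam's betas (and Rprop's etas/step_sizes) are the nested tuple options fixed here.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))

# Rank 0's parameters and optimizer options are broadcast to every worker;
# tuple-valued options now come back with their original nested types.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# The broadcast options remain compatible with a normal training step.
loss = F.mse_loss(model(torch.randn(4, 10)), torch.randn(4, 2))
optimizer.zero_grad()
loss.backward()
optimizer.step()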
