Conversation

@taylanbil (Collaborator) commented Jun 14, 2019

... set of weights


PAIR=@jysohn23

Daniel and I were investigating how parallel model training works, and we uncovered an issue. After adding

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(devices)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  # BEGIN ADDITION
  models = model_parallel._models
  for i, model in enumerate(models):
      print('I am model {}'.format(i))
      print(model.fc2.weight)
      print('-'*80)
  # END ADDITION
  return accuracy * 100.0

to test/test_train_mnist.py, we saw that, without the change in this PR, the weights of the models on different devices are not the same. We tracked this down to the weights being randomly initialized on every network() call. Deepcopy guarantees that all devices start from the same point, so all gradient computations are based at the same point and the backward passes are mathematically correct. As a result, accuracy after 1 epoch goes up to 96% (from 84% without this change).
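
For illustration, here is a minimal CPU-only sketch of the effect (not the PR code; nn.Linear stands in for the actual MNIST network):

import copy
import torch
import torch.nn as nn

def network():
  # Every call re-runs the default random initialization.
  return nn.Linear(784, 10)

# Constructing one model per device: the replicas start from different weights.
m1, m2 = network(), network()
print(torch.equal(m1.weight, m2.weight))  # False

# Constructing once and deep-copying: the replicas start from identical weights.
base = network()
r1, r2 = copy.deepcopy(base), copy.deepcopy(base)
print(torch.equal(r1.weight, r2.weight))  # True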

@taylanbil (Collaborator, Author)

@dlibenzi , more details in the email I sent.

@dlibenzi (Collaborator) left a comment

Approving with the required changes.

xm.Replication(self._device_ids, replication_devices)
if replication_devices else None)
self._models = []
netw = network()
@dlibenzi (Collaborator)

Good catch Taylan!
But let's make a change, since it came up with SF as well.

module = network() if callable(network) else network
for device in device_ids:
  device_module = deepcopy(module).to(device=torch.device(device))
  self._models.append(device_module)

@dlibenzi (Collaborator)

Actually, you do not need deepcopy. As long as the model is generated once outside the loop, we are fine:

module = network() if callable(network) else network
for device in device_ids:
  device_module = module.to(device=torch.device(device))
  self._models.append(device_module)

The to() call will simply copy the same parameters to the different devices.

@dlibenzi (Collaborator) commented Jun 15, 2019

Actually2 😁
The network (torch.nn.Module) is a callable, so maybe something like:

module = network if isinstance(network, torch.nn.Module) else network()
for device in device_ids:
  device_module = module.to(device=torch.device(device))
  self._models.append(device_module)

@taylanbil (Collaborator, Author)

Interesting. We were already calling network() before, so I guess this is a separate fix; two birds with one stone.

@taylanbil (Collaborator, Author)

@dlibenzi I tested this version, and it actually does not work without deepcopy.

Exception in model function for device=xla:8: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:8
Exception in model function for device=xla:7: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:7
Exception in model function for device=xla:1: Function AddmmBackward returned an invalid gradient at index 2 - expected device xla:8 but got xla:1
Exception in model function for device=xla:5: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:5
Exception in model function for device=xla:6: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:6
Exception in model function for device=xla:2: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:2
Exception in model function for device=xla:4: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:4
Exception in model function for device=xla:3: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:3
Traceback (most recent call last):
  File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
    result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
  File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
    result.result = loop_fn(module, loader, torch.device(device), context)
  File "/pytorch/xla/test/test_train_mnist.py", line 103, in train_loop_fn
    loss.backward()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py", line 120, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
Traceback (most recent call last):
  File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
    result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
  File "/pytorch/xla/test/test_train_mnist.py", line 103, in train_loop_fn
    loss.backward()
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py", line 120, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
Traceback (most recent call last):
  File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:3
Traceback (most recent call last):
  File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
    result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
    result.result = loop_fn(module, loader, torch.device(device), context)

@dlibenzi (Collaborator)

Hmm, module.to() should create a new model with all parameters moved.
Can you debug a little?

@taylanbil (Collaborator, Author)

@dlibenzi I investigated a little; the docstring of torch.nn.Module.to() says it modifies the module in place: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L334
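
For clarity, a small sketch (not from this thread) of how Tensor.to() and Module.to() differ:

import torch
import torch.nn as nn

t = torch.zeros(2)
t64 = t.to(torch.float64)
print(t64 is t)  # False: Tensor.to() returns a new tensor and leaves the source unchanged

m = nn.Linear(2, 2)
m64 = m.to(torch.float64)
print(m64 is m)  # True: Module.to() converts the module in place and returns self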

@dlibenzi (Collaborator)

> @dlibenzi I investigated a little; the docstring of torch.nn.Module.to() says it modifies the module in place: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L334

Oh right. I remember we had a conversation with the PyTorch folks about this. IMHO a to() API should leave the source unchanged (as Tensor.to() does), but anyway.
Let's add a deepcopy() in the loop then.

@taylanbil (Collaborator, Author)

I agree; in fact, that's what I tried first: I moved the module creation out of the loop and got that error. Only then did I move to deepcopy. Anyway, I'm adding deepcopy back in and leaving the isinstance line intact.
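
For reference, keeping the isinstance line and adding deepcopy back should presumably give a loop roughly like this (a sketch assembled from the snippets above, not the merged diff):

module = network if isinstance(network, torch.nn.Module) else network()
for device in device_ids:
  # deepcopy first, so that Module.to()'s in-place move never touches the shared source module.
  device_module = deepcopy(module).to(device=torch.device(device))
  self._models.append(device_module)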

@taylanbil (Collaborator, Author)

I tested it; this final version runs and achieves 96% accuracy after 1 epoch on MNIST (as before).

@dlibenzi (Collaborator)

> I tested it; this final version runs and achieves 96% accuracy after 1 epoch on MNIST (as before).

Thanks Taylan, will merge once the CI passes!

@dlibenzi dlibenzi merged commit fa5ec82 into pytorch:master Jun 17, 2019
@taylanbil taylanbil deleted the weights branch June 18, 2019 18:29
@taylanbil taylanbil restored the weights branch June 18, 2019 18:30
@taylanbil taylanbil deleted the weights branch June 18, 2019 18:31