deepcopy the network so the parallel versions initialize to the same set of weights #755
Conversation
@dlibenzi, more details in the email I sent.
dlibenzi left a comment
Approving with the required changes.
torch_xla_py/data_parallel.py
Outdated
xm.Replication(self._device_ids, replication_devices)
if replication_devices else None)
self._models = []
netw = network()
Good catch Taylan!
But let's do a change, since it came up with SF as well.
module = network() if callable(network) else network
for device in device_ids:
  device_module = deepcopy(module).to(device=torch.device(device))
  self._models.append(device_module)
Actually, you do not need deepcopy. As long as the model is generated once outside the loop, we are fine:
module = network() if callable(network) else network
for device in device_ids:
  device_module = module.to(device=torch.device(device))
  self._models.append(device_module)
The to() call will simply copy the same parameters to the different devices.
Actually2 😁
The network (torch.nn.Module) is a callable, so maybe something like:
module = network if isinstance(network, torch.nn.Module) else network()
for device in device_ids:
  device_module = module.to(device=torch.device(device))
  self._models.append(device_module)
Interesting. We were already calling network() before, so I guess this is a separate fix, two birds with 1 stone.
@dlibenzi I tested this version, and it actually does not work w/o deepcopy.
Exception in model function for device=xla:8: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:8
Exception in model function for device=xla:7: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:7
Exception in model function for device=xla:1: Function AddmmBackward returned an invalid gradient at index 2 - expected device xla:8 but got xla:1
Exception in model function for device=xla:5: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:5
Exception in model function for device=xla:6: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:6
Exception in model function for device=xla:2: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:2
Exception in model function for device=xla:4: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:4
Exception in model function for device=xla:3: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:3
Traceback (most recent call last):
File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
result.result = loop_fn(module, loader, torch.device(device), context)
File "/pytorch/xla/test/test_train_mnist.py", line 103, in train_loop_fn
loss.backward()
File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py", line 120, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
Traceback (most recent call last):
File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
File "/pytorch/xla/test/test_train_mnist.py", line 103, in train_loop_fn
loss.backward()
File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py", line 120, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
Traceback (most recent call last):
File "/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Function AddmmBackward returned an invalid gradient at index 0 - expected device xla:1 but got xla:3
Traceback (most recent call last):
File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
result.result = loop_fn(module, loader, torch.device(device), context)
Traceback (most recent call last):
Traceback (most recent call last):
File "/pytorch/xla/torch_xla_py/data_parallel.py", line 204, in _module_runner
result.result = loop_fn(module, loader, torch.device(device), context)
Hmm, module.to() should create a new model with all parameters moved.
Can you debug a little?
@dlibenzi I investigated a little; the torch docstring says nn.Module.to() modifies the module in-place (it moves the parameters and returns the same module, unlike Tensor.to(), which returns a new tensor).
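For context, a minimal stand-alone illustration of that in-place behavior (not from this thread, only assumes stock torch and runs on CPU):

import torch

linear = torch.nn.Linear(2, 2)
moved = linear.to(torch.device("cpu"))
print(moved is linear)  # True: Module.to() moves parameters in place and returns self

t = torch.zeros(2)
t2 = t.to(torch.float64)
print(t2 is t)  # False: Tensor.to() returns a new tensor when a conversion happens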
Oh right. I remember we had a conversation with PT folks. IMHO a
I agree; in fact that's what I tried first: I moved module creation out of the loop and it gave me that error. Only then did I get to deepcopy. Anyway, adding deepcopy back in, and leaving the isinstance version in.
I tested; this final version runs and achieves 96% accuracy after 1 epoch on MNIST (as before).
Thanks Taylan, will merge once the CI passes!
PAIR=@jysohn23
Daniel and I were investigating how the parallel model training works, and we uncovered an issue. After adding a check to test/test_train_mnist.py, we saw that without this edit the weights of the models on different devices aren't the same. We then figured out that this is because the weights are randomly initialized on every network() call. deepcopy guarantees that all devices start from the same weights, so all gradient computations share the same starting point and the backward passes are mathematically correct. As a result, accuracy after 1 epoch goes up to 96% (84% without this change).
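For reference, a condensed sketch of the resulting replication logic, paraphrased from the snippets discussed above rather than copied from the merged diff (the _replicate helper name is made up for illustration):

import copy
import torch

def _replicate(network, device_ids):
  # Build (or reuse) the module once, then deep-copy it per device so every
  # replica starts from the exact same randomly initialized weights.
  module = network if isinstance(network, torch.nn.Module) else network()
  models = []
  for device in device_ids:
    device_module = copy.deepcopy(module).to(device=torch.device(device))
    models.append(device_module)
  return models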