DTensor: add comm tests to test_tp_examples #121669
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121669
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 6035e67 with merge base 443444d:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@awgu I can just pull the tests to wrap the whole parallelize method since that shouldn't be making any network calls. Not a problem. Do you know when that's landing?
Nice and fast changes! I have some comments around the sharding init comm tracking.
self._check_module(model, model_tp)
if is_seq_parallel:
    self.assertDictEqual(comm_mode.get_comm_counts(), {
        c10d_functional.all_reduce: 30,
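For context on where these numbers come from, a minimal sketch of how CommDebugMode-style counting typically works (the import path, the wrapped region, and the variable names are my assumptions, not the test's literal code):

```python
import torch
from torch.distributed._tensor.debug import CommDebugMode  # import path assumed

c10d_functional = torch.ops.c10d_functional

comm_mode = CommDebugMode()
with comm_mode:
    # model_tp, inp, and optim_tp are assumed to be set up by the test harness.
    out = model_tp(inp)       # forward
    out.sum().backward()      # backward
    optim_tp.step()           # optimizer step

# get_comm_counts() maps each collective op to how many times it ran, e.g.
# {c10d_functional.all_reduce: 30, ...}
counts = comm_mode.get_comm_counts()
```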
wat, I am shocked that there are 30 allreduces in the optimizer step 😮 I was expecting around 5 allreduces hmm
could you pass `foreach=True` to the Adam optimizer and see if that would reduce the number of allreduces?
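i.e. something along these lines (the learning rate and variable names here are placeholders, not the test's actual values):

```python
# Hypothetical tweak: enable the fused multi-tensor (foreach) Adam path.
optim_tp = torch.optim.Adam(model_tp.parameters(), lr=1e-4, foreach=True)
```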
I think foreach is a single kernel but the tensors are still kept separate
It seems each LayerNorm would incur 2 all-reduces (weight & bias) in the backward pass. With the default two-layer Transformer, we have two `attention_norm`, two `ffn_norm`, and one final `norm` -- adding up to 10 all-reduces already?
In general, I wonder if we should include a breakdown of the collectives in the test -- ideally it should be a function of the number of TransformerBlocks (and other relevant configs), so even if we modify the model architecture, the test could still pass.
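As a sketch of what that breakdown could look like (the per-component counts here are assumptions pulled from the discussion above, not verified against the test):

```python
def expected_layernorm_all_reduces(num_layers: int) -> int:
    """Hypothetical helper: expected backward all-reduces contributed by LayerNorms alone."""
    norms_per_block = 2        # attention_norm + ffn_norm per TransformerBlock
    final_norms = 1            # the model-level output norm
    grads_per_norm = 2         # weight grad + bias grad each need an all-reduce
    return grads_per_norm * (norms_per_block * num_layers + final_norms)

# Default two-layer Transformer from the test:
assert expected_layernorm_all_reduces(2) == 10
```

That way the hard-coded 30 could be assembled from named pieces (norm grads, embedding/output collectives, optimizer-step collectives, etc.) rather than asserted as a single magic number.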
Traceback (most recent call last):
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 540, in wrapper
self._join_processes(fn)
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 759, in _join_processes
self._check_return_codes(elapsed_time)
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 809, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 656, in run_test
getattr(self, test_name)()
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
fn()
File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 2739, in wrapper
method(*args, **kwargs)
File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 439, in instantiated_test
test(self, **param_kwargs)
File "/home/tristanr/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 374, in wrapper
func(self, *args, **kwargs) # type: ignore[misc]
File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 181, in wrapper
return func(*args, **kwargs)
File "/home/tristanr/pytorch/test/distributed/tensor/parallel/test_tp_examples.py", line 247, in test_transformer_training
optim_tp.step()
File "/home/tristanr/pytorch/torch/optim/optimizer.py", line 391, in wrapper
out = func(*args, **kwargs)
File "/home/tristanr/pytorch/torch/optim/optimizer.py", line 76, in _use_grad
ret = func(self, *args, **kwargs)
File "/home/tristanr/pytorch/torch/optim/adam.py", line 168, in step
adam(
File "/home/tristanr/pytorch/torch/optim/adam.py", line 318, in adam
func(params,
File "/home/tristanr/pytorch/torch/optim/adam.py", line 522, in _multi_tensor_adam
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
File "/home/tristanr/pytorch/torch/distributed/_tensor/api.py", line 279, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
File "/home/tristanr/pytorch/torch/distributed/_tensor/dispatch.py", line 111, in dispatch
op_info = self.unwrap_to_op_info(op_call, args, kwargs)
File "/home/tristanr/pytorch/torch/distributed/_tensor/dispatch.py", line 314, in unwrap_to_op_info
raise RuntimeError(
RuntimeError: aten._foreach_lerp_.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
Seems to error out with `foreach=True`.
If you want to bypass that error:

```python
from torch.distributed._tensor.experimental import implicit_replication

with implicit_replication():
    optim.step()
```
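(For what it's worth, my understanding is that implicit_replication treats the plain torch.Tensor operands that show up inside those fused optimizer ops as replicated DTensors, which is why it sidesteps the mixed torch.Tensor/DTensor check; it's experimental, so worth confirming before relying on it in the test.)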
cc: @wanchaol
Ohh I see, this is nn.LayerNorm, not RMSNorm, so it has 2 allreduces per layer norm -- this makes sense then
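For reference, a quick illustrative check (my own snippet, not from the PR) that nn.LayerNorm carries two affine parameters by default, hence two gradient all-reduces per norm in the backward:

```python
import torch.nn as nn

# With the default elementwise_affine=True, LayerNorm has both a weight
# and a bias parameter, so its backward yields two gradients to all-reduce.
ln = nn.LayerNorm(16)
print([name for name, _ in ln.named_parameters()])  # ['weight', 'bias']
```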
Compare: 5114cc4 to 6035e67
This LGTM! Thanks for the fast PR!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This adds some basic comm tests to test_tp_examples. This validates that the expected distributed calls are being made for `test_transformer_training`.
Fixes #121649
Test plan:
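The PR body doesn't spell out the exact invocation; something like `python test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training` (assumed from the file touched) should exercise the new assertions.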
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang