
Conversation

d4l3k (Member) commented Mar 11, 2024

This adds some basic comm tests to test_tp_examples, validating that the expected distributed calls are made during test_transformer_training.

Fixes #121649

Test plan:

pytest test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training
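
For reference, a minimal sketch of how such comm counts can be collected with DTensor's CommDebugMode (assuming the utility lives in torch.distributed._tensor.debug as on this branch; model_tp and inp are placeholders for the test's parallelized model and input):

import torch
from torch.distributed._tensor.debug import CommDebugMode

# functional collective ops are the keys returned by get_comm_counts()
c10d_functional = torch.ops.c10d_functional

comm_mode = CommDebugMode()
with comm_mode:
    # one illustrative forward/backward step on the tensor-parallel model
    model_tp(inp).sum().backward()

counts = comm_mode.get_comm_counts()
# e.g. self.assertEqual(counts[c10d_functional.all_reduce], expected_all_reduces)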

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang

d4l3k requested a review from wanchaol March 11, 2024 20:59
pytorch-bot added the topic: not user facing label Mar 11, 2024

pytorch-bot commented Mar 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121669

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 6035e67 with merge base 443444d:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

d4l3k requested a review from kurman March 11, 2024 20:59
github-actions bot added the oncall: distributed label Mar 11, 2024

awgu (Collaborator) commented Mar 11, 2024

Oof, this will conflict with #121660 -- we want to have a canonical TP/SP sharding for the Transformer so we can use that for FSDP + TP/SP unit tests.
cc: @wanchaol

Should we just duplicate the TP/SP sharding logic since it seems we need to insert inline tests here?

d4l3k (Member, Author) commented Mar 11, 2024

@awgu I can just move the tests to wrap the whole parallelize method, since that shouldn't be making any network calls. Not a problem. Do you know when that's landing?

d4l3k requested a review from awgu March 11, 2024 21:09

wanchaol (Collaborator) left a comment

Nice and fast changes! I have some comments around the sharding init comm tracking.

self._check_module(model, model_tp)
if is_seq_parallel:
    self.assertDictEqual(comm_mode.get_comm_counts(), {
        c10d_functional.all_reduce: 30,

Collaborator:

Wat, I am shocked that there are 30 all-reduces in the optimizer step 😮 I was expecting around 5 all-reduces, hmm.

Could you pass foreach=True to the Adam optimizer and see if that reduces the number of all-reduces?
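
(A minimal illustration of that suggestion; model_tp and the learning rate here are placeholders:)

import torch

# hypothetical: switch Adam to its multi-tensor (foreach) implementation
optim_tp = torch.optim.Adam(model_tp.parameters(), lr=1e-3, foreach=True)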

Collaborator:

I think foreach uses a single kernel, but the tensors are still kept separate, so the number of all-reduces would stay the same.

tianyu-l (Contributor) commented Mar 12, 2024

It seems each LayerNorm would incur 2 all-reduces (weight & bias) in the backward pass. With the default two-layer Transformer, we have two attention_norm, two ffn_norm, and one final norm -- adding up to 10 all-reduces already?

In general, I wonder if we should include a breakdown of the collectives in the test -- ideally it should be a function of the number of TransformerBlocks (and other relevant configs), so even if we modify the model architecture, the test could still pass.
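
A hypothetical sketch of such a breakdown, deriving the expected backward-pass all-reduce count from the layer count instead of hard-coding it (the function name is made up; the factor of 2 per LayerNorm follows the reasoning above):

def expected_layernorm_all_reduces(n_layers: int) -> int:
    # each TransformerBlock has an attention_norm and an ffn_norm, plus one
    # final norm; under sequence parallelism each nn.LayerNorm contributes
    # 2 all-reduces in backward (one for its weight, one for its bias)
    num_norms = 2 * n_layers + 1
    return 2 * num_norms

assert expected_layernorm_all_reduces(2) == 10  # matches the estimate above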

Member Author:

Traceback (most recent call last):
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 540, in wrapper
    self._join_processes(fn)
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 759, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 809, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 656, in run_test
    getattr(self, test_name)()
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 542, in wrapper
    fn()
  File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 2739, in wrapper
    method(*args, **kwargs)
  File "/home/tristanr/pytorch/torch/testing/_internal/common_utils.py", line 439, in instantiated_test
    test(self, **param_kwargs)
  File "/home/tristanr/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 374, in wrapper
    func(self, *args, **kwargs)  # type: ignore[misc]
  File "/home/tristanr/pytorch/torch/testing/_internal/common_distributed.py", line 181, in wrapper
    return func(*args, **kwargs)
  File "/home/tristanr/pytorch/test/distributed/tensor/parallel/test_tp_examples.py", line 247, in test_transformer_training
    optim_tp.step()
  File "/home/tristanr/pytorch/torch/optim/optimizer.py", line 391, in wrapper
    out = func(*args, **kwargs)
  File "/home/tristanr/pytorch/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/tristanr/pytorch/torch/optim/adam.py", line 168, in step
    adam(
  File "/home/tristanr/pytorch/torch/optim/adam.py", line 318, in adam
    func(params,
  File "/home/tristanr/pytorch/torch/optim/adam.py", line 522, in _multi_tensor_adam
    torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
  File "/home/tristanr/pytorch/torch/distributed/_tensor/api.py", line 279, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/home/tristanr/pytorch/torch/distributed/_tensor/dispatch.py", line 111, in dispatch
    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
  File "/home/tristanr/pytorch/torch/distributed/_tensor/dispatch.py", line 314, in unwrap_to_op_info
    raise RuntimeError(
RuntimeError: aten._foreach_lerp_.Scalar: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Seems to error out with foreach=True

Collaborator:

If you want to bypass that error:

from torch.distributed._tensor.experimental import implicit_replication

# implicitly treat plain torch.Tensors in the op as replicated DTensors,
# which avoids the mixed torch.Tensor/DTensor dispatch error above
with implicit_replication():
    optim.step()

cc: @wanchaol

Collaborator:

Ohh I see, this is nn.LayerNorm, not RMSNorm, which has 2 all-reduces per layer norm -- this makes sense then.
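
(A quick way to see where the factor of 2 comes from, assuming the test model uses plain nn.LayerNorm with its default elementwise_affine=True:)

import torch.nn as nn

ln = nn.LayerNorm(16)
# prints ['weight', 'bias'] -- one gradient all-reduce each under sequence parallelism
print([name for name, _ in ln.named_parameters()])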

d4l3k force-pushed the tristanr/test_tp_examples branch from 5114cc4 to 6035e67 on March 14, 2024 17:24
d4l3k requested a review from wanchaol March 14, 2024 20:30

wanchaol (Collaborator) left a comment

This LGTM! Thanks for the fast PR!

d4l3k (Member, Author) commented Mar 15, 2024

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Mar 15, 2024

pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

d4l3k deleted the tristanr/test_tp_examples branch March 15, 2024 18:17

Labels: ciflow/trunk, Merged, oncall: distributed, topic: not user facing

Projects: None yet

Development

Successfully merging this pull request may close these issues:

guard on tensor parallelism test examples

6 participants