
[pipelining] Add grad test for interleaved schedules #126931

Closed
wants to merge 3 commits

Conversation

@kwen2501 (Contributor) commented May 22, 2024


pytorch-bot bot commented May 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126931

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 763dd43 with merge base c46b38b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the `oncall: distributed` and `topic: not user facing` labels May 22, 2024
kwen2501 added a commit that referenced this pull request May 22, 2024
ghstack-source-id: 183286926363630293f1b6b3f6655400f50f0538
Pull Request resolved: #126931
```
Traceback (most recent call last):
  File "/data/users/kw2501/pytorch/torch/testing/_internal/common_utils.py", line 2756, in wrapper
    method(*args, **kwargs)
  File "/data/users/kw2501/pytorch/torch/testing/_internal/common_utils.py", line 443, in instantiated_test
    test(self, **param_kwargs)
  File "/data/users/kw2501/pytorch/test/distributed/pipelining/test_schedule.py", line 316, in test_grad_with_manual_interleaved
    out = schedule.step(target=target, losses=losses)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/kw2501/pytorch/torch/distributed/pipelining/PipelineSchedule.py", line 578, in step
    self._step_microbatches(args_split, kwargs_split, targets_split, losses)
  File "/data/users/kw2501/pytorch/torch/distributed/pipelining/PipelineSchedule.py", line 820, in _step_microbatches
    ops.extend(bwd_stage.get_bwd_send_ops())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/kw2501/pytorch/torch/distributed/pipelining/PipelineStage.py", line 339, in get_bwd_send_ops
    raise RuntimeError(
RuntimeError: [1] for chunk 0 has gradients None and is expecting to send gradients to stage 0

To execute this test, run the following from the base repo dir:
     python test/distributed/pipelining/test_schedule.py -k ScheduleTest.test_grad_with_manual_interleaved_ScheduleClass0
```
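The check that raises this error can be pictured with a simplified, dependency-free mimic. This is not the actual `PipelineStage.get_bwd_send_ops` implementation; the function signature and op representation here are illustrative only:

```python
def get_bwd_send_ops(grads, dst_stage, rank, chunk):
    """Simplified sketch of the guard: refuse to send a gradient that is None."""
    ops = []
    for g in grads:
        if g is None:
            # Mirrors the error seen in the traceback above.
            raise RuntimeError(
                f"[{rank}] for chunk {chunk} has gradients None "
                f"and is expecting to send gradients to stage {dst_stage}"
            )
        ops.append(("send", g, dst_stage))
    return ops
```

A stage whose backward never ran, or whose inputs did not require grad, holds `None` gradients, which trips the guard before any send is posted.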

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request May 23, 2024
ghstack-source-id: 2dbaf5e4a27442563288663de2dd21310d83623f
Pull Request resolved: #126931
@kwen2501 changed the title from "[WIP][pipelining] Add grad test for interleaved schedules" to "[pipelining] Add grad test for interleaved schedules" May 23, 2024
@kwen2501 kwen2501 requested review from wconstab and H-Huang May 23, 2024 00:15
```python
with torch.no_grad():
    y = ref_mod(x)
    # Add a small perturbation
    target = y + torch.randn(batch_size, d_hid, device=self.device)
```
nice way of doing the grad test.

except, how did you determine the perturbation is 'small'? (is the default of randn small relative to the norm of the model's output?)
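One way to answer this empirically is to compare the norm of the perturbation against the norm of the model's output. A pure-Python sketch with stand-in vectors (not the actual test tensors; the output scale here is assumed for illustration):

```python
import math
import random

def relative_perturbation(y, noise):
    """||noise|| / ||y||: size of the perturbation relative to the output."""
    norm = lambda v: math.sqrt(sum(e * e for e in v))
    return norm(noise) / norm(y)

random.seed(0)
d_hid = 512
y = [random.gauss(0.0, 5.0) for _ in range(d_hid)]      # stand-in model output
noise = [random.gauss(0.0, 1.0) for _ in range(d_hid)]  # torch.randn analogue
ratio = relative_perturbation(y, noise)
# Unit-variance noise is "small" only when the output's scale is larger;
# a ratio well below 1 would support the claim.
```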

@wconstab (Contributor) left a comment:
nice! thanks for putting this in.

@kwen2501 (Contributor, Author):
@pytorchbot merge

@pytorch-bot added the `ciflow/trunk` label (triggers trunk jobs) May 23, 2024
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@kwen2501 (Contributor, Author):

@pytorchbot merge -f "the pull or windows failure does not seem related"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


@clee2000 (Contributor):

@pytorchbot revert -m "newly added test fails distributed/pipelining/test_schedule.py::ScheduleTest::test_grad_with_manual_interleaved_ScheduleClass0 https://hud.pytorch.org/pytorch/pytorch/commit/abf6d4e6bc1a9a0e08bfc2204560ca7858fa90cd https://github.com/pytorch/pytorch/actions/runs/9214413308/job/25352507591, pull workflow failed on startup on PR, so no distributed tests ran at all" -c nosignal

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.

pytorchmergebot added a commit that referenced this pull request May 23, 2024
This reverts commit abf6d4e.

Reverted #126931 on behalf of https://github.com/clee2000 due to newly added test fails distributed/pipelining/test_schedule.py::ScheduleTest::test_grad_with_manual_interleaved_ScheduleClass0 https://hud.pytorch.org/pytorch/pytorch/commit/abf6d4e6bc1a9a0e08bfc2204560ca7858fa90cd https://github.com/pytorch/pytorch/actions/runs/9214413308/job/25352507591, pull workflow failed on startup on PR, so no distributed tests ran at all ([comment](#126931 (comment)))
@pytorchmergebot (Collaborator):

@kwen2501 your PR has been successfully reverted.

Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```
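The rank-to-stage mapping above is the usual wrapped (round-robin) assignment used for interleaved schedules; a minimal sketch of the rule (the function name is illustrative, not a library API):

```python
def interleaved_stage_ids(rank, num_ranks, stages_per_rank):
    """Stage indices owned by `rank` when stages are wrapped around the ranks."""
    return [rank + i * num_ranks for i in range(stages_per_rank)]

# 2 ranks x 2 stages per rank = 4 pipeline stages in total
for rank in range(2):
    print(f"Rank {rank} stages: {interleaved_stage_ids(rank, 2, 2)}")
```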

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang awgu penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request May 24, 2024
Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```

Pull Request resolved: #126931
Approved by: https://github.com/wconstab
ghstack dependencies: #126812, #126721, #126735, #126927
ghstack-source-id: effa0de1fe6d4d422fdcab0813d98fc1f02e9186
@kwen2501 (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot pushed a commit that referenced this pull request May 25, 2024
Added to `multigpu` test config, which is run periodically.

Pull Request resolved: #127066
Approved by: https://github.com/H-Huang, https://github.com/wconstab
ghstack dependencies: #127136, #126931
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this pull request May 28, 2024
Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```

Pull Request resolved: pytorch#126931
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#126812, pytorch#126721, pytorch#126735, pytorch#126927
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this pull request May 28, 2024
…#126931)"

This reverts commit abf6d4e.

Reverted pytorch#126931 on behalf of https://github.com/clee2000 due to newly added test fails distributed/pipelining/test_schedule.py::ScheduleTest::test_grad_with_manual_interleaved_ScheduleClass0 https://hud.pytorch.org/pytorch/pytorch/commit/abf6d4e6bc1a9a0e08bfc2204560ca7858fa90cd https://github.com/pytorch/pytorch/actions/runs/9214413308/job/25352507591, pull workflow failed on startup on PR, so no distributed tests ran at all ([comment](pytorch#126931 (comment)))
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this pull request May 28, 2024
Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```

Pull Request resolved: pytorch#126931
Approved by: https://github.com/wconstab
ghstack dependencies: pytorch#127136
titaiwangms pushed a commit to titaiwangms/pytorch that referenced this pull request May 28, 2024
Added to `multigpu` test config, which is run periodically.

Pull Request resolved: pytorch#127066
Approved by: https://github.com/H-Huang, https://github.com/wconstab
ghstack dependencies: pytorch#127136, pytorch#126931
bigfootjon pushed a commit that referenced this pull request May 28, 2024
Added `test_grad_with_manual_interleaved`:
- Model: `MultiMLP`
- Tested schedules: Interleaved1F1B, LoopedBFS
- Two stages per rank
```
Rank 0 stages: [0, 2]
Rank 1 stages: [1, 3]
```

Pull Request resolved: #126931
Approved by: https://github.com/wconstab
ghstack dependencies: #127136

(cherry picked from commit c1d2564)
bigfootjon pushed a commit that referenced this pull request May 28, 2024
Added to `multigpu` test config, which is run periodically.

Pull Request resolved: #127066
Approved by: https://github.com/H-Huang, https://github.com/wconstab
ghstack dependencies: #127136, #126931

(cherry picked from commit 8bd26ec)
@github-actions github-actions bot deleted the gh/kwen2501/36/head branch June 25, 2024 01:57
Labels
`ciflow/trunk` (trigger trunk jobs), `Merged`, `oncall: distributed`, `Reverted`, `topic: not user facing`
4 participants