Conversation

H-Huang
Member

@H-Huang H-Huang commented Jun 24, 2024

Stack from ghstack (oldest at bottom):

Changes

  • small fix in stage error message
  • Move format_pipeline_order and _validate_pipeline_order out of test_schedule.py into schedules.py.
  • Wrap the execution runtime in a try-except which on error will log the timestep and schedule plan before re-raising the exception.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k
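
The third change can be pictured with a minimal sketch. This is not the PR's actual code: `run_schedule`, `execute`, and `format_plan` are hypothetical stand-ins, and the only behavior taken from the description is "log the timestep and schedule plan, then re-raise".

```python
import logging

logger = logging.getLogger(__name__)


def run_schedule(rank, actions, execute, format_plan=repr):
    """Run each action in order; on failure, log context and re-raise.

    Every name here is a hypothetical stand-in, not the PR's real API.
    """
    for time_step, action in enumerate(actions):
        try:
            execute(action)
        except Exception:
            # Log where the failure happened plus the full plan, then
            # re-raise so the caller still sees the original exception.
            logger.error(
                "Exception in rank %s at time step %s\n%s",
                rank,
                time_step,
                format_plan(actions),
            )
            raise
```

A bare `raise` keeps the original exception type and traceback, so the logging is purely additive.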


pytorch-bot bot commented Jun 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129369

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 57f195c with merge base 37e3c60 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the "oncall: distributed" label (Add this issue/PR to distributed oncall triage queue) Jun 24, 2024
H-Huang added a commit that referenced this pull request Jun 24, 2024
@H-Huang H-Huang requested review from wconstab June 24, 2024 14:11
H-Huang added a commit that referenced this pull request Jun 25, 2024
@H-Huang H-Huang changed the title from "[pipelining] [BE] Move pipeline_order validation to utils" to "[pipelining] [BE] Move pipeline_order validation to schedules.py" Jun 26, 2024
@H-Huang H-Huang requested a review from haocizhang June 26, 2024 19:03
]
# Join the rows into a single string
formatted_table = (
"=========== ALL_RANK_ACTIONS ===========\n"
Contributor

nit: I think this function could stay as a helper function, maybe called by repr.

And it's a little better IMO to remove the ====ALL_RANK.. part so that inside repr we can customize a bit more.

repr should probably print out the class name and relevant field values, and maybe a shortened summary of the actions rather than the whole page-long list by default? I'd suggest looking at how the torch.tensor repr works.

PipelineSchedule
[..possibly, some args such as number of stages, etc]
Actions
[ ... ]
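
A minimal sketch of the repr layout suggested above. `ToySchedule`, its fields, and `max_shown` are all hypothetical and do not correspond to the real schedule classes; the sketch only shows the "class name, key fields, truncated actions" shape.

```python
from dataclasses import dataclass, field


@dataclass(repr=False)
class ToySchedule:
    # Hypothetical fields chosen only to illustrate the layout.
    n_stages: int
    n_microbatches: int
    actions: list = field(default_factory=list)
    max_shown: int = 4  # how many actions to print before truncating

    def __repr__(self):
        shown = ", ".join(map(str, self.actions[: self.max_shown]))
        hidden = len(self.actions) - self.max_shown
        if hidden > 0:
            # Summarize the tail instead of printing every action.
            shown += f", ... ({hidden} more)"
        return (
            f"{type(self).__name__}(n_stages={self.n_stages}, "
            f"n_microbatches={self.n_microbatches})\n"
            f"Actions\n[{shown}]"
        )
```

The full table could still be produced by a separate helper when someone asks for it explicitly, which is the split the comment proposes.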

_batch_p2p(ops).wait()
except Exception as e:
logger.error(
"Exception in rank %s at time step %s", self.rank, time_step
Contributor

This string should include more info. Not sure what amount is right, but some things that might be good:

  • which stage?
  • which action?

Maybe also phrase it more like "PipelineSchedule {schedulename?} caught the following exception when running {}th microbatch {Action} for stage {}" or something?
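
One way the richer message could be assembled, as a rough sketch only. `describe_failure` and all of its parameters are hypothetical; the wording follows the phrasing proposed in the comment rather than anything in the merged code.

```python
def describe_failure(schedule_name, rank, time_step, stage_index, action, microbatch_index):
    """Build a log line naming the schedule, stage, action, and microbatch.

    Hypothetical helper: the parameter names are illustrative assumptions.
    """
    return (
        f"{schedule_name} caught the following exception on rank {rank} "
        f"at time step {time_step} when running microbatch {microbatch_index} "
        f"{action} for stage {stage_index}"
    )
```

Building the string in one helper keeps the `except` block short and makes the format easy to test on its own.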

@H-Huang
Member Author

H-Huang commented Jul 1, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) Jul 1, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

H-Huang added a commit that referenced this pull request Jul 1, 2024
@H-Huang
Member Author

H-Huang commented Jul 2, 2024

@pytorchbot merge -i

H-Huang added a commit that referenced this pull request Jul 2, 2024
@H-Huang
Member Author

H-Huang commented Jul 2, 2024

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu), trunk / linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build-test / test (distributed, 1, 3, linux.8xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@clee2000
Contributor

clee2000 commented Jul 2, 2024

@pytorchbot revert -m "broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed cuda https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773 https://hud.pytorch.org/pytorch/pytorch/commit/ec789a3c9ddd4e550b3dea6934ce2d41deb98784. You can see the error on the PR, but Dr. CI classified it wrong" -c weird

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@H-Huang your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Jul 2, 2024
….py (#129369)"

This reverts commit ec789a3.

Reverted #129369 on behalf of https://github.com/clee2000 due to broke test/distributed/pipelining/test_schedule.py::ScheduleTest::test_non_symmetric_stage_ids_ScheduleClass0 on distributed cuda https://github.com/pytorch/pytorch/actions/runs/9766039400/job/26959115773 https://hud.pytorch.org/pytorch/pytorch/commit/ec789a3c9ddd4e550b3dea6934ce2d41deb98784.  You can see the error on the PR, but Dr. CI classified it wrong ([comment](#129369 (comment)))
H-Huang added a commit that referenced this pull request Jul 3, 2024
@H-Huang
Member Author

H-Huang commented Jul 3, 2024

Thanks for catching that @clee2000, yeah the error was legit. Fixed now

@H-Huang
Member Author

H-Huang commented Jul 3, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
#130041

Details for Dev Infra team Raised by workflow job

@H-Huang
Member Author

H-Huang commented Jul 4, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/H-Huang/130/head branch August 4, 2024 02:02
Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • oncall: distributed (Add this issue/PR to distributed oncall triage queue)
  • release notes: distributed (pipeline) (release notes category)
  • Reverted

4 participants