[pipelining] Add manual pipeline stage #126123

Closed
wants to merge 3 commits into from

H-Huang (Member) commented May 13, 2024

Add `ManualPipelineStage` under `_PipelineStage.py`.

Fix some type hints, since `args_recv_info` can contain more than one RecvInfo. Previously the hint was `Tuple[InputInfo]`, which denotes a tuple of exactly one item. This differs from `List[InputInfo]`, which can hold any number of items. The hint is now `Tuple[InputInfo, ...]`, which allows any number of items.
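
To illustrate the type hint distinction described above, here is a minimal sketch using only the standard `typing` module. The `InputInfo` dataclass and the `count_inputs` helper are hypothetical stand-ins written for this example; they are not the classes or functions from the PR.

```python
from dataclasses import dataclass
from typing import Tuple, get_args


@dataclass
class InputInfo:
    """Hypothetical stand-in for the real class; for illustration only."""
    name: str


# Exactly one element: a static type checker rejects a 2-tuple for this hint.
OneInfo = Tuple[InputInfo]

# Any number of elements (including zero), each of type InputInfo.
ManyInfos = Tuple[InputInfo, ...]


def count_inputs(args_recv_info: ManyInfos) -> int:
    """Hypothetical helper that accepts a tuple of any length."""
    return len(args_recv_info)


infos = (InputInfo("x"), InputInfo("y"))
n = count_inputs(infos)

# Inspect the type arguments at runtime to see how the two hints differ.
one_args = get_args(OneInfo)    # a single type argument
many_args = get_args(ManyInfos)  # a type argument followed by Ellipsis
```

Python does not enforce these hints at runtime; the difference only surfaces in a static checker such as mypy, which would flag passing `infos` where `OneInfo` is expected.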

Stack from ghstack (oldest at bottom):

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126123

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c14c901 with merge base 7f1d5ab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the `oncall: distributed` and `release notes: distributed (pipeline)` labels May 13, 2024
H-Huang added a commit that referenced this pull request May 13, 2024
ghstack-source-id: 810e042f19328e74385fa6a57ccf8d15cfbaa7bc
Pull Request resolved: #126123
H-Huang added a commit that referenced this pull request May 14, 2024
ghstack-source-id: 25a61caadb9e7ef56f5b51e0766aeca501dd59d6
Pull Request resolved: #126123
H-Huang requested review from wconstab and kwen2501 May 14, 2024 14:42
H-Huang added a commit that referenced this pull request May 14, 2024
ghstack-source-id: cae56595598646cc58af6dd6298d03aa31cc4740
Pull Request resolved: #126123
kwen2501 (Contributor) left a comment

LGTM.
nit: privatize methods

return metadata


def get_stage_shapes(

I'm happy to land this code as is first, but for the lazy shape inference changes (and adding complete info like strides, dtype, etc.), can you open an issue to start discussing the proposed changes?

)
return grad_recv_info

def init_p2p_neighbors(self):

Should this API be exposed only from ManualStage? Ideally, I think the base stage should expose it.

(Should we recommend that users call this? How much does it help perf?)

return True


def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):

Will this method still be needed after we switch to lazy shape inference? I think we should be confident enough in our shape inference code that a separate validation pass would be superfluous. Maybe we just delete this?

PLACEHOLDER_VAL = -1


def create_empty_tensors(

As a lower priority than the shape inference work, I'd like to replace this with pytree, if that makes sense. The upside would be not having to write a helper function at all, while also getting other data structures (like a dict of tensors) for free.

There are two possible downsides. First, pytree perf isn't great, but that won't matter if we only use it during initialization. Second, we would be supporting more types of inputs, so we'd also need to define which input types are not supported and guard against those.

return metadata_tensor


def extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:

I didn't see any helper method on `torch.Tensor` for getting its metadata as a dict, but I think building a dict here with keys like 'shape', 'stride', and 'dtype' makes sense. My approach would be to send that whole dict over the wire and then feed the received dict as **kwargs to a `torch.empty` call. Do you think that would work?

H-Huang (Member, Author) commented May 14, 2024

@pytorchbot merge

pytorch-bot bot added the `ciflow/trunk` label May 14, 2024
pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
Pull Request resolved: pytorch#126123
Approved by: https://github.com/kwen2501
@github-actions github-actions bot deleted the gh/H-Huang/115/head branch June 14, 2024 01:55