[pipelining] Add manual pipeline stage #126123
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126123.
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit c14c901 with merge base 7f1d5ab.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 810e042f19328e74385fa6a57ccf8d15cfbaa7bc
Pull Request resolved: #126123

Add `ManualPipelineStage` under `_PipelineStage.py`.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

ghstack-source-id: 25a61caadb9e7ef56f5b51e0766aeca501dd59d6
Pull Request resolved: #126123

Add `ManualPipelineStage` under `_PipelineStage.py`. Fix some type hints, since `args_recv_info` can contain more than one `RecvInfo`. Previously the hint was `Tuple[InputInfo]`, which means a tuple of size 1. This is different from `List[InputInfo]`, which can contain any number of items. I needed to update it to `Tuple[InputInfo, ...]` to make the number of items flexible.

ghstack-source-id: cae56595598646cc58af6dd6298d03aa31cc4740
Pull Request resolved: #126123
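As a minimal, self-contained illustration of the typing distinction described in the commit message (using plain `int` in place of `InputInfo`):

```python
from typing import List, Tuple

# Tuple[int] describes a tuple of exactly one int; a static checker
# such as mypy would reject (1, 2) for this annotation.
single: Tuple[int] = (1,)

# Tuple[int, ...] describes a tuple of any length, analogous to
# List[int] but immutable.
flexible: Tuple[int, ...] = (1, 2, 3)
items: List[int] = [1, 2, 3]
```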
LGTM.
nit: privatize methods
Review comment on `def get_stage_shapes(`:
I'm happy to land this code first as is, but for the lazy shape inference changes (and adding complete metadata such as strides, dtype, etc.), can you open an issue to start discussing the proposed changes?
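A hypothetical sketch of what lazy shape inference could look like at the buffer level: allocation is deferred until metadata from the first received activation is known, rather than required up front at stage construction. The class and method names here are invented for illustration, not the PR's actual API:

```python
import torch

class LazyRecvBuffer:
    """Defers allocating a receive buffer until the sender's tensor
    metadata (shape and dtype) has been observed."""

    def __init__(self) -> None:
        self._buf = None

    def materialize(self, shape: tuple, dtype: torch.dtype) -> torch.Tensor:
        if self._buf is None:
            # First use: allocate from the observed metadata.
            self._buf = torch.empty(shape, dtype=dtype)
        return self._buf

buf = LazyRecvBuffer().materialize((2, 3), torch.float32)
assert tuple(buf.shape) == (2, 3)
```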
Review comment on `def init_p2p_neighbors(self):`
Should this API be exposed only from ManualStage? Ideally, I think it should be exposed by the base stage.
(Should we recommend that users call this? How much does it help perf?)
Review comment on `def validate_stage_shapes(pipeline_stages: List[ManualPipelineStage]):`
Will this method still be needed after we switch to lazy shape inference? I think we should be confident enough in our shape-inference code that a separate validation pass would be superfluous. Maybe we should just delete this?
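For reference, the core check such a validation pass performs can be sketched in a few lines of plain Python; the function name and data layout below are hypothetical, not the PR's actual implementation:

```python
from typing import List, Tuple

ShapeList = List[tuple]

def check_adjacent_stage_shapes(stage_io: List[Tuple[ShapeList, ShapeList]]) -> None:
    """Each entry is (input_shapes, output_shapes) for one stage; the
    outputs of stage i must match the inputs of stage i + 1."""
    for i in range(len(stage_io) - 1):
        outs, ins = stage_io[i][1], stage_io[i + 1][0]
        if outs != ins:
            raise ValueError(f"stage {i} -> {i + 1} shape mismatch: {outs} vs {ins}")

# Two stages whose boundary shapes agree pass silently.
check_adjacent_stage_shapes([([(8, 16)], [(8, 32)]), ([(8, 32)], [(8, 4)])])
```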
Review comment on `def create_empty_tensors(`:
As a lower-priority item than the shape inference, I'd like to replace this with pytree if that makes sense. The upside would be not having to write a helper function at all, while also getting other data structures (like a dict of tensors) for free.
The downsides: pytree's perf isn't great, but that won't matter if we only use it during initialization; and we would be supporting more types of inputs, so we'd also need to define which input types are not supported and guard against those.
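A minimal sketch of the pytree idea, using PyTorch's (private) `torch.utils._pytree` utility module; the nested structure here is made up for illustration:

```python
import torch
from torch.utils._pytree import tree_map  # note: private utility module

# A nested structure of example inputs; pytree handles lists, tuples,
# and dicts of tensors without a hand-written helper.
example_inputs = {"x": torch.randn(2, 3), "extra": [torch.randn(4)]}

# tree_map applies the function to every tensor leaf and rebuilds the
# same container structure around the results.
empties = tree_map(torch.empty_like, example_inputs)

assert tuple(empties["x"].shape) == (2, 3)
assert tuple(empties["extra"][0].shape) == (4,)
```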
Review comment on `def extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:`
I didn't see a helper method on `torch.Tensor` for getting its metadata as a dict, but I think building a dict here with keys like shape, stride, and dtype makes sense, then just sending that whole dict over the wire, and finally feeding the received dict as **kwargs to a `torch.empty` call would be my approach. Do you think that would work?
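A sketch of the metadata-dict approach, with hypothetical helper names. One caveat to the kwargs idea: `torch.empty` does not take a stride keyword, so if stride needs to round-trip too, `torch.empty_strided` would be needed on the receiving side; the sketch below covers size and dtype only:

```python
import torch

def tensor_metadata(t: torch.Tensor) -> dict:
    # Capture what is needed to allocate a matching buffer on the receiver.
    return {"size": tuple(t.shape), "dtype": t.dtype}

def alloc_from_metadata(meta: dict) -> torch.Tensor:
    # torch.empty accepts the size plus keyword args such as dtype.
    return torch.empty(meta["size"], dtype=meta["dtype"])

sent = torch.randn(2, 3, dtype=torch.float32)
received = alloc_from_metadata(tensor_metadata(sent))
assert received.shape == sent.shape and received.dtype == sent.dtype
```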
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add `ManualPipelineStage` under `_PipelineStage.py`. Fix some type hints, since `args_recv_info` can contain more than one `RecvInfo`. Previously the hint was `Tuple[InputInfo]`, which means a tuple of size 1. This is different from `List[InputInfo]`, which can contain any number of items. I needed to update it to `Tuple[InputInfo, ...]` to make the number of items flexible.

Pull Request resolved: pytorch#126123
Approved by: https://github.com/kwen2501
Add `ManualPipelineStage` under `_PipelineStage.py`.

Fix some type hints, since `args_recv_info` can contain more than one `RecvInfo`. Previously the hint was `Tuple[InputInfo]`, which means a tuple of size 1. This is different from `List[InputInfo]`, which can contain any number of items. I needed to update it to `Tuple[InputInfo, ...]` to make the number of items flexible.

Stack from ghstack (oldest at bottom).

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k