[SPMD] Support SPMDFullToShardShape #6922
@@ -7,14 +7,27 @@ namespace torch_xla {

```cpp
class CustomSharding : public XlaNode {
 public:
  // The following enum represents the custom_call_target being
  // passed to xla builder. The actual sharding will still be
  // attached to the XLATensor.
  enum class Type {
    kSharding,
    kSPMDFullToShardShape,
    kSPMDShardToFullShape,
  };
```
Comment on lines +13 to +17:

> This enum is really confusing; can you add a comment about what these values actually do? I was reading the SPMD code again. This op itself only means we want to shard the underlying value, and the actual sharding resides in the XLATensor or the base XLA IR object?

> Right, this is just the name of the custom call. The sharding annotation is in the XLATensor as usual. I can add more explanation.

> Maybe we can state explicitly in the enum class name (or somewhere) that this is the sharding type for the custom call.

> I guess the current approach sort of does that already? Can you be more specific? @yeounoh

> I agree,
```cpp
  // Make a custom call to Sharding.
  CustomSharding(const torch::lazy::Value& input);
  CustomSharding(const torch::lazy::Value& input,
                 const xla::Shape& output_shape, const Type& type);

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  std::string ToString() const override;

  Type type;
  xla::Shape output_shape;
};

}  // namespace torch_xla
```
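For readers unfamiliar with how these nodes reach the XLA builder, below is a minimal sketch of what a lowering for this node could look like, assuming the `Type` values map one-to-one to the XLA custom-call targets `"Sharding"`, `"SPMDFullToShardShape"`, and `"SPMDShardToFullShape"`. The `CustomCallTarget` helper, the include paths, and the body are illustrative assumptions rather than the PR's actual implementation; as noted in the review thread above, the node only names the custom call, while the sharding annotation itself stays on the XLATensor.

```cpp
// Sketch only (not the PR's code): shows how Lower() might turn the Type enum
// into a custom_call_target string for the XLA builder. Include paths and the
// CustomCallTarget helper are assumptions made for illustration.
#include <string>

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/custom_sharding.h"
#include "xla/client/xla_builder.h"

namespace torch_xla {

// Hypothetical helper mapping the enum to the custom-call target name.
static std::string CustomCallTarget(CustomSharding::Type type) {
  switch (type) {
    case CustomSharding::Type::kSharding:
      return "Sharding";
    case CustomSharding::Type::kSPMDFullToShardShape:
      return "SPMDFullToShardShape";
    case CustomSharding::Type::kSPMDShardToFullShape:
      return "SPMDShardToFullShape";
  }
  return "Sharding";  // unreachable; keeps compilers happy
}

XlaOpVector CustomSharding::Lower(LoweringContext* loctx) const {
  // The node only emits a named custom call; the sharding annotation itself
  // still lives on the XLATensor, as discussed in the review thread above.
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = xla::CustomCall(input.builder(), CustomCallTarget(type),
                                      {input}, output_shape);
  return ReturnOp(output, loctx);
}

}  // namespace torch_xla
```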
> Do you intend to keep this xx.cpu?

> Yea, it's more like a note that this won't work. I was trying to use self.assertRaises, but that doesn't capture the exception. I have noticed this before too: when libtpu crashes, it's hard to catch it at the Python level. Not sure why. Maybe you have some better ideas?

> Oh, I think I ran into a similar issue before. The way I handled it was ugly, through xla/test/spmd/test_dynamo_spmd.py, lines 172 to 181 in a7a1357.

> A C++ crash at the PyTorch level can be caught with self.assertRaises, but not a libtpu-level one. I'm not sure why. Yea, not even with this hack...

> cc @will-cromar Do you know how to catch a libtpu exception in Python? Appreciate your insights.

> I don't think you can. To make a proper runtime error, you have to raise an exception, and Google internal binaries don't generally do that. I wrote about a similar case in #6700 (comment).

> Thanks, Will. That makes a lot of sense now.
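To make that distinction concrete, here is a small standalone C++ sketch (not tied to PyTorch/XLA or libtpu; all names are illustrative) of why a failure that throws an exception can be surfaced to Python-level handlers such as self.assertRaises, while a CHECK-style failure that calls abort() kills the whole process before any handler can run.

```cpp
// Minimal sketch: an exception thrown in C++ can be caught (and, through the
// Python bindings, translated into a Python RuntimeError), while abort()
// terminates the process outright, so no test-level handler ever sees it.
#include <cstdlib>
#include <iostream>
#include <stdexcept>

void fail_with_exception() {
  // Comparable to a framework-side check that raises a proper error.
  throw std::runtime_error("recoverable error");
}

void fail_with_abort() {
  // Comparable to a crash inside a prebuilt binary such as libtpu.
  std::abort();
}

int main() {
  try {
    fail_with_exception();
  } catch (const std::runtime_error& e) {
    std::cout << "caught: " << e.what() << std::endl;  // this is reached
  }

  try {
    fail_with_abort();  // never returns; the process dies here
  } catch (...) {
    std::cout << "this line is never printed" << std::endl;
  }
  return 0;
}
```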