Add all_gather_out and reduce_scatter_out #3359
Conversation
    double scale, int64_t scatter_dim, int64_t shard_count,
    std::vector<std::vector<int64_t>> groups);

static ir::Value reduce_scatter_out(XLATensor& output, const XLATensor& input,
Do you need to follow the Google C++ style guide here? https://google.github.io/styleguide/cppguide.html#Inputs_and_Outputs requires output parameters to be placed at the end of the parameter list.
Thanks for pointing it out! I was following the existing pattern in this file, which puts the out tensor as the first argument (a sketch contrasting the two orderings follows below):
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L907
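For illustration only, a minimal sketch contrasting the style-guide ordering with the ordering this file already uses. The Tensor type and the function names here are hypothetical stand-ins, not the real XLA declarations:

#include <cstdint>
#include <vector>

struct Tensor {};  // hypothetical stand-in for XLATensor

// Google C++ style: pure inputs first, output parameters last (as pointers).
void ReduceScatterStyleGuide(const Tensor& input, double scale,
                             int64_t scatter_dim, int64_t shard_count,
                             const std::vector<std::vector<int64_t>>& groups,
                             Tensor* output);

// Pattern already used in tensor_methods.cpp: the out tensor leads the
// argument list, mirroring the existing *_out overloads in this file.
void ReduceScatterExistingPattern(Tensor& output, const Tensor& input,
                                  double scale, int64_t scatter_dim,
                                  int64_t shard_count,
                                  const std::vector<std::vector<int64_t>>& groups);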
Force-pushed from 19f5e1d to 65756ff.
@hjm-aws I will merge this PR once all tests pass.
Thanks @JackCaoG!
LGTM. Some nits below.
    const std::shared_ptr<ir::Value>& token, double scale, int64_t scatter_dim,
    int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups) {
  XLATensor out = bridge::GetXlaTensor(output);
nit: is there a reason we don't do this conversion inside the reduce_scatter_out call (like we do for input)?
reduce_scatter_out takes out as a non-const reference (instead of a const reference), so the result of bridge::GetXlaTensor(output) has to be assigned to a named variable before the call: a temporary cannot bind to a non-const lvalue reference (see the sketch below).
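A minimal, self-contained sketch of the binding rule in play here. All names are hypothetical: Handle stands in for XLATensor, MakeHandle for bridge::GetXlaTensor, and TakesNonConstRef for reduce_scatter_out.

#include <memory>

struct Data {};

// Hypothetical handle type standing in for XLATensor.
struct Handle {
  std::shared_ptr<Data> data = std::make_shared<Data>();
};

Handle MakeHandle() { return Handle{}; }  // stands in for bridge::GetXlaTensor
void TakesConstRef(const Handle&) {}
void TakesNonConstRef(Handle&) {}         // stands in for reduce_scatter_out

int main() {
  TakesConstRef(MakeHandle());       // OK: a temporary binds to a const reference.
  // TakesNonConstRef(MakeHandle()); // error: a temporary (rvalue) cannot bind
                                     // to a non-const lvalue reference.
  Handle out = MakeHandle();         // so the handle gets a name first,
  TakesNonConstRef(out);             // then is passed by non-const reference.
}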
I came here for inspiration on how to set the ir::Value for an output tensor. Here I agree with Milad: I don't believe we need to keep the out variable outside of the reduce_scatter_out call. Leaving out outside doesn't accomplish anything; out will be destructed immediately after the return statement anyway.
XLATensor is a holder for a shared_ptr<Data>, so it can be used as a temporary. XLATensor::SetIrValue (and any other non-const method on XLATensor) manipulates the shared_ptr<Data> member inside the XLATensor, so passing bridge::GetXlaTensor(output) directly into a function that calls non-const methods on it is fine (see the sketch below).
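A hypothetical sketch of that handle-over-shared-state pattern; it is not the real XLATensor, whose Data of course carries far more than a single field.

#include <cassert>
#include <memory>

struct Data {
  int ir_value = 0;
};

class TensorHandle {
 public:
  TensorHandle() : data_(std::make_shared<Data>()) {}
  // Non-const method: mutates the shared Data, not the handle object itself.
  void SetIrValue(int v) { data_->ir_value = v; }
  int ir_value() const { return data_->ir_value; }

 private:
  std::shared_ptr<Data> data_;
};

int main() {
  TensorHandle original;
  TensorHandle copy = original;        // copies alias the same Data
  copy.SetIrValue(42);                 // mutating through a copy (or a temporary)...
  assert(original.ir_value() == 42);   // ...is visible through the original handle.
}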
NVM, I found the compiler doesn't like it :D
    at::Tensor& output, const at::Tensor& input,
    const std::shared_ptr<ir::Value>& token, int64_t dim, int64_t shard_count,
    const std::vector<std::vector<int64_t>>& replica_groups) {
  XLATensor out = bridge::GetXlaTensor(output);
ditto
This is to enable us to support the torch.distributed API.
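As context for why out variants help here: torch.distributed collectives such as all_gather and reduce_scatter fill caller-provided output tensors in place, which is the calling convention the new *_out methods mirror. Below is a hedged sketch contrasting a functional collective with an out variant; the Tensor and Token types and the exact parameter lists are illustrative placeholders, and the real declarations are in the diff above.

#include <cstdint>
#include <utility>
#include <vector>

struct Tensor {};  // hypothetical stand-in for XLATensor
struct Token {};   // hypothetical stand-in for ir::Value

// Functional form: returns a freshly created result tensor plus a token.
std::pair<Tensor, Token> all_gather(const Tensor& input, int64_t dim,
                                    int64_t shard_count,
                                    const std::vector<std::vector<int64_t>>& groups);

// Out variant: writes the result into a caller-provided tensor, matching the
// in-place semantics of torch.distributed collectives, and returns only the token.
Token all_gather_out(Tensor& output, const Tensor& input, const Token& token,
                     int64_t dim, int64_t shard_count,
                     const std::vector<std::vector<int64_t>>& groups);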