
Conversation

@JackCaoG
Collaborator

This is to enable us to support the torch distributed API.

@JackCaoG JackCaoG requested review from hjm-aws and miladm February 10, 2022 05:28
double scale, int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups);

static ir::Value reduce_scatter_out(XLATensor& output, const XLATensor& input,
Collaborator

Do you need to follow the Google C++ style here? https://google.github.io/styleguide/cppguide.html#Inputs_and_Outputs requires output parameters to be placed at the end of the parameter list.

Collaborator Author

Thanks for pointing it out! I was following the existing pattern in this file; we put out as the first tensor argument:

https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L907
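
For illustration only (hypothetical declarations with a stand-in Tensor type, not the real XLATensor API), the two conventions being discussed look roughly like this:

struct Tensor {};  // hypothetical stand-in for XLATensor, for illustration only

// Google C++ style: output parameters follow input parameters,
// typically passed as pointers.
void ReduceScatterGoogleStyle(const Tensor& input, double scale,
                              Tensor* output) {}

// Pattern already used in tensor_methods.cpp: the output tensor is the
// first argument, passed by non-const reference (the *_out convention).
void ReduceScatterExistingPattern(Tensor& output, const Tensor& input,
                                  double scale) {}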

@JackCaoG JackCaoG force-pushed the out_of_place_commun branch from 19f5e1d to 65756ff Compare February 10, 2022 22:30
@JackCaoG
Collaborator Author

@hjm-aws I will merge this PR once all tests pass.

@JackCaoG JackCaoG merged commit 3a239e6 into master Feb 11, 2022
@JackCaoG JackCaoG deleted the out_of_place_commun branch February 11, 2022 01:29
Collaborator

@miladm miladm left a comment

Thanks @JackCaoG!
LGTM. Some nits are below.

const std::shared_ptr<ir::Value>& token, double scale, int64_t scatter_dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups) {
XLATensor out = bridge::GetXlaTensor(output);
Collaborator

nit: is there a reason we don't do this inside the reduce_scatter_out call (like we do for input)?

Collaborator Author

reduce_scatter_out takes out as a non-const reference (instead of a const reference), so it has to be assigned to a named variable outside of the function call.
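
A minimal sketch of the constraint, using hypothetical stand-in types rather than the real XLATensor/bridge API: a non-const lvalue reference parameter cannot bind to a temporary, so the converted tensor has to be stored in a named variable before the call.

#include <memory>

// Hypothetical stand-ins for XLATensor and bridge::GetXlaTensor.
struct Tensor {
  std::shared_ptr<int> data = std::make_shared<int>(0);
  void SetIrValue(int v) { *data = v; }  // non-const member function
};

Tensor GetXlaTensor() { return Tensor{}; }  // returns a temporary (rvalue)

// Mirrors reduce_scatter_out: the output is a non-const reference.
void reduce_scatter_out(Tensor& output, const Tensor& input) {
  output.SetIrValue(42);
}

int main() {
  Tensor input;
  // reduce_scatter_out(GetXlaTensor(), input);  // error: cannot bind a
  //                                             // non-const lvalue reference
  //                                             // to a temporary
  Tensor out = GetXlaTensor();     // name the tensor first...
  reduce_scatter_out(out, input);  // ...then pass it by reference
}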

Collaborator
@hjm-aws hjm-aws Mar 10, 2022

I came here for inspiration on how to set the ir::Value for an output tensor. Here I agree with Milad: I don't believe we need to declare the out variable outside of the reduce_scatter_out call. Declaring out outside doesn't do anything -- out will be destructed immediately after the return statement anyway.

XLATensor is a holder for a shared_ptr<Data>, so it can be used as a temp. XLATensor::SetIrValue (and any other non-const method on XLATensor) manipulates the shared_ptr<Data> member inside XLATensor, so passing bridge::GetXlaTensor(output) into a function that will call a non-const method on XLATensor is fine.
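
A rough sketch of that argument (hypothetical types, not the actual XLATensor implementation): because the handle only wraps a shared_ptr<Data>, a copy made from the output tensor still mutates the same underlying Data.

#include <cassert>
#include <memory>

// Hypothetical stand-in for XLATensor: a thin handle around shared Data.
struct Data { int ir_value = 0; };

struct TensorHandle {
  std::shared_ptr<Data> data = std::make_shared<Data>();
  // Non-const method that mutates the shared Data, like XLATensor::SetIrValue.
  void SetIrValue(int v) { data->ir_value = v; }
};

int main() {
  TensorHandle output;
  TensorHandle copy = output;  // copies only the shared_ptr, not the Data
  copy.SetIrValue(7);          // mutation is visible through the original handle
  assert(output.data->ir_value == 7);
  return 0;
}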

Collaborator

NVM, I found the compiler doesn't like it :D

at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<ir::Value>& token, int64_t dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups) {
XLATensor out = bridge::GetXlaTensor(output);
Collaborator

ditto
