-
Notifications
You must be signed in to change notification settings - Fork 547
implement send and recv using collective_permute #9373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
WARNING: This function is not very reliable, may produce wrong results under | ||
certain inputs. Use it at your own risk. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed in #8815 there's no context for this ancient warning. Given the age, lack of details, and lack of any other reported bugs I think it's best to remove it. If we get a specific bug report then we can act on that.
dist.init_process_group("xla", init_method='xla://') | ||
device = torch_xla.device() | ||
world_size = xr.world_size() | ||
cutoff = world_size // 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if the world size is not even, this test will hang. For example, if world size is 3, then index 0 will send to 1 and 1 will recv from 0, but index 2 will try to recv from 1 without an associated send.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I'll update the test so that it is more defensive
torch_xla/distributed/xla_backend.py
Outdated
logging.warning( | ||
"Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute()." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it happen to print it everytime we trace?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably. I'm not sure how to only make it print once -- will look into it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked around, and couldn't find a in built way to do this through logging.warning
. Given this is at warning level and can be filtered out, is it worth to seek a solution?
# test/test_torch_distributed_xla_backend.py for an example. | ||
def make_recv_channel_id(self, src_rank, tag): | ||
raise NotImplementedError | ||
|
||
# Call site e.g. | ||
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913 | ||
def recv(self, out_tensors, src_rank, tag=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need the warning on the recv end too, so each host has it?
# test/test_torch_distributed_xla_backend.py for an example. | ||
def make_send_channel_id(self, dst_rank, tag): | ||
raise NotImplementedError | ||
|
||
# Call site e.g. | ||
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877 | ||
def send(self, tensors, dst_rank, tag=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we're warning to use collective_permute, but it still ends up using a collective permute, should the warning itself be clearer that this is happening under the hood?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could word this better. The real advice is to restructure your code so that each process calls collective_permute with all of the send-recv pairs
@@ -326,6 +326,28 @@ def test_all_to_all_single(self, use_dynamo): | |||
expected.sort().values), | |||
f"Got {val}, expected {expected}") | |||
|
|||
@staticmethod |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last time we checked, we also noticed that https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py didn't work on the CPU, but send/recv did. We might want to double check it.
Is test/test_torch_distributed_xla_backend.py
tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
test/test_torch_distributed_xla_backend.py
tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?
It is, but it just checks that the expected IR is emitted. It doesn't run anything. And in this case it wasn't a reliable test because, at least for TPU, that IR does not actually run.
test_mp_collective_permute is run for both TPU and Neuron. I don't think it works for CPU but neither do send/recv. The success of test_mp_collective_permute indicates this change should work for Neuron, but to be more certain I could add a test that covers a pipeline-like transfer in addition to the existing test of a permutation-like transfer.
The most direct test would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion, for Neuron.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The most direct test would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion, for Neuron.
This would be great. Any chance we can move it outside of this file and make it general? I can help test it out if so. Otherwise, I'll need to follow up if we can port this entire file to Neuron. I see tpu.num_expected_global_devices
, and pjrt.run_multiprocess
, but haven't seen/used these before.
test/pjrt/test_collective_ops_tpu.py
Outdated
dist.recv(tensor, index - cutoff) | ||
return tensor.cpu() | ||
|
||
def test_send_recv(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original test separated both send and receive. While this is more code efficient, it might be harder to debug as it will not be obvious what the issue is.
I think keeping a test for the total interaction is valid, but is there a way to replicate the other two tests that existed previously?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
send and recv don't work independently. The original test was a "dry run" -- it checked the IR but didn't execute. If it did execute it would fail.
torch_xla/distributed/xla_backend.py
Outdated
logging.warning( | ||
"Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute()." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked around, and couldn't find a in built way to do this through logging.warning
. Given this is at warning level and can be filtered out, is it worth to seek a solution?
torch_xla/distributed/xla_backend.py
Outdated
# in the sending process it is unchanged. The solution used here is to | ||
# have every process copy a linear combination of the two tensors, but | ||
# send/recv use different coefficients to achieve different outcomes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This took a couple reads until I understood what was going on here. My understanding is that by having both result_t * X + t * Y
you are having both operation IRs be the same as X and Y are constants. That way when the IRs are compared they will be equivalent.
If this understanding is correct, could you add a little bit more here to make it more apparent?
# test/test_torch_distributed_xla_backend.py for an example. | ||
def make_recv_channel_id(self, src_rank, tag): | ||
raise NotImplementedError | ||
|
||
# Call site e.g. | ||
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913 | ||
def recv(self, out_tensors, src_rank, tag=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not assume someone reading "recv" will have read the documentation for "send". I think we should add documentation here. I would then add a note specific about what the IR expectation will be for "send" and "recv" on each of their comments.
The approach implemented here works for a "pipeline" type operation but does not work for a "permutation" type operation. The way this is commonly done in native pytorch in order to avoid deadlocks is that half of the devices send and the other receive, then they switch roles. What this means is that the sending and receiving tensors must be different, and one half of the devices end up having a different IR than the other half, resulting in a deadlock. I'm still searching for a way around this. |
The only way I was able to make a "permutation" type op (every device sends and every device receives) work is by inserting a sync after each set of send/recv. This is not ideal. It's better than the status quo for TPU, which is that send/recv don't work at all. But since Neuron does have something working I'll defer to you @rpsilva-aws . We can put this on ice until the Send/Recv XLA ops can be called directly. |
Hm, that does complicate things... I have it working on TRN, though I deviated a bit with multi-operands to capture tokens. I'll end up creating a PR for this one, which would build upon the work you had in the prior commits. Actually, TRN has the same limitation for send/recv, requiring a graph break. Do you think we can merge this PR without the sync since it's working for existing devices (e.g. TRN), and revisit as we figure out the underlying issues with TPU? If you want to defer until the new ops, or we re-raise the need as we bring in our work, both are ok with me. |
There are two tests in the PR,
I'd be interested in seeing that |
#9315