
Conversation

@JackCaoG (Collaborator):

Cherry picking #3339 to the release 1.11 branch

@miladm self-requested a review February 16, 2022 01:10
@miladm (Collaborator) left a comment:

LGTM - thanks @JackCaoG (pending a pass of the currently running CI tests).

expected = torch.ones((2, 3)) * i
assert torch.all(o.cpu() == expected), f'{o} != {expected}'
expected0 = torch.zeros_like(input)
assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'
Contributor:
Not sure if this is needed?

@JackCaoG (Collaborator, author):
We need to make sure that, after all_gather, the tensor has the expected value.

Contributor:
Yes, then shouldn't we reverse the order here?

xoutput0 = xoutputs[0] # copy
dist.all_gather(xoutputs, xinput)

@JackCaoG (Collaborator, author):
Oh, it is a special use case we need to support. We need to make sure xoutput0 also gets updated when we do all_gather on xoutputs.
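
For context, a minimal sketch of the aliasing pattern being discussed, assuming the XLA process group is already initialized and `device` is an XLA device; the helper and variable names are illustrative, not taken from the PR:

import torch
import torch.distributed as dist

def check_all_gather_alias(device, world_size):
  # Hypothetical test helper, not the PR's code. Each rank contributes
  # ones * rank; xoutput0 is taken before the collective and must reflect
  # rank 0's contribution (all zeros) afterwards, i.e. it aliases xoutputs[0].
  xinput = torch.ones((2, 3), device=device) * dist.get_rank()
  xoutputs = [torch.full((2, 3), -1.0, device=device) for _ in range(world_size)]
  xoutput0 = xoutputs[0]  # keep a reference before all_gather
  dist.all_gather(xoutputs, xinput)
  expected0 = torch.zeros((2, 3))
  assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'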

@@ -0,0 +1,38 @@
import os
Contributor:
Probably not easy in the current test structure, but it would be more consistent if we could group all the all_reduce tests in a single test file using a unit-test framework. Also, it might be better to have a separate distributed test folder under test/.

@JackCaoG (Collaborator, author):
Yeah, we should eventually have a separate repo for distributed tests.
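
A rough sketch of the kind of grouping being suggested, assuming the process group has already been initialized with the XLA backend by the multiprocessing test launcher; the class and test names are hypothetical, not from the PR:

# Hypothetical test grouping, not part of this PR. Assumes the launcher has
# already called dist.init_process_group with the XLA backend on every process.
import unittest

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


class XlaBackendCollectiveTest(unittest.TestCase):

  def test_all_reduce(self):
    device = xm.xla_device()
    t = torch.ones((2, 3), device=device)
    dist.all_reduce(t)  # defaults to SUM across all ranks
    self.assertTrue(torch.all(t.cpu() == dist.get_world_size()))

  def test_all_gather(self):
    device = xm.xla_device()
    t = torch.ones((2, 3), device=device) * dist.get_rank()
    outputs = [torch.zeros((2, 3), device=device) for _ in range(dist.get_world_size())]
    dist.all_gather(outputs, t)
    for i, o in enumerate(outputs):
      self.assertTrue(torch.all(o.cpu() == i))


if __name__ == '__main__':
  unittest.main()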

'''ProcessGroup for XLA devices. See ProcessGroup for doc.

Here we are implementing only a Python subclass. For implementing a
C++/Python extension, see
Contributor:
Do we have the C++ binding in a separate PR?

@JackCaoG (Collaborator, author):
Not sure if AWS plans to do that; it is also unclear to me how a C++ binding is going to help.
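
For readers unfamiliar with the mechanism: a Python-only ProcessGroup subclass can be wired into torch.distributed through backend registration, roughly as in the sketch below. The creation function's exact signature and the class internals are assumptions and vary across PyTorch releases; this is not the PR's actual code.

# Rough sketch only. Shows how a pure-Python ProcessGroup subclass can be
# registered as a torch.distributed backend so that
# dist.init_process_group('xla', ...) can construct it.
import torch.distributed as dist
from torch.distributed import ProcessGroup  # assumed import path


class ProcessGroupXla(ProcessGroup):

  def __init__(self, rank, size):
    super().__init__(rank, size)
  # Collective methods (allreduce, allgather, reduce_scatter, ...) are
  # overridden here in Python and lowered to torch_xla collectives.


def _create_xla_process_group(store, rank, size, timeout):
  # The argument list passed by torch.distributed differs between releases.
  return ProcessGroupXla(rank, size)


dist.Backend.register_backend('xla', _create_xla_process_group)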

return WorkXla(output_tensors)

# Call site:
# https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683
Contributor:
Replace the reference link with https://github.com/pytorch/pytorch/blob/release/1.11/torch/distributed/distributed_c10d.py#L2774

@JackCaoG (Collaborator, author):
I think it is OK; this PR was already submitted to master and we don't need to replace links for every release branch.

    else:
      raise ValueError(f'Invalid reduce op {reduce_op}')

  def allreduce(self, tensors, all_reduce_options):
Contributor:
Nit: inconsistent naming convention; should use all_reduce(), all_gather() like reduce_scatter below.

@JackCaoG (Collaborator, author):
This API, if I am not mistaken, is inherited from torch.distributed, so it has to be allreduce. Example: https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1217
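
For context on the excerpt above, the `else: raise ValueError(...)` closes a reduce-op translation step; a hedged sketch of what such a mapping could look like follows. The helper name and the use of the xm.REDUCE_* constants are assumptions, not the PR's code.

# Hypothetical helper: translate a torch.distributed ReduceOp into the reduce
# type used by torch_xla collectives. Not taken from the PR.
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def _get_reduce_type(reduce_op):
  if reduce_op == dist.ReduceOp.SUM:
    return xm.REDUCE_SUM
  elif reduce_op == dist.ReduceOp.PRODUCT:
    return xm.REDUCE_MUL
  elif reduce_op == dist.ReduceOp.MIN:
    return xm.REDUCE_MIN
  elif reduce_op == dist.ReduceOp.MAX:
    return xm.REDUCE_MAX
  else:
    raise ValueError(f'Invalid reduce op {reduce_op}')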

    raise NotImplementedError


class WorkXla(Work):
Contributor:
Nit: should comment/describe the class better, as it's not obvious from the name.

@JackCaoG (Collaborator, author):
We should fix that on master.
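
A sketch of the kind of class docstring being asked for, with an illustrative wait() body; the import path for Work and the synchronization detail are assumptions, not the PR's implementation:

# Sketch only; not the PR's code. Illustrates documenting the Work subclass
# returned by ProcessGroupXla collectives.
from torch._C._distributed_c10d import Work  # assumed import path


class WorkXla(Work):
  """Async work handle returned by ProcessGroupXla collective ops.

  torch.distributed collectives return a Work object; callers invoke wait()
  to block until the collective's result is available. For XLA tensors the
  result materializes when the traced graph containing the collective runs.
  """

  def __init__(self, tensors):
    super().__init__()
    self._tensors = tensors

  def wait(self, timeout=None):
    # Placeholder body: a real implementation would synchronize the XLA
    # computation that produces self._tensors before returning.
    return True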

@yeounoh (Contributor) left a comment:
Minor comments mostly, approving as the original PR was already approved.

* Add XLA backend for torch.distributed

* Add XLA backend multiprocessing tests.

* Linter fixes.

* Address Jack's comments.

* Fix multiprocessing tests: forgot to import the backend.

* Addressing Shen and Jack's comments.

* Fix a search/replace error.

* Fix typo in test_mp_all_gather_xla_backend to make it real.

* Use new reduce_scatter output param and use tensor.copy_ for all_gather result tensor to avoid graph execution.

* Lint fix.

* Fix TODO(alias).

* Add XRT_WORKERS and XRT_DEVICE_MAP setting back to the unit test, as we do not aim to exercise GPU-specific code in the unit test.

* Lint fix.

* Skip XLA backend unit tests for GPU/TPU.

* Address Jack's comments.

* rename tests according to Jack's comment.
@JackCaoG force-pushed the cherry_pick_torch_distributed branch from d6f3e85 to ba3ed7a on February 16, 2022 04:08
@JackCaoG merged commit 5066ab1 into release/1.11 on Feb 16, 2022
@JackCaoG deleted the cherry_pick_torch_distributed branch on February 16, 2022 18:20