Add XLA backend for torch.distributed (#3339) #3378
Conversation
LGTM - Thanks @JackCaoG (pending the currently running CI tests passing).
    expected = torch.ones((2, 3)) * i
    assert torch.all(o.cpu() == expected), f'{o} != {expected}'
    expected0 = torch.zeros_like(input)
    assert torch.all(xoutput0.cpu() == expected0), f'{xoutput0} != {expected0}'
Not sure if this is needed?
We need to make sure the tensor has the expected value after all_gather.
Yes, but then shouldn't we reverse the order here?

    xoutput0 = xoutputs[0]  # copy
    dist.all_gather(xoutputs, xinput)
Oh, it is a special use case we need to support: we need to make sure xoutput0 also gets updated when we do all_gather on xoutputs.
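To make the aliasing requirement concrete, here is a minimal sketch of the property under discussion. It uses a single-process gloo group so it runs standalone; the XLA-backend test in this PR checks the same property on XLA tensors, and the port and fill value chosen here are arbitrary.

    import os
    import torch
    import torch.distributed as dist

    # Single-process group only, so the sketch runs without torch_xla.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    input = torch.full((2, 3), 7.0)
    outputs = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
    output0 = outputs[0]  # reference taken before the collective
    dist.all_gather(outputs, input)
    # all_gather fills the output tensors in place, so a reference taken
    # beforehand must observe the gathered value as well.
    assert torch.all(output0 == input), f'{output0} != {input}'

    dist.destroy_process_group()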
    @@ -0,0 +1,38 @@
    import os
Probably not easy in the current test structure, but it would be more consistent if we could group the all_reduce tests in a single test file using the unit test framework. It might also be better to have a separate distributed test folder under test/.
Yeah, we should eventually have a separate repo for distributed tests.
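For reference, a grouped test file could look roughly like the sketch below. The file path, class name, and the single-process gloo setup used to keep it runnable are assumptions, not the PR's code; the real tests would initialize the 'xla' backend and move tensors to the XLA device.

    # Hypothetical grouped test file, e.g. test/distributed/test_xla_backend.py.
    import os
    import unittest

    import torch
    import torch.distributed as dist


    class CollectiveOpsTest(unittest.TestCase):

      @classmethod
      def setUpClass(cls):
        # Single-process gloo group so the sketch runs anywhere; a real test
        # would init_process_group('xla', ...) and use the XLA device.
        os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
        os.environ.setdefault('MASTER_PORT', '29501')
        dist.init_process_group('gloo', rank=0, world_size=1)

      @classmethod
      def tearDownClass(cls):
        dist.destroy_process_group()

      def test_all_reduce(self):
        world = dist.get_world_size()
        t = torch.ones((2, 3)) * dist.get_rank()
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        self.assertTrue(torch.all(t == torch.ones((2, 3)) * sum(range(world))))

      def test_all_gather(self):
        world = dist.get_world_size()
        t = torch.ones((2, 3)) * dist.get_rank()
        outputs = [torch.zeros_like(t) for _ in range(world)]
        dist.all_gather(outputs, t)
        for i, o in enumerate(outputs):
          self.assertTrue(torch.all(o == torch.ones((2, 3)) * i))


    if __name__ == '__main__':
      unittest.main()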
    '''ProcessGroup for XLA devices. See ProcessGroup for doc.

    Here we are implementing only a Python subclass. For implementing a
    C++/Python extension, see
Do we have the C++ binding in a separate PR?
Not sure if AWS plans to do that; it is also unclear to me how a C++ binding is going to help.
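For context, the Python-only route looks roughly like the sketch below: subclass ProcessGroup in Python and register a creator function under the 'xla' backend name. The creator signature (prefix_store, rank, size, timeout) is what torch.distributed around the 1.10/1.11 releases passes to registered third-party backends; treat the exact names here as assumptions rather than this PR's code.

    import torch.distributed as dist


    class ProcessGroupXla(dist.ProcessGroup):
      '''ProcessGroup for XLA devices, implemented purely in Python.'''

      def __init__(self, prefix_store, rank, size, timeout):
        super().__init__(rank, size)
        self.prefix_store = prefix_store
        self.timeout = timeout

      # Collective overrides (allreduce, allgather, ...) would go here.


    def _create_xla_process_group(prefix_store, rank, size, timeout):
      return ProcessGroupXla(prefix_store, rank, size, timeout)


    # After registration, dist.init_process_group('xla', ...) constructs the
    # Python subclass above; no C++ extension is involved.
    dist.Backend.register_backend('xla', _create_xla_process_group)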
    return WorkXla(output_tensors)

    # Call site:
    # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683
Replace the reference link with https://github.com/pytorch/pytorch/blob/release/1.11/torch/distributed/distributed_c10d.py#L2774
I think it is OK; this PR was already submitted to master, and we don't need to replace the links for every release branch.
    else:
      raise ValueError(f'Invalid reduce op {reduce_op}')

    def allreduce(self, tensors, all_reduce_options):
Nit: inconsistent naming convention; should this use all_reduce() and all_gather(), like reduce_scatter below?
This API, if I am not mistaken, is inherited from torch.distributed, so it has to be allreduce. Example: https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1217
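A minimal illustration of why the spelling is fixed (illustration only, not this PR's code): the frontend functions in distributed_c10d.py dispatch to the process group's allreduce()/allgather() methods, while reduce_scatter already happens to use the underscored spelling in the c10d interface.

    import torch.distributed as dist


    class ProcessGroupSketch(dist.ProcessGroup):
      '''Illustration only: the override names come from the c10d interface.'''

      def allreduce(self, tensors, opts):
        # dist.all_reduce(tensor) ends up calling group.allreduce([tensor], opts).
        raise NotImplementedError

      def allgather(self, output_tensors, input_tensors, opts):
        # dist.all_gather(tensor_list, tensor) calls group.allgather(...).
        raise NotImplementedError

      def reduce_scatter(self, output_tensors, input_tensors, opts):
        # Already underscored in the base interface, hence the inconsistency.
        raise NotImplementedError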
    raise NotImplementedError

    class WorkXla(Work):
Nit: we should comment/describe the class better, as its purpose is not obvious from the name.
We should fix that on master.
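A rough sketch of the kind of class description the nit asks for is below. The import path for Work, the constructor arguments, and the mark_step-based wait() are inferred from the diff context and from how lazy XLA execution generally works; they are assumptions, not the exact code in this PR.

    # Assumed import path for the c10d work handle; adjust to wherever Work
    # is exposed in your torch version.
    from torch._C._distributed_c10d import Work

    import torch_xla.core.xla_model as xm


    class WorkXla(Work):
      '''Work handle returned by ProcessGroupXla collectives.

      c10d callers use this handle to wait for a collective to finish. Because
      XLA executes lazily, wait() forces the pending graph (which contains the
      collective) to run, rather than polling an in-flight transport operation.
      '''

      def __init__(self, cc_tensors):
        super().__init__()
        # Tensors produced by the collective; kept alive until the work is done.
        self.cc_tensors = cc_tensors

      def wait(self):
        xm.mark_step()
        return True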
Mostly minor comments; approving, as the original PR was already approved.
* Add XLA backend for torch.distributed
* Add XLA backend multiprocessing tests.
* Linter fixes.
* Address Jack's comments.
* Fix multiprocessing tests: forgot to import the backend.
* Addressing Shen and Jack's comments.
* Fix a search/replace error.
* Fix typo in test_mp_all_gather_xla_backend to make it real.
* Use new reduce_scatter output param and use tensor.copy_ for all_gather result tensor to avoid graph execution.
* Lint fix.
* Fix TODO(alias).
* Add XRT_WORKERS and XRT_DEVICE_MAP setting back to the unit test as we do not aim to exercise GPU specific code in the unit test.
* Lint fix.
* Skip XLA backend unit tests for GPU/TPU.
* Address Jack's comments.
* Rename tests according to Jack's comment.
Force-pushed from d6f3e85 to ba3ed7a.
Cherry-picking #3339 to the release 1.11 branch.