Add support for NCCL alltoall #44374

Status: Closed. Wants to merge 34 commits; the diff below shows changes from 11 commits.
Commits (34, all by zasdfgbnm):
- d427943  Add support for NCCL all to all (Sep 8, 2020)
- 0215034  fix (Sep 9, 2020)
- d03c3b1  fix (Sep 9, 2020)
- 03831ec  fix (Sep 9, 2020)
- d0da056  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Sep 9, 2020)
- ffe67dc  tests (Sep 9, 2020)
- 16cec66  fix (Sep 9, 2020)
- c331af0  fix (Sep 9, 2020)
- 865f4a8  cleanup (Sep 9, 2020)
- 7efc5a3  group (Oct 16, 2020)
- 1c39d26  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Oct 16, 2020)
- 1f54c1f  save (Oct 16, 2020)
- 0683f63  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Oct 16, 2020)
- a78d22d  error message (Oct 16, 2020)
- 4fb63ac  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Oct 20, 2020)
- 2ab4ff2  OpType::ALLTOALL (Oct 20, 2020)
- 62a0d7b  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Oct 20, 2020)
- 213409e  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Oct 21, 2020)
- 76090bb  update all to all (Oct 21, 2020)
- 8ba2aa8  more (Oct 21, 2020)
- 06e9d9a  fix (Oct 21, 2020)
- e2586cb  fix (Oct 21, 2020)
- 3d3e3d8  @skip_if_rocm (Oct 21, 2020)
- 81ccdda  Merge branch 'master' into nccl-all2all (Oct 23, 2020)
- de82338  fix (Oct 23, 2020)
- d1590e8  Merge branch 'master' into nccl-all2all (Nov 9, 2020)
- e94b602  Merge branch 'master' into nccl-all2all (Nov 14, 2020)
- 81c214b  Update ProcessGroupNCCL.cpp (Nov 15, 2020)
- 58d50c5  Update ProcessGroupNCCL.cpp (Nov 16, 2020)
- 1e163a6  fix (Nov 19, 2020)
- 3e5f29f  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Nov 19, 2020)
- 482368a  fix (Nov 19, 2020)
- 7abc38a  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Dec 7, 2020)
- 3b20dd6  Merge branch 'master' of github.com:pytorch/pytorch into nccl-all2all (Jan 5, 2021)
50 changes: 50 additions & 0 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -1555,12 +1555,62 @@ void ProcessGroupNCCL::groupEnd() {
--ncclActiveGroupCounter_;
}

#ifdef ENABLE_NCCL_P2P_SUPPORT
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& /* unused */) {
  auto device = outputTensors[0].device();
  for (size_t r = 0; r < outputTensors.size(); r++) {
    check_gpu_single_tensor(outputTensors[r]);
    check_gpu_single_tensor(inputTensors[r]);
    TORCH_CHECK(device == outputTensors[r].device() && device == inputTensors[r].device(),
        "Tensors must be on the same device")
  }
  std::vector<at::Tensor> inputTensor0 = {inputTensors[0]};
  std::vector<at::Tensor> outputTensor0 = {outputTensors[0]};
  return collective(
      inputTensor0,
      outputTensor0,
      [&](at::Tensor& /* unused */,
          at::Tensor& /* unused */,
          ncclComm_t comm,
          at::cuda::CUDAStream& stream) {
        C10D_NCCL_CHECK(ncclGroupStart());
        for (size_t r = 0; r < outputTensors.size(); r++) {
          at::Tensor& input = inputTensors[r];
          at::Tensor& output = outputTensors[r];
          if (input.numel() != 0) {
            C10D_NCCL_CHECK(ncclSend(
                input.data_ptr(),
                input.numel(),
                getNcclDataType(input.scalar_type()),
                r,
                comm,
                stream.stream()));
          }
          if (output.numel() != 0) {
            C10D_NCCL_CHECK(ncclRecv(
                output.data_ptr(),
                output.numel(),
                getNcclDataType(output.scalar_type()),
                r,
                comm,
                stream.stream()));
          }
        }
        C10D_NCCL_CHECK(ncclGroupEnd());
        return ncclSuccess;
      });
}
#else
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall(
    std::vector<at::Tensor>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllToAllOptions& /* unused */) {
  throw std::runtime_error("ProcessGroupNCCL does not support alltoall");
}
#endif

std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather(
std::vector<std::vector<at::Tensor>>& /* unused */,
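Aside (not part of the diff): a minimal sketch of how the new NCCL alltoall path would be exercised from Python, modeled on the tests in the next file. It assumes an NCCL process group has already been initialized (for example via torch.distributed.init_process_group("nccl", ...)) and that each rank owns a CUDA device; the tensor shapes and the run_alltoall name are illustrative only.

import torch
import torch.distributed as dist

def run_alltoall():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    # input_tensors[r] is sent to rank r; output_tensors[r] is received from rank r.
    input_tensors = [torch.full((4,), float(rank), device=device) for _ in range(world_size)]
    output_tensors = [torch.empty(4, device=device) for _ in range(world_size)]

    # With the NCCL backend this dispatches to ProcessGroupNCCL::alltoall above,
    # which pairs ncclSend/ncclRecv per peer inside a ncclGroupStart/ncclGroupEnd block.
    dist.all_to_all(output_tensors, input_tensors)

    # output_tensors[r] now holds the chunk that rank r sent to this rank.
    torch.cuda.synchronize(device)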
55 changes: 48 additions & 7 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -72,6 +72,7 @@ def __eq__(self, other):

BACKEND = os.environ["BACKEND"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
SKIP_NCCL_A2A = (BACKEND == 'nccl' and not bool(int(os.getenv("PYTORCH_TEST_NCCL_A2A", 0))))

DEFAULT_TIMEOUT = 300
CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
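
Aside (not part of the diff): the SKIP_NCCL_A2A flag above means the NCCL all-to-all tests run only when the backend is NCCL and PYTORCH_TEST_NCCL_A2A is set to a nonzero value. Below is a self-contained sketch of that gating logic; the skip_nccl_a2a helper is a hypothetical name used only for illustration.

def skip_nccl_a2a(backend, env):
    # Mirrors the expression above: skip when running the NCCL backend
    # without PYTORCH_TEST_NCCL_A2A opted in.
    return backend == "nccl" and not bool(int(env.get("PYTORCH_TEST_NCCL_A2A", 0)))

assert skip_nccl_a2a("nccl", {}) is True                               # skipped by default in OSS builds
assert skip_nccl_a2a("nccl", {"PYTORCH_TEST_NCCL_A2A": "1"}) is False  # opted in, tests run
assert skip_nccl_a2a("mpi", {}) is False                               # non-NCCL backends are unaffected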
@@ -2018,7 +2019,14 @@ def _test_all_to_all_single_unequal_split_helper(
            self.assertEqual(out_tensor, expected_tensor)
        self._barrier()

    def _test_all_to_all_helper(self, group, group_id, rank):
    def _test_all_to_all_helper(
        self,
        group,
        group_id,
        rank,
        cuda=False,
        rank_to_GPU=None,
    ):
        if group_id is not None:
            size = len(group)
            in_splits = [i + 1 for i in group]
@@ -2027,6 +2035,10 @@ def _test_all_to_all_helper(self, group, group_id, rank):
            ]
            out_tensors = [torch.ones([(rank + 1), size]) for _ in group]
            expected_tensors = [torch.ones([rank + 1, size]) * i for i in group]
            if cuda:
                in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
                expected_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors]
                out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
            dist.all_to_all(out_tensors, in_tensors, group=group_id)
            for t1, t2 in zip(out_tensors, expected_tensors):
                self.assertEqual(t1, t2)
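
Aside (not part of the diff): a single-process illustration of the exchange pattern the helper above verifies, using a hypothetical world size of 3 and no process group. all_to_all semantics: the tensor a rank receives in slot i is the tensor rank i placed in that rank's slot of its own input list.

import torch

size = 3  # hypothetical world size
# inputs[src][dst]: the [dst + 1, size] tensor, filled with src, that rank src sends to rank dst.
inputs = {src: [torch.ones([dst + 1, size]) * src for dst in range(size)] for src in range(size)}

for rank in range(size):
    # What all_to_all would deliver to `rank`: slot src comes from rank src.
    received = [inputs[src][rank] for src in range(size)]
    expected = [torch.ones([rank + 1, size]) * src for src in range(size)]
    for got, want in zip(received, expected):
        assert torch.equal(got, want)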
@@ -2039,7 +2051,7 @@ def test_all_to_all_single_equal_split(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_to_all_single_equal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2063,7 +2075,7 @@ def test_all_to_all_single_unequal_split(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2085,6 +2097,13 @@ def test_all_to_all(self):
        group, group_id, rank = self._init_global_test()
        self._test_all_to_all_helper(group, group_id, rank)

    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
    def test_all_to_all_cuda(self):
        group, group_id, rank = self._init_global_test()
        rank_to_GPU = self._init_multigpu_helper()
        self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)

    @unittest.skipIf(
        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
    )
@@ -2093,7 +2112,7 @@ def test_all_to_all_single_equal_split_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_to_all_single_equal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2119,7 +2138,7 @@ def test_all_to_all_single_unequal_split_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2143,14 +2162,29 @@ def test_all_to_all_group(self):
        group, group_id, rank = self._init_group_test()
        self._test_all_to_all_helper(group, group_id, rank)

    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
    @skip_if_small_worldsize
    def test_all_to_all_group_cuda(self):
        group, group_id, rank = self._init_group_test()
        rank_to_GPU = self._init_multigpu_helper()
        self._test_all_to_all_helper(
            group,
            group_id,
            rank,
            True,
            rank_to_GPU)

    @unittest.skipIf(
        BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
    )
    def test_all_to_all_single_equal_split_full_group(self):
        group, group_id, rank = self._init_full_group_test()
        self._test_all_to_all_single_equal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2174,7 +2208,7 @@ def test_all_to_all_single_unequal_split_full_group(self):
        group, group_id, rank = self._init_full_group_test()
        self._test_all_to_all_single_unequal_split_helper(group, group_id, rank)

    @unittest.skip("NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(
        BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
    )
@@ -2196,6 +2230,13 @@ def test_all_to_all_full_group(self):
        group, group_id, rank = self._init_full_group_test()
        self._test_all_to_all_helper(group, group_id, rank)

    @unittest.skipIf(SKIP_NCCL_A2A, "NCCL A2A is not enabled for OSS builds")
    @unittest.skipIf(BACKEND != "nccl", "Only NCCL supports CUDA all_to_all")
    def test_all_to_all_full_group_cuda(self):
        group, group_id, rank = self._init_full_group_test()
        rank_to_GPU = self._init_multigpu_helper()
        self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)

    # BARRIER
    def _test_barrier_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None):