# Introduce ProcessGroupCudaP2P (#122163)
## Context

This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:

- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL` that additionally maintains a P2P workspace enabling SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- A post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` patterns and replaces them with the fused dispatcher ops.

To enable the prototype feature (see the sketch following this description):

- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.

*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to have a performant baseline with a reasonable design on which each component can be further improved.*

## Benchmark

Setup:

- 8 x H100 (500W) + 3rd-gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. The number of layers was reduced from 32 to 8 for benchmarking purposes.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0

<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro-pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn

<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR

`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA. `ProcessGroupCudaP2P` aims to bridge the gap where certain important patterns can be better optimized via fine-grained P2P memory access than with collectives in the latest version of NCCL. It is meant to complement NCCL rather than replace it.

Usage:

```
# Using ProcessGroupCudaP2P
dist.init_process_group(backend="cuda_p2p", ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
pg_options = ProcessGroupCudaP2P.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
pg_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying both
# ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
pg_options = ProcessGroupCudaP2P.Options()
pg_options.nccl_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
```
```
# Down-casting the backend to access p2p buffers for cuda_p2p specific
# optimizations
if is_cuda_p2p_group(group):
    backend = get_cuda_p2p_backend(group)
    if required_p2p_buffer_size > backend.get_buffer_size():
        # fallback
        ...
    p2p_buffer = backend.get_p2p_buffer(...)
else:
    # fallback
    ...
```

Pull Request resolved: #122163
Approved by: https://github.com/wanchaol
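As referenced in the enablement steps above, here is a minimal illustrative sketch of flipping the two switches. Assumptions not in the original description: a torchrun-style launch with `RANK`/`WORLD_SIZE` set in the environment, and a placeholder `torch.nn.Linear` standing in for a real tensor-parallel model.

```
import torch
import torch._inductor.config
import torch.distributed as dist

# 1. Route collectives through ProcessGroupCudaP2P so the fused ops can use
#    the SM-free P2P workspace.
dist.init_process_group(backend="cuda_p2p")

# 2. Enable the post-grad fx pass that rewrites all-gather -> matmul and
#    matmul -> reduce-scatter into the fused dispatcher ops.
torch._inductor.config._micro_pipeline_tp = True

# Placeholder: in practice this would be a tensor-parallel model whose
# compiled graph contains the patterns above.
model = torch.nn.Linear(1024, 1024, device="cuda")
compiled = torch.compile(model)
```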
Commit 2dd2699 (1 parent: 8a45979). 13 changed files with 794 additions and 118 deletions.
The commit adds a new test file exercising the cuda_p2p-specific APIs (shown with diff markup stripped):

```
# Owner(s): ["module: c10d"]
import os
from typing import List

import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import (
    get_cuda_p2p_backend,
    get_p2p_buffer_size,
    is_cuda_p2p_group,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
)


def requires_cuda_p2p_access():
    # P2P access requires at least two CUDA devices that can all reach each
    # other's memory directly, plus an available NCCL build.
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= 2
        and dist.is_nccl_available()
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


@requires_nccl()
@requires_cuda_p2p_access()
class ProcessGroupCudaP2PTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self, buffer_size: int) -> None:
        os.environ["TEST_INTRA_NODE_COMM"] = "1"
        torch.cuda.set_device(self.device)

        # Initialize a cuda_p2p process group with the requested P2P
        # workspace size.
        store = dist.FileStore(self.file_name, self.world_size)
        options = dist.ProcessGroupCudaP2P.Options()
        options.buffer_size = buffer_size
        dist.init_process_group(
            backend="cuda_p2p",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
            pg_options=options,
        )

    @skip_if_lt_x_gpu(2)
    def test_p2p_apis(self) -> None:
        BUFFER_SIZE = 4 * 1024

        self._init_process_group(BUFFER_SIZE)

        # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
        assert is_cuda_p2p_group(dist.group.WORLD)
        assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE

        backend = get_cuda_p2p_backend(dist.group.WORLD)
        assert isinstance(backend, dist.ProcessGroupCudaP2P)
        assert backend.get_buffer_size() == BUFFER_SIZE

        # A BUFFER_SIZE-byte workspace holds exactly BUFFER_SIZE // 4 float32
        # elements; requesting one element more, or offsetting past the end,
        # must fail.
        backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
        with self.assertRaises(RuntimeError):
            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
        with self.assertRaises(RuntimeError):
            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)

        # Verify cuda p2p specific APIs on a non-cuda p2p process group
        non_cuda_p2p_pg = dist.new_group(backend="nccl")

        assert not is_cuda_p2p_group(non_cuda_p2p_pg)
        assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
        with self.assertRaises(TypeError):
            get_cuda_p2p_backend(non_cuda_p2p_pg)

        dist.barrier()
        torch.cuda.synchronize()
        dist.destroy_process_group()

    @skip_if_lt_x_gpu(2)
    def test_p2p_buffer(self) -> None:
        BUFFER_SIZE = 4 * 1024

        self._init_process_group(BUFFER_SIZE)
        rank = self.rank
        world_size = self.world_size

        assert is_cuda_p2p_group(dist.group.WORLD)
        backend = get_cuda_p2p_backend(dist.group.WORLD)
        local_buffer = backend.get_p2p_buffer(
            rank, (BUFFER_SIZE // 4,), torch.float
        )
        remote_buffer = backend.get_p2p_buffer(
            (rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
        )

        # Each rank fills its own buffer; after the intra-node barrier, the
        # view into the next rank's buffer must contain that rank's value.
        local_buffer.fill_(rank)
        backend.intra_node_barrier()
        assert remote_buffer.eq((rank + 1) % world_size).all()

        dist.barrier()
        torch.cuda.synchronize()
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```
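For illustration, here is a minimal standalone sketch (not part of this commit) of the one-sided pattern the test above exercises. It assumes a torchrun-style launch with `RANK` and `WORLD_SIZE` set in the environment, and GPUs with peer access to one another:

```
import os

import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import get_cuda_p2p_backend, is_cuda_p2p_group


def main() -> None:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)

    # Ask for a 4 KiB P2P workspace per rank.
    options = dist.ProcessGroupCudaP2P.Options()
    options.buffer_size = 4 * 1024
    dist.init_process_group(backend="cuda_p2p", pg_options=options)
    assert is_cuda_p2p_group(dist.group.WORLD)

    backend = get_cuda_p2p_backend(dist.group.WORLD)
    # Views into this rank's and the next rank's workspace
    # (1024 float32 elements == 4 KiB).
    mine = backend.get_p2p_buffer(rank, (1024,), torch.float)
    peer = backend.get_p2p_buffer((rank + 1) % world_size, (1024,), torch.float)

    mine.fill_(rank)
    backend.intra_node_barrier()
    # One-sided read of the neighbor's buffer: no collective was issued
    # for this data movement.
    assert peer.eq((rank + 1) % world_size).all()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```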
Reverted #122163 on behalf of https://github.com/jithunnair-amd because it was breaking ROCm distributed CI on trunk.