Introduce ProcessGroupCudaP2P (#122163)
## Context
This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via
Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:
- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL` that additionally maintains a P2P workspace enabling SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- A post-grad fx pass that detects the `all-gather -> matmul` and `matmul -> reduce-scatter` patterns and replaces them with the fused dispatcher ops.

To enable the prototype feature (a minimal configuration sketch follows this list):
- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.
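
A minimal configuration sketch, assuming a `torchrun`-style launch that sets `RANK`, `WORLD_SIZE`, and `LOCAL_RANK`; the toy model and the environment handling are illustrative, while the backend string and the Inductor flag come from this PR:

```
# Hedged sketch: the toy `model` and the LOCAL_RANK handling are illustrative;
# the backend string and the Inductor flag come from this PR.
import os

import torch
import torch._inductor.config as inductor_config
import torch.distributed as dist

# Step 1: select the cuda_p2p backend (a thin wrapper around ProcessGroupNCCL
# that also initializes the P2P workspace).
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="cuda_p2p")

# Step 2: let Inductor's post-grad pass fuse all-gather -> matmul and
# matmul -> reduce-scatter into the fused dispatcher ops.
inductor_config._micro_pipeline_tp = True

# In real use this would be the TP-sharded model; a toy module keeps the
# sketch self-contained.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
compiled_model = torch.compile(model)
```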

*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to establish a performant baseline with a reasonable design on which each component can be further improved.*

## Benchmark
Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. The number of layers was reduced from 32 to 8 for benchmarking purposes.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0
<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn
<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR
`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA.

`ProcessGroupCudaP2P` aims to bridge the gap for important patterns that can be optimized better via fine-grained P2P memory access than with the collectives available in the latest version of NCCL. It is meant to complement NCCL rather than replace it.
Usage:
```
    # Using ProcessGroupCudaP2P
    dist.init_process_group(backend="cuda_p2p", ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
    pg_options = ProcessGroupCudaP2P.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
    pg_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying both
    # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
    pg_options = ProcessGroupCudaP2P.Options()
    pg_options.nccl_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Down-casting the backend to access p2p buffers for cuda_p2p-specific
    # optimizations
    if is_cuda_p2p_group(group):
        backend = get_cuda_p2p_backend(group)
        if required_p2p_buffer_size > backend.get_buffer_size():
            ...  # fall back to a non-P2P path
        p2p_buffer = backend.get_p2p_buffer(...)
    else:
        ...  # fall back to a non-P2P path
```
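
To further illustrate the down-cast path, below is a hedged sketch of a tiny custom intra-node exchange built directly on the P2P workspace. It mirrors the write/barrier/read pattern used by the test and benchmark in this PR; the function name, shapes, and the assumption that the group was created with a sufficiently large `buffer_size` are illustrative.

```
import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import get_cuda_p2p_backend, is_cuda_p2p_group


def neighbor_exchange(group: dist.ProcessGroup, numel: int) -> torch.Tensor:
    # Assumes `group` was created with backend="cuda_p2p" and a buffer_size of
    # at least numel * 4 bytes (float32).
    assert is_cuda_p2p_group(group)
    backend = get_cuda_p2p_backend(group)
    rank, world_size = group.rank(), group.size()

    # Each rank writes its payload into its own P2P buffer, which peers can
    # read directly after the barrier.
    local_buf = backend.get_p2p_buffer(rank, (numel,), torch.float)
    local_buf.fill_(float(rank))
    backend.intra_node_barrier().wait()

    # Read the next rank's buffer without any collective call.
    remote_buf = backend.get_p2p_buffer((rank + 1) % world_size, (numel,), torch.float)
    out = remote_buf.clone()

    # Barrier again so the workspace can be safely reused afterwards.
    backend.intra_node_barrier().wait()
    return out
```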

Pull Request resolved: #122163
Approved by: https://github.com/wanchaol
yifuwang authored and pytorchmergebot committed May 22, 2024
1 parent 8a45979 commit 2dd2699
Showing 13 changed files with 794 additions and 118 deletions.
1 change: 1 addition & 0 deletions .ci/pytorch/multigpu-test.sh
@@ -18,6 +18,7 @@ time python test/run_test.py --verbose -i distributed/test_c10d_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_nccl
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_cuda_p2p
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
120 changes: 64 additions & 56 deletions benchmarks/distributed/intra_node_comm/allgather_matmul.py
@@ -1,17 +1,16 @@
#!/usr/bin/env python3
# This file contains an example for using IntraNodeComm to implement efficient fused
# This file contains an example for using cuda_p2p backend to implement efficient fused
# allgather_matmul (inspired by https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 and
# @lw's efficient GPU implementation in xformers). Its purpose to help guide the
# development of relevant primitives and serve as an example for interested users.
#
# The benchmark can be executed as follows:
# torchrun --nproc-per-node 8 allgather_matmul.py
#
# NOTE: _IntraNodeComm is a prototype API which WILL change over time.
import os

import torch
import torch._C._distributed_c10d as c10d
import torch.distributed as dist
from torch.distributed._cuda_p2p import ProcessGroupCudaP2P

M = 16384
N = 8192
@@ -21,55 +20,60 @@
BENCH_ITERS = 50


comm = None
internal_stream = None
internal_event = None
def allgather_matmul(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
group = dist.group.WORLD
group_size = group.size()
A = torch.ops._c10d_functional.all_gather_into_tensor(A_shard, group_size, "0")
A = torch.ops._c10d_functional.wait_tensor(A)
return A @ B


def allgather_matmul(A_shard, B, out, rank, world_size):
def allgather_matmul_p2p(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Equivalent to `torch.matmul(dist.all_gather(A_shard), B)`.
"""
buf_0 = torch.empty_like(A_shard)
buf_1 = torch.empty_like(A_shard)
out_shards = [
out[i : i + A_shard.shape[0]]
for i in range(0, world_size * A_shard.shape[0], A_shard.shape[0])
]
group = dist.group.WORLD
group_size = group.size()
rank = group.rank()
backend = group._get_backend(torch.device("cuda"))

out = torch.empty(
(A_shard.shape[0] * group.size(), B.shape[1]),
dtype=A_shard.dtype,
device="cuda",
)
out_shards = out.chunk(group_size)
local_p2p_buf = backend.get_p2p_buffer(rank, A_shard.shape, A_shard.dtype)

# Perform matmul with the local input shard
torch.matmul(A_shard, B, out=out_shards[rank])

# In another stream, copy the local input shard into the intra-node
# buffer. After the barrier, all peers' input shards are accessible
# via their intra-node buffer without requiring synchronization.
with torch.cuda.stream(internal_stream):
comm.put(A_shard)
comm.barrier()
internal_event.record()
internal_event.wait()

# Copy input shard from remote buffer and perform matmul.
# Alternate between two streams to offset the wave quantization
# effect of smaller matmuls.
for i in range(1, world_size):
with torch.cuda.stream(backend.stream()):
local_p2p_buf.copy_(A_shard)
work = backend.intra_node_barrier()
work.wait()

buf_0 = torch.empty_like(A_shard)
buf_1 = torch.empty_like(A_shard)
for i in range(1, group_size):
if i % 2 == 0:
buf = buf_0
stream = torch.cuda.current_stream()
else:
buf = buf_1
stream = internal_stream
remote = (i + rank) % world_size
stream = backend.stream()
remote_rank = (i + rank) % group_size
remote_p2p_buf = backend.get_p2p_buffer(
remote_rank, A_shard.shape, A_shard.dtype
)
with torch.cuda.stream(stream):
comm.get(remote, buf)
torch.matmul(buf, B, out=out_shards[remote])
buf.copy_(remote_p2p_buf)
torch.matmul(buf, B, out=out_shards[remote_rank])

# Perform another barrier to ensure all peers have completed consuming the
# intra-node buffer so it can be reused.
with torch.cuda.stream(internal_stream):
comm.barrier()
internal_event.record()
internal_event.wait()
with torch.cuda.stream(backend.stream()):
work = backend.intra_node_barrier()
work.wait()
return out


def do_bench(fn):
@@ -89,42 +93,39 @@ def do_bench(fn):


def main():
os.environ["ENABLE_INTRA_NODE_COMM"] = "1"

rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

assert M % world_size == 0

torch.cuda.set_device(local_rank)
store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))

global comm, internal_stream, internal_event
comm = c10d._IntraNodeComm(
store=store,
rank=rank,
world_size=world_size,
buffer_size=M * K * torch.finfo(torch.bfloat16).bits // 8 // world_size,
)
internal_stream = torch.cuda.Stream()
internal_event = torch.cuda.Event()

options = ProcessGroupCudaP2P.Options()
options.buffer_size = M * N * 2 // world_size
dist.init_process_group("cuda_p2p", pg_options=options)

torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
B = torch.randn((K, N), dtype=torch.bfloat16, device="cuda")
out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")

stride = M // world_size
A_shard = A[rank * stride : (rank + 1) * stride]

comm.barrier()
torch.cuda.synchronize()
allgather_matmul_ms = do_bench(
lambda: allgather_matmul(A_shard, B, out, rank, world_size)
assert torch.allclose(
allgather_matmul(A_shard, B),
allgather_matmul_p2p(A_shard, B),
)

comm.barrier()
dist.barrier()
torch.cuda.synchronize()
allgather_matmul_ms = do_bench(lambda: allgather_matmul(A_shard, B))

dist.barrier()
torch.cuda.synchronize()
allgather_matmul_p2p_ms = do_bench(lambda: allgather_matmul_p2p(A_shard, B))

dist.barrier()
torch.cuda.synchronize()
matmul_ms = do_bench(lambda: torch.matmul(A, B))

@@ -134,8 +135,15 @@ def main():
f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
f"{allgather_matmul_ms:.4} ms/iter"
)
print(
"allgather_matmul_p2p "
f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
f"{allgather_matmul_p2p_ms:.4} ms/iter"
)
print(f"matmul (M={M}, N={N}, K={K}): {matmul_ms:.4} ms/iter")

dist.destroy_process_group()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -675,6 +675,7 @@ libtorch_cuda_distributed_base_sources = [
# These files are only supported on Linux (and others) but not on Windows.
libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
"torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp",
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",
139 changes: 139 additions & 0 deletions test/distributed/test_cuda_p2p.py
@@ -0,0 +1,139 @@
# Owner(s): ["module: c10d"]
import os
from typing import List

import torch

import torch.distributed as dist
from torch.distributed._cuda_p2p import (
get_cuda_p2p_backend,
get_p2p_buffer_size,
is_cuda_p2p_group,
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
)


def requires_cuda_p2p_access():
cuda_p2p_access_available = (
torch.cuda.is_available()
and torch.cuda.device_count() >= 2
and dist.is_nccl_available()
)
num_devices = torch.cuda.device_count()
for i in range(num_devices - 1):
for j in range(i + 1, num_devices):
if not torch.cuda.can_device_access_peer(i, j):
cuda_p2p_access_available = False
break
if not cuda_p2p_access_available:
break

return skip_but_pass_in_sandcastle_if(
not cuda_p2p_access_available,
"cuda p2p access is not available",
)


@requires_nccl()
@requires_cuda_p2p_access()
class ProcessGroupCudaP2PTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()

@property
def world_size(self) -> int:
return 2

@property
def ranks(self) -> List[int]:
return list(range(self.world_size))

@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")

def _init_process_group(self, buffer_size: int) -> None:
os.environ["TEST_INTRA_NODE_COMM"] = "1"
torch.cuda.set_device(self.device)

# Verify cuda p2p specific APIs on ProcessGroupCudaP2P
store = dist.FileStore(self.file_name, self.world_size)
options = dist.ProcessGroupCudaP2P.Options()
options.buffer_size = buffer_size
dist.init_process_group(
backend="cuda_p2p",
world_size=self.world_size,
rank=self.rank,
store=store,
pg_options=options,
)

@skip_if_lt_x_gpu(2)
def test_p2p_apis(self) -> None:
BUFFER_SIZE = 4 * 1024

self._init_process_group(BUFFER_SIZE)

# Verify cuda p2p specific APIs on ProcessGroupCudaP2P
assert is_cuda_p2p_group(dist.group.WORLD)
assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE

backend = get_cuda_p2p_backend(dist.group.WORLD)
assert isinstance(backend, dist.ProcessGroupCudaP2P)
assert backend.get_buffer_size() == BUFFER_SIZE

backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
with self.assertRaises(RuntimeError):
backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
with self.assertRaises(RuntimeError):
backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)

# Verify cuda p2p specific APIs on non-cuda p2p process group
non_cuda_p2p_pg = dist.new_group(backend="nccl")

assert not is_cuda_p2p_group(non_cuda_p2p_pg)
assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
with self.assertRaises(TypeError):
get_cuda_p2p_backend(non_cuda_p2p_pg)

dist.barrier()
torch.cuda.synchronize()
dist.destroy_process_group()

@skip_if_lt_x_gpu(2)
def test_p2p_buffer(self) -> None:
BUFFER_SIZE = 4 * 1024

self._init_process_group(BUFFER_SIZE)
rank = self.rank
world_size = self.world_size

assert is_cuda_p2p_group(dist.group.WORLD)
backend = get_cuda_p2p_backend(dist.group.WORLD)
local_buffer = backend.get_p2p_buffer(
(rank) % world_size, (BUFFER_SIZE // 4,), torch.float
)
remote_buffer = backend.get_p2p_buffer(
(rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
)

local_buffer.fill_(rank)
backend.intra_node_barrier()
assert remote_buffer.eq((rank + 1) % world_size).all()

dist.barrier()
torch.cuda.synchronize()
dist.destroy_process_group()


if __name__ == "__main__":
run_tests()
27 changes: 27 additions & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -605,3 +605,30 @@ def _register_process_group(
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...

class ProcessGroupCudaP2P(Backend):
class Options:
nccl_options: Optional[ProcessGroupNCCL.Options]
buffer_size: Optional[int]

def __init__(self) -> None: ...

def __init__(
self,
store: Store,
rank: int,
size: int,
options: ProcessGroupCudaP2P.Options,
) -> None: ...
def is_p2p_available(self) -> bool: ...
def get_buffer_size(self) -> int: ...
def stream(self) -> torch.cuda.Stream: ...
def intra_node_barrier(self) -> Work: ...
def get_p2p_buffer(
self,
rank: int,
sizes: torch.Size,
dtype: torch.dtype,
storage_offset: Optional[int] = 0,
) -> torch.Tensor: ...
def _shutdown(self) -> None: ...
(The remaining changed files were not loaded.)

1 comment on commit 2dd2699

@pytorchmergebot
Reverted #122163 on behalf of https://github.com/jithunnair-amd due to: This is breaking ROCm distributed CI on trunk.
