Introduce a prototype for SymmetricMemory (#128582)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* __->__ #128582

This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory-access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for both.

### SymmetricMemory

`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the `get_buffer` APIs and the synchronization primitives), as well as for **custom communication kernels** (via the buffer and signal_pad device pointers).

### Python API Example

```python
from torch._C._distributed_c10d import _SymmetricMemory

# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of devices identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backend might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)

# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p(
    (64, 64), (64, 1), torch.float32, device, group_name
)

# Users can write Python custom ops that leverage the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).

# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has collective semantics and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)

# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)

if symm_mem.rank == 0:
    symm_mem.wait_signal(src_rank=1)
    assert buf.eq(42).all()
else:
    # The remote buffer can be used as a regular tensor
    buf.fill_(42)
    symm_mem.put_signal(dst_rank=0)

symm_mem.barrier()

if symm_mem.rank == 0:
    symm_mem.barrier()
    assert buf.eq(43).all()
else:
    new_val = torch.empty_like(buf)
    new_val.fill_(43)
    # Contiguous copies to/from a remote buffer utilize copy engines,
    # which bypass SMs (i.e., no need to load the data into registers).
    buf.copy_(new_val)
    symm_mem.barrier()
```
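
To illustrate the op-level custom communication pattern use case, here is a minimal sketch of an all-gather-style helper built only from the primitives shown above. `simple_all_gather` is hypothetical (not part of this PR) and assumes `set_group_info` has already been called for `group_name`:

```python
import torch
from torch._C._distributed_c10d import _SymmetricMemory

# Hypothetical helper, not part of this PR: gather `inp` from every rank
# using only empty_strided_p2p/rendezvous/get_buffer/barrier.
def simple_all_gather(inp: torch.Tensor, group_name: str) -> torch.Tensor:
    numel = inp.numel()
    # Allocate a symmetric staging buffer sized for this op; no fixed-size
    # workspace needs to be reserved upfront.
    t = _SymmetricMemory.empty_strided_p2p(
        (numel,), (1,), inp.dtype, inp.device, group_name
    )
    symm_mem = _SymmetricMemory.rendezvous(t)
    t.copy_(inp.reshape(-1))
    # Make every rank's contribution visible before remote reads begin.
    symm_mem.barrier()
    out = torch.empty(
        (symm_mem.world_size, numel), dtype=inp.dtype, device=inp.device
    )
    for r in range(symm_mem.world_size):
        # get_buffer(r, ...) views rank r's allocation; a contiguous copy
        # from it can be serviced by the copy engines.
        out[r].copy_(symm_mem.get_buffer(r, (numel,), inp.dtype))
    # Keep buffers alive until all ranks have finished reading.
    symm_mem.barrier()
    return out
```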

### Custom CUDA Comm Kernels

Given a tensor, users can access the associated `SymmetricMemory` object, which provides pointers to the remote buffers/signal_pads needed by custom communication kernels.

```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  ...
  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;
  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
  ...
};
```

### Limitations of IntraNodeComm and ProcessGroupCudaP2p
`IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace size needs to be specified upfront.
- Cannot avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.

In addition, they only offer out-of-the-box communication kernels and don't expose the pointers required for user-defined, custom CUDA comm kernels.
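
By contrast, `SymmetricMemory` allocations are created per op at exactly the size the op needs, and the persistent `alloc_id` overload (exercised by `test_empty_strided_p2p_persistent` in the diff below) returns the same mapping across re-allocations. A minimal sketch, assuming the allocation signature used in the new test and that `set_group_info` was already called for group `"0"`:

```python
import torch
from torch._C._distributed_c10d import _SymmetricMemory

device = torch.device("cuda")
# With a stable alloc_id (last argument), freeing and re-allocating yields
# the same underlying mapping, so no fixed workspace has to be provisioned
# upfront and the allocation can be reused across iterations.
t1 = _SymmetricMemory.empty_strided_p2p(
    (1024,), (1,), torch.uint8, device, "0", 42
)
ptr = t1.data_ptr()
del t1
t2 = _SymmetricMemory.empty_strided_p2p(
    (1024,), (1,), torch.uint8, device, "0", 42
)
assert t2.data_ptr() == ptr  # same persistent allocation
```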


Differential Revision: [D58849033](https://our.internmc.facebook.com/intern/diff/D58849033)
Pull Request resolved: #128582
Approved by: https://github.com/wanchaol
yifuwang authored and pytorchmergebot committed Jun 21, 2024
1 parent f0443ad commit 217aac9
Showing 16 changed files with 1,265 additions and 111 deletions.
1 change: 1 addition & 0 deletions .lintrunner.toml
```diff
@@ -68,6 +68,7 @@ include_patterns = [
     'aten/src/ATen/native/cudnn/*.cpp',
     'c10/**/*.h',
     'c10/**/*.cpp',
+    'distributed/c10d/*SymmetricMemory.*',
     'torch/csrc/**/*.h',
     'torch/csrc/**/*.hpp',
     'torch/csrc/**/*.cpp',
```
1 change: 1 addition & 0 deletions BUILD.bazel
```diff
@@ -744,6 +744,7 @@ cc_library(
         "torch/csrc/cuda/python_nccl.cpp",
         "torch/csrc/cuda/nccl.cpp",
         "torch/csrc/distributed/c10d/intra_node_comm.cu",
+        "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
         "torch/csrc/distributed/c10d/Utils.cu",
         "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     ],
```
2 changes: 2 additions & 0 deletions build_variables.bzl
```diff
@@ -501,6 +501,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp",
     "torch/csrc/distributed/c10d/Store.cpp",
+    "torch/csrc/distributed/c10d/SymmetricMemory.cpp",
     "torch/csrc/distributed/c10d/TCPStore.cpp",
     "torch/csrc/distributed/c10d/TCPStoreBackend.cpp",
     "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp",
@@ -684,6 +685,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/intra_node_comm.cpp",
     "torch/csrc/distributed/c10d/intra_node_comm.cu",
+    "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
     "torch/csrc/distributed/c10d/Utils.cu",
     "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
```
19 changes: 11 additions & 8 deletions c10/cuda/driver_api.h
```diff
@@ -18,14 +18,17 @@
     }                                 \
   } while (0)
 
-#define C10_LIBCUDA_DRIVER_API(_) \
-  _(cuMemAddressReserve)          \
-  _(cuMemRelease)                 \
-  _(cuMemMap)                     \
-  _(cuMemAddressFree)             \
-  _(cuMemSetAccess)               \
-  _(cuMemUnmap)                   \
-  _(cuMemCreate)                  \
+#define C10_LIBCUDA_DRIVER_API(_)   \
+  _(cuMemAddressReserve)            \
+  _(cuMemRelease)                   \
+  _(cuMemMap)                       \
+  _(cuMemAddressFree)               \
+  _(cuMemSetAccess)                 \
+  _(cuMemUnmap)                     \
+  _(cuMemCreate)                    \
+  _(cuMemGetAllocationGranularity)  \
+  _(cuMemExportToShareableHandle)   \
+  _(cuMemImportFromShareableHandle) \
   _(cuGetErrorString)
 
 #define C10_NVML_DRIVER_API(_) \
```
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
```diff
@@ -560,6 +560,7 @@ if(USE_CUDA)
     append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS)
     set_source_files_properties(
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
       PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
     )
   endif()
```
158 changes: 158 additions & 0 deletions test/distributed/test_symmetric_memory.py
```python
# Owner(s): ["module: c10d"]

import torch

import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed.distributed_c10d import _get_process_group_store

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


def requires_cuda_p2p_access():
    cuda_p2p_access_available = (
        torch.cuda.is_available() and torch.cuda.device_count() >= 2
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        _SymmetricMemory.set_group_info(
            "0",
            self.rank,
            self.world_size,
            _get_process_group_store(dist.GroupMember.WORLD),
        )

    def _verify_symmetric_memory(self, symm_mem):
        self.assertEqual(symm_mem.world_size, 2)

        buf = symm_mem.get_buffer(0, (64, 64), torch.float32)
        if symm_mem.rank == 0:
            symm_mem.wait_signal(src_rank=1)
            self.assertTrue(buf.eq(42).all())
        else:
            buf.fill_(42)
            symm_mem.put_signal(dst_rank=0)

        symm_mem.barrier()

        if symm_mem.rank == 0:
            symm_mem.barrier()
            self.assertTrue(buf.eq(43).all())
        else:
            buf.fill_(43)
            symm_mem.barrier()

        symm_mem.barrier()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name)

        t = torch.empty(shape, dtype=dtype, device=device)
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.rendezvous(t)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        symm_mem = _SymmetricMemory.rendezvous(t)

        del t
        self._verify_symmetric_memory(symm_mem)
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_empty_strided_p2p_persistent(self) -> None:
        self._init_process()

        shape = (64, 64)
        stride = (64, 1)
        dtype = torch.float32
        device = self.device
        alloc_id = 42  # Persistent allocation
        group_name = "0"
        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)

        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        data_ptr = t.data_ptr()

        # Verify that a persistent allocation fails if there's an active
        # allocation with the same alloc_id.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.empty_strided_p2p(*alloc_args)

        # Verify that a persistent allocation succeeds when there are no
        # active allocations with the same alloc_id, and that the returned
        # tensor has the same data pointer.
        del t
        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
        self.assertEqual(t.data_ptr(), data_ptr)

        # Verify that get_symmetric_memory fails if called before rendezvous.
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.get_symmetric_memory(t)

        symm_mem_0 = _SymmetricMemory.rendezvous(t)
        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
        self.assertEqual(id(symm_mem_0), id(symm_mem_1))

        self._verify_symmetric_memory(symm_mem_0)
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```
30 changes: 30 additions & 0 deletions torch/_C/_distributed_c10d.pyi
```diff
@@ -637,3 +637,33 @@ class ProcessGroupCudaP2P(Backend):
         storage_offset: Optional[int] = 0,
     ) -> torch.Tensor: ...
     def _shutdown(self) -> None: ...
+
+class _SymmetricMemory:
+    @staticmethod
+    def set_group_info(
+        group_name: str, rank: int, world_size: int, store: Store
+    ) -> None: ...
+    @staticmethod
+    def empty_strided_p2p(
+        size: torch.types._size,
+        stride: torch.types._size,
+        dtype: torch.dtype,
+        device: torch.device,
+        group_name: str,
+    ) -> torch.Tensor: ...
+    @property
+    def rank(self) -> int: ...
+    @property
+    def world_size(self) -> int: ...
+    @staticmethod
+    def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
+    def get_buffer(
+        self,
+        rank: int,
+        sizes: torch.Size,
+        dtype: torch.dtype,
+        storage_offset: Optional[int] = 0,
+    ) -> torch.Tensor: ...
+    def barrier(self, channel: int = 0) -> None: ...
+    def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
+    def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
```