Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce a prototype for SymmetricMemory (#128582)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them. ### SymmetricMemory `SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers). ### Python API Example ```python from torch._C.distributed_c10d import _SymmetricMemory # Set a store for rendezvousing symmetric allocations on a group of devices # identified by group_name. The concept of groups is logical; users can # utilize predefined groups (e.g., a group of device identified by a # ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator # backends might employ a more efficient communication channel for the actual # rendezvous process and only use the store for bootstrapping purposes. _SymmetricMemory.set_group_info(group_name, rank, world_size, store) # Identical to empty_strided, but allows symmetric memory access to be # established for the allocated tensor via _SymmetricMemory.rendezvous(). # This function itself is not a collective operation. t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name) # Users can write Python custom ops that leverages the symmetric memory access. # Below are examples of things users can do (assuming the group's world_size is 2). # Establishes symmetric memory access on tensors allocated via # _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process, # and the mapping between a local memory region and the associated SymmetricMemory # object is unique. Subsequent calls to rendezvous() with the same tensor will receive # the cached SymmetricMemory object. # # The function has a collective semantic and must be invoked simultaneously # from all rendezvous participants. symm_mem = _SymmetricMemory.rendezvous(t) # This represents the allocation on rank 0 and is accessible from all devices. buf = symm_mem.get_buffer(0, (64, 64), torch.float32) if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) assert buf.eq(42).all() else: # The remote buffer can be used as a regular tensor buf.fill_(42) symm_mem.put_signal(dst_rank=0) symm_mem.barrier() if symm_mem.rank == 0: symm_mem.barrier() assert buf.eq(43).all() else: new_val = torch.empty_like(buf) new_val.fill_(43) # Contiguous copies to/from a remote buffer utilize copy engines # which bypasses SMs (i.e. no need to load the data into registers) buf.copy_(new_val) symm_mem.barrier() ``` ### Custom CUDA Comm Kernels Given a tensor, users can access the associated `SymmetricMemory` which provides pointer to remote buffers/signal_pads needed for custom communication kernels. ```cpp TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory( const at::Tensor& tensor); class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { public: ... virtual std::vector<void*> get_buffer_ptrs() = 0; virtual std::vector<void*> get_signal_pad_ptrs() = 0; virtual void** get_buffer_ptrs_dev() = 0; virtual void** get_signal_pad_ptrs_dev() = 0; virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; ... }; ``` ### Limitations of IntraNodeComm and ProcessGroupCudaP2p Both `IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach: - Leads to awkward UX in which the required workspace needs to be specified upfront. - Can not avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather). - Prevents torch.compile from eliminating all copies. In addition, they only offer out-of-the-box communication kernels and don't expose required pointers for user-defined, custom CUDA comm kernels. * __->__ #128582 Differential Revision: [D58849033](https://our.internmc.facebook.com/intern/diff/D58849033) Pull Request resolved: #128582 Approved by: https://github.com/wanchaol
- Loading branch information