diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index de232fcb4cc06..be5d5e26cd7d4 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1199,6 +1199,29 @@ def test_close_pg(self):
         with self.assertRaises(dist.DistBackendError):
             pg.allreduce([t])
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_tensor_register_hook(self):
+        os.environ["NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"
+
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_id = self.rank_to_GPU[self.rank][0]
+
+        def allgather_base(output_t, input_t):
+            work = pg._allgather_base(output_t, input_t)
+            work.wait()
+
+        # allgather_base is agnostic to the number of GPUs.
+        # Each rank contributes one tensor regardless of the GPU count.
+        tensor = torch.tensor([self.rank]).cuda(local_device_id)
+        output_t = torch.empty((self.world_size), dtype=tensor.dtype).cuda(local_device_id)
+
+        allgather_base(output_t, tensor)
+
+        # Verification
+        self.assertEqual(torch.arange(self.world_size), output_t)
+
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
 ):
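The test above exercises the hook through the test harness. Outside the harness, a minimal usage sketch (not part of the patch, assuming a standard two-process `torchrun --nproc_per_node=2` launch) looks like the following; the environment variable must be set before the NCCL process group is created:

```python
import os

import torch
import torch.distributed as dist

# Opt into allocator-hook based registration; must be set before the
# ProcessGroupNCCL backend is constructed.
os.environ["NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)

# The allocator segments backing these tensors are registered with (and later
# deregistered from) the NCCL communicators by the hooks added in this patch.
inp = torch.full((1024,), float(rank), device="cuda")
out = torch.empty(world_size * 1024, device="cuda")
dist.all_gather_into_tensor(out, inp)

dist.destroy_process_group()
```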
diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp
index a7729f1b0a3e0..5b4419459a982 100644
--- a/torch/csrc/distributed/c10d/NCCLUtils.hpp
+++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp
@@ -60,6 +60,14 @@
 #define NCCL_HAS_COMM_CTA_CGA
 #endif
 
+#if defined(NCCL_REGISTRATION_SUPPORTED) ||                              \
+    ((defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+      (NCCL_MINOR >= 19)))
+#define NCCL_HAS_COMM_REGISTER
+#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
+#define NCCL_HAS_COMM_REGISTER
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason)                                   \
   do {                                                                        \
@@ -264,6 +272,21 @@ class NCCLComm {
       return;
     }
 
+#ifdef NCCL_HAS_COMM_REGISTER
+    // Deregister all registered segments before aborting.
+    for (auto& it : registeredSegmentHandles_) {
+      void* handle = it.second;
+      C10D_NCCL_CHECK(
+          ::ncclCommDeregister(ncclComm_, handle),
+          c10::str(
+              "Failed to deregister segment handle ",
+              handle,
+              " on ncclComm_ ",
+              ncclComm_));
+    }
+    registeredSegmentHandles_.clear();
+#endif
+
     // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
     // timeout)
     commFailureReason_ = commFailureReason;
@@ -306,6 +329,62 @@ class NCCLComm {
 #endif
   }
 
+  ncclResult_t registerSegment(void* ptr, size_t size) {
+    std::unique_lock<std::mutex> lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+    // We only register segments that come from the cache allocator, which are
+    // guaranteed to have disjoint address ranges. Thus, a ptr always maps to
+    // a unique handle and must not be registered again before the current
+    // registration is deregistered and the segment freed.
+    TORCH_CHECK(
+        registeredSegmentHandles_.count(ptr) == 0,
+        "Segment with ptr ",
+        ptr,
+        " has already been registered on ncclComm_ ",
+        ncclComm_);
+
+    void* handle;
+    C10D_NCCL_CHECK(
+        ncclCommRegister(ncclComm_, ptr, size, &handle),
+        c10::str(
+            "Failed to register segment with ptr ",
+            ptr,
+            ", size ",
+            size,
+            " on ncclComm_ ",
+            ncclComm_));
+    registeredSegmentHandles_[ptr] = handle;
+    return ncclSuccess;
+#else
+    return ncclInvalidUsage;
+#endif
+  }
+
+  ncclResult_t deregisterSegment(void* ptr) {
+    std::unique_lock<std::mutex> lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+    TORCH_CHECK(
+        registeredSegmentHandles_.count(ptr) == 1,
+        "Segment with ptr ",
+        ptr,
+        " is not registered on ncclComm_ ",
+        ncclComm_);
+
+    void* handle = registeredSegmentHandles_[ptr];
+    C10D_NCCL_CHECK(
+        ncclCommDeregister(ncclComm_, handle),
+        c10::str(
+            "Failed to deregister segment handle ",
+            handle,
+            " on ncclComm_ ",
+            ncclComm_));
+    registeredSegmentHandles_.erase(ptr);
+    return ncclSuccess;
+#else
+    return ncclInvalidUsage;
+#endif
+  }
+
  protected:
   ncclComm_t ncclComm_;
   // Unique nccl_id for this communicator.
@@ -318,6 +397,10 @@ class NCCLComm {
   // Optional reason for communicator failure, provided by ProcessGroupNCCL for
   // better error messaging.
   c10::optional<std::string> commFailureReason_;
+#ifdef NCCL_HAS_COMM_REGISTER
+  // Stores handles for segments registered with NCCL.
+  std::unordered_map<void*, void*> registeredSegmentHandles_;
+#endif
 };
 
 // Helper that automatically cleans up premul sums.
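ncclCommRegister / ncclCommDeregister are only compiled in when NCCL_HAS_COMM_REGISTER is defined, i.e. when PyTorch is built against NCCL >= 2.19 (or a build defining NCCL_REGISTRATION_SUPPORTED); otherwise registerSegment / deregisterSegment simply return ncclInvalidUsage. As a rough illustration only (the real gate is the compile-time macro above, not a Python check), the build-time NCCL version can be inspected before opting in:

```python
import os

import torch

# torch.cuda.nccl.version() reports the NCCL version PyTorch was built with,
# e.g. (2, 19, 3); per the macro above, registration support needs >= 2.19.
if torch.cuda.nccl.version() >= (2, 19):
    os.environ["NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"
```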
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index bbadbb1618c89..898ca0c06b8d1 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -317,6 +318,55 @@ c10::List<c10::IValue> new_list() {
 }
 
 } // namespace
 
+// Map from each communicator to its device index.
+// This map is used when registering/deregistering cache allocator segments.
+// Design notes:
+// - Each segment should be registered only with the communicator on the
+//   same device.
+// - We cannot reuse devNCCLCommMap_ in each ProcessGroup because its keys may
+//   be ranks rather than devices in the point-to-point case.
+// - This map also has to be maintained as a global variable since the register
+//   hooks are called outside the scope of any PG, so we need to traverse
+//   communicators across all PGs.
+static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
+static std::mutex ncclCommDevIdxMapMutex;
+static bool allocatorHooksAttached = false;
+void cacheAllocatorRegisterHook(
+    const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
+  // Register after SEGMENT_ALLOC
+  if (te.action_ !=
+      c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
+    return;
+  }
+
+  std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+  for (auto& it : ncclCommDevIdxMap) {
+    auto& ncclComm = it.first;
+    auto& devIdx = it.second;
+    if (te.device_ == devIdx) {
+      ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
+    }
+  }
+}
+
+void cacheAllocatorDeregisterHook(
+    const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
+  // Deregister before SEGMENT_FREE
+  if (te.action_ !=
+      c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
+    return;
+  }
+
+  std::lock_guard<std::mutex> lock(ncclCommDevIdxMapMutex);
+  for (auto& it : ncclCommDevIdxMap) {
+    auto& ncclComm = it.first;
+    auto& devIdx = it.second;
+    if (te.device_ == devIdx) {
+      ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
+    }
+  }
+}
+
 struct NCCLTraceBuffer {
   static NCCLTraceBuffer* get() {
     // intentionally leak on exit
@@ -817,6 +867,12 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
   for (const auto& ncclComm : ncclComms_) {
     ncclComm->ncclCommAbort();
   }
+
+  ncclCommDevIdxMapMutex.lock();
+  for (const auto& comm : ncclComms_) {
+    ncclCommDevIdxMap.erase(comm);
+  }
+  ncclCommDevIdxMapMutex.unlock();
 }
 
 ProcessGroupNCCL::CoalescedWorkNCCL::CoalescedWorkNCCL(
@@ -881,6 +937,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       parseEnvVarIntDefault("TORCH_NCCL_TRACE_BUFFER_SIZE", 0) > 0);
 #endif
   avoidRecordStreams_ = parseEnvVarFlag(TORCH_NCCL_AVOID_RECORD_STREAMS);
+#ifdef NCCL_HAS_COMM_REGISTER
+  useTensorRegisterAllocatorHook_ =
+      parseEnvVarFlag(NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK);
+  if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
+          expandable_segments()) {
+    useTensorRegisterAllocatorHook_ = false;
+    LOG(INFO)
+        << "[Rank " << rank_
+        << "] disables NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode.";
+  }
+#endif
 
   if (blockingWait_) {
     if (asyncErrorHandling_ != NoHandling || desyncDebug_) {
@@ -932,6 +999,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << options_->is_high_priority_stream
             << ", TORCH_DISTRIBUTED_DEBUG: "
             << std::string(torch_distributed_debug)
+#ifdef NCCL_HAS_COMM_REGISTER
+            << ", NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
+            << useTensorRegisterAllocatorHook_
+#endif
             << ", NCCL_DEBUG: " << std::string(nccl_debug)
             << ", ID=" << this->getID();
@@ -947,6 +1018,19 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       std::vector<int64_t>(), // outSplitSizes
       size_); // worldSize
 
+  // Attach hooks to the CUDA caching allocator so they are triggered whenever
+  // a traced action occurs. In the hooks above, we register a newly allocated
+  // segment when a SEGMENT_ALLOC action occurs, and deregister a segment when
+  // a SEGMENT_FREE action occurs.
+  // We attach the hooks only once, at the first PG creation.
+  if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) {
+    c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
+        &cacheAllocatorRegisterHook);
+    c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
+        &cacheAllocatorDeregisterHook);
+    allocatorHooksAttached = true;
+  }
+
 #ifdef USE_NCCL_WITH_UCC
   static c10::once_flag initialize_ucc_lib_flag;
   c10::call_once(initialize_ucc_lib_flag, [&] {
@@ -1124,6 +1208,20 @@ void abortCommsFromMap(
 
 // Abort all communicators on this rank
 void ProcessGroupNCCL::abort(c10::optional<std::string> abortReason) {
+  // Remove records from the global ncclCommDevIdxMap before aborting, so that
+  // a newly allocated cache segment does not get registered to communicators
+  // that are already aborted. Note that ncclCommDevIdxMap is a global
+  // container which may hold other PGs' communicators, so we must erase only
+  // the communicators of the current PG.
+  ncclCommDevIdxMapMutex.lock();
+  for (auto& it : devNCCLCommMap_) {
+    auto& ncclComms = it.second;
+    for (const auto& ncclComm : ncclComms) {
+      ncclCommDevIdxMap.erase(ncclComm);
+    }
+  }
+  ncclCommDevIdxMapMutex.unlock();
+
   std::lock_guard<std::mutex> lock(mutex_);
   abortCommsFromMap(devNCCLCommMap_, rank_, abortReason);
   abortCommsFromMap(inInitializationCommMap_, rank_, abortReason);
@@ -1498,6 +1596,12 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) {
   devNCCLCommMap_.erase(devNCCLCommMapKey);
   // Clear used device indices.
   usedDeviceIdxs_.clear();
+
+  ncclCommDevIdxMapMutex.lock();
+  for (const auto& comm : ncclComms) {
+    ncclCommDevIdxMap.erase(comm);
+  }
+  ncclCommDevIdxMapMutex.unlock();
 }
 
 std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
@@ -1662,6 +1766,34 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
   if (it != inInitializationCommMap_.end()) {
     devNCCLCommMap_.emplace(devicesKey, std::move(it->second));
     inInitializationCommMap_.erase(devicesKey);
+
+    // Now the ncclComms are fully initialized.
+    // Register all active CUDA memory segments in the cache allocator with
+    // the new NCCL communicators.
+    if (useTensorRegisterAllocatorHook_) {
+      auto snapshot = c10::cuda::CUDACachingAllocator::snapshot();
+      // Register a segment with a new NCCL communicator only if they are on
+      // the same device.
+      for (const auto& segmentInfo : snapshot.segments) {
+        for (const auto i : c10::irange(devices.size())) {
+          if (segmentInfo.device != devices[i].index())
+            continue;
+          ncclComms[i]->registerSegment(
+              reinterpret_cast<void*>(segmentInfo.address),
+              segmentInfo.total_size);
+        }
+      }
+
+      // Record the mapping between each ncclComm and its device index so that
+      // the register hook can later register a newly allocated segment with
+      // the communicators on the same device.
+      // NOTE: we need to remove the communicator from this map when it is
+      // destroyed; otherwise, a segment may get registered with an invalid
+      // communicator.
+      ncclCommDevIdxMapMutex.lock();
+      for (const auto i : c10::irange(devices.size())) {
+        ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
+      }
+      ncclCommDevIdxMapMutex.unlock();
+    }
   }
 
   it = devNCCLCommMap_.find(devicesKey);
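The registration performed in getNCCLComm() walks the caching allocator's snapshot and registers every existing segment that lives on the communicator's device; the trace hooks then keep that set up to date as segments are allocated and freed. As a rough Python-side illustration only (not part of the patch), the segments involved for the current device can be listed via torch.cuda.memory_snapshot():

```python
import torch

# Each snapshot entry on the communicator's device roughly corresponds to one
# registerSegment(address, total_size) call made when the communicator is
# created (illustrative only; the actual walk happens in C++).
device_index = torch.cuda.current_device()
for seg in torch.cuda.memory_snapshot():
    if seg["device"] == device_index:
        print(hex(seg["address"]), seg["total_size"])
```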
diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index ac69725a4adef..53bec2d8d5594 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -76,6 +76,12 @@ enum ErrorHandlingMode {
 constexpr const char* TORCH_NCCL_AVOID_RECORD_STREAMS =
     "TORCH_NCCL_AVOID_RECORD_STREAMS";
 
+// If set, ProcessGroupNCCL attaches allocation and free hooks to the CUDA
+// caching allocator so that whenever a segment is allocated or freed,
+// ProcessGroupNCCL can register/deregister it with the NCCL communicators on
+// the same device.
+constexpr const char* NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
+    "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK";
+
 // ProcessGroupNCCL implements NCCL bindings for c10d.
 //
 // All functions of the class are expected to be called in the same order
@@ -766,6 +772,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // for the operation to complete.
   bool blockingWait_ = false;
 
+  // Whether or not to hook the cache allocator so that all allocated segments
+  // are registered with the NCCL communicators.
+  bool useTensorRegisterAllocatorHook_ = false;
+
   // Whether or not the workCleanupThread is used to perform async error
   // handling.
   ErrorHandlingMode asyncErrorHandling_ = NoHandling;
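Because the hooks fire on allocator trace events, registration happens at segment granularity rather than per tensor: a segment is registered when the allocator maps it (SEGMENT_ALLOC) and deregistered only when the allocator actually releases it (SEGMENT_FREE), e.g. on torch.cuda.empty_cache(); merely freeing a tensor keeps its segment cached and registered. Note also that the constructor change above disables the hook when the allocator runs in expandable segments mode (PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True). A rough single-process sketch of that lifecycle (not part of the patch; whether a given allocation maps a brand-new segment depends on allocator state):

```python
import torch

before = len(torch.cuda.memory_snapshot())

# A large allocation will typically map a new segment, i.e. the SEGMENT_ALLOC
# point where the register hook would fire.
x = torch.empty(64 * 1024 * 1024, device="cuda")
after_alloc = len(torch.cuda.memory_snapshot())

# Freeing the tensor returns the block to the cache; the segment stays mapped
# (and registered) until the allocator releases it (SEGMENT_FREE), e.g.:
del x
torch.cuda.empty_cache()
after_free = len(torch.cuda.memory_snapshot())

print(before, after_alloc, after_free)
```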