From c5e509827c6f780c74cd9f785d7afb2ad3eb21d7 Mon Sep 17 00:00:00 2001
From: Pieter Noordhuis
Date: Tue, 2 Jul 2019 05:01:11 -0700
Subject: [PATCH] Add device guard around MPI operations

If the current CUDA device is not the same as the device that hosts the
tensor the operation works on, then OpenMPI will segfault, as reported
in #21922. This change adds a device guard for every operation to
ensure the correct device is set.

Fixes #21922.
---
 torch/lib/c10d/ProcessGroupMPI.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp
index ea1f9f46ced45..14341b82a03da 100644
--- a/torch/lib/c10d/ProcessGroupMPI.cpp
+++ b/torch/lib/c10d/ProcessGroupMPI.cpp
@@ -2,6 +2,8 @@
 
 #include <map>
 
+#include <c10/core/DeviceGuard.h>
+
 #if defined(OPEN_MPI) && OPEN_MPI
 #include <mpi-ext.h> // Needed for CUDA-aware check
 #endif
@@ -316,6 +318,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::broadcast(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Bcast(
             data.data_ptr(),
@@ -337,6 +340,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allreduce(
   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
       [opts, this](std::unique_ptr<WorkEntry>& entry) {
         auto data = (entry->src)[0];
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allreduce(
             MPI_IN_PLACE,
@@ -363,6 +367,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::reduce(
         void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
         void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Reduce(
             sendbuf,
@@ -402,6 +407,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather(
         std::vector<at::Tensor>& outputDataVec = entry->dst;
         auto flatOutputTensor = newLikeFlat(outputDataVec);
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Allgather(
             data.data_ptr(),
@@ -456,6 +462,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
           recvbuf = flatOutputTensor.data_ptr();
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Gather(
             data.data_ptr(),
@@ -529,6 +536,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::scatter(
           }
         }
 
+        c10::DeviceGuard guard(data.device());
         std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
         MPI_CHECK(MPI_Scatter(
             sendbuf,
@@ -569,6 +577,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::send(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Isend(
         tensor.data_ptr(),
@@ -593,6 +602,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recv(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
@@ -616,6 +626,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::recvAnysource(
   MPI_Request request = MPI_REQUEST_NULL;
 
   {
+    c10::DeviceGuard guard(tensor.device());
     std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
     MPI_CHECK(MPI_Irecv(
         tensor.data_ptr(),
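
Background on the guard used above (not part of the patch): c10::DeviceGuard is an RAII helper that switches the current device to the one passed to its constructor and restores the previous device when the guard leaves scope, which is why placing it immediately before each MPI call is sufficient. The sketch below illustrates that pattern with the raw CUDA runtime API; the class name ScopedCudaDeviceGuard and the usage shown are illustrative assumptions, not code from this patch or from c10.

// Minimal sketch of the RAII device-guard pattern the patch relies on,
// written against the raw CUDA runtime API. This is an illustration,
// not the c10::DeviceGuard implementation.
#include <cuda_runtime.h>

class ScopedCudaDeviceGuard { // hypothetical name, for illustration only
 public:
  explicit ScopedCudaDeviceGuard(int target_device) {
    cudaGetDevice(&previous_device_);  // remember the caller's device
    if (previous_device_ != target_device) {
      cudaSetDevice(target_device);    // switch to the tensor's device
    }
    target_device_ = target_device;
  }
  ~ScopedCudaDeviceGuard() {
    if (previous_device_ != target_device_) {
      cudaSetDevice(previous_device_); // restore on scope exit
    }
  }

 private:
  int previous_device_ = 0;
  int target_device_ = 0;
};

// Usage mirroring the patch: pin the device that hosts the buffer before
// handing its pointer to a CUDA-aware MPI call, e.g.
//   ScopedCudaDeviceGuard guard(/*target_device=*/1);
//   MPI_Bcast(device_ptr, count, datatype, root, comm);

Unlike this CUDA-only sketch, c10::DeviceGuard accepts any c10::Device and is effectively a no-op for CPU tensors, which is presumably why the patch can add the guard unconditionally to every operation rather than branching on tensor.is_cuda().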