diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index dd32ff91603c..e9d8f618eb21 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1000,17 +1000,16 @@ that adds a prefix to each key inserted to the store. const c10::intrusive_ptr<::c10d::Store>&, int, int, - const c10::intrusive_ptr< - ::c10d::ProcessGroupNCCL::Options>&>(), + ::c10d::ProcessGroupNCCL::Options>(), py::call_guard()) .def( py::init([](const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::milliseconds& timeout) { - auto options = ::c10d::ProcessGroupNCCL::Options::create(); - options->isHighPriorityStream = false; - options->opTimeout = timeout; + ::c10d::ProcessGroupNCCL::Options options; + options.isHighPriorityStream = false; + options.opTimeout = timeout; return std::make_shared<::c10d::ProcessGroupNCCL>( store, rank, size, options); }), @@ -1021,8 +1020,7 @@ that adds a prefix to each key inserted to the store. ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis), py::call_guard()); - intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>( - processGroupNCCL, "Options") + py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options") .def(py::init<>()) .def_readwrite( "is_high_priority", diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 0b1a4c9f34e6..acb81d0cad6d 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -440,14 +440,14 @@ ProcessGroupNCCL::ProcessGroupNCCL( const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options) + Options options) : ProcessGroup(rank, size), store_(store), ncclCommCounter_(0), terminateProcessGroup_(false), - opTimeout_(options->opTimeout), + opTimeout_(options.opTimeout), futureNCCLCallbackStreams_(c10::cuda::device_count()), - isHighPriorityStream_(options->isHighPriorityStream) { + isHighPriorityStream_(options.isHighPriorityStream) { TORCH_CHECK(at::cuda::getNumGPUs() != 0, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"); blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index b84cc4deb051..b93bd0c2d70c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -18,8 +17,6 @@ #include #include -#include - namespace c10d { // Environment variable which controls whether or not wait() is blocking or @@ -178,16 +175,9 @@ class ProcessGroupNCCL : public ProcessGroup { friend class ProcessGroupNCCL; }; - struct Options : torch::CustomClassHolder { + struct Options { explicit Options(); - // return intrusive_ptr of the object - static c10::intrusive_ptr create( - std::chrono::milliseconds timeout = kNoTimeout, - bool isHighStream = false) { - return c10::make_intrusive(); - } - std::chrono::milliseconds opTimeout; bool isHighPriorityStream; }; @@ -406,7 +396,7 @@ class ProcessGroupNCCL : public ProcessGroup { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& options = Options::create()); + Options options = Options()); // This constructor includes the deprecated `groupName` argument. // If you have existing code that uses the `groupName`, you can replace @@ -416,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup { int rank, int size, const std::string& groupName, - const c10::intrusive_ptr& options = Options::create()) + Options options = Options()) : ProcessGroupNCCL(store, rank, size, options) {} virtual ~ProcessGroupNCCL(); diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 82ca25049c63..e19981c523de 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10d::ProcessGroupNCCL::Options opts) : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {} std::exception_ptr checkForNCCLErrors( @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { const c10::intrusive_ptr& store, int rank, int size, - const c10::intrusive_ptr& opts) + c10d::ProcessGroupNCCL::Options opts) : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), set_timedout_error_(false) {} @@ -177,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(1000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(1000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options); @@ -206,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0); - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(3000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLTimedOutErrors pg( store_, 0, 1, options); @@ -229,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { return; } - auto options = c10d::ProcessGroupNCCL::Options::create(); - options->opTimeout = std::chrono::milliseconds(3000); + c10d::ProcessGroupNCCL::Options options; + options.opTimeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg( store_, 0, 1, options);