Skip to content

Commit

Permalink
Revert D24667127: [c10d] switch ProcessGroupNCCL:Options to be managed by intrusive_ptr
Browse files Browse the repository at this point in the history

Test Plan: revert-hammer

Differential Revision:
D24667127 (ae5c2fe)

Original commit changeset: 54986193ba1b

fbshipit-source-id: 12e1ebea1981c0b1b6dff4c8a2e2045878d44537
  • Loading branch information
wanchaol authored and facebook-github-bot committed Nov 11, 2020
1 parent 0c64f9f commit 2204374
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 31 deletions.
12 changes: 5 additions & 7 deletions torch/csrc/distributed/c10d/init.cpp
Expand Up @@ -1000,17 +1000,16 @@ that adds a prefix to each key inserted to the store.
const c10::intrusive_ptr<::c10d::Store>&,
int,
int,
const c10::intrusive_ptr<
::c10d::ProcessGroupNCCL::Options>&>(),
::c10d::ProcessGroupNCCL::Options>(),
py::call_guard<py::gil_scoped_release>())
.def(
py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::milliseconds& timeout) {
auto options = ::c10d::ProcessGroupNCCL::Options::create();
options->isHighPriorityStream = false;
options->opTimeout = timeout;
::c10d::ProcessGroupNCCL::Options options;
options.isHighPriorityStream = false;
options.opTimeout = timeout;
return std::make_shared<::c10d::ProcessGroupNCCL>(
store, rank, size, options);
}),
Expand All @@ -1021,8 +1020,7 @@ that adds a prefix to each key inserted to the store.
::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis),
py::call_guard<py::gil_scoped_release>());

intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
processGroupNCCL, "Options")
py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options")
.def(py::init<>())
.def_readwrite(
"is_high_priority",
Expand Down
6 changes: 3 additions & 3 deletions torch/lib/c10d/ProcessGroupNCCL.cpp
Expand Up @@ -440,14 +440,14 @@ ProcessGroupNCCL::ProcessGroupNCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const c10::intrusive_ptr<Options>& options)
Options options)
: ProcessGroup(rank, size),
store_(store),
ncclCommCounter_(0),
terminateProcessGroup_(false),
opTimeout_(options->opTimeout),
opTimeout_(options.opTimeout),
futureNCCLCallbackStreams_(c10::cuda::device_count()),
isHighPriorityStream_(options->isHighPriorityStream) {
isHighPriorityStream_(options.isHighPriorityStream) {
TORCH_CHECK(at::cuda::getNumGPUs() != 0,
"ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT);
Expand Down
16 changes: 3 additions & 13 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -1,6 +1,5 @@
#pragma once

#include <chrono>
#include <iostream>
#include <list>
#include <mutex>
Expand All @@ -18,8 +17,6 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

#include <torch/custom_class.h>

namespace c10d {

// Environment variable which controls whether or not wait() is blocking or
Expand Down Expand Up @@ -178,16 +175,9 @@ class ProcessGroupNCCL : public ProcessGroup {
friend class ProcessGroupNCCL;
};

struct Options : torch::CustomClassHolder {
struct Options {
explicit Options();

// return intrusive_ptr of the object
static c10::intrusive_ptr<Options> create(
std::chrono::milliseconds timeout = kNoTimeout,
bool isHighStream = false) {
return c10::make_intrusive<Options>();
}

std::chrono::milliseconds opTimeout;
bool isHighPriorityStream;
};
Expand Down Expand Up @@ -406,7 +396,7 @@ class ProcessGroupNCCL : public ProcessGroup {
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const c10::intrusive_ptr<Options>& options = Options::create());
Options options = Options());

// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
Expand All @@ -416,7 +406,7 @@ class ProcessGroupNCCL : public ProcessGroup {
int rank,
int size,
const std::string& groupName,
const c10::intrusive_ptr<Options>& options = Options::create())
Options options = Options())
: ProcessGroupNCCL(store, rank, size, options) {}

virtual ~ProcessGroupNCCL();
Expand Down
16 changes: 8 additions & 8 deletions torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp
Expand Up @@ -40,7 +40,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
const c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options>& opts)
c10d::ProcessGroupNCCL::Options opts)
: ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {}

std::exception_ptr checkForNCCLErrors(
Expand Down Expand Up @@ -109,7 +109,7 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
const c10::intrusive_ptr<c10d::Store>& store,
int rank,
int size,
const c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options>& opts)
c10d::ProcessGroupNCCL::Options opts)
: ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
set_timedout_error_(false) {}

Expand Down Expand Up @@ -177,8 +177,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
}

ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
auto options = c10d::ProcessGroupNCCL::Options::create();
options->opTimeout = std::chrono::milliseconds(1000);
c10d::ProcessGroupNCCL::Options options;
options.opTimeout = std::chrono::milliseconds(1000);
ProcessGroupNCCLSimulateErrors pg(
store_, 0, 1, options);

Expand Down Expand Up @@ -206,8 +206,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
}

ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
auto options = c10d::ProcessGroupNCCL::Options::create();
options->opTimeout = std::chrono::milliseconds(3000);
c10d::ProcessGroupNCCL::Options options;
options.opTimeout = std::chrono::milliseconds(3000);
ProcessGroupNCCLTimedOutErrors pg(
store_, 0, 1, options);

Expand All @@ -229,8 +229,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
return;
}

auto options = c10d::ProcessGroupNCCL::Options::create();
options->opTimeout = std::chrono::milliseconds(3000);
c10d::ProcessGroupNCCL::Options options;
options.opTimeout = std::chrono::milliseconds(3000);
ProcessGroupNCCLSimulateErrors pg(
store_, 0, 1, options);

Expand Down

0 comments on commit 2204374

Please sign in to comment.