Back out "Revert D31299350: Back out "Revert D31005792: [NCCL] Init d…
Browse files Browse the repository at this point in the history
…ummy NCCL comms in constructor"" (#66393)

Summary:
Pull Request resolved: #66393

Third try!

Fixes:
- test_nccl_timeout could be flaky because of its 1s timeout; bump the timeout to resolve the flakiness. In general, though, we should not be relying on time.sleep for this test; filed #66354 to track that (see the sketch after this list).
- ciflow/all did not actually run the multigpu tests due to a bug; this has since been fixed.
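
For illustration only, a minimal sketch (not part of this diff) of the pattern the timeout fix moves toward: give the ProcessGroupNCCL constructor a generous timeout so its health check can complete, and bound each collective's wait explicitly instead of sleeping. The run() helper, two-rank setup, and file path are assumptions, not code from this change.

    from datetime import timedelta

    import torch
    import torch.distributed as c10d

    def run(rank: int, world_size: int, file_name: str) -> None:
        # Hypothetical per-rank entry point; launch one process per GPU externally.
        store = c10d.FileStore(file_name, world_size)
        # Generous construction timeout so the constructor's dummy-comm health
        # check has headroom on slow CI machines.
        pg = c10d.ProcessGroupNCCL(store, rank, world_size,
                                   timeout=timedelta(seconds=15))
        # Bound the collective explicitly rather than relying on time.sleep.
        work = pg.allreduce(torch.ones(10).cuda(rank))
        work.wait(timeout=timedelta(seconds=5))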
ghstack-source-id: 140560113

Test Plan: CI

Reviewed By: mrshenli

Differential Revision: D31534735

fbshipit-source-id: 8b7e0f4fed3972b7a77cbcda28876c9eefb0c7e2
rohan-varma authored and facebook-github-bot committed Oct 15, 2021
1 parent 59b2806 commit 06fa6c1
Showing 6 changed files with 219 additions and 34 deletions.
79 changes: 74 additions & 5 deletions test/cpp/c10d/ProcessGroupNCCLTest.cpp
@@ -1,9 +1,12 @@
#include <chrono>
#include <iostream>

#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupNCCL.hpp>
#include "CUDATest.hpp"
#include "TestUtils.hpp"
#include "c10d/Types.hpp"
#include "c10d/ProcessGroup.hpp"

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
@@ -19,7 +22,10 @@ using c10d::ProcessGroup;

class NCCLTestBase {
public:
NCCLTestBase(const std::string& path) : path_(path) {}
NCCLTestBase(
const std::string& path,
const std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout
) : path_(path), pgTimeout_(pgTimeout) {}

NCCLTestBase(NCCLTestBase&& other) {
path_ = std::move(other.path_);
@@ -33,19 +39,22 @@ class NCCLTestBase {
void initialize(int rank, int size) {
auto store = c10::make_intrusive<::c10d::FileStore>(path_, size);

c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
opts->timeout = pgTimeout_;
pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>(
new ::c10d::ProcessGroupNCCL(store, rank, size));
new ::c10d::ProcessGroupNCCL(store, rank, size, std::move(opts)));
}

protected:
std::string path_;
std::unique_ptr<::c10d::ProcessGroupNCCL> pg_;
std::chrono::milliseconds pgTimeout_;
};

class NCCLTest : public NCCLTestBase {
public:
NCCLTest(const std::string& path, int worldSize)
: NCCLTestBase(path),
NCCLTest(const std::string& path, int worldSize, std::chrono::milliseconds pgTimeout = kProcessGroupDefaultTimeout)
: NCCLTestBase(path, pgTimeout),
numDevices_(cudaNumDevices()),
state_(::at::globalContext().lazyInitCUDA()),
worldSize_(worldSize) {
@@ -497,10 +506,50 @@ void testReduceScatter(const std::string& path, int rank, int size) {
}
}

void testProcessGroupNCCLHealthCheckFailHelper(const std::string& path, bool timeout) {
// simulate world_size > 1 here via threads.
const int worldSize = 4;
std::mutex m;
std::unordered_set<uint64_t> nums;
auto runTest = [&](int i) {
NCCLTest test(path, worldSize, std::chrono::milliseconds(3000));
// Catch error relating to health check failure
bool error_caught = false;
try {
test.initialize(timeout ? 0 : -1, worldSize);
} catch (const std::exception &e) {
std::string errMsg = e.what();
const std::string kTimeoutErr = "Failed to initialize NCCL communicator on rank";
const std::string kInvalidRankErr = "Invalid rank";
std::string expectedSubstr = timeout ? kTimeoutErr : kInvalidRankErr;
bool cond = errMsg.find(expectedSubstr) != std::string::npos;
EXPECT_TRUE(cond);
error_caught = true;
}
EXPECT_TRUE(error_caught);
};
std::vector<std::thread> threads;
threads.reserve(worldSize);
for (const auto r : c10::irange(worldSize)) {
threads.emplace_back(std::thread([=]() { runTest(r); }));
}
for (auto& t : threads) {
t.join();
}
}

void testProcessGroupNCCLHealthCheckFailException(const std::string& path, int /* unused */, int /* unused */) {
testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ false);
}

void testProcessGroupNCCLHealthCheckFailTimeout(const std::string& path, int /* unused */, int /* unused */) {
testProcessGroupNCCLHealthCheckFailHelper(path, /* timeout */ true);
}

void testSequenceNumInit(const std::string& path, int /* unused */, int /* unused */) {
// Note: ProcessGroupNCCLTest doesn't support multiprocess testing. So we
// simulate world_size > 1 here via threads.
const int worldSize = 4;
const int worldSize = 2;
std::mutex m;
std::unordered_set<uint64_t> nums;
auto runTest = [&](int i) {
@@ -625,6 +674,26 @@ TEST_F(ProcessGroupNCCLTest, testSequenceNumInit) {
}
}

TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailTimeout) {
if (skipTest()) {
return;
}
{
TemporaryFile file;
testProcessGroupNCCLHealthCheckFailTimeout(file.path, rank_, size_);
}
}

TEST_F(ProcessGroupNCCLTest, testProcessGroupNCCLHealthCheckFailException) {
if (skipTest()) {
return;
}
{
TemporaryFile file;
testProcessGroupNCCLHealthCheckFailException(file.path, rank_, size_);
}
}

TEST_F(ProcessGroupNCCLTest, testReduceScatterBase) {
if (skipTest()) {
return;
40 changes: 20 additions & 20 deletions test/distributed/test_c10d_nccl.py
@@ -657,11 +657,13 @@ def test_nccl_propagate_error_reason(self):
# otherwise process will be taken down and we can't check for errors.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
os.environ["NCCL_BLOCKING_WAIT"] = "1"
timeout = timedelta(seconds=2)
# TODO: smaller timeout can fail since PG NCCL does health check in
# constructor. Look into reducing this test's runtime.
store = c10d.FileStore(self.file_name, self.world_size)
pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timeout)
# provide sufficient timeout to initialize NCCL comm.
pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, timeout=timedelta(seconds=15))
pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
pg.barrier().wait()
pg.barrier().wait(timedelta(seconds=5))
# Simulate stuckness in rank 0.
if self.rank == 0:
pg_gloo.barrier().wait()
@@ -670,7 +672,7 @@ if self.rank != 0:
if self.rank != 0:
# Time out due to rank 0 not calling into allreduce.
with self.assertRaises(RuntimeError):
pg.allreduce([inp]).wait()
pg.allreduce([inp]).wait(timedelta(seconds=5))

# Now when nonzero rank attempts to use communicator, original failure reason should be logged.
try:
@@ -2264,14 +2266,14 @@ def _test_nccl_errors_blocking(self, func):
store,
self.rank,
self.world_size,
timeout=timedelta(seconds=self.op_timeout_sec),
timeout=timedelta(seconds=10),
)
process_group.allreduce(torch.rand(10).cuda(self.rank))
if self.rank == 0:
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
# Operation would time out in blocking mode.
work.wait()
work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
# Run some GPU operations to make sure cuda has not gotten stuck.
# It was observed cuda could get stuck if NCCL communicators were
# not properly aborted before throwing RuntimeError.
@@ -2340,13 +2342,13 @@ def test_nccl_blocking_wait_with_barrier(self):
store,
self.rank,
self.world_size,
timeout=timedelta(seconds=self.op_timeout_sec),
timeout=timedelta(seconds=10),
)
process_group.barrier().wait()
if self.rank == 0:
with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
# This should timeout
process_group.barrier().wait()
process_group.barrier().wait(timeout=timedelta(seconds=self.op_timeout_sec))

def _run_invalid_nccl_blocking_wait_env(self, val):
os.environ["NCCL_BLOCKING_WAIT"] = val
@@ -2383,21 +2385,19 @@ def test_nccl_timeout(self):
store = c10d.FileStore(self.file_name, self.world_size)

# Initialize process_group.
timeout = 1
process_group = c10d.ProcessGroupNCCL(
store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
store, self.rank, self.world_size, timeout=timedelta(seconds=10)
)
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5))

if self.rank == 0:
# This should timeout in about 1 second.
start = time.time()
# Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
else:
# Sleep to ensure timeout.
time.sleep(2 * timeout)
time.sleep(10)

self._wait_for_comm_abort(process_group)

@@ -2547,14 +2547,14 @@ def test_nccl_barrier_timeout(self):
store = c10d.FileStore(self.file_name, self.world_size)
if self.rank == 0:
with self.assertRaisesRegex(
RuntimeError, "Timed out initializing process group"
RuntimeError, "Health check failure"
):
c10d.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=store,
timeout=timedelta(seconds=1),
timeout=timedelta(seconds=10),
)

@requires_nccl()
@@ -2566,12 +2566,12 @@ def test_nccl_barrier_timeout_new_group(self):
rank=self.rank,
world_size=self.world_size,
store=store,
timeout=timedelta(seconds=1),
timeout=timedelta(seconds=10),
)

if self.rank == 0:
with self.assertRaisesRegex(
RuntimeError, "Timed out initializing process group"
RuntimeError, "Health check failure"
):
c10d.new_group([0, 1], timeout=timedelta(seconds=1))

@@ -2589,12 +2589,12 @@ def test_nccl_barrier_timeout_new_group_non_member(self):
rank=self.rank,
world_size=self.world_size,
store=store,
timeout=timedelta(seconds=1),
timeout=timedelta(seconds=10),
)

if self.rank == 1:
with self.assertRaisesRegex(
RuntimeError, "Timed out initializing process group"
RuntimeError, "Health check failure"
):
c10d.new_group([0, 1], timeout=timedelta(seconds=1))

8 changes: 4 additions & 4 deletions test/distributed/test_jit_c10d.py
@@ -39,7 +39,7 @@ def setUp(self):

def _create_nccl_pg(self, name_prefix):
tcp_store = create_tcp_store(jit_class=True)
opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(10000, True)

name = unique_process_group_name(name_prefix)

@@ -49,7 +49,7 @@ def _create_nccl_pg_as_base_process_group(self, name):
tcp_store = create_tcp_store(jit_class=True)

return torch.classes.dist_c10d.frontend().new_process_group_helper(
self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
self.world_size, self.rank, [], "nccl", tcp_store, name, 10000)

@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@@ -172,7 +172,7 @@ def test_frontend_singleton(self):
pg_name = unique_process_group_name("singleton_test_process_group")

ProcessGroupNCCL1 = frontend1.new_process_group_helper(
self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 0)
self.world_size, self.rank, [], "nccl", tcp_store, pg_name, 10000)

ProcessGroupNCCL2 = frontend2.get_process_group_by_name(pg_name)
self.assertEqual(frontend2.get_name_of_process_group(ProcessGroupNCCL2), pg_name)
@@ -189,7 +189,7 @@ def __init__(self):

name = unique_process_group_name("module_member_process_group")
self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
1, 0, [], "nccl", tcp_store, name, 0)
1, 0, [], "nccl", tcp_store, name, 10000)

def forward(self, input: torch.Tensor):
if self.pg is None:
