enable alltoall_single torchscript support (#48345)
Summary: Pull Request resolved: #48345

Test Plan: wait for sandcastle

Differential Revision: D25074475

fbshipit-source-id: 04261f8453567154b0464f8348320e936ca06384
wanchaol authored and facebook-github-bot committed Jan 7, 2021
1 parent 4e2ab2c commit 838e73d
Showing 3 changed files with 39 additions and 17 deletions.
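What the change enables, in outline: the alltoall_single binding added to the torchbound ProcessGroupNCCL (in torch/csrc/distributed/c10d/init.cpp below) can now be called from TorchScript. The following is a minimal sketch, not code from this PR: the method and split-size arguments follow the binding below, while the function name, tensor shapes, and process-group setup are illustrative assumptions for a 2-rank run.

from typing import List

import torch

@torch.jit.script
def scatter_gather(
    pg: torch.classes.dist_c10d.ProcessGroupNCCL,
    output: torch.Tensor,
    input: torch.Tensor,
    output_split_sizes: List[int],
    input_split_sizes: List[int],
) -> torch.Tensor:
    # alltoall_single returns a dist_c10d.Work handle; wait() blocks
    # until the collective has filled `output`.
    work = pg.alltoall_single(output, input, output_split_sizes, input_split_sizes)
    work.wait()
    return output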
2 changes: 1 addition & 1 deletion test/backward_compatibility/check_backward_compatibility.py
@@ -72,7 +72,7 @@ def allow_listed(schema, allow_list):
 dont_parse_list = [
     ("_TorchScriptTesting.*", datetime.date(2099, 9, 17)),
     ("test_backend", datetime.date(2099, 9, 17)),
-    ("c10d.frontend", datetime.date(2020, 12, 30)),
+    ("dist_c10d", datetime.date(2021, 1, 30)),
 ]


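The dont_parse_list entry above tells the backward-compatibility checker to skip the newly bound dist_c10d schemas until the listed expiry date. Roughly, each tuple pairs a schema-name pattern with a date; the sketch below shows that intended semantics under stated assumptions and is not the checker's actual code.

import datetime
import re

def should_skip(schema_line, dont_parse_list):
    # An entry suppresses parsing only until its expiry date passes.
    for pattern, expiry in dont_parse_list:
        if datetime.date.today() > expiry:
            continue
        if re.search(pattern, schema_line):
            return True
    return False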
27 changes: 14 additions & 13 deletions test/distributed/test_jit_c10d.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as c10d
 import time
+from datetime import timedelta
 from typing import List

 import torch.testing._internal.common_utils as common
@@ -31,6 +32,14 @@ def unique_process_group_name(prefix):
     now = int(time.time() * 1000)
     return "%s_%d" % (prefix, now)

+def _create_tcp_store():
+    addr = "localhost"
+    port = common.find_free_port()
+    timeout = timedelta(minutes=5)
+    timeout_millisecond = int(timeout / timedelta(milliseconds=1))
+    return torch.classes.dist_c10d.TCPStore(addr, port, 1, True, timeout_millisecond)
+
+
 @unittest.skipIf(
     TEST_WITH_TSAN,
     "TSAN is not fork-safe since we're forking in a multi-threaded environment",
@@ -48,19 +57,15 @@ def setUp(self):
             raise unittest.SkipTest("NCCL test requires 2+ GPUs")

     def _create_nccl_pg(self, name_prefix):
-        addr = "localhost"
-        port = common.find_free_port()
-        tcp_store = torch.classes.dist_c10d.TCPStore(addr, port, 1, True)
+        tcp_store = _create_tcp_store()
         opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)

         name = unique_process_group_name(name_prefix)

-        return torch.classes.dist_c10d.ProcessGroupNCCL(tcp_store, self.rank, self.world_size, opts, name)
+        return torch.classes.dist_c10d.ProcessGroupNCCL(tcp_store, self.rank, self.world_size, opts, name)

     def _create_nccl_pg_as_base_process_group(self, name):
-        addr = "localhost"
-        port = common.find_free_port()
-        tcp_store = torch.classes.dist_c10d.TCPStore(addr, port, 1, True)
+        tcp_store = _create_tcp_store()

         return torch.classes.dist_c10d.frontend().new_process_group_helper(
             self.world_size, self.rank, [], "nccl", tcp_store, name, 0)
@@ -155,9 +160,7 @@ def test_frontend_singleton(self):
         frontend1 = torch.classes.dist_c10d.frontend()
         frontend2 = torch.classes.dist_c10d.frontend()

-        addr = "localhost"
-        port = common.find_free_port()
-        tcp_store = torch.classes.dist_c10d.TCPStore(addr, port, 1, True)
+        tcp_store = _create_tcp_store()

         pg_name = unique_process_group_name("singleton_test_process_group")

@@ -180,9 +183,7 @@ def test_process_group_as_module_member(self):
         class TestModule(torch.nn.Module):
             def __init__(self):
                 super(TestModule, self).__init__()
-                addr = "localhost"
-                port = common.find_free_port()
-                tcp_store = torch.classes.dist_c10d.TCPStore(addr, port, 1, True)
+                tcp_store = _create_tcp_store()

                 name = unique_process_group_name("module_member_process_group")
                 self.pg = torch.classes.dist_c10d.frontend().new_process_group_helper(
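One detail worth noting in the test refactor above: the torchbound TCPStore constructor (changed in init.cpp below) takes its new timeout as an integer millisecond count, so _create_tcp_store converts a timedelta by dividing it by a one-millisecond timedelta. A small worked check of that conversion:

from datetime import timedelta

timeout = timedelta(minutes=5)
# timedelta / timedelta yields a float ratio; int() truncates it.
timeout_ms = int(timeout / timedelta(milliseconds=1))
assert timeout_ms == 300000  # 5 min * 60 s/min * 1000 ms/s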
27 changes: 24 additions & 3 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1260,11 +1260,25 @@ static const auto TCPStoreTorchBind =
         .def(torch::init([](const std::string& host_name,
                             int64_t port,
                             int64_t world_size,
-                            bool is_master) {
+                            bool is_master,
+                            int64_t timeout) {
+          auto timeout_milliseconds = std::chrono::milliseconds(timeout);
           return c10::make_intrusive<::c10d::TCPStore>(
-              host_name, port, world_size, is_master);
+              host_name, port, world_size, is_master, timeout_milliseconds);
         }));

+// TODO: This should really take Store as its constructor argument instead of
+// TCPStore, but TorchScript does not support polymorphism, which forces us to
+// cast in C++ rather than rely on automatic casting.
+static const auto PrefixStoreTorchBind =
+    torch::class_<::c10d::PrefixStore>("dist_c10d", "PrefixStore")
+        .def(torch::init([](const std::string& prefix,
+                            const c10::intrusive_ptr<::c10d::TCPStore>& store) {
+          return c10::make_intrusive<::c10d::PrefixStore>(prefix, store);
+        }));
+
 // Torchbind the ProcessGroup to make it available in TorchScript
 static const auto ProcessGroupWorkTorchBind =
     torch::class_<::c10d::ProcessGroup::Work>("dist_c10d", "Work")
@@ -1624,7 +1638,14 @@ static const auto ProcessGroupNCCLTorchBind =
                     outputSplitSizes,
                     inputSplitSizes,
                     ::c10d::AllToAllOptions());
-              });
+              })
+        .def("size", [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+          return (int64_t) self->getSize();
+        })
+        .def("rank", [](const c10::intrusive_ptr<::c10d::ProcessGroupNCCL>& self) {
+          return (int64_t) self->getRank();
+        });
 #endif

 static const auto DistributedC10dFrontendTorchBind =
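Putting the new bindings together, a rough eager-mode sketch of the surface this diff exposes; the host, port, rank, and world size are placeholder values, the classes are only registered in distributed-enabled builds, and constructing the NCCL group still requires a CUDA device.

import torch

# TCPStore now requires the fifth timeout argument, in milliseconds.
store = torch.classes.dist_c10d.TCPStore("localhost", 29500, 1, True, 300000)

# PrefixStore namespaces keys under a prefix; per the TODO above it takes a
# concrete TCPStore because TorchScript cannot up-cast to a base Store.
prefix_store = torch.classes.dist_c10d.PrefixStore("worker0", store)

# The new size() and rank() bindings mirror getSize()/getRank() as int64s.
opts = torch.classes.dist_c10d.ProcessGroupNCCLOptions(0, True)
pg = torch.classes.dist_c10d.ProcessGroupNCCL(store, 0, 1, opts, "example_pg")
print(pg.size(), pg.rank())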
