Skip to content

Commit

Permalink
c10d: added Collectives abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed May 17, 2024
1 parent 51e9bb8 commit b345e70
Show file tree
Hide file tree
Showing 11 changed files with 837 additions and 53 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ cc_library(
[
"torch/*.h",
"torch/csrc/**/*.h",
"torch/csrc/distributed/c10d/*.hpp",
"torch/csrc/distributed/c10d/**/*.hpp",
"torch/lib/libshm/*.h",
],
exclude = [
Expand Down
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ libtorch_core_sources = sorted(
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/Functional.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
Expand Down
189 changes: 189 additions & 0 deletions test/distributed/test_control_collectives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


class TestCollectives(TestCase):
    """Smoke tests for the ``dist._StoreCollectives`` control-plane collectives.

    Each test simulates a multi-rank job with a ``ThreadPool`` of
    ``world_size`` workers sharing a single in-process ``HashStore``. The
    ``*_timeout`` tests construct only one participant, so the remaining
    ranks never arrive and the operation must fail fast.
    """

    def test_barrier(self) -> None:
        # Every rank enters the barrier; it releases only once all arrive.
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        # Rank 2 sends; every other rank receives the identical payload.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        # Rank 2 gathers every rank's contribution (including its own),
        # returned as bytes in rank order.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        # Rank 2 scatters one value per rank; every rank -- sender included,
        # since scatter_send returns the sender's own piece -- ends up with
        # the element matching its rank.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)

            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        # Each rank contributes its rank id; every rank sees the total.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast_timeout(self) -> None:
        # No sender ever broadcasts, so the receive hits the store wait
        # timeout.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        # Only rank 1 participates; the error message names the absent ranks.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        # No sender ever scatters, so the receive hits the store wait timeout.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        # Only rank 1 participates; the error message names the absent ranks.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        # Only rank 1 reaches the barrier; the error names the absent ranks.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        # NOTE(review): all_sum surfaces a "barrier failed" message here --
        # presumably it is built on top of barrier internally; confirm
        # against the C++ StoreCollectives implementation.
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        # A key may be consumed by at most one collective per store instance;
        # every subsequent operation reusing "foo" must raise.
        store = dist.HashStore()

        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)


# Entry point: refuse to run if a CUDA context already exists in this
# process, then delegate to the common PyTorch test runner.
if __name__ == "__main__":
    cuda_untouched = not torch.cuda._initialized
    assert cuda_untouched, "test_distributed must not have initialized CUDA context on main process"
    run_tests()
14 changes: 14 additions & 0 deletions torch/_C/_distributed_c10d.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,20 @@ class PrefixStore(Store):
@property
def underlying_store(self) -> Store: ...

class _ControlCollectives:
def barrier(self, key: str, timeout: timedelta, blocking: bool) -> None: ...
def broadcast_send(self, key: str, data: str, timeout: timedelta) -> None: ...
def broadcast_recv(self, key: str, timeout: timedelta) -> str: ...
def gather_send(self, key: str, data: str, timeout: timedelta) -> None: ...
def gather_recv(self, key: str, timeout: timedelta) -> str: ...
def scatter_send(self, key: str, data: str, timeout: timedelta) -> None: ...
def scatter_recv(self, key: str, timeout: timedelta) -> str: ...
def all_gather(self, key: str, data: str, timeout: timedelta) -> str: ...
def all_sum(self, key: str, data: str, timeout: timedelta) -> int: ...

class _StoreCollectives(_ControlCollectives):
    # Store-backed implementation of the control collectives: coordinates
    # `world_size` ranks through the shared `store`; `rank` identifies the
    # calling participant.
    def __init__(self, store: Store, rank: int, world_size: int) -> None: ...

class _DistributedBackendOptions:
def __init__(self): ...
@property
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/HashStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store {
std::vector<uint8_t> get(const std::string& key) override;

// Overload without an explicit timeout: blocks until all `keys` exist,
// using the store's currently configured timeout (timeout_) rather than
// Store::kDefaultTimeout, so setTimeout()/StoreTimeoutGuard adjustments
// are honored here as well.
void wait(const std::vector<std::string>& keys) override {
  wait(keys, timeout_);
}

void wait(
Expand Down
29 changes: 29 additions & 0 deletions torch/csrc/distributed/c10d/Store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder {
std::chrono::milliseconds timeout_;
};

/*
StoreTimeoutGuard is a RAII guard that sets the store's client timeout for the
lifetime of the guard and restores the previous timeout on destruction.

NOTE(review): the timeout is changed on the shared Store object itself, so any
concurrent user of the same Store observes the temporary value while the guard
is alive.
*/
class StoreTimeoutGuard {
 public:
  explicit StoreTimeoutGuard(
      Store& store,
      const std::chrono::milliseconds& timeout)
      // Capture the previous timeout in the member-init list (avoids a
      // default-construct-then-assign in the body), then apply the new one.
      : store_(store), oldTimeout_(store.getTimeout()) {
    store.setTimeout(timeout);
  }

  ~StoreTimeoutGuard() {
    // Restore whatever timeout was in effect at construction time.
    store_.setTimeout(oldTimeout_);
  }

  /* Disabling copy and move semantics */
  StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
  StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
  StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
  StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;

 private:
  Store& store_;                          // borrowed; must outlive the guard
  std::chrono::milliseconds oldTimeout_;  // timeout restored by the dtor
};

} // namespace c10d
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <chrono>
#include <cstdint>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <torch/custom_class.h>

namespace c10d {

using namespace std::chrono_literals;

// Abstract interface for small control-plane collectives
// (barrier/broadcast/gather/scatter/all_gather/all_sum) operating on byte
// buffers rather than tensors. Every operation is addressed by a string
// `key`; each timeout defaults to 5 minutes (via std::chrono_literals).
// Implementations (e.g. the Store-backed one) define the actual transport.
class TORCH_API ControlCollectives : public torch::CustomClassHolder {
 public:
  // Synchronize all ranks at `key`. `block` presumably controls whether the
  // caller waits for every rank to arrive -- confirm against implementations.
  virtual void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) = 0;

  // One rank publishes `data` under `key`; the other ranks receive it.
  virtual void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Non-root ranks send their contribution; the root calls gatherRecv with
  // its own `data` and receives one buffer per rank.
  virtual void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  // The root scatters one buffer per rank and gets its own piece back;
  // every other rank receives its piece via scatterRecv.
  virtual std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  virtual std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Every rank contributes `data` and receives all ranks' buffers.
  virtual std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  // Every rank contributes an integer and receives the sum over all ranks.
  virtual int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) = 0;
};

} // namespace c10d

0 comments on commit b345e70

Please sign in to comment.