c10d: added Collectives abstraction
d4l3k committed May 10, 2024
1 parent b08072f commit b9e1a53
Showing 10 changed files with 701 additions and 1 deletion.
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -486,6 +486,7 @@ libtorch_core_sources = sorted(
# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/Functional.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
151 changes: 151 additions & 0 deletions test/distributed/test_collectives.py
@@ -0,0 +1,151 @@
from multiprocessing.pool import ThreadPool
from datetime import timedelta

import torch
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
import torch.distributed as dist

class TestCollectives(TestCase):
def test_barrier(self) -> None:
store = dist.HashStore()

world_size = 2

def f(rank: int) -> None:
collectives = dist.StoreCollectives(store, rank, world_size)
collectives.barrier("foo", timedelta(seconds=10), True)

with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))

def test_broadcast(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(seconds=10)

def f(rank: int) -> None:
collectives = dist.StoreCollectives(store, rank, world_size)
if rank == 2:
collectives.broadcast_send("foo", "data", timeout)
else:
out = collectives.broadcast_recv("foo", timeout)
self.assertEqual(out, "data")

with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))

def test_gather(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(seconds=10)

def f(rank: int) -> None:
collectives = dist.StoreCollectives(store, rank, world_size)
if rank == 2:
out = collectives.gather_recv("foo", str(rank), timeout)
self.assertEqual(out, ["0", "1", "2", "3"])
else:
collectives.gather_send("foo", str(rank), timeout)

with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))

def test_scatter(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(seconds=10)

def f(rank: int) -> None:
collectives = dist.StoreCollectives(store, rank, world_size)
if rank == 2:
out = collectives.scatter_send("foo", [str(i) for i in range(world_size)], timeout)
else:
out = collectives.scatter_recv("foo", timeout)
self.assertEqual(out, str(rank))

with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))

def test_all_sum(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(seconds=10)

def f(rank: int) -> None:
collectives = dist.StoreCollectives(store, rank, world_size)
out = collectives.all_sum("foo", rank, timeout)
self.assertEqual(out, sum(range(world_size)))

with ThreadPool(world_size) as pool:
pool.map(f, range(world_size))

def test_broadcast_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "Wait timeout"):
collectives.broadcast_recv("foo", timeout)

def test_gather_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "gather failed -- missing ranks: 0, 2, 3"):
collectives.gather_recv("foo", "data", timeout)

def test_scatter_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "Wait timeout"):
collectives.scatter_recv("foo", timeout)

def test_all_gather_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "all_gather failed -- missing ranks: 0, 2, 3"):
collectives.all_gather("foo", "data", timeout)

def test_barrier_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "barrier failed -- missing ranks: 0, 2, 3"):
collectives.barrier("foo", timeout, True)

def test_all_sum_timeout(self) -> None:
store = dist.HashStore()

world_size = 4
timeout = timedelta(milliseconds=1)
collectives = dist.StoreCollectives(store, 1, world_size)
with self.assertRaisesRegex(Exception, "barrier failed -- missing ranks: 0, 2, 3"):
collectives.all_sum("foo", 1, timeout)




if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"

run_tests()
14 changes: 14 additions & 0 deletions torch/_C/_distributed_c10d.pyi
@@ -210,6 +210,20 @@ class PrefixStore(Store):
@property
def underlying_store(self) -> Store: ...

class Collectives:
    def barrier(self, prefix: str, timeout: timedelta, blocking: bool) -> None: ...
    def broadcast_send(self, prefix: str, data: str, timeout: timedelta) -> None: ...
    def broadcast_recv(self, prefix: str, timeout: timedelta) -> str: ...
    def gather_send(self, prefix: str, data: str, timeout: timedelta) -> None: ...
    def gather_recv(self, prefix: str, data: str, timeout: timedelta) -> List[str]: ...
    def scatter_send(self, prefix: str, data: List[str], timeout: timedelta) -> str: ...
    def scatter_recv(self, prefix: str, timeout: timedelta) -> str: ...
    def all_gather(self, prefix: str, data: str, timeout: timedelta) -> List[str]: ...
    def all_sum(self, prefix: str, data: int, timeout: timedelta) -> int: ...

class StoreCollectives(Collectives):
def __init__(self, store: Store, rank: int, world_size: int) -> None: ...

class _DistributedBackendOptions:
def __init__(self): ...
@property
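
A usage sketch of the bindings declared above (illustrative only, not part of this commit): each rank wraps a shared Store in a StoreCollectives and issues control-plane collectives under caller-chosen key prefixes ("init" and "steps" below are arbitrary). A HashStore plus a thread pool mirrors the tests in this commit; the same code should also work across processes with, for example, a TCPStore.

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch.distributed as dist


def run(rank: int, store: dist.Store, world_size: int) -> None:
    # One StoreCollectives per rank, all sharing the same underlying store.
    collectives = dist.StoreCollectives(store, rank, world_size)
    # Every collective is namespaced by a caller-chosen key prefix.
    peers = collectives.all_gather("init", f"rank{rank}", timedelta(seconds=10))
    total = collectives.all_sum("steps", rank, timedelta(seconds=10))
    print(rank, peers, total)


if __name__ == "__main__":
    world_size = 4
    store = dist.HashStore()  # in-process; a TCPStore would span processes
    with ThreadPool(world_size) as pool:
        pool.starmap(run, [(r, store, world_size) for r in range(world_size)])
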
59 changes: 59 additions & 0 deletions torch/csrc/distributed/c10d/Collectives.hpp
@@ -0,0 +1,59 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <chrono>
#include <cstdint>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <torch/custom_class.h>

namespace c10d {

using namespace std::chrono_literals;

class TORCH_API Collectives : public torch::CustomClassHolder {
public:
virtual void barrier(
const std::string& prefix,
std::chrono::milliseconds timeout = 5min,
bool block = true) = 0;

virtual void broadcast_send(
const std::string& prefix,
const std::vector<uint8_t>& data,
std::chrono::milliseconds timeout = 5min) = 0;
virtual std::vector<uint8_t> broadcast_recv(
const std::string& prefix,
std::chrono::milliseconds timeout = 5min) = 0;

virtual void gather_send(
const std::string& prefix,
const std::vector<uint8_t>& data,
std::chrono::milliseconds timeout = 5min) = 0;
virtual std::vector<std::vector<uint8_t>> gather_recv(
const std::string& prefix,
const std::vector<uint8_t>& data,
std::chrono::milliseconds timeout = 5min) = 0;

virtual std::vector<uint8_t> scatter_send(
const std::string& prefix,
const std::vector<std::vector<uint8_t>>& data,
std::chrono::milliseconds timeout = 5min) = 0;
virtual std::vector<uint8_t> scatter_recv(
const std::string& prefix,
std::chrono::milliseconds timeout = 5min) = 0;

virtual std::vector<std::vector<uint8_t>> all_gather(
const std::string& prefix,
const std::vector<uint8_t>& data,
std::chrono::milliseconds timeout = 5min) = 0;

virtual int64_t all_sum(
const std::string& prefix,
int64_t data,
std::chrono::milliseconds timeout = 5min) = 0;
};

} // namespace c10d
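
The Store-backed implementation itself (StoreCollectives.cpp, among the files not shown above) is not reproduced here. As a rough illustration of how a collective such as barrier can be layered on the existing Store primitives (set, wait), here is a hedged Python sketch; the helper name and key scheme are assumptions of this sketch, not code from this commit.

from datetime import timedelta

import torch.distributed as dist


def store_barrier(
    store: dist.Store, prefix: str, rank: int, world_size: int, timeout: timedelta
) -> None:
    # Hypothetical store-based barrier: each rank publishes a per-rank key,
    # then blocks until every rank's key is visible in the shared store.
    store.set(f"{prefix}/arrived/{rank}", "1")
    store.wait([f"{prefix}/arrived/{r}" for r in range(world_size)], timeout)

A real implementation also needs per-invocation key uniqueness and a non-blocking mode, which is why every method in the interface takes an explicit prefix and barrier exposes a block flag.
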
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/HashStore.hpp
@@ -22,7 +22,7 @@ class TORCH_API HashStore : public Store {
std::vector<uint8_t> get(const std::string& key) override;

void wait(const std::vector<std::string>& keys) override {
-    wait(keys, Store::kDefaultTimeout);
+    wait(keys, timeout_);
}

void wait(
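
This one-line change makes a bare wait(keys) honor the timeout configured via setTimeout() instead of always falling back to the five-minute Store::kDefaultTimeout, which is what allows a per-call timeout (see StoreTimeoutGuard below) to actually shorten waits. A small behavioral sketch using the existing Python Store bindings:

from datetime import timedelta

import torch.distributed as dist

store = dist.HashStore()
store.set_timeout(timedelta(milliseconds=50))
try:
    # With this change, wait() without an explicit timeout uses the
    # configured 50 ms timeout rather than the default.
    store.wait(["key-that-is-never-set"])
except Exception as e:
    print("timed out as expected:", e)
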
29 changes: 29 additions & 0 deletions torch/csrc/distributed/c10d/Store.hpp
@@ -97,4 +97,33 @@ class TORCH_API Store : public torch::CustomClassHolder {
std::chrono::milliseconds timeout_;
};

/*
StoreTimeoutGuard is a RAII guard that sets the store's timeout for its lifetime
and restores the previous timeout when the guard is destroyed.
*/
class StoreTimeoutGuard {
public:
explicit StoreTimeoutGuard(
Store& store,
const std::chrono::milliseconds& timeout)
: store_(store) {
old_timeout_ = store.getTimeout();
store.setTimeout(timeout);
}

~StoreTimeoutGuard() {
store_.setTimeout(old_timeout_);
}

/* Disabling copy and move semantics */
StoreTimeoutGuard(const StoreTimeoutGuard&) = delete;
StoreTimeoutGuard& operator=(const StoreTimeoutGuard&) = delete;
StoreTimeoutGuard(StoreTimeoutGuard&&) = delete;
StoreTimeoutGuard& operator=(StoreTimeoutGuard&&) = delete;

private:
Store& store_;
std::chrono::milliseconds old_timeout_;
};

} // namespace c10d
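
For readers working from Python, a rough analogue of StoreTimeoutGuard (illustrative only; the store_timeout helper and the explicitly passed previous timeout are assumptions of this sketch, not APIs added by this commit):

from contextlib import contextmanager
from datetime import timedelta

import torch.distributed as dist


@contextmanager
def store_timeout(store: dist.Store, timeout: timedelta, previous: timedelta):
    # Rough Python analogue of the C++ StoreTimeoutGuard: temporarily override
    # the store timeout, then restore the caller-supplied previous value on exit.
    store.set_timeout(timeout)
    try:
        yield store
    finally:
        store.set_timeout(previous)

Inside such a block a bare store.wait(keys) uses the temporary timeout, which is presumably how the C++ guard is used around each Store-backed collective.
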