-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing 11 changed files, with 837 additions and 53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
# Owner(s): ["oncall: distributed"] | ||
|
||
from datetime import timedelta | ||
from multiprocessing.pool import ThreadPool | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from torch.testing._internal.common_utils import run_tests, TestCase | ||
|
||
|
||
class TestCollectives(TestCase):
    """Tests for the store-backed control-plane collectives.

    Each positive test spins up ``world_size`` threads sharing a single
    in-process ``dist.HashStore`` and runs one collective end to end.
    The ``*_timeout`` tests create a single rank so the collective cannot
    complete and must raise.  ``test_unique`` checks that a key may only
    be used for one collective.
    """

    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            # block=True: wait for every rank before returning.
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            # Rank 2 is the (arbitrary, non-zero) broadcast root.
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                # The receiving rank contributes its own value too and
                # gets everyone's contributions back as bytes, rank-ordered.
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                # scatter_send returns the sender's own shard as well.
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
            # Every rank (sender included) ends up with its own shard.
            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            # 0 + 1 + ... + (world_size - 1)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        # Only one rank participates, so the recv can never be satisfied.
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        # The error message names the ranks that never checked in.
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        # NOTE(review): the message suggests all_sum is implemented on top
        # of barrier -- confirm against the C++ implementation.
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        store = dist.HashStore()

        # world_size == 1, so the first broadcast_send completes immediately.
        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        # Any further collective on the same key must be rejected.
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)
||
if __name__ == "__main__":
    # Distributed tests manage device setup themselves; a CUDA context
    # initialized in the main process would leak into spawned workers.
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp — 59 additions, 0 deletions.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#pragma once | ||
|
||
#include <ATen/core/ivalue.h> | ||
#include <chrono> | ||
#include <cstdint> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include <c10/macros/Macros.h> | ||
#include <torch/custom_class.h> | ||
|
||
namespace c10d { | ||
|
||
using namespace std::chrono_literals; | ||
|
||
// Abstract interface for store-backed control-plane collectives.
// Implementations coordinate small out-of-band values (keys, byte blobs,
// integers) between ranks; this is not a data-plane (tensor) collective API.
//
// NOTE(review): default arguments on virtual functions are statically bound
// to the declared type -- overriders should not restate different defaults.
class TORCH_API ControlCollectives : public torch::CustomClassHolder {
 public:
  /// Blocks until all ranks have entered the barrier for `key`, or until
  /// `timeout` expires. The meaning of `block == false` is
  /// implementation-defined -- see the concrete subclass.
  virtual void barrier(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) = 0;

  /// Publishes `data` under `key` for the other ranks to receive.
  virtual void broadcastSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Receives the bytes that the sending rank published under `key`.
  virtual std::vector<uint8_t> broadcastRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Contributes `data` to a gather rooted at the rank calling gatherRecv.
  virtual void gatherSend(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Collects every rank's contribution (including this rank's `data`);
  /// presumably rank-ordered -- confirm with the implementation.
  virtual std::vector<std::vector<uint8_t>> gatherRecv(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Distributes one entry of `data` to each rank; returns this rank's own
  /// shard.
  virtual std::vector<uint8_t> scatterSend(
      const std::string& key,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Receives the shard scattered to this rank under `key`.
  virtual std::vector<uint8_t> scatterRecv(
      const std::string& key,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// All-to-all gather: every rank contributes `data` and every rank
  /// receives all contributions.
  virtual std::vector<std::vector<uint8_t>> allGather(
      const std::string& key,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Sums `data` across all ranks and returns the total on every rank.
  virtual int64_t allSum(
      const std::string& key,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) = 0;
};
|
||
} // namespace c10d |
Oops, something went wrong.