-
Notifications
You must be signed in to change notification settings - Fork 22.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
701 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from multiprocessing.pool import ThreadPool | ||
from datetime import timedelta | ||
|
||
import torch | ||
from torch.testing._internal.common_utils import ( | ||
run_tests, | ||
TestCase, | ||
) | ||
import torch.distributed as dist | ||
|
||
class TestCollectives(TestCase):
    """Tests for the store-backed collective primitives exposed as
    ``dist.StoreCollectives`` (barrier, broadcast, gather, scatter,
    all_gather, all_sum).

    Multiple "ranks" are simulated inside one process with a thread pool;
    all ranks share a single ``dist.HashStore`` as the rendezvous backend.
    """

    def _run_parallel(self, fn, world_size: int) -> None:
        # Run fn(rank) concurrently for every rank. pool.map re-raises any
        # worker exception (including assertion failures) on this thread,
        # so test failures inside ranks still fail the test.
        with ThreadPool(world_size) as pool:
            pool.map(fn, range(world_size))

    def test_barrier(self) -> None:
        store = dist.HashStore()
        world_size = 2

        def f(rank: int) -> None:
            collectives = dist.StoreCollectives(store, rank, world_size)
            # block=True: wait until all ranks have arrived.
            collectives.barrier("foo", timedelta(seconds=10), True)

        self._run_parallel(f, world_size)

    def test_broadcast(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist.StoreCollectives(store, rank, world_size)
            if rank == 2:
                # broadcast_send returns nothing; only receivers check data.
                collectives.broadcast_send("foo", "data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, "data")

        self._run_parallel(f, world_size)

    def test_gather(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist.StoreCollectives(store, rank, world_size)
            if rank == 2:
                # The gathering rank contributes its own value and receives
                # the contributions of every rank, in rank order.
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, ["0", "1", "2", "3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        self._run_parallel(f, world_size)

    def test_scatter(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist.StoreCollectives(store, rank, world_size)
            if rank == 2:
                # scatter_send also returns the sender's own chunk, so the
                # assertion below holds for every rank.
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)

            self.assertEqual(out, str(rank))

        self._run_parallel(f, world_size)

    def test_all_sum(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist.StoreCollectives(store, rank, world_size)
            # Every rank contributes its rank id; all ranks see the total.
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        self._run_parallel(f, world_size)

    # Timeout tests: a single rank (rank 1 of 4) runs alone with a tiny
    # timeout, so the collective must fail reporting the absent ranks.

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()
        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist.StoreCollectives(store, 1, world_size)
        # NOTE(review): the expected message is "barrier failed" — all_sum
        # presumably synchronizes via a barrier internally; confirm against
        # the StoreCollectives implementation.
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Distributed tests must start with a clean main process: initializing
    # CUDA here would be inherited by (and corrupt) spawned workers.
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )
    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#pragma once | ||
|
||
#include <ATen/core/ivalue.h> | ||
#include <chrono> | ||
#include <cstdint> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include <c10/macros/Macros.h> | ||
#include <torch/custom_class.h> | ||
|
||
namespace c10d { | ||
|
||
using namespace std::chrono_literals; | ||
|
||
/// Abstract interface for store-backed collective operations.
///
/// Every operation is keyed by a string `prefix` (presumably used to
/// namespace the keys written to the underlying store — confirm with the
/// concrete implementation) and accepts a timeout defaulting to 5 minutes.
/// Data payloads are opaque byte vectors.
class TORCH_API Collectives : public torch::CustomClassHolder {
 public:
  /// Synchronize all ranks at a named barrier.
  /// NOTE(review): the semantics of `block = false` are implementation
  /// defined from this header alone — confirm with the implementation.
  virtual void barrier(
      const std::string& prefix,
      std::chrono::milliseconds timeout = 5min,
      bool block = true) = 0;

  /// One rank publishes `data`; has no return value.
  virtual void broadcast_send(
      const std::string& prefix,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Receive the bytes published via broadcast_send under the same prefix.
  virtual std::vector<uint8_t> broadcast_recv(
      const std::string& prefix,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Contribute `data` to a gather rooted at the rank calling gather_recv.
  virtual void gather_send(
      const std::string& prefix,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Root side of gather: contributes its own `data` and returns one
  /// payload per rank (presumably in rank order — confirm).
  virtual std::vector<std::vector<uint8_t>> gather_recv(
      const std::string& prefix,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Root side of scatter: distributes one payload per rank and returns a
  /// payload itself (presumably the root's own chunk — confirm).
  virtual std::vector<uint8_t> scatter_send(
      const std::string& prefix,
      const std::vector<std::vector<uint8_t>>& data,
      std::chrono::milliseconds timeout = 5min) = 0;
  /// Receive this rank's chunk of a scatter under the same prefix.
  virtual std::vector<uint8_t> scatter_recv(
      const std::string& prefix,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Contribute `data` and receive every rank's contribution.
  virtual std::vector<std::vector<uint8_t>> all_gather(
      const std::string& prefix,
      const std::vector<uint8_t>& data,
      std::chrono::milliseconds timeout = 5min) = 0;

  /// Contribute an integer and receive the sum across all ranks.
  virtual int64_t all_sum(
      const std::string& prefix,
      int64_t data,
      std::chrono::milliseconds timeout = 5min) = 0;
};
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.