diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py
index 6541af0abb90..374fba30d26c 100644
--- a/tests/ignite/conftest.py
+++ b/tests/ignite/conftest.py
@@ -90,14 +90,71 @@ def _destroy_dist_context():
     _set_model(_SerialModel())
 
 
+def _find_free_port():
+    """Ask the OS for a currently unused TCP port and return its number.
+
+    Binding to port 0 lets the OS choose the port. The socket is closed before
+    returning, so nothing holds the port once this function has returned.
+    """
+    # Taken from https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
+    import socket
+    from contextlib import closing
+
+    # closing() releases the socket even if bind() or getsockname() raises.
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+        sock.bind(("", 0))
+        return sock.getsockname()[1]
+
+
+def _setup_free_port(local_rank):
+    """Share one free TCP port between local ranks through ``/tmp/free_port``.
+
+    Local rank 0 picks a free port and writes it to the file. Every other rank
+    checks the file once per second, up to 10 times, and returns the port it reads.
+
+    Raises:
+        RuntimeError: if a non-zero rank could not read a port after 10 attempts.
+    """
+    import os
+    import time
+
+    port_file = "/tmp/free_port"
+
+    if local_rank == 0:
+        port = _find_free_port()
+        # Write to a sibling file and rename it into place, so that polling ranks
+        # never observe an empty or half-written port file.
+        tmp_port_file = "{}.{}".format(port_file, os.getpid())
+        with open(tmp_port_file, "w") as h:
+            h.write(str(port))
+        os.replace(tmp_port_file, port_file)
+        return port
+
+    for _ in range(10):
+        time.sleep(1)
+        try:
+            with open(port_file, "r") as h:
+                return int(h.readline())
+        except (FileNotFoundError, ValueError):
+            # File not published yet, or its content is not a number: retry.
+            continue
+
+    # NOTE(review): a port file left over from a previous run is not detected here.
+    raise RuntimeError("Failed to fetch free port on local rank {}".format(local_rank))
+
+
 @pytest.fixture()
 def distributed_context_single_node_nccl(local_rank, world_size):
 
+    free_port = _setup_free_port(local_rank)
+
+    print(local_rank, "Port:", free_port)
+
     dist_info = {
         "backend": "nccl",
         "world_size": world_size,
         "rank": local_rank,
-        "init_method": "tcp://localhost:2223",
+        "init_method": "tcp://localhost:{}".format(free_port),
     }
     yield _create_dist_context(dist_info, local_rank)
     _destroy_dist_context()
@@ -108,11 +165,14 @@ def distributed_context_single_node_gloo(local_rank, world_size):
 
     from datetime import timedelta
 
-    init_method = "tcp://localhost:2223"
-    temp_file = None
     if sys.platform.startswith("win"):
         temp_file = tempfile.NamedTemporaryFile(delete=False)
         init_method = "file:///{}".format(temp_file.name.replace("\\", "/"))
+    else:
+        free_port = _setup_free_port(local_rank)
+        print(local_rank, "Port:", free_port)
+        init_method = "tcp://localhost:{}".format(free_port)
+        temp_file = None
 
     dist_info = {
         "backend": "gloo",