This repository has been archived by the owner on Jan 6, 2023. It is now read-only.

Use EtcdStore rather than TCPStore when using etcd_rdzv #34

Closed · wants to merge 2 commits
69 changes: 69 additions & 0 deletions test/distributed/utils_test.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import multiprocessing as mp
import unittest
from unittest.mock import patch

import torchelastic.distributed as edist


def _return_false():
return False


def _return_true():
return True


def _return_one():
return 1


def _get_rank(ignored):
"""
wrapper around torchelastic.distributed.get_rank()
take the element in the input argument as parameter
since multiprocessing.Pool.map requires the function to
"""
return edist.get_rank()


class TestUtils(unittest.TestCase):
@patch("torch.distributed.is_available", _return_true)
@patch("torch.distributed.is_initialized", _return_false)
def test_get_rank_no_process_group_initialized(self):
# always return rank 0 when process group is not initialized
num_procs = 4
with mp.Pool(num_procs) as p:
ret = p.map(_get_rank, range(0, num_procs))
for rank in ret:
self.assertEqual(0, rank)

@patch("torch.distributed.is_available", _return_false)
@patch("torch.distributed.is_initialized", _return_true)
def test_get_rank_no_dist_available(self):
# always return rank 0 when distributed torch is not available
num_procs = 4
with mp.Pool(num_procs) as p:
ret = p.map(_get_rank, range(0, num_procs))
for rank in ret:
self.assertEqual(0, rank)

@patch("torch.distributed.is_available", _return_true)
@patch("torch.distributed.is_initialized", _return_true)
@patch("torch.distributed.get_rank", _return_one)
def test_get_rank(self):
world_size = 4
with mp.Pool(world_size) as p:
ret = p.map(_get_rank, range(0, world_size))

# since we mocked a return value of 1
# from torch.distributed.get_rank()
# we expect that the sum of ranks == world_size
self.assertEqual(world_size, sum(ret))
1 change: 1 addition & 0 deletions torchelastic/distributed/__init__.py
@@ -7,3 +7,4 @@
# LICENSE file in the root directory of this source tree.

from .collectives import * # noqa F401
from .utils import * # noqa F401
18 changes: 18 additions & 0 deletions torchelastic/distributed/utils.py
@@ -0,0 +1,18 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.distributed as dist


def get_rank():
"""
Simple wrapper for correctly getting rank in both distributed
/ non-distributed settings
"""
return dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
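
For context, a minimal usage sketch of the new helper (not part of this diff; the commented-out init_process_group arguments are environment-specific assumptions):

import torchelastic.distributed as edist

# Before torch.distributed.init_process_group() has been called (or when
# torch.distributed is unavailable), get_rank() falls back to rank 0
# instead of raising.
assert edist.get_rank() == 0

# Once the default process group is initialized, the helper delegates to
# torch.distributed.get_rank(), e.g.:
# torch.distributed.init_process_group(backend="gloo", init_method="env://")
# print(edist.get_rank())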
53 changes: 2 additions & 51 deletions torchelastic/rendezvous/etcd_rendezvous.py
@@ -94,23 +94,8 @@ def __del__(self):
def next_rendezvous(self):
rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()

# TODO: https://github.com/pytorch/elastic/issues/11
# make EtcdStore the default and remove TCPStore code
# Setup a c10d store for this specific rendezvous version,
# by piggybacking on the etcd handler used during rendezvous.
        # Switch back to the EtcdStore code path, i.e.
        # store = self._rdzv_impl.setup_kv_store(rdzv_version),
        # once the pybind11-trampoline fix for c10d::Store is included in
        # the next pytorch release. Then, remove this hack.
import torchelastic.rendezvous # noqa

if "_TORCHELASTIC_USE_ETCDSTORE" in torchelastic.rendezvous.__dict__:
log.info("Using EtcdStore for c10d::Store implementation")
store = self._rdzv_impl.setup_kv_store(rdzv_version)
else:
log.info("Using TCPStore for c10d::Store implementation")
store = setup_tcpstore(rank, world_size, rdzv_version, self._rdzv_impl)
log.info("Creating EtcdStore as the c10d::Store implementation")
store = self._rdzv_impl.setup_kv_store(rdzv_version)

return store, rank, world_size
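
For illustration, a hedged sketch of how a caller consumes the handler's output after this change (the handler construction is assumed to happen elsewhere, e.g. via the torchelastic rendezvous registry; it is not part of this diff). Since EtcdStore provides the same c10d::Store interface that TCPStore did, caller-side synchronization code is unchanged:

store, rank, world_size = rdzv_handler.next_rendezvous()

# The returned store supports the usual c10d::Store primitives,
# regardless of the backing implementation.
store.set(f"worker_{rank}_ready", "1")
store.wait([f"worker_{r}_ready" for r in range(world_size)])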

@@ -1065,40 +1050,6 @@ def _get_socket_with_port():
raise RuntimeError("Failed to create a socket")


# Helper function to setup a TCPStore-based c10d::Store implementation.
def setup_tcpstore(rank, world_size, rdzv_version, rdzv_impl):
if rank == 0:
import socket
from contextlib import closing

# FIXME: ideally, TCPStore should have an API that
# accepts a pre-constructed socket.
with closing(_get_socket_with_port()) as sock:
host = socket.gethostname()
port = sock.getsockname()[1]

rdzv_impl.store_extra_data(
rdzv_version, key="tcpstore_server", value="{}:{}".format(host, port)
)

log.info(f"Setting up TCPStore server on {host}:{port}")
start_daemon = True
sock.close() # FIXME: get rid of race-condition by improving TCPStore API
store = TCPStore(host, port, world_size, start_daemon)
log.info(f"TCPStore server initialized on {host}:{port}")
else:
hostport = rdzv_impl.load_extra_data(rdzv_version, key="tcpstore_server")
log.info(f"Rank {rank} will conenct to TCPStore server at {hostport}")

import re

host, port = re.match(r"(.+):(\d+)$", hostport).groups()
start_daemon = False
store = TCPStore(host, int(port), world_size, start_daemon)

return store


# Helper for _etcd_rendezvous_handler(url)
def _parse_etcd_client_params(params):
kwargs = {}