Skip to content

Commit

Permalink
[TorchElastic] Option for sharing TCPStore created by rdzv handlers (#…
Browse files Browse the repository at this point in the history
…125743)

Summary:

1. Define explicit `use_agent_store` on rdzv handlers. Handlers that set is true can share the store.
2. Instead of agent coordinating master_add/master_port values, the logic is now encapsulated by a *rdzv_handler* where `RendezvousInfo` will have `RendezvousStoreInfo` object that handlers must return.
    - Depending on the implementation they can either:
         - point to existing store (and expected to `use_agent_store` as true - point 1). Client code will rely on `TORCHELASTIC_USE_AGENT_STORE` env variable to know if the store is shared.
         - build args that `torch.distributed.init_process_group` can bootstrap by creating new store.

Additional points:

- When TCPStore is shared, it should be wrapped in PrefixStore to qualify/scope namespace for other usecases.
- `next_rendezvous` signature changed to return instance of `RendezvousInfo` instead of a (store, rank, world_size) tuple for extensibility purposes.

Why:
- Reduce moving parts
   - easier to swap implementation
   - improve tractability
   - addressing perf/debug-ability will benefit all usecases
   -
Test Plan: CI

Differential Revision: D57055235

Pull Request resolved: #125743
Approved by: https://github.com/d4l3k
  • Loading branch information
kurman authored and pytorchmergebot committed May 22, 2024
1 parent fde1e8a commit d62b025
Show file tree
Hide file tree
Showing 13 changed files with 238 additions and 144 deletions.
11 changes: 10 additions & 1 deletion docs/source/elastic/rendezvous.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ Registry
:members:

.. autoclass:: RendezvousHandlerRegistry
:members:

.. automodule:: torch.distributed.elastic.rendezvous.registry

Expand All @@ -28,6 +27,16 @@ Handler
.. autoclass:: RendezvousHandler
:members:

Dataclasses
-----------
.. autoclass:: RendezvousInfo

.. currentmodule:: torch.distributed.elastic.rendezvous.api

.. autoclass:: RendezvousStoreInfo

.. automethod:: build(rank, store)

Exceptions
----------
.. autoclass:: RendezvousError
Expand Down
34 changes: 17 additions & 17 deletions test/distributed/elastic/agent/server/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ def _get_worker_spec(
):
run_id = str(uuid.uuid4().int)
port = get_free_port()
endpoint = f"127.0.0.1:{port}"
if local_addr is None:
endpoint = f"127.0.0.1:{port}"
else:
endpoint = f"{local_addr}:{port}"

rdzv_params = RendezvousParameters(
backend="static",
endpoint=endpoint,
Expand Down Expand Up @@ -298,7 +302,8 @@ def _get_record_metrics_test_calls(
return calls

def test_rendezvous(self):
spec = self._get_worker_spec(max_restarts=1)
hostname = _get_fq_hostname()
spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
agent = TestAgent(spec)
worker_group = agent.get_worker_group()
agent._rendezvous(worker_group)
Expand All @@ -307,10 +312,8 @@ def test_rendezvous(self):
self.assertEqual(1, worker_group.group_world_size)
self.assertEqual(0, worker_group.group_rank)

master_addr, master_port = agent._get_master_addr_port(worker_group.store)

self.assertEqual(_get_fq_hostname(), master_addr)
self.assertTrue(master_port > 0)
self.assertEqual(hostname, worker_group.master_addr)
self.assertTrue(worker_group.master_port > 0)

rank_set = {w.global_rank for w in worker_group.workers}
for w in worker_group.workers:
Expand All @@ -326,28 +329,25 @@ def test_rendezvous(self):
self.assertSetEqual(set(range(w.world_size)), rank_set)

def test_rendezvous_default_master_addr(self):
spec = self._get_worker_spec(max_restarts=1)
hostname = _get_fq_hostname()
spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
agent = TestAgent(spec)
worker_group = agent.get_worker_group()
agent._rendezvous(worker_group)

master_addr, master_port = agent._get_master_addr_port(worker_group.store)

self.assertEqual(_get_fq_hostname(), master_addr)
self.assertGreater(master_port, 0)
self.assertEqual(_get_fq_hostname(), worker_group.master_addr)
self.assertGreater(worker_group.master_port, 0)

def test_rendezvous_master_addr_with_local_addr(self):
spec_local_addr = "1.2.3.4"
spec_local_addr = "127.0.0.1"
spec = self._get_worker_spec(max_restarts=1, local_addr=spec_local_addr)
agent = TestAgent(spec)
worker_group = agent.get_worker_group()
agent._rendezvous(worker_group)

master_addr, master_port = agent._get_master_addr_port(worker_group.store)

self.assertNotEqual(_get_fq_hostname(), master_addr)
self.assertEqual(spec_local_addr, master_addr)
self.assertGreater(master_port, 0)
self.assertNotEqual(_get_fq_hostname(), worker_group.master_addr)
self.assertEqual(spec_local_addr, worker_group.master_addr)
self.assertGreater(worker_group.master_port, 0)

def test_initialize_workers(self):
spec = self._get_worker_spec(max_restarts=1)
Expand Down
6 changes: 3 additions & 3 deletions test/distributed/elastic/rendezvous/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Dict, SupportsInt, Tuple
from typing import Any, cast, Dict, SupportsInt
from unittest import TestCase

from torch.distributed import Store
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousHandlerRegistry,
RendezvousInfo,
RendezvousParameters,
)

Expand Down Expand Up @@ -196,7 +196,7 @@ def __init__(self, params: RendezvousParameters) -> None:
def get_backend(self) -> str:
return "dummy_backend"

def next_rendezvous(self) -> Tuple[Store, int, int]:
def next_rendezvous(self) -> RendezvousInfo:
raise NotImplementedError

def is_closed(self) -> bool:
Expand Down
33 changes: 18 additions & 15 deletions test/distributed/elastic/rendezvous/dynamic_rendezvous_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from unittest import TestCase
from unittest.mock import call, MagicMock, Mock, patch

from torch.distributed import Store
from torch.distributed import HashStore, Store
from torch.distributed.elastic.rendezvous import (
RendezvousClosedError,
RendezvousError,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousTimeoutError,
Expand Down Expand Up @@ -1154,9 +1155,11 @@ def setUp(self) -> None:

self._store = DummyStore()

self._mock_store_get = MagicMock(return_value=b"dummy_value")
self._mock_store_get = MagicMock(return_value=b"123")
self._mock_store_set = MagicMock()

setattr(self._store, "get", self._mock_store_get) # noqa: B010
setattr(self._store, "set", self._mock_store_set) # noqa: B010

self._state_holder = FakeRendezvousStateHolder()

Expand Down Expand Up @@ -1208,14 +1211,14 @@ def test_next_rendezvous_returns_expected_value(self) -> None:

handler = self._create_handler()

store, rank, world_size = handler.next_rendezvous()
rdzv_info = handler.next_rendezvous()

self.assertEqual(rank, 2)
self.assertEqual(world_size, 3)
self.assertEqual(rdzv_info.rank, 2)
self.assertEqual(rdzv_info.world_size, 3)

_ = store.get("dummy_key")
_ = rdzv_info.store.get("dummy_key")

self._mock_store_get.assert_called_once_with(
self._mock_store_get.assert_called_with(
"torch.rendezvous.dummy_run_id.0/dummy_key"
)

Expand Down Expand Up @@ -1595,7 +1598,7 @@ def join(self, *args):

class IntegrationTest(TestCase):
def setUp(self) -> None:
self._store = DummyStore()
self._store = HashStore()
self._handlers = []
self._backend = _InMemoryRendezvousBackend()

Expand Down Expand Up @@ -1631,15 +1634,15 @@ def test_all_nodes_join_rendezvous(self) -> None:
handler1_thread.start()
handler2_thread.start()

store1, rank1, world_size1 = handler1_thread.join()
store2, rank2, world_size2 = handler2_thread.join()
self.assertEqual(store1.underlying_store, self._store)
self.assertEqual(store2.underlying_store, self._store)
rdzv_info1: RendezvousInfo = handler1_thread.join()
rdzv_info2: RendezvousInfo = handler2_thread.join()
self.assertEqual(rdzv_info1.store.underlying_store, self._store)
self.assertEqual(rdzv_info2.store.underlying_store, self._store)

self.assertNotEqual(rank1, rank2)
self.assertNotEqual(rdzv_info1.rank, rdzv_info2.rank)

self.assertEqual(world_size1, 2)
self.assertEqual(world_size2, 2)
self.assertEqual(rdzv_info1.world_size, 2)
self.assertEqual(rdzv_info2.world_size, 2)

def test_redundancy_list(self) -> None:
handler1 = self._create_handler(min_nodes=2, max_nodes=2)
Expand Down
8 changes: 4 additions & 4 deletions test/distributed/elastic/rendezvous/etcd_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def test_etcd_server_with_rendezvous(self):
last_call_timeout=30,
)
rdzv_handler = EtcdRendezvousHandler(rdzv)
store, rank, world_size = rdzv_handler.next_rendezvous()
self.assertIsNotNone(store)
self.assertEqual(0, rank)
self.assertEqual(1, world_size)
rdzv_info = rdzv_handler.next_rendezvous()
self.assertIsNotNone(rdzv_info.store)
self.assertEqual(0, rdzv_info.rank)
self.assertEqual(1, rdzv_info.world_size)
finally:
server.stop()
16 changes: 8 additions & 8 deletions test/distributed/elastic/rendezvous/static_rendezvous_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def test_static_rdzv_multiple_calls(self):
rdzv_handler = create_rdzv_handler(rdzv_params)

# Call rendezvous two times
store, rank, world_size = rdzv_handler.next_rendezvous()
self.assertIsNotNone(store)
self.assertEqual(0, rank)
self.assertEqual(1, world_size)
rdzv_info = rdzv_handler.next_rendezvous()
self.assertIsNotNone(rdzv_info.store)
self.assertEqual(0, rdzv_info.rank)
self.assertEqual(1, rdzv_info.world_size)

store, rank, world_size = rdzv_handler.next_rendezvous()
self.assertIsNotNone(store)
self.assertEqual(0, rank)
self.assertEqual(1, world_size)
rdzv_info = rdzv_handler.next_rendezvous()
self.assertIsNotNone(rdzv_info.store)
self.assertEqual(0, rdzv_info.rank)
self.assertEqual(1, rdzv_info.world_size)
91 changes: 16 additions & 75 deletions torch/distributed/elastic/agent/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
import time
import traceback
import warnings
from contextlib import closing, contextmanager
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
from torch.distributed import Store
from torch.distributed.elastic.events import Event, EventSource, record
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import (
Expand Down Expand Up @@ -251,7 +250,7 @@ class WorkerGroup:
group contains cross instance workers or not depends on the implementation of the agent.
"""

__slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]
__slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state", "master_addr", "master_port"]

def __init__(self, spec: WorkerSpec):
self.spec = spec
Expand All @@ -261,6 +260,8 @@ def __init__(self, spec: WorkerSpec):
self.store = None
self.group_rank = None
self.group_world_size = None
self.master_addr = None
self.master_port = None

self.state = WorkerState.INIT

Expand Down Expand Up @@ -356,37 +357,6 @@ def is_failed(self) -> bool:
return self.state == WorkerState.FAILED


def _get_socket_with_port() -> socket.socket:
"""Return a free port on localhost.
The free port is "reserved" by binding a temporary socket on it.
Close the socket before passing the port to the entity that
requires it. Usage example::
sock = _get_socket_with_port()
with closing(sock):
port = sock.getsockname()[1]
sock.close()
# there is still a race-condition that some other process
# may grab this port before func() runs
func(port)
"""
addrs = socket.getaddrinfo(
host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
)
for addr in addrs:
family, type, proto, _, _ = addr
s = socket.socket(family, type, proto)
try:
s.bind(("localhost", 0))
s.listen(0)
return s
except OSError as e:
s.close()
logger.info("Socket creation attempt failed.", exc_info=e)
raise RuntimeError("Failed to create a socket")


def _get_fq_hostname() -> str:
return socket.getfqdn(socket.gethostname())

Expand Down Expand Up @@ -506,36 +476,6 @@ def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool
"""
raise NotImplementedError

@staticmethod
def _set_master_addr_port(
store: Store,
master_addr: Optional[str],
master_port: Optional[int],
local_addr: Optional[str],
):
if master_port is None:
sock = _get_socket_with_port()
with closing(sock):
master_port = sock.getsockname()[1]

if master_addr is None:
# If user specified the address for the local node, use it as the master addr if not exist
if local_addr:
master_addr = local_addr
else:
master_addr = _get_fq_hostname()

store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

@staticmethod
def _get_master_addr_port(store: Store) -> Tuple[str, int]:
master_addr = store.get("MASTER_ADDR").decode(encoding="UTF-8")
master_port = int(store.get("MASTER_PORT").decode(encoding="UTF-8"))
return (master_addr, master_port)

# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `torch.distributed.elastic.metrics.prof`.
@prof
def _rendezvous(self, worker_group: WorkerGroup) -> None:
r"""Run rendezvous for the workers specified by the worker spec.
Expand All @@ -546,7 +486,16 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None:
spec = worker_group.spec

with self.record_duration("RENDEZVOUS"):
store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
rdzv_info = spec.rdzv_handler.next_rendezvous()
store = rdzv_info.store
group_rank = rdzv_info.rank
group_world_size = rdzv_info.world_size

# master_addr/master_port could be explicitly overriden
# TODO: BC - specific to static rdzv and can be simplifed further
master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr
master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port

self._store = store

with self.record_duration("ASSIGN_WORKER_RANKS"):
Expand All @@ -555,17 +504,9 @@ def _rendezvous(self, worker_group: WorkerGroup) -> None:
worker_group.store = store
worker_group.group_rank = group_rank
worker_group.group_world_size = group_world_size
worker_group.master_addr = master_addr
worker_group.master_port = master_port

if group_rank == 0:
self._set_master_addr_port(
store,
spec.master_addr,
spec.master_port,
spec.local_addr,
)

with self.record_duration("GET_MASTER_ADDR_PORT"):
master_addr, master_port = self._get_master_addr_port(store)
restart_count = spec.max_restarts - self._remaining_restarts

logger.info(
Expand Down
Loading

0 comments on commit d62b025

Please sign in to comment.