Skip to content

Commit

Permalink
[torch/elastic] Revise the rendezvous handler registry logic. (#55466)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #55466

Improve the implementation and the unit test coverage of `RendezvousHandlerRegistry`.

### Note
See the original diff (D27442325 (df299db)) that had to be reverted due to an unexpected Python version incompatibility between the internal and external PyTorch CI tests.

Test Plan: Run the existing and newly-introduced unit tests.

Reviewed By: tierex

Differential Revision: D27623215

fbshipit-source-id: 51538d0f154f64e04f685a95d40d805b478c93f9
  • Loading branch information
cbalioglu authored and facebook-github-bot committed Apr 8, 2021
1 parent 8ac0619 commit 493a233
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 96 deletions.
147 changes: 91 additions & 56 deletions test/distributed/elastic/rendezvous/api_test.py
Expand Up @@ -7,67 +7,14 @@
from typing import Any, Dict, SupportsInt, Tuple, cast
from unittest import TestCase

from torch.distributed import Store
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousHandlerRegistry,
RendezvousParameters
)


def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
return MockRendezvousHandler()


class MockRendezvousHandler(RendezvousHandler):
def next_rendezvous(
self,
# pyre-ignore[11]: Annotation `Store` is not defined as a type.
) -> Tuple["torch.distributed.Store", int, int]: # noqa F821
raise NotImplementedError()

def get_backend(self) -> str:
return "mock"

def is_closed(self) -> bool:
return False

def set_closed(self):
pass

def num_nodes_waiting(self) -> int:
return -1

def get_run_id(self) -> str:
return ""


class RendezvousHandlerFactoryTest(TestCase):
def test_double_registration(self):
factory = RendezvousHandlerFactory()
factory.register("mock", create_mock_rdzv_handler)
with self.assertRaises(ValueError):
factory.register("mock", create_mock_rdzv_handler)

def test_no_factory_method_found(self):
factory = RendezvousHandlerFactory()
rdzv_params = RendezvousParameters(
backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
)

with self.assertRaises(ValueError):
factory.create_handler(rdzv_params)

def test_create_handler(self):
rdzv_params = RendezvousParameters(
backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
)

factory = RendezvousHandlerFactory()
factory.register("mock", create_mock_rdzv_handler)
mock_rdzv_handler = factory.create_handler(rdzv_params)
self.assertTrue(isinstance(mock_rdzv_handler, MockRendezvousHandler))


class RendezvousParametersTest(TestCase):
def setUp(self) -> None:
self._backend = "dummy_backend"
Expand Down Expand Up @@ -236,3 +183,91 @@ def test_get_as_int_raises_error_if_value_is_invalid(self) -> None:
r"valid integer value.$",
):
params.get_as_int("dummy_param")


class _DummyRendezvousHandler(RendezvousHandler):
def __init__(self, params: RendezvousParameters) -> None:
self.params = params

def get_backend(self) -> str:
return "dummy_backend"

def next_rendezvous(self) -> Tuple[Store, int, int]:
raise NotImplementedError()

def is_closed(self) -> bool:
return False

def set_closed(self) -> None:
pass

def num_nodes_waiting(self) -> int:
return -1

def get_run_id(self) -> str:
return ""

def shutdown(self) -> bool:
return False


class RendezvousHandlerRegistryTest(TestCase):
def setUp(self) -> None:
self._params = RendezvousParameters(
backend="dummy_backend",
endpoint="dummy_endpoint",
run_id="dummy_run_id",
min_nodes=1,
max_nodes=1,
)

self._registry = RendezvousHandlerRegistry()

@staticmethod
def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
return _DummyRendezvousHandler(params)

def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
self._registry.register("dummy_backend", self._create_handler)
self._registry.register("dummy_backend", self._create_handler)

def test_register_raises_error_if_called_twice_with_different_creators(self) -> None:
self._registry.register("dummy_backend", self._create_handler)

other_create_handler = lambda p: _DummyRendezvousHandler(p) # noqa: E731

with self.assertRaisesRegex(
ValueError,
r"^The rendezvous backend 'dummy_backend' cannot be registered with "
rf"'{other_create_handler}' as it is already registered with '{self._create_handler}'.$",
):
self._registry.register("dummy_backend", other_create_handler)

def test_create_handler_returns_handler(self) -> None:
self._registry.register("dummy_backend", self._create_handler)

handler = self._registry.create_handler(self._params)

self.assertIsInstance(handler, _DummyRendezvousHandler)

self.assertIs(handler.params, self._params)

def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
with self.assertRaisesRegex(
ValueError,
r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call "
r"`register`\?$",
):
self._registry.create_handler(self._params)

def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
self._registry.register("dummy_backend_2", self._create_handler)

with self.assertRaisesRegex(
RuntimeError,
r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
r"'dummy_backend_2'.$",
):
self._params.backend = "dummy_backend_2"

self._registry.create_handler(self._params)
30 changes: 19 additions & 11 deletions torch/distributed/elastic/rendezvous/__init__.py
@@ -1,4 +1,3 @@
#!/usr/bin/env/python3
# -*- coding: utf-8 -*-

# Copyright (c) Facebook, Inc. and its affiliates.
Expand Down Expand Up @@ -102,13 +101,22 @@
to participate in next rendezvous.
"""

from .api import ( # noqa: F401
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousStateError,
RendezvousTimeoutError,
)
from .api import *
from .registry import _register_default_handlers


_register_default_handlers()


__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
51 changes: 31 additions & 20 deletions torch/distributed/elastic/rendezvous/api.py
Expand Up @@ -219,51 +219,62 @@ def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]


class RendezvousHandlerFactory:
"""
Creates ``RendezvousHandler`` instances for supported rendezvous backends.
"""
class RendezvousHandlerRegistry:
"""Represents a registry of `RendezvousHandler` backends."""

def __init__(self):
self._registry: Dict[str, RendezvousHandlerCreator] = {}
_registry: Dict[str, RendezvousHandlerCreator]

def register(self, backend: str, creator: RendezvousHandlerCreator):
"""
Registers a new rendezvous backend.
def __init__(self) -> None:
self._registry = {}

def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
"""Registers a new rendezvous backend.
Args:
backend:
The name of the backend.
creater:
The callback to invoke to construct the `RendezvousHandler`.
"""
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")

current_creator: Optional[RendezvousHandlerCreator]
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None # type: ignore[assignment]
current_creator = None

if current_creator is not None:
if current_creator is not None and current_creator != creator:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with"
f" '{creator.__module__}.{creator.__name__}' as it is already"
f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
f"is already registered with '{current_creator}'."
)

self._registry[backend] = creator

def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""
Creates a new ``RendezvousHandler`` instance for the specified backend.
"""
"""Creates a new `RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call {self.register.__name__}?"
f"to call `{self.register.__name__}`?"
)

handler = creator(params)

# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
f"requested backend '{params.backend}'."
f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
f"backend '{params.backend}'."
)

return handler


# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
22 changes: 13 additions & 9 deletions torch/distributed/elastic/rendezvous/registry.py
Expand Up @@ -4,16 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import etcd_rendezvous
from .api import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
)
from .api import RendezvousHandler, RendezvousParameters
from .api import rendezvous_handler_registry as handler_registry

_factory = RendezvousHandlerFactory()
_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)

def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
from . import etcd_rendezvous

return etcd_rendezvous.create_rdzv_handler(params)


def _register_default_handlers() -> None:
handler_registry.register("etcd", _create_etcd_handler)


# The legacy function kept for backwards compatibility.
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
return _factory.create_handler(params)
return handler_registry.create_handler(params)

0 comments on commit 493a233

Please sign in to comment.