Skip to content

Commit

Permalink
[torch/elastic] Revise the rendezvous handler registry logic.
Browse files Browse the repository at this point in the history
Summary: Improve the implementation and the unit test coverage of `RendezvousHandlerRegistry`.

Test Plan: Run the existing and newly-introduced unit tests.

Reviewed By: tierex

Differential Revision: D27442325

fbshipit-source-id: 8519a2caacbe2e3ce5d9a02e87a910503dea27d7
  • Loading branch information
cbalioglu authored and facebook-github-bot committed Apr 6, 2021
1 parent 359d0a0 commit df299db
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 98 deletions.
147 changes: 91 additions & 56 deletions test/distributed/elastic/rendezvous/api_test.py
Expand Up @@ -7,67 +7,14 @@
from typing import Any, Dict, SupportsInt, Tuple, cast
from unittest import TestCase

from torch.distributed import Store
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousHandlerRegistry,
RendezvousParameters
)


def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
    """Build a ``MockRendezvousHandler``; the supplied parameters are not used."""
    handler = MockRendezvousHandler()
    return handler


class MockRendezvousHandler(RendezvousHandler):
    """Stub rendezvous handler that reports the ``"mock"`` backend.

    Every query method returns a fixed value, and ``next_rendezvous`` is
    deliberately left unimplemented.
    """

    def get_backend(self) -> str:
        """Return the fixed backend name of this stub."""
        return "mock"

    def get_run_id(self) -> str:
        """Return an empty run id."""
        return ""

    def num_nodes_waiting(self) -> int:
        """Return the fixed value ``-1``."""
        return -1

    def is_closed(self) -> bool:
        """Report the rendezvous as not closed."""
        return False

    def set_closed(self):
        """Do nothing."""
        pass

    def next_rendezvous(
        self,
        # pyre-ignore[11]: Annotation `Store` is not defined as a type.
    ) -> Tuple["torch.distributed.Store", int, int]:  # noqa F821
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()


class RendezvousHandlerFactoryTest(TestCase):
    """Unit tests for ``RendezvousHandlerFactory`` registration and handler creation."""

    @staticmethod
    def _make_params() -> RendezvousParameters:
        """Return parameters that request the ``"mock"`` backend."""
        return RendezvousParameters(
            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
        )

    def test_double_registration(self):
        """Registering the same backend name a second time raises ``ValueError``."""
        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        with self.assertRaises(ValueError):
            factory.register("mock", create_mock_rdzv_handler)

    def test_no_factory_method_found(self):
        """Creating a handler for an unregistered backend raises ``ValueError``."""
        factory = RendezvousHandlerFactory()
        rdzv_params = self._make_params()

        with self.assertRaises(ValueError):
            factory.create_handler(rdzv_params)

    def test_create_handler(self):
        """A registered creator is used to build the handler."""
        rdzv_params = self._make_params()

        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        mock_rdzv_handler = factory.create_handler(rdzv_params)
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only reports "False is not true".
        self.assertIsInstance(mock_rdzv_handler, MockRendezvousHandler)


class RendezvousParametersTest(TestCase):
def setUp(self) -> None:
self._backend = "dummy_backend"
Expand Down Expand Up @@ -236,3 +183,91 @@ def test_get_as_int_raises_error_if_value_is_invalid(self) -> None:
r"valid integer value.$",
):
params.get_as_int("dummy_param")


class _DummyRendezvousHandler(RendezvousHandler):
def __init__(self, params: RendezvousParameters) -> None:
self.params = params

def get_backend(self) -> str:
return "dummy_backend"

def next_rendezvous(self) -> Tuple[Store, int, int]:
raise NotImplementedError()

def is_closed(self) -> bool:
return False

def set_closed(self) -> None:
pass

def num_nodes_waiting(self) -> int:
return -1

def get_run_id(self) -> str:
return ""

def shutdown(self) -> bool:
return False


class RendezvousHandlerRegistryTest(TestCase):
    """Unit tests for ``RendezvousHandlerRegistry``."""

    def setUp(self) -> None:
        """Create fresh parameters and an empty registry for every test."""
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
        )

        self._registry = RendezvousHandlerRegistry()

    @staticmethod
    def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
        """Creator callback registered by the tests."""
        return _DummyRendezvousHandler(params)

    def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
        """Registering the same creator twice is accepted and stays usable."""
        self._registry.register("dummy_backend", self._create_handler)
        self._registry.register("dummy_backend", self._create_handler)

        # Beyond "no exception was raised", confirm the backend still resolves
        # to the registered creator after the repeated registration.
        handler = self._registry.create_handler(self._params)

        self.assertIsInstance(handler, _DummyRendezvousHandler)

    def test_register_raises_error_if_called_twice_with_different_creators(self) -> None:
        """Registering a different creator for a taken name raises ``ValueError``."""
        import re  # local import: only this test builds a pattern from runtime text

        self._registry.register("dummy_backend", self._create_handler)

        other_create_handler = lambda p: _DummyRendezvousHandler(p)  # noqa: E731

        # The expected message embeds the string form of both callables. That
        # text is arbitrary (qualified names, addresses), so escape it to keep
        # any regex metacharacters from changing what the pattern matches.
        new_creator = re.escape(str(other_create_handler))
        old_creator = re.escape(str(self._create_handler))

        with self.assertRaisesRegex(
            ValueError,
            r"^The rendezvous backend 'dummy_backend' cannot be registered with "
            rf"'{new_creator}' as it is already registered with '{old_creator}'\.$",
        ):
            self._registry.register("dummy_backend", other_create_handler)

    def test_create_handler_returns_handler(self) -> None:
        """The registered creator is invoked with the caller's parameters."""
        self._registry.register("dummy_backend", self._create_handler)

        handler = self._registry.create_handler(self._params)

        self.assertIsInstance(handler, _DummyRendezvousHandler)

        self.assertIs(handler.params, self._params)

    def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
        """Requesting an unknown backend raises ``ValueError``."""
        with self.assertRaisesRegex(
            ValueError,
            r"^The rendezvous backend 'dummy_backend' is not registered\. Did you forget to call "
            r"`register`\?$",
        ):
            self._registry.create_handler(self._params)

    def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
        """A handler reporting a different backend than requested raises ``RuntimeError``."""
        self._registry.register("dummy_backend_2", self._create_handler)

        # Arrange outside the context manager so that only the call under test
        # can satisfy the expected exception.
        self._params.backend = "dummy_backend_2"

        with self.assertRaisesRegex(
            RuntimeError,
            r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
            r"'dummy_backend_2'\.$",
        ):
            self._registry.create_handler(self._params)
31 changes: 19 additions & 12 deletions torch/distributed/elastic/rendezvous/__init__.py
@@ -1,5 +1,3 @@
#!/usr/bin/env/python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -101,13 +99,22 @@
to participate in next rendezvous.
"""

from .api import ( # noqa: F401
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousStateError,
RendezvousTimeoutError,
)
from .api import *
from .registry import _register_default_handlers


_register_default_handlers()


__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
54 changes: 33 additions & 21 deletions torch/distributed/elastic/rendezvous/api.py
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import abc
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, final

from torch.distributed import Store

Expand Down Expand Up @@ -219,51 +219,63 @@ def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]


class RendezvousHandlerFactory:
"""
Creates ``RendezvousHandler`` instances for supported rendezvous backends.
"""
@final
class RendezvousHandlerRegistry:
"""Represents a registry of `RendezvousHandler` backends."""

def __init__(self):
self._registry: Dict[str, RendezvousHandlerCreator] = {}
_registry: Dict[str, RendezvousHandlerCreator]

def register(self, backend: str, creator: RendezvousHandlerCreator):
"""
Registers a new rendezvous backend.
def __init__(self) -> None:
self._registry = {}

def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
"""Registers a new rendezvous backend.
Args:
backend:
The name of the backend.
creator:
The callback to invoke to construct the `RendezvousHandler`.
"""
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")

current_creator: Optional[RendezvousHandlerCreator]
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None # type: ignore[assignment]
current_creator = None

if current_creator is not None:
if current_creator is not None and current_creator != creator:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with"
f" '{creator.__module__}.{creator.__name__}' as it is already"
f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
f"is already registered with '{current_creator}'."
)

self._registry[backend] = creator

def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""
Creates a new ``RendezvousHandler`` instance for the specified backend.
"""
"""Creates a new `RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call {self.register.__name__}?"
f"to call `{self.register.__name__}`?"
)

handler = creator(params)

# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
f"requested backend '{params.backend}'."
f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
f"backend '{params.backend}'."
)

return handler


# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
22 changes: 13 additions & 9 deletions torch/distributed/elastic/rendezvous/registry.py
Expand Up @@ -4,16 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import etcd_rendezvous
from .api import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
)
from .api import RendezvousHandler, RendezvousParameters
from .api import rendezvous_handler_registry as handler_registry

_factory = RendezvousHandlerFactory()
_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)

def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create the etcd rendezvous handler for ``params``.

    The ``etcd_rendezvous`` module is imported inside the function body, so
    the import happens when a handler is requested rather than when this
    module is loaded.
    """
    from . import etcd_rendezvous

    create = etcd_rendezvous.create_rdzv_handler
    return create(params)


def _register_default_handlers() -> None:
    """Register the built-in ``"etcd"`` backend with the shared handler registry."""
    handler_registry.register(backend="etcd", creator=_create_etcd_handler)


# The legacy function kept for backwards compatibility.
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
return _factory.create_handler(params)
return handler_registry.create_handler(params)

0 comments on commit df299db

Please sign in to comment.