Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch/elastic] Revise the rendezvous handler registry logic. #55466

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
147 changes: 91 additions & 56 deletions test/distributed/elastic/rendezvous/api_test.py
Expand Up @@ -7,67 +7,14 @@
from typing import Any, Dict, SupportsInt, Tuple, cast
from unittest import TestCase

from torch.distributed import Store
from torch.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousHandlerRegistry,
RendezvousParameters
)


def create_mock_rdzv_handler(ignored: RendezvousParameters) -> RendezvousHandler:
    """Return a new ``MockRendezvousHandler``; the ``ignored`` parameters are not used."""
    handler = MockRendezvousHandler()
    return handler


class MockRendezvousHandler(RendezvousHandler):
    """Minimal ``RendezvousHandler`` stub whose methods return fixed values."""

    def next_rendezvous(
        self,
        # pyre-ignore[11]: Annotation `Store` is not defined as a type.
    ) -> Tuple["torch.distributed.Store", int, int]:  # noqa F821
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()

    def get_backend(self) -> str:
        """Return the fixed backend name ``"mock"``."""
        backend_name = "mock"
        return backend_name

    def is_closed(self) -> bool:
        """Report the rendezvous as not closed."""
        closed = False
        return closed

    def set_closed(self):
        """Do nothing."""

    def num_nodes_waiting(self) -> int:
        """Return the fixed value ``-1``."""
        waiting = -1
        return waiting

    def get_run_id(self) -> str:
        """Return an empty run id."""
        run_id = ""
        return run_id


class RendezvousHandlerFactoryTest(TestCase):
    """Tests for registering creators with, and creating handlers from, a factory."""

    @staticmethod
    def _make_params() -> RendezvousParameters:
        """Build the parameters shared by the handler-creation tests."""
        return RendezvousParameters(
            backend="mock", endpoint="", run_id="foobar", min_nodes=1, max_nodes=2
        )

    def test_double_registration(self):
        """Registering the same backend name twice raises ``ValueError``."""
        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        with self.assertRaises(ValueError):
            factory.register("mock", create_mock_rdzv_handler)

    def test_no_factory_method_found(self):
        """Creating a handler for an unregistered backend raises ``ValueError``."""
        factory = RendezvousHandlerFactory()
        rdzv_params = self._make_params()

        with self.assertRaises(ValueError):
            factory.create_handler(rdzv_params)

    def test_create_handler(self):
        """A registered creator is used to build the handler."""
        rdzv_params = self._make_params()

        factory = RendezvousHandlerFactory()
        factory.register("mock", create_mock_rdzv_handler)
        mock_rdzv_handler = factory.create_handler(rdzv_params)

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(mock_rdzv_handler, MockRendezvousHandler)


class RendezvousParametersTest(TestCase):
def setUp(self) -> None:
self._backend = "dummy_backend"
Expand Down Expand Up @@ -236,3 +183,91 @@ def test_get_as_int_raises_error_if_value_is_invalid(self) -> None:
r"valid integer value.$",
):
params.get_as_int("dummy_param")


class _DummyRendezvousHandler(RendezvousHandler):
    """Stub handler that records the parameters it was constructed with."""

    def __init__(self, params: RendezvousParameters) -> None:
        # Kept so tests can check which parameters reached the creator.
        self.params = params

    def get_backend(self) -> str:
        """Return the fixed backend name ``"dummy_backend"``."""
        backend_name = "dummy_backend"
        return backend_name

    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()

    def is_closed(self) -> bool:
        """Report the rendezvous as not closed."""
        closed = False
        return closed

    def set_closed(self) -> None:
        """Do nothing."""

    def num_nodes_waiting(self) -> int:
        """Return the fixed value ``-1``."""
        waiting = -1
        return waiting

    def get_run_id(self) -> str:
        """Return an empty run id."""
        run_id = ""
        return run_id

    def shutdown(self) -> bool:
        """Return ``False``."""
        result = False
        return result


class RendezvousHandlerRegistryTest(TestCase):
    """Tests for ``RendezvousHandlerRegistry.register`` and ``create_handler``."""

    def setUp(self) -> None:
        # Every test gets fresh parameters and an empty registry.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
        )

        self._registry = RendezvousHandlerRegistry()

    @staticmethod
    def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
        """Creator callback registered by the tests."""
        return _DummyRendezvousHandler(params)

    def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
        # Passes as long as the second, identical registration does not raise.
        self._registry.register("dummy_backend", self._create_handler)
        self._registry.register("dummy_backend", self._create_handler)

    def test_register_raises_error_if_called_twice_with_different_creators(self) -> None:
        self._registry.register("dummy_backend", self._create_handler)

        def other_create_handler(params: RendezvousParameters) -> RendezvousHandler:
            return _DummyRendezvousHandler(params)

        with self.assertRaises(ValueError) as ctx:
            self._registry.register("dummy_backend", other_create_handler)

        # The message embeds the creators' reprs. Compare the text exactly
        # instead of interpolating those reprs, unescaped, into a regex.
        self.assertEqual(
            str(ctx.exception),
            "The rendezvous backend 'dummy_backend' cannot be registered with "
            f"'{other_create_handler}' as it is already registered with "
            f"'{self._create_handler}'.",
        )

    def test_create_handler_returns_handler(self) -> None:
        self._registry.register("dummy_backend", self._create_handler)

        handler = self._registry.create_handler(self._params)

        self.assertIsInstance(handler, _DummyRendezvousHandler)

        # The creator must receive the very parameters object that was passed in.
        self.assertIs(handler.params, self._params)

    def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call "
            r"`register`\?$",
        ):
            self._registry.create_handler(self._params)

    def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
        # The dummy handler always reports 'dummy_backend', so requesting
        # 'dummy_backend_2' makes the created handler disagree with the request.
        self._registry.register("dummy_backend_2", self._create_handler)
        self._params.backend = "dummy_backend_2"

        # Only the call under test sits inside the context manager, so an
        # unexpected error during setup cannot be mistaken for the expected one.
        with self.assertRaisesRegex(
            RuntimeError,
            r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
            r"'dummy_backend_2'.$",
        ):
            self._registry.create_handler(self._params)
30 changes: 19 additions & 11 deletions torch/distributed/elastic/rendezvous/__init__.py
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright (c) Facebook, Inc. and its affiliates.
Expand Down Expand Up @@ -102,13 +101,22 @@
to participate in next rendezvous.
"""

from .api import ( # noqa: F401
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
RendezvousStateError,
RendezvousTimeoutError,
)
from .api import *
from .registry import _register_default_handlers


# Register the built-in rendezvous backends as a side effect of importing this
# package.
_register_default_handlers()


# Public names exported by this package.
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
51 changes: 31 additions & 20 deletions torch/distributed/elastic/rendezvous/api.py
Expand Up @@ -219,51 +219,62 @@ def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]


class RendezvousHandlerFactory:
"""
Creates ``RendezvousHandler`` instances for supported rendezvous backends.
"""
class RendezvousHandlerRegistry:
"""Represents a registry of `RendezvousHandler` backends."""

def __init__(self):
self._registry: Dict[str, RendezvousHandlerCreator] = {}
_registry: Dict[str, RendezvousHandlerCreator]

def register(self, backend: str, creator: RendezvousHandlerCreator):
"""
Registers a new rendezvous backend.
def __init__(self) -> None:
self._registry = {}

def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
"""Registers a new rendezvous backend.

Args:
backend:
The name of the backend.
creator:
The callback to invoke to construct the `RendezvousHandler`.
"""
if not backend:
raise ValueError("The rendezvous backend name must be a non-empty string.")

current_creator: Optional[RendezvousHandlerCreator]
try:
current_creator = self._registry[backend]
except KeyError:
current_creator = None # type: ignore[assignment]
current_creator = None

if current_creator is not None:
if current_creator is not None and current_creator != creator:
raise ValueError(
f"The rendezvous backend '{backend}' cannot be registered with"
f" '{creator.__module__}.{creator.__name__}' as it is already"
f" registered with '{current_creator.__module__}.{current_creator.__name__}'."
f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
f"is already registered with '{current_creator}'."
)

self._registry[backend] = creator

def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
"""
Creates a new ``RendezvousHandler`` instance for the specified backend.
"""
"""Creates a new `RendezvousHandler`."""
try:
creator = self._registry[params.backend]
except KeyError:
raise ValueError(
f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
f"to call {self.register.__name__}?"
f"to call `{self.register.__name__}`?"
)

handler = creator(params)

# Do some sanity check.
if handler.get_backend() != params.backend:
raise RuntimeError(
f"The rendezvous handler backend '{handler.get_backend()}' does not match the "
f"requested backend '{params.backend}'."
f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
f"backend '{params.backend}'."
)

return handler


# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()
22 changes: 13 additions & 9 deletions torch/distributed/elastic/rendezvous/registry.py
Expand Up @@ -4,16 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import etcd_rendezvous
from .api import (
RendezvousHandler,
RendezvousHandlerFactory,
RendezvousParameters,
)
from .api import RendezvousHandler, RendezvousParameters
from .api import rendezvous_handler_registry as handler_registry

_factory = RendezvousHandlerFactory()
_factory.register("etcd", etcd_rendezvous.create_rdzv_handler)

def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create the etcd rendezvous handler for ``params``.

    ``etcd_rendezvous`` is imported on each call rather than when this module
    is imported.
    """
    from . import etcd_rendezvous

    create = etcd_rendezvous.create_rdzv_handler
    return create(params)


def _register_default_handlers() -> None:
    """Register the ``"etcd"`` backend with the default handler registry."""
    handler_registry.register(backend="etcd", creator=_create_etcd_handler)


# The legacy function kept for backwards compatibility.
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
return _factory.create_handler(params)
return handler_registry.create_handler(params)