Add type annotations to torch._C._distributed_rpc module.
ghstack-source-id: cd8c93b49d1c02174f3bc31f3ef6f2a07308a0fa
Pull Request resolved: #46624
xuzhao9 committed Nov 5, 2020
1 parent f5964ef commit 841f462
Showing 13 changed files with 306 additions and 51 deletions.
6 changes: 6 additions & 0 deletions mypy.ini
@@ -53,6 +53,12 @@ ignore_errors = True
[mypy-torch.distributed.*]
ignore_errors = True

[mypy-torch.distributed.rpc.*]
ignore_errors = False

[mypy-torch.distributed.rpc._testing.*]
ignore_errors = True

[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True

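With `ignore_errors = False`, mypy now checks the `torch.distributed.rpc` package against the new stubs (the `_testing` subpackage stays exempt for now). The stubs also let mypy catch caller-side mistakes; a hypothetical snippet it would reject:

```python
from torch._C._distributed_rpc import WorkerInfo

ok = WorkerInfo("worker0", 0)        # matches the stub: (name: str, worker_id: int)
bad = WorkerInfo("worker0", "zero")  # mypy: argument 2 has incompatible type "str"
```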
8 changes: 4 additions & 4 deletions torch/_C/_autograd.pyi
@@ -4,10 +4,10 @@ from enum import Enum
# Defined in tools/autograd/init.cpp

class ProfilerState(Enum):
-    Disable = 0
-    CPU = 1
-    CUDA = 2
-    NVTX = 3
+    Disable = ...
+    CPU = ...
+    CUDA = ...
+    NVTX = ...


class ProfilerConfig:
16 changes: 8 additions & 8 deletions torch/_C/_distributed_c10d.pyi
@@ -31,14 +31,14 @@ class Reducer:
    ...

class ReduceOp(Enum):
-    SUM = 0
-    PRODUCT = 1
-    MIN = 2
-    MAX = 3
-    BAND = 4
-    BOR = 5
-    BXOR = 6
-    UNUSED = 7
+    SUM = ...
+    PRODUCT = ...
+    MIN = ...
+    MAX = ...
+    BAND = ...
+    BOR = ...
+    BXOR = ...
+    UNUSED = ...
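Replacing the literal values with `...` here and in `ProfilerState` above follows stub convention: the concrete values live in the C++ extension, and the stub only declares the members. Member access still type-checks; a minimal sketch:

```python
from torch._C._distributed_c10d import ReduceOp

op: ReduceOp = ReduceOp.SUM  # members resolve normally; only the stub values are elided
```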

class BroadcastOptions:
    rootRank: int
194 changes: 194 additions & 0 deletions torch/_C/_distributed_rpc.pyi
@@ -0,0 +1,194 @@
from typing import Tuple, Dict, Optional, List, Any, overload
from datetime import timedelta
import enum
import torch
from . import Future
from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent
from ._distributed_c10d import ProcessGroup, Store

# This module is defined in torch/csrc/distributed/rpc/init.cpp

_DEFAULT_NUM_SEND_RECV_THREADS: int
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float

class RpcBackendOptions:
    rpc_timeout: float
    init_method: str
    def __init__(
        self,
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
    ): ...

class WorkerInfo:
    def __init__(self, name: str, worker_id: int): ...
    @property
    def name(self) -> str: ...
    @property
    def id(self) -> int: ...
    def __eq__(self, other: object) -> bool: ...
    def __repr__(self) -> str: ...

class RpcAgent:
    def join(self): ...
    def sync(self): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def get_debug_info(self) -> Dict[str, str]: ...
    def get_metrics(self) -> Dict[str, str]: ...
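The two `get_worker_info` overloads mirror the C++ bindings registered in init.cpp further down. A usage sketch through the public API, assuming RPC is already initialized on workers "worker0" and "worker1":

```python
import torch.distributed.rpc as rpc

me = rpc.get_worker_info()             # no-argument overload: this worker's info
peer = rpc.get_worker_info("worker1")  # by-name overload
print(me.id, peer.name)
```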

class PyRRef:
    def __init__(self, value: Any, type_hint: Any = None): ...
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def local_value(self) -> Any: ...
    def rpc_sync(self) -> Any: ...
    def rpc_async(self) -> Any: ...
    def remote(self) -> Any: ...
    def _serialize(self) -> Tuple: ...
    @staticmethod
    def _deserialize(tp: Tuple) -> 'PyRRef': ...
    def _get_type(self) -> Any: ...
    def _get_future(self) -> Future: ...
    def _get_profiling_future(self) -> Future: ...
    def _set_profiling_future(self, profilingFuture: Future): ...
    def __repr__(self) -> str: ...
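`PyRRef` is exposed to users as `torch.distributed.rpc.RRef`. A minimal owner-side sketch with a single-worker group (address and port are illustrative):

```python
import os
import torch
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"  # illustrative rendezvous settings
os.environ["MASTER_PORT"] = "29500"
rpc.init_rpc("worker0", rank=0, world_size=1)

rref = rpc.RRef(torch.ones(2))  # owner-side RRef wrapping a local value
assert rref.is_owner()
print(rref.owner_name(), rref.local_value())

rpc.shutdown()
```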

class ProcessGroupRpcBackendOptions(RpcBackendOptions):
    num_send_recv_threads: int
    def __init__(
        self,
        num_send_recv_threads: int,
        rpc_timeout: float,
        init_method: str
    ): ...

class ProcessGroupAgent(RpcAgent):
    def __init__(
        self,
        worker_name: str,
        pg: ProcessGroup,
        numSendRecvThreads: int,
        rpcTimeout: timedelta
    ): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def join(self): ...
    def shutdown(self): ...
    def sync(self): ...

class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    num_worker_threads: int
    device_maps: Dict[str, Dict[int, int]]
    def __init__(
        self,
        num_worker_threads: int,
        _transports: Optional[List],
        _channels: Optional[List],
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
        device_maps: Dict[str, Dict[int, int]] = dict()): ...
    def set_device_map(self, to: str, device_map: Dict[int, int]): ...
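The keys of `device_maps` are destination worker names; each value maps a local device index to a device on that worker. A hedged sketch via the public wrapper re-exported from torch.distributed.rpc below:

```python
from torch.distributed.rpc import TensorPipeRpcBackendOptions

opts = TensorPipeRpcBackendOptions(num_worker_threads=16)
# Tensors sent to "worker1" from local cuda:0 should arrive on its cuda:1.
opts.set_device_map("worker1", {0: 1})
```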

class TensorPipeAgent(RpcAgent):
    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: int,
        pg: ProcessGroup,
        opts: _TensorPipeRpcBackendOptionsBase,
    ): ...
    def join(self): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ...

def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent() -> RpcAgent: ...
def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> Dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
): ...
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool
): ...
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: Tuple,
    kwargsDict: Dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
): ...
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any
): ...
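These private `_invoke_*` entry points back the public `rpc_sync`, `rpc_async`, and `remote` calls. A sketch of that public surface, assuming two initialized workers:

```python
import torch
import torch.distributed.rpc as rpc

# Assumes rpc.init_rpc(...) already ran on "worker0" and "worker1".
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))   # blocks for the result
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))  # returns a Future
rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))    # returns an RRef
print(ret, fut.wait(), rref.to_here())
```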
def get_rpc_timeout() -> float: ...
def enable_gil_profiling(flag: bool): ...
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...

class RemoteProfilerManager:
    @staticmethod
    def set_current_profiling_key(key: str): ...

def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...
22 changes: 21 additions & 1 deletion torch/csrc/distributed/rpc/init.cpp
@@ -38,7 +38,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
    throw python_error();
  }

-  auto module = py::handle(rpc_module).cast<py::module>();
+  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+  if (!torch_C_module)
+    return nullptr;
+  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
+  auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

+  auto module = py::handle(m).cast<py::module>();

  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
@@ -114,6 +120,20 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
"join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
.def(
"sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
.def(
"shutdown",
&RpcAgent::shutdown,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(void)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (RpcAgent::*)(const std::string&)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
&RpcAgent::getWorkerInfos,
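Since the bindings now live in a `torch._C` submodule, they are importable under the new name; a quick check, assuming a distributed-enabled build:

```python
import torch._C._distributed_rpc as rpc_c

print(rpc_c._DEFAULT_RPC_TIMEOUT_SEC)     # module-level constant declared in the stub
print(rpc_c._is_current_rpc_agent_set())  # False until init_rpc runs
```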
39 changes: 31 additions & 8 deletions torch/distributed/rpc/__init__.py
@@ -1,6 +1,7 @@
import logging
import threading

+from typing import Generator, Tuple
import torch
import torch.distributed as dist

@@ -20,12 +21,33 @@ def is_available():


if is_available():
-    from . import api, backend_registry, functions, _set_profiler_node_id
-    from . import (
+    from . import api, backend_registry, functions
+    from torch._C._distributed_rpc import (
        _disable_jit_rref_pickle,
        _enable_jit_rref_pickle,
        _disable_server_process_global_profiler,
        _enable_server_process_global_profiler,
        _set_and_start_rpc_agent,
        _set_profiler_node_id,
        _is_current_rpc_agent_set,
        _rref_context_get_debug_info,
        _set_rpc_timeout,
        _get_current_rpc_agent,
        get_rpc_timeout,
        enable_gil_profiling,
        RpcBackendOptions,
        _TensorPipeRpcBackendOptionsBase,
        ProcessGroupRpcBackendOptions,
        ProcessGroupAgent,
        TensorPipeAgent,
        WorkerInfo,
        _DEFAULT_INIT_METHOD,
        _DEFAULT_NUM_SEND_RECV_THREADS,
        _DEFAULT_NUM_WORKER_THREADS,
        _UNSET_RPC_TIMEOUT,
        _DEFAULT_RPC_TIMEOUT_SEC,
    ) # noqa: F401
+    from torch._C._distributed_c10d import Store
    from .api import * # noqa: F401
    from .options import TensorPipeRpcBackendOptions # noqa: F401
    from .backend_registry import BackendType
@@ -36,6 +58,7 @@ def is_available():

    import numbers

+    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
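The annotation gives mypy the element type of the rendezvous generator that `init_rpc` unpacks further down in this file, roughly:

```python
# Sketch of the consuming code inside init_rpc (not shown in this hunk):
rendezvous_iterator = dist.rendezvous(
    rpc_backend_options.init_method, rank=rank, world_size=world_size
)
store, _, _ = next(rendezvous_iterator)  # store: Store, plus rank and world size
```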

    def init_rpc(
        name,
@@ -104,18 +127,19 @@ def init_rpc(
            raise TypeError(
                f"Could not infer backend for options {rpc_backend_options}"
            )
-        if backend != BackendType.TENSORPIPE:
+        # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
+        if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
            logger.warning(
-                f"RPC was initialized with no explicit backend but with options "
+                f"RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                f"corresponding to {backend}, hence that backend will be used "
                f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                f"warning pass `backend={backend}` explicitly."
            )

        if backend is None:
-            backend = BackendType.TENSORPIPE
+            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]

-        if backend == BackendType.PROCESS_GROUP:
+        if backend == BackendType.PROCESS_GROUP:  # type: ignore[attr-defined]
            logger.warning(
                "RPC was initialized with the PROCESS_GROUP backend which is "
                "deprecated and slated to be removed and superseded by the TENSORPIPE "
@@ -176,7 +200,7 @@ def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_optio


    def _init_rpc_backend(
-        backend=backend_registry.BackendType.TENSORPIPE,
+        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
        store=None,
        name=None,
        rank=-1,
@@ -204,7 +228,6 @@ def _init_rpc_backend(

    @api._require_initialized
    def _get_debug_info():
-        from . import _rref_context_get_debug_info
        info = _rref_context_get_debug_info()
        info.update(api._get_current_rpc_agent().get_debug_info())
        info.update(dist_autograd._get_debug_info())
