diff --git a/mypy.ini b/mypy.ini
index 122fde2d6cf6..5da4b4b52d62 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -53,6 +53,12 @@ ignore_errors = True
 [mypy-torch.distributed.*]
 ignore_errors = True
 
+[mypy-torch.distributed.rpc.*]
+ignore_errors = False
+
+[mypy-torch.distributed.rpc._testing.*]
+ignore_errors = True
+
 [mypy-torch.testing._internal.hypothesis_utils.*]
 ignore_errors = True
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index a989bb19ad8c..926457fe80ee 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -4,10 +4,10 @@
 from enum import Enum
 
 # Defined in tools/autograd/init.cpp
 class ProfilerState(Enum):
-    Disable = 0
-    CPU = 1
-    CUDA = 2
-    NVTX = 3
+    Disable = ...
+    CPU = ...
+    CUDA = ...
+    NVTX = ...
 
 class ProfilerConfig:
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index f50d3c829656..115284567439 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -31,14 +31,14 @@ class Reducer:
     ...
 
 class ReduceOp(Enum):
-    SUM = 0
-    PRODUCT = 1
-    MIN = 2
-    MAX = 3
-    BAND = 4
-    BOR = 5
-    BXOR = 6
-    UNUSED = 7
+    SUM = ...
+    PRODUCT = ...
+    MIN = ...
+    MAX = ...
+    BAND = ...
+    BOR = ...
+    BXOR = ...
+    UNUSED = ...
 
 class BroadcastOptions:
     rootRank: int
diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi
new file mode 100644
index 000000000000..216799cb763a
--- /dev/null
+++ b/torch/_C/_distributed_rpc.pyi
@@ -0,0 +1,194 @@
+from typing import Tuple, Dict, Optional, List, Any, overload
+from datetime import timedelta
+import enum
+import torch
+from . import Future
+from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent
+from ._distributed_c10d import ProcessGroup, Store
+
+# This module is defined in torch/csrc/distributed/rpc/init.cpp
+
+_DEFAULT_NUM_SEND_RECV_THREADS: int
+_DEFAULT_INIT_METHOD: str
+_DEFAULT_NUM_WORKER_THREADS: int
+_UNSET_RPC_TIMEOUT: float
+_DEFAULT_RPC_TIMEOUT_SEC: float
+
+class RpcBackendOptions:
+    rpc_timeout: float
+    init_method: str
+    def __init__(
+        self,
+        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
+        init_method: str = _DEFAULT_INIT_METHOD,
+    ): ...
+
+class WorkerInfo:
+    def __init__(self, name: str, worker_id: int): ...
+    @property
+    def name(self) -> str: ...
+    @property
+    def id(self) -> int: ...
+    def __eq__(self, other: object) -> bool: ...
+    def __repr__(self) -> str: ...
+
+class RpcAgent:
+    def join(self): ...
+    def sync(self): ...
+    def shutdown(self): ...
+    @overload
+    def get_worker_info(self) -> WorkerInfo: ...
+    @overload
+    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
+    def get_worker_infos(self) -> List[WorkerInfo]: ...
+    def get_debug_info(self) -> Dict[str, str]: ...
+    def get_metrics(self) -> Dict[str, str]: ...
+
+class PyRRef:
+    def __init__(self, value: Any, type_hint: Any = None): ...
+    def is_owner(self) -> bool: ...
+    def confirmed_by_owner(self) -> bool: ...
+    def owner(self) -> WorkerInfo: ...
+    def owner_name(self) -> str: ...
+    def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
+    def local_value(self) -> Any: ...
+    def rpc_sync(self) -> Any: ...
+    def rpc_async(self) -> Any: ...
+    def remote(self) -> Any: ...
+    def _serialize(self) -> Tuple: ...
+    @staticmethod
+    def _deserialize(tp: Tuple) -> 'PyRRef': ...
+    def _get_type(self) -> Any: ...
+    def _get_future(self) -> Future: ...
+    def _get_profiling_future(self) -> Future: ...
+    def _set_profiling_future(self, profilingFuture: Future): ...
+    def __repr__(self) -> str: ...
+    ...
+
+class ProcessGroupRpcBackendOptions(RpcBackendOptions):
+    num_send_recv_threads: int
+    def __init__(
+        self,
+        num_send_recv_threads: int,
+        rpc_timeout: float,
+        init_method: str
+    ): ...
+
+class ProcessGroupAgent(RpcAgent):
+    def __init__(
+        self,
+        worker_name: str,
+        pg: ProcessGroup,
+        numSendRecvThreads: int,
+        rpcTimeout: timedelta
+    ): ...
+    @overload
+    def get_worker_info(self) -> WorkerInfo: ...
+    @overload
+    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
+    @overload
+    def get_worker_info(self, id: int) -> WorkerInfo: ...
+    def get_worker_infos(self) -> List[WorkerInfo]: ...
+    def join(self): ...
+    def shutdown(self): ...
+    def sync(self): ...
+
+class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
+    num_worker_threads: int
+    device_maps: Dict[str, Dict[int, int]]
+    def __init__(
+        self,
+        num_worker_threads: int,
+        _transports: Optional[List],
+        _channels: Optional[List],
+        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
+        init_method: str = _DEFAULT_INIT_METHOD,
+        device_maps: Dict[str, Dict[int, int]] = dict()): ...
+    def set_device_map(self, to: str, device_map: Dict[str, Dict[int, int]]): ...
+
+class TensorPipeAgent(RpcAgent):
+    def __init__(
+        self,
+        store: Store,
+        name: str,
+        worker_id: int,
+        world_size: int,
+        pg: ProcessGroup,
+        opts: _TensorPipeRpcBackendOptionsBase,
+    ): ...
+    def join(self): ...
+    def shutdown(self): ...
+    @overload
+    def get_worker_info(self) -> WorkerInfo: ...
+    @overload
+    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
+    @overload
+    def get_worker_info(self, id: int) -> WorkerInfo: ...
+    def get_worker_infos(self) -> List[WorkerInfo]: ...
+    def _set_reverse_device_maps(self, reverseDeviceMaps: Dict[str, Dict[int, int]]): ...
+
+def _is_current_rpc_agent_set() -> bool: ...
+def _get_current_rpc_agent()-> RpcAgent: ...
+def _set_and_start_rpc_agent(agent: RpcAgent): ...
+def _reset_current_rpc_agent(): ...
+def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
+def _destroy_rref_context(ignoreRRefLeak: bool): ...
+def _rref_context_get_debug_info() -> Dict[str, str]: ...
+def _cleanup_python_rpc_handler(): ...
+def _invoke_rpc_builtin(
+    dst: WorkerInfo,
+    opName: str,
+    rpcTimeoutSeconds: float,
+    *args: Any,
+    **kwargs: Any
+    ): ...
+def _invoke_rpc_python_udf(
+    dst: WorkerInfo,
+    pickledPythonUDF: str,
+    tensors: List[torch.Tensor],
+    rpcTimeoutSeconds: float,
+    isAsyncExecution: bool
+    ): ...
+def _invoke_rpc_torchscript(
+    dstWorkerName: str,
+    qualifiedNameStr: str,
+    argsTuple: Tuple,
+    kwargsDict: Dict,
+    rpcTimeoutSeconds: float,
+    isAsyncExecution: bool,
+    ): ...
+def _invoke_remote_builtin(
+    dst: WorkerInfo,
+    opName: str,
+    rpcTimeoutSeconds: float,
+    *args: Any,
+    **kwargs: Any
+    ): ...
+def _invoke_remote_python_udf(
+    dst: WorkerInfo,
+    pickledPythonUDF: str,
+    tensors: List[torch.Tensor],
+    rpcTimeoutSeconds: float,
+    isAsyncExecution: bool,
+    ): ...
+def _invoke_remote_torchscript(
+    dstWorkerName: WorkerInfo,
+    qualifiedNameStr: str,
+    rpcTimeoutSeconds: float,
+    isAsyncExecution: bool,
+    *args: Any,
+    **kwargs: Any
+    ): ...
+def get_rpc_timeout() -> float: ...
+def enable_gil_profiling(flag: bool): ...
+def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
+
+class RemoteProfilerManager:
+    @staticmethod
+    def set_current_profiling_key(key: str): ...
+
+def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
+def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
+def _set_profiler_node_id(default_node_id: int): ...
+def _enable_jit_rref_pickle(): ...
+def _disable_jit_rref_pickle(): ...
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index 2f608bd6fd33..cb332bd25e1b 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -38,7 +38,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
     throw python_error();
   }
 
-  auto module = py::handle(rpc_module).cast<py::module>();
+  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+  if (!torch_C_module)
+    return nullptr;
+  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
+  auto m = torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");
+
+  auto module = py::handle(m).cast<py::module>();
 
   auto rpcBackendOptions =
       shared_ptr_class_<RpcBackendOptions>(
@@ -114,6 +120,20 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           "join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>())
       .def(
           "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
+      .def(
+          "shutdown",
+          &RpcAgent::shutdown,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "get_worker_info",
+          (const WorkerInfo& (RpcAgent::*)(void)const) &
+              RpcAgent::getWorkerInfo,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "get_worker_info",
+          (const WorkerInfo& (RpcAgent::*)(const std::string&)const) &
+              RpcAgent::getWorkerInfo,
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "get_worker_infos",
           &RpcAgent::getWorkerInfos,
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 2c579bcc8fe9..8eb95fee9b92 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -1,6 +1,7 @@
 import logging
 import threading
+from typing import Generator, Tuple
 
 import torch
 import torch.distributed as dist
@@ -20,12 +21,33 @@ def is_available():
 
 
 if is_available():
-    from . import api, backend_registry, functions, _set_profiler_node_id
-    from . import (
+    from . import api, backend_registry, functions
+    from torch._C._distributed_rpc import (
         _disable_jit_rref_pickle,
         _enable_jit_rref_pickle,
+        _disable_server_process_global_profiler,
+        _enable_server_process_global_profiler,
         _set_and_start_rpc_agent,
+        _set_profiler_node_id,
+        _is_current_rpc_agent_set,
+        _rref_context_get_debug_info,
+        _set_rpc_timeout,
+        _get_current_rpc_agent,
+        get_rpc_timeout,
+        enable_gil_profiling,
+        RpcBackendOptions,
+        _TensorPipeRpcBackendOptionsBase,
+        ProcessGroupRpcBackendOptions,
+        ProcessGroupAgent,
+        TensorPipeAgent,
+        WorkerInfo,
+        _DEFAULT_INIT_METHOD,
+        _DEFAULT_NUM_SEND_RECV_THREADS,
+        _DEFAULT_NUM_WORKER_THREADS,
+        _UNSET_RPC_TIMEOUT,
+        _DEFAULT_RPC_TIMEOUT_SEC,
     )  # noqa: F401
+    from torch._C._distributed_c10d import Store
     from .api import *  # noqa: F401
     from .options import TensorPipeRpcBackendOptions  # noqa: F401
     from .backend_registry import BackendType
@@ -36,6 +58,7 @@ def is_available():
 
     import numbers
 
+    rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
 
     def init_rpc(
         name,
@@ -104,18 +127,19 @@ def init_rpc(
                 raise TypeError(
                     f"Could not infer backend for options {rpc_backend_options}"
                 )
-            if backend != BackendType.TENSORPIPE:
+            # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865)
+            if backend != BackendType.TENSORPIPE:  # type: ignore[attr-defined]
                 logger.warning(
-                    f"RPC was initialized with no explicit backend but with options "
+                    f"RPC was initialized with no explicit backend but with options "  # type: ignore[attr-defined]
                     f"corresponding to {backend}, hence that backend will be used "
                     f"instead of the default {BackendType.TENSORPIPE}. To silence this "
                     f"warning pass `backend={backend}` explicitly."
                 )
 
         if backend is None:
-            backend = BackendType.TENSORPIPE
+            backend = BackendType.TENSORPIPE  # type: ignore[attr-defined]
 
-        if backend == BackendType.PROCESS_GROUP:
+        if backend == BackendType.PROCESS_GROUP:  # type: ignore[attr-defined]
             logger.warning(
                 "RPC was initialized with the PROCESS_GROUP backend which is "
                 "deprecated and slated to be removed and superseded by the TENSORPIPE "
@@ -176,7 +200,7 @@ def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_optio
 
 
     def _init_rpc_backend(
-        backend=backend_registry.BackendType.TENSORPIPE,
+        backend=BackendType.TENSORPIPE,  # type: ignore[attr-defined]
         store=None,
         name=None,
         rank=-1,
@@ -204,7 +228,6 @@ def _init_rpc_backend(
 
     @api._require_initialized
     def _get_debug_info():
-        from . import _rref_context_get_debug_info
         info = _rref_context_get_debug_info()
         info.update(api._get_current_rpc_agent().get_debug_info())
         info.update(dist_autograd._get_debug_info())
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index 7cb99066b507..e88ced794454 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -4,11 +4,11 @@
 import inspect
 import logging
 import threading
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Set, Any
 
 import torch
 
-from . import (
+from torch._C._distributed_rpc import (
     PyRRef,
     RemoteProfilerManager,
     WorkerInfo,
@@ -99,10 +99,10 @@ def __init__(self):
 
 # States used by `def _all_gather()`.
 # `_ALL_WORKER_NAMES` is initialized on initiaizing RPC layer.
-_ALL_WORKER_NAMES = None
+_ALL_WORKER_NAMES: Set[Any] = set()
 _all_gather_dict_lock = threading.RLock()
 _all_gather_sequence_id = 0
-_all_gather_sequence_id_to_states = collections.defaultdict(AllGatherStates)
+_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(AllGatherStates)
 
 
 def _init_rpc_states(agent):
@@ -379,16 +379,18 @@ def _rref_typeof_on_user(rref):
 
 try:
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar):
+    class RRef(PyRRef, Generic[T]):
         pass
 except TypeError as exc:
     # TypeError: metaclass conflict: the metaclass of a derived class
     # must be a (non-strict) subclass of the metaclasses of all its bases
-    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):
+    # Mypy doesn't understand __class__ (mypy bug #4177)
+    class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore
         pass
 
     # Combine the implementation class and the type class.
-    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):
+    # Types for classes expecting a certain generic parameter (mypy bug #7791)
+    class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore
         pass
@@ -564,7 +566,8 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
@@ -639,7 +642,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
             dst_worker_info.name,
         )
         RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
-        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)
+        # Mypy doesn't support re-def of a variable not in the same block (#1174)
+        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]
 
     with ctx_manager as rf:
         args = args if args else ()
diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py
index 6dac7cb0863a..fe2acbe807b1 100644
--- a/torch/distributed/rpc/backend_registry.py
+++ b/torch/distributed/rpc/backend_registry.py
@@ -28,8 +28,10 @@ def _backend_type_repr(self):
 """
 
 # Create an enum type, `BackendType`, with empty members.
-BackendType = enum.Enum(value="BackendType", names={})
-BackendType.__repr__ = _backend_type_repr
+# Can't handle Function Enum API (mypy bug #9079)
+BackendType = enum.Enum(value="BackendType", names=dict())  # type: ignore[misc]
+# Unable to assign a function a method (mypy bug #2427)
+BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
 BackendType.__doc__ = _backend_type_doc
 
 def backend_registered(backend_name):
@@ -73,8 +75,10 @@ def register_backend(
         },
         **existing_enum_dict
     )
-    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
-    BackendType.__repr__ = _backend_type_repr
+    # Can't handle Function Enum API (mypy bug #9079)
+    BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)  # type: ignore[misc]
+    # Unable to assign a function a method (mypy bug #2427)
+    BackendType.__repr__ = _backend_type_repr  # type: ignore[assignment]
     BackendType.__doc__ = _backend_type_doc
     return BackendType[backend_name]
diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py
index c2dd804e4c81..e6d79e6e5981 100644
--- a/torch/distributed/rpc/constants.py
+++ b/torch/distributed/rpc/constants.py
@@ -1,6 +1,6 @@
 from datetime import timedelta
 
-from . import (
+from torch._C._distributed_rpc import (
     _DEFAULT_INIT_METHOD,
     _DEFAULT_NUM_SEND_RECV_THREADS,
     _DEFAULT_NUM_WORKER_THREADS,
@@ -10,16 +10,16 @@
 
 
 # For any RpcAgent.
-DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC
-DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
-DEFAULT_SHUTDOWN_TIMEOUT = 5.0
+DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
+DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
+DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
 
 # For ProcessGroupAgent.
-DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
+DEFAULT_NUM_SEND_RECV_THREADS: int = _DEFAULT_NUM_SEND_RECV_THREADS
 
 # For TensorPipeAgent.
-DEFAULT_NUM_WORKER_THREADS = _DEFAULT_NUM_WORKER_THREADS
+DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
 
 # Ensure that we don't time out when there are long periods of time without
 # any operations against the underlying ProcessGroup.
-DEFAULT_PROCESS_GROUP_TIMEOUT = timedelta(milliseconds=2 ** 31 - 1)
+DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2 ** 31 - 1)
 
 # Value indicating that timeout is not set for RPC call, and the default should be used.
-UNSET_RPC_TIMEOUT = _UNSET_RPC_TIMEOUT
+UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
diff --git a/torch/distributed/rpc/functions.py b/torch/distributed/rpc/functions.py
index d761f7b4046b..f0d106c53844 100644
--- a/torch/distributed/rpc/functions.py
+++ b/torch/distributed/rpc/functions.py
@@ -160,5 +160,6 @@ def async_execution(fn):
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
         return fn(*args, **kwargs)
-    wrapper._wrapped_async_rpc_function = fn
+    # Can't declare and use attributes of function objects (mypy#2087)
+    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
     return wrapper
diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py
index 08ffcbf0bcaf..073582a93e63 100644
--- a/torch/distributed/rpc/internal.py
+++ b/torch/distributed/rpc/internal.py
@@ -10,7 +10,7 @@
 import torch
 import torch.distributed as dist
 
-from . import _get_current_rpc_agent
+from torch._C._distributed_rpc import _get_current_rpc_agent
 
 
 # Thread local tensor tables to store tensors while pickling torch.Tensor
@@ -37,7 +37,8 @@ class _InternalRPCPickler:
     """
 
     def __init__(self):
-        self._dispatch_table = copyreg.dispatch_table.copy()
+        # Ignore type error because dispatch_table is defined in third-party package
+        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
         self._dispatch_table[torch.Tensor] = self._tensor_reducer
 
     @classmethod
@@ -80,9 +81,11 @@ def serialize(self, obj):
         #
         # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
         # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
-        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
         # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
-        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer
+        # Ignore type error because dispatch_table is defined in third-party package
+        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]
 
         # save _thread_local_tensor_tables.send_tables if it is in nested call
         global _thread_local_tensor_tables
@@ -224,8 +227,8 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke
     profile_key = "rpc_{}#{}({} -> {})".format(
         exec_type.value, str(func_name), current_worker_name, dest_worker_name
     )
-    rf = torch.autograd._RecordFunction()
-    torch.autograd._run_before_callbacks(rf, profile_key)
+    rf = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
+    torch.autograd._run_before_callbacks(rf, profile_key)  # type: ignore[attr-defined]
     return rf
diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py
index 149a2544d217..edec5778da6e 100644
--- a/torch/distributed/rpc/options.py
+++ b/torch/distributed/rpc/options.py
@@ -1,4 +1,4 @@
-from . import _TensorPipeRpcBackendOptionsBase
+from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
 from . import constants as rpc_contants
 
 import torch
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 8dc25a6a56da..47be7c4c0ff9 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -2601,7 +2601,7 @@ def test_rref_context_debug_info(self):
 
     @dist_init
     def test_disable_gil_profiling(self):
-        # test that rpc.enable_gil_profilig(false) will result in
+        # test that rpc.enable_gil_profiling(false) will result in
         # GIL wait time not being recorded.
         # GIL profiling should be disabled by default.
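
Illustrative only, not part of the patch: with the new torch._C._distributed_rpc stub in place, mypy can type-check call sites that consume these bindings. A minimal sketch follows (the describe helper is hypothetical; importing the module requires a build with distributed support, and calling get_rpc_timeout() assumes an initialized RPC agent at runtime):

from torch._C._distributed_rpc import WorkerInfo, get_rpc_timeout

def describe(worker: WorkerInfo) -> str:
    # The stub types WorkerInfo.name as str, WorkerInfo.id as int, and
    # get_rpc_timeout() as returning float, so mypy checks these accesses.
    return f"{worker.name} (id={worker.id}), rpc timeout={get_rpc_timeout()}s"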