From fe77ded48a8af49cdfb9bf41264750fbb9b936bf Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Fri, 6 Nov 2020 00:47:23 -0800
Subject: [PATCH] Add Python declaration of torch._C and torch._C._autograd
 modules. (#46622)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46622

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24761503

Pulled By: xuzhao9

fbshipit-source-id: c7ff9a9e46480a83bf6961e09972b5d20bdeb67b
---
 torch/_C/__init__.pyi.in                 | 21 ++++++++++++++++++++
 torch/_C/_autograd.pyi                   |  3 ++-
 torch/_C/_distributed_autograd.pyi       | 25 ++++++++++++++++++++++++
 torch/csrc/distributed/autograd/init.cpp | 10 +++++++++-
 torch/distributed/autograd/__init__.py   | 14 +++++++++++++
 5 files changed, 71 insertions(+), 2 deletions(-)
 create mode 100644 torch/_C/_distributed_autograd.pyi

diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 64c746e7eff2..bdebb355e33f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -293,7 +293,15 @@ class Node:
     ...
 
 # Defined in torch/aten/src/ATen/core/function_schema.h
+class Argument:
+    name: str
+    type: JitType
+    default_value: Optional[Any]
+    def has_default_value(self) -> _bool: ...
+    ...
 class FunctionSchema:
+    arguments: List[Argument]
+    returns: List[Argument]
     ...
 
 # Defined in torch/csrc/jit/python/script_init.cpp
@@ -332,6 +340,7 @@ class CompilationUnit:
     def __init__(self) -> None: ...
     def find_function(self, name: str) -> ScriptFunction: ...
     def define(self, script: str, rcb: ResolutionCallback): ...
+    def get_interface(self, name: str) -> InterfaceType: ...
 
 class ScriptModule:
     def setattr(self, name: str, value: Any): ...
@@ -780,3 +789,15 @@ class Def(TreeView):
 
 class Decl(TreeView):
     ...
+
+# Defined in torch/csrc/distributed/rpc/init.cpp
+def _rpc_init() -> _bool: ...
+
+# Defined in torch/csrc/distributed/autograd/init.cpp
+def _dist_autograd_init() -> _bool: ...
+
+# Defined in torch/csrc/distributed/c10d/init.cpp
+def _c10d_init() -> _bool: ...
+
+# Defined in torch/csrc/distributed/rpc/testing/init.cpp
+def _faulty_agent_init() -> _bool: ...
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index a154fb1948c1..a989bb19ad8c 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -12,7 +12,8 @@ class ProfilerState(Enum):
 
 class ProfilerConfig:
     def __init__(
-        self, state: ProfilerState,
+        self,
+        state: ProfilerState,
         report_input_shapes: bool,
         profile_memory: bool,
         with_stack: bool
diff --git a/torch/_C/_distributed_autograd.pyi b/torch/_C/_distributed_autograd.pyi
new file mode 100644
index 000000000000..39cbb984c635
--- /dev/null
+++ b/torch/_C/_distributed_autograd.pyi
@@ -0,0 +1,25 @@
+import torch
+from typing import Dict, List, Set, Any
+
+# This module is defined in torch/csrc/distributed/autograd/init.cpp
+
+class DistAutogradContext:
+    def _context_id(self) -> int: ...
+    def _recv_functions(self) -> Dict[int, Any]: ...
+    def _send_functions(self) -> Dict[int, Any]: ...
+    def _known_worker_ids(self) -> Set[int]: ...
+
+def _new_context() -> DistAutogradContext: ...
+def _release_context(context_id: int) -> None: ...
+def _get_max_id() -> int: ...
+def _is_valid_context(worker_id: int) -> bool: ...
+def _retrieve_context(context_id: int) -> DistAutogradContext: ...
+def _current_context() -> DistAutogradContext: ...
+def _init(worker_id: int) -> None: ...
+def _get_debug_info() -> Dict[str, str]: ...
+def backward(
+    context_id: int,
+    roots: List[torch.Tensor],
+    retain_graph = False
+) -> None: ...
+def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp
index 09de7abb87a5..ad6dfa7d8f46 100644
--- a/torch/csrc/distributed/autograd/init.cpp
+++ b/torch/csrc/distributed/autograd/init.cpp
@@ -22,7 +22,15 @@ PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
     throw python_error();
   }
 
-  auto module = py::handle(autograd_module).cast<py::module>();
+  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
+  if (!torch_C_module) {
+    throw python_error();
+  }
+
+  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
+  auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings");
+
+  auto module = py::handle(m).cast<py::module>();
 
   auto distAutogradContext =
       shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py
index a56b41bce8c6..c8d4366e4429 100644
--- a/torch/distributed/autograd/__init__.py
+++ b/torch/distributed/autograd/__init__.py
@@ -10,6 +10,20 @@ def is_available():
 if is_available() and not torch._C._dist_autograd_init():
     raise RuntimeError("Failed to initialize torch.distributed.autograd")
 
+if is_available():
+    from torch._C._distributed_autograd import (
+        get_gradients,
+        backward,
+        _init,
+        _new_context,
+        _release_context,
+        _get_max_id,
+        _is_valid_context,
+        _retrieve_context,
+        _current_context,
+        _get_debug_info,
+        DistAutogradContext,
+    )
 
 class context(object):
     '''
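
Note (not part of the patch): a minimal usage sketch of the torch.distributed.autograd API that the new stubs describe. It assumes an RPC worker group is already initialized and that a peer named "worker1" exists; the worker name and the train_step helper are illustrative, not taken from this change.

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    def train_step(t1: torch.Tensor, t2: torch.Tensor) -> None:
        # Inputs must have requires_grad=True for gradients to be recorded.
        # Each distributed backward pass runs inside its own autograd context.
        with dist_autograd.context() as context_id:
            # Forward pass: run an op on the remote worker over RPC.
            loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
            # Distributed backward pass over the send/recv functions recorded
            # in this context.
            dist_autograd.backward(context_id, [loss])
            # Gradients are stored per-context rather than in Tensor.grad.
            grads = dist_autograd.get_gradients(context_id)
            print(grads)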