Skip to content

Commit

Permalink
Add Python declaration of torch._C and torch._C._autograd modules. (#…
Browse files Browse the repository at this point in the history
…46622)

Summary: Pull Request resolved: #46622

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24761503

Pulled By: xuzhao9

fbshipit-source-id: c7ff9a9e46480a83bf6961e09972b5d20bdeb67b
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 6, 2020
1 parent fccfe7b commit fe77ded
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 2 deletions.
21 changes: 21 additions & 0 deletions torch/_C/__init__.pyi.in
Expand Up @@ -293,7 +293,15 @@ class Node:


# Defined in torch/aten/src/ATen/core/function_schema.h
class Argument:
name: str
type: JitType
default_value: Optional[Any]
def has_default_value(self) -> _bool: ...
...
class FunctionSchema:
arguments: List[Argument]
returns: List[Argument]
...

# Defined in torch/csrc/jit/python/script_init.cpp
Expand Down Expand Up @@ -332,6 +340,7 @@ class CompilationUnit:
def __init__(self) -> None: ...
def find_function(self, name: str) -> ScriptFunction: ...
def define(self, script: str, rcb: ResolutionCallback): ...
def get_interface(self, name: str) -> InterfaceType: ...

class ScriptModule:
def setattr(self, name: str, value: Any): ...
Expand Down Expand Up @@ -780,3 +789,15 @@ class Def(TreeView):

class Decl(TreeView):
...

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...

# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...

# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...

# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
3 changes: 2 additions & 1 deletion torch/_C/_autograd.pyi
Expand Up @@ -12,7 +12,8 @@ class ProfilerState(Enum):

class ProfilerConfig:
def __init__(
self, state: ProfilerState,
self,
state: ProfilerState,
report_input_shapes: bool,
profile_memory: bool,
with_stack: bool
Expand Down
25 changes: 25 additions & 0 deletions torch/_C/_distributed_autograd.pyi
@@ -0,0 +1,25 @@
import torch
from typing import Dict, List, Set, Any

# This module is defined in torch/csrc/distributed/autograd/init.cpp

class DistAutogradContext:
def _context_id(self) -> int: ...
def _recv_functions(self) -> Dict[int, Any]: ...
def _send_functions(self) -> Dict[int, Any]: ...
def _known_worker_ids(self) -> Set[int]: ...

def _new_context() -> DistAutogradContext: ...
def _release_context(context_id: int) -> None: ...
def _get_max_id() -> int: ...
def _is_valid_context(worker_id: int) -> bool: ...
def _retrieve_context(context_id: int) -> DistAutogradContext: ...
def _current_context() -> DistAutogradContext: ...
def _init(worker_id: int) -> None: ...
def _get_debug_info() -> Dict[str, str]: ...
def backward(
context_id: int,
roots: List[torch.Tensor],
retain_graph = False
) -> None: ...
def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...
10 changes: 9 additions & 1 deletion torch/csrc/distributed/autograd/init.cpp
Expand Up @@ -22,7 +22,15 @@ PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
throw python_error();
}

auto module = py::handle(autograd_module).cast<py::module>();
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module) {
throw python_error();
}

auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings");

auto module = py::handle(m).cast<py::module>();

auto distAutogradContext =
shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
Expand Down
14 changes: 14 additions & 0 deletions torch/distributed/autograd/__init__.py
Expand Up @@ -10,6 +10,20 @@ def is_available():
if is_available() and not torch._C._dist_autograd_init():
raise RuntimeError("Failed to initialize torch.distributed.autograd")

if is_available():
from torch._C._distributed_autograd import (
get_gradients,
backward,
_init,
_new_context,
_release_context,
_get_max_id,
_is_valid_context,
_retrieve_context,
_current_context,
_get_debug_info,
DistAutogradContext,
)

class context(object):
'''
Expand Down

0 comments on commit fe77ded

Please sign in to comment.