Enable torch.autograd typechecks #44451

Closed

Changes from all commits
3 changes: 0 additions & 3 deletions mypy.ini
@@ -233,9 +233,6 @@ ignore_errors = True
[mypy-torch.utils.hipify.hipify_python]
ignore_errors = True

[mypy-torch.autograd]
ignore_errors = True

[mypy-torch.autograd._functions.tensor]
ignore_errors = True

12 changes: 6 additions & 6 deletions test/test_type_hints.py
@@ -86,7 +86,7 @@ def get_all_examples():
if docstr and fname not in blocklist:
e = get_examples_from_docstring(docstr)
if e:
example_file_lines.append("\n\ndef example_torch_{}():".format(fname))
example_file_lines.append(f"\n\ndef example_torch_{fname}():")
example_file_lines += e

for fname in dir(torch.Tensor):
@@ -95,7 +95,7 @@ def get_all_examples():
if docstr and fname not in blocklist:
e = get_examples_from_docstring(docstr)
if e:
example_file_lines.append("\n\ndef example_torch_tensor_{}():".format(fname))
example_file_lines.append(f"\n\ndef example_torch_tensor_{fname}():")
example_file_lines += e

return "\n".join(example_file_lines)
@@ -156,7 +156,7 @@ def test_doc_examples(self):
os.path.abspath(fn),
])
if result != 0:
self.fail("mypy failed:\n{}".format(stdout))
self.fail(f"mypy failed:\n{stdout}")

@unittest.skipIf(not HAVE_MYPY, "need mypy")
def test_type_hint_examples(self):
@@ -175,7 +175,7 @@ def test_type_hint_examples(self):
example_path,
])
if result != 0:
self.fail("mypy failed for exampl {}\n{}".format(example, stdout))
self.fail(f"mypy failed for example {example}\n{stdout}")

@unittest.skipIf(not HAVE_MYPY, "need mypy")
def test_run_mypy(self):
@@ -215,7 +215,7 @@ def is_torch_mypyini(path_to_file):
finally:
os.chdir(cwd)
if result != 0:
self.fail("mypy failed: {} {}".format(stdout, stderr))
self.fail(f"mypy failed: {stdout} {stderr}")

@unittest.skipIf(not HAVE_MYPY, "need mypy")
def test_run_mypy_strict(self):
@@ -237,7 +237,7 @@ def test_run_mypy_strict(self):
finally:
os.chdir(cwd)
if result != 0:
self.fail("mypy failed: {} {}".format(stdout, stderr))
self.fail(f"mypy failed: {stdout} {stderr}")

if __name__ == '__main__':
run_tests()
40 changes: 40 additions & 0 deletions torch/_C/_autograd.pyi
@@ -0,0 +1,40 @@
from typing import List
from enum import Enum

# Defined in tools/autograd/init.cpp

class ProfilerState(Enum):
    Disable = 0
    CPU = 1
    CUDA = 2
    NVTX = 3


class ProfilerConfig:
    def __init__(self, state: ProfilerState, report_input_shapes: bool, profile_memory: bool) -> None: ...
    ...


class ProfilerEvent:
    def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cpu_memory_usage(self) -> int: ...
    def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
    def cuda_memory_usage(self) -> int: ...
    def device(self) -> int: ...
    def handle(self) -> int: ...
    def has_cuda(self) -> bool: ...
    def is_remote(self) -> bool: ...
    def kind(self) -> int: ...
    def name(self) -> str: ...
    def node_id(self) -> int: ...
    def sequence_nr(self) -> int: ...
    def shapes(self) -> List[List[int]]: ...
    def thread_id(self) -> int: ...
    ...


def _enable_profiler(config: ProfilerConfig) -> None: ...
def _disable_profiler() -> List[List[ProfilerEvent]]: ...
def _profiler_enabled() -> bool: ...
def _enable_record_function(enable: bool) -> None: ...
def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
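
For orientation, here is a minimal sketch of how these stubbed bindings fit together. The call sequence is an assumption pieced together from the signatures in this file, not code from the PR:

import torch
from torch._C._autograd import (
    ProfilerConfig, ProfilerState, _disable_profiler, _enable_profiler,
)

# Profile CPU ops only; don't record input shapes or memory usage.
config = ProfilerConfig(ProfilerState.CPU, False, False)
_enable_profiler(config)        # mypy now checks this call against the stub
(torch.ones(4) * 2).sum()       # some work to profile
events = _disable_profiler()    # typed above as List[List[ProfilerEvent]]
for thread_events in events:
    for evt in thread_events:
        print(evt.name(), evt.cpu_memory_usage())
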
47 changes: 25 additions & 22 deletions torch/autograd/__init__.py
@@ -8,8 +8,9 @@
"""
import torch
import warnings
from typing import Any, Callable, Union, Tuple, Sequence, Optional

from torch.types import _TensorOrTensors
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

from .variable import Variable
from .function import Function, NestedIOFunction
@@ -22,9 +23,10 @@

__all__ = ['Variable', 'Function', 'backward', 'grad_mode']

_OptionalTensor = Optional[torch.Tensor]

def _make_grads(outputs, grads):
new_grads = []
def _make_grads(outputs: Sequence[torch.Tensor], grads: Sequence[_OptionalTensor]) -> Tuple[_OptionalTensor, ...]:
new_grads: List[_OptionalTensor] = []
for out, grad in zip(outputs, grads):
if isinstance(grad, torch.Tensor):
if not out.shape == grad.shape:
@@ -33,7 +35,7 @@ def _make_grads(outputs, grads):
+ str(grad.shape) + " and output["
+ str(outputs.index(out)) + "] has a shape of "
+ str(out.shape) + ".")
if (out.dtype.is_complex != grad.dtype.is_complex):
if out.dtype.is_complex != grad.dtype.is_complex:
raise RuntimeError("For complex Tensors, both grad_output and output"
" are required to have the same dtype."
" Mismatch in dtype: grad_output["
@@ -55,6 +57,14 @@ def _make_grads(outputs, grads):
return tuple(new_grads)


def _tensor_or_tensors_to_tuple(tensors: Optional[_TensorOrTensors], length: int) -> Tuple[_OptionalTensor, ...]:
    if tensors is None:
        return (None, ) * length
    if isinstance(tensors, torch.Tensor):
        return (tensors, )
    return tuple(tensors)


def backward(
tensors: _TensorOrTensors,
grad_tensors: Optional[_TensorOrTensors] = None,
@@ -112,19 +122,13 @@ def backward(

tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)

if grad_tensors is None:
grad_tensors = [None] * len(tensors)
elif isinstance(grad_tensors, torch.Tensor):
grad_tensors = [grad_tensors]
else:
grad_tensors = list(grad_tensors)

grad_tensors = _make_grads(tensors, grad_tensors)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
grad_tensors_ = _make_grads(tensors, grad_tensors_)
if retain_graph is None:
retain_graph = create_graph

Variable._execution_engine.run_backward(
tensors, grad_tensors, retain_graph, create_graph,
tensors, grad_tensors_, retain_graph, create_graph,
allow_unreachable=True) # allow_unreachable flag


@@ -189,20 +193,14 @@ def grad(
"(defaults to True). To accumulate gradient for other "
"parts of the graph, please use torch.autograd.backward.")

if grad_outputs is None:
grad_outputs = [None] * len(outputs)
elif isinstance(grad_outputs, torch.Tensor):
grad_outputs = [grad_outputs]
else:
grad_outputs = list(grad_outputs)

grad_outputs = _make_grads(outputs, grad_outputs)
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
grad_outputs_ = _make_grads(outputs, grad_outputs_)

if retain_graph is None:
retain_graph = create_graph

return Variable._execution_engine.run_backward(
outputs, grad_outputs, retain_graph, create_graph,
outputs, grad_outputs_, retain_graph, create_graph,
inputs, allow_unused)


@@ -230,3 +228,8 @@ def variable(*args, **kwargs):

if not torch._C._autograd_init():
raise RuntimeError("autograd initialization failed")

# Import all native method/classes
from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent,
_enable_profiler, _disable_profiler, _profiler_enabled,
_enable_record_function, _set_empty_test_observer)
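
As a quick illustration of the new helper (a sketch based on the definition above; _tensor_or_tensors_to_tuple is private and shown here only to make its normalization behavior concrete):

import torch
from torch.autograd import _tensor_or_tensors_to_tuple

t = torch.ones(3, requires_grad=True)
print(_tensor_or_tensors_to_tuple(None, 2))    # (None, None)
print(_tensor_or_tensors_to_tuple(t, 1))       # (t,)
print(_tensor_or_tensors_to_tuple([t, t], 2))  # (t, t)

# backward() and grad() now route their grad arguments through this helper
# before _make_grads, so None, a single Tensor, or a sequence all typecheck:
torch.autograd.backward((t * t).sum())         # grad_tensors=None
g, = torch.autograd.grad((t * 2).sum(), t)     # grad_outputs=None
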
15 changes: 10 additions & 5 deletions torch/csrc/autograd/init.cpp
@@ -12,23 +12,28 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
using namespace torch::autograd::profiler;
auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch.tensor"));
if (!tensor_module)
throw python_error();
return nullptr;

// NOTE: "leaks" THPVariableClass
THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
if (!THPVariableClass)
throw python_error();
return nullptr;

auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
if (!autograd_module)
throw python_error();
return nullptr;

// NOTE: "leaks" Function
THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
if (!THPFunctionClass)
throw python_error();
return nullptr;

auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module)
return nullptr;
auto _C_m = py::handle(torch_C_module).cast<py::module>();
auto m = _C_m.def_submodule("_autograd", "autograd bindings");

auto m = py::handle(autograd_module).cast<py::module>();

py::enum_<ProfilerState>(m, "ProfilerState")
.value("Disabled", ProfilerState::Disabled)
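
The net effect of the init.cpp change is twofold: import failures now return nullptr (with the Python error already set) instead of throwing a C++ exception, and the profiler bindings are registered on a dedicated torch._C._autograd submodule rather than directly on torch.autograd. Combined with the re-import in torch/autograd/__init__.py above, the old attribute access keeps working; a small sketch of that invariant (an assumption based on the re-import, not a test from the PR):

import torch
from torch._C._autograd import ProfilerState

# Both names resolve to the same class object thanks to the re-export.
assert torch.autograd.ProfilerState is ProfilerState
assert not torch.autograd._profiler_enabled()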