Add type annotations to torch.onnx.* modules #45258

Closed · wants to merge 8 commits
24 changes: 0 additions & 24 deletions mypy.ini
@@ -146,30 +146,6 @@ ignore_errors = True
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
ignore_errors = True

[mypy-torch.onnx.operators]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset8]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset9]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset11]
ignore_errors = True

[mypy-torch.onnx.symbolic_caffe2]
ignore_errors = True

[mypy-torch.onnx.symbolic_helper]
ignore_errors = True

[mypy-torch.onnx.symbolic_registry]
ignore_errors = True

[mypy-torch.onnx.utils]
ignore_errors = True

[mypy-torch.multiprocessing.pool]
ignore_errors = True

68 changes: 66 additions & 2 deletions torch/_C/__init__.pyi.in
@@ -165,7 +165,7 @@ def wait(fut: Future) -> Any: ...
def _collect_all(futures: List[Future]) -> Future: ...

def unify_type_list(types: List[JitType]) -> JitType: ...
def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ...
def _freeze_module(module: ScriptModule, preserved_attrs: List[str] = [], freeze_interfaces: _bool = True) -> ScriptModule: ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
@@ -216,6 +216,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
# Defined in torch/csrc/jit/python/script_init.cpp
ResolutionCallback = Callable[[str], Callable[..., Any]]

# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
@@ -244,6 +246,54 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...

def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
graph: Graph,
paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
graph: Graph,
paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ...
def _jit_pass_onnx_block(
old_block: Block,
new_block: Block,
operator_export_type: _onnx.OperatorExportTypes,
env: Dict[Value, Value]
) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ...

def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
def _jit_script_compile_overload(
qualname: str,
@@ -279,8 +329,18 @@ def import_ir_module_from_buffer(
extra_files: Dict[str, Any]
) -> ScriptModule: ...

def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
graph: Graph,
inputs: Tuple[Tensor, ...],
with_grad: _bool,
propagate: _bool
) -> Graph: ...

# Defined in torch/torch/csrc/jit/ir/ir.h
class Graph:
def eraseInput(self, i: _int) -> None: ...
...

# Defined in torch/csrc/jit/ir/ir.h
@@ -364,8 +424,8 @@ class ScriptFunction:
def qualified_name(self) -> str: ...

class ScriptMethod:
graph: Graph
...

class ModuleDict:
def __init__(self, mod: ScriptModule) -> None: ...
def items(self) -> List[Tuple[str, Any]]: ...
@@ -376,6 +436,10 @@ class ParameterDict:
class BufferDict:
def __init__(self, mod: ScriptModule) -> None: ...

# Defined in torch/csrc/jit/api/module.h
class Module:
...

# Defined in torch/csrc/Module.cpp
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
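The stubs above give mypy real signatures for the torch._C passes that torch.onnx.utils calls. A minimal sketch of how such a call now type-checks; the scripted function is only a convenient way to obtain a Graph, and both passes happen to be no-ops on this trivial graph:

import torch

@torch.jit.script
def double(x: torch.Tensor) -> torch.Tensor:
    return x + x

graph = double.graph
# Both calls are checked against the new stubs: (graph: Graph) -> None
torch._C._jit_pass_remove_inplace_ops(graph)
torch._C._jit_pass_lower_all_tuples(graph)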
1 change: 1 addition & 0 deletions torch/_C/_onnx.pyi
@@ -29,6 +29,7 @@ class OperatorExportTypes(Enum):
ONNX_ATEN = ...
ONNX_ATEN_FALLBACK = ...
RAW = ...
ONNX_FALLTHROUGH = ...

class TrainingMode(Enum):
EVAL = ...
23 changes: 13 additions & 10 deletions torch/onnx/symbolic_helper.py
@@ -2,6 +2,7 @@
import torch
import warnings
from sys import maxsize as maxsize
from typing import Set

import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
@@ -125,7 +126,7 @@ def decorator(fn):
def wrapper(g, *args, **kwargs):
# some args may be optional, so the length may be smaller
assert len(arg_descriptors) >= len(args)
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore
# only support _outputs in kwargs
assert len(kwargs) <= 1
if len(kwargs) == 1:
@@ -215,11 +216,11 @@ def _try_get_scalar_type(*args):

def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice
return _slice(g, input, axes, starts, ends)
from torch.onnx.symbolic_opset9 import _slice as _slice9
return _slice9(g, input, axes, starts, ends)
else:
from torch.onnx.symbolic_opset10 import _slice
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
from torch.onnx.symbolic_opset10 import _slice as _slice10
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)


def _is_fp(value):
@@ -356,23 +357,24 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
size = g.op("Concat", *size, axis_i=0)
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
else:
return _unimplemented("Both size and scales are None in __interpolate")
return _unimplemented("interpolate", "Both size and scales are None in __interpolate")
return scale_factor, mode


def _unbind_helper(g, self, dim, _outputs):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import unbind
else:
from torch.onnx.symbolic_opset11 import unbind
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
return unbind(g, self, dim, _outputs)


def _scatter_helper(g, self, dim, index, src):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
from torch.onnx.symbolic_opset11 import scatter
# for mypy, scatter was imported two lines above
from torch.onnx.symbolic_opset11 import scatter # type: ignore
return scatter(g, self, dim, index, src)


@@ -420,7 +422,8 @@ def _index_fill_reshape_helper(g, self, dim, index):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
from torch.onnx.symbolic_opset11 import scatter
# for mypy, scatter was imported two lines above
from torch.onnx.symbolic_opset11 import scatter # type: ignore

if self.type().dim() is None:
return _unimplemented("index_fill", "input rank not accesible")
@@ -608,4 +611,4 @@ def _cast_func_template(to_i, g, input, non_blocking):

# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
_quantized_ops = set()
_quantized_ops: Set[int] = set()
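Two approaches to the conditional imports appear in this file: _slice_helper aliases each import to a distinct name (_slice9 / _slice10), while _unbind_helper and _scatter_helper keep a single name and silence mypy's redefinition error. A minimal sketch of the second pattern, using a hypothetical helper name:

def _example_scatter_helper(g, self, dim, index, src, opset_version):
    # Importing the same name from two modules triggers mypy's no-redef error,
    # so the second import is silenced rather than renamed.
    if opset_version <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        from torch.onnx.symbolic_opset11 import scatter  # type: ignore[no-redef]
    return scatter(g, self, dim, index, src)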
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset8.py
@@ -4,7 +4,7 @@
import torch.onnx.symbolic_opset9 as sym_opset9

from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import _cast_Float
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore

import warnings

9 changes: 6 additions & 3 deletions torch/onnx/symbolic_opset9.py
@@ -13,6 +13,8 @@
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented

from typing import Optional

import numpy
import math
import warnings
@@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self):
if dtype is not None:
# pytorch reduce-ops cast all other integral types to int64
if not sym_help._is_fp(self) and not (dtype == 'Long'):
self = _cast_Long(g, self, False)
self = _cast_Long(g, self, False) # type: ignore
return self


@@ -2085,7 +2087,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
# It's really only necessary because those operators expand to something that
# only works with int32 types in Caffe2...
if lengths.type().scalarType() != 'Int':
lengths = _cast_Int(g, lengths, False)
lengths = _cast_Int(g, lengths, False) # type: ignore
return g.op("prim::PackPadded", input, lengths, outputs=2)


@@ -2429,7 +2431,7 @@ def _get_arange_dtype(dtype):


def masked_fill(g, self, mask, value):
mask = _cast_Bool(g, mask, False)
mask = _cast_Bool(g, mask, False) # type: ignore
value = sym_help._maybe_get_scalar(value)
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)

@@ -2727,6 +2729,7 @@ def as_strided(g, self, sizes, strides, offset=None):
sizes = sym_help._maybe_get_const(sizes, 'is')
rank = len(strides)
self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
ind: Optional[torch.Tensor]
if not sym_help._is_value(sizes):
ind = torch.tensor([0], dtype=torch.long)
for i, (size, stride) in enumerate(zip(sizes, strides)):
5 changes: 3 additions & 2 deletions torch/onnx/symbolic_registry.py
@@ -1,16 +1,17 @@
import warnings
import importlib
from inspect import getmembers, isfunction
from typing import Dict, Tuple, Any, Union

# The symbolic registry "_registry" is a dictionary that maps operators
# (for a specific domain and opset version) to their symbolic functions.
# An operator is defined by its domain, opset version, and opname.
# The keys are tuples (domain, version), (where domain is a string, and version is an int),
# and the operator's name (string).
# The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
_registry = {}
_registry: Dict[Tuple[str, int], Dict] = {}

_symbolic_versions = {}
_symbolic_versions: Dict[Union[int, str], Any] = {}
from torch.onnx.symbolic_helper import _onnx_stable_opsets
for opset_version in _onnx_stable_opsets:
module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
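A minimal sketch (illustrative names only) of the structure the new Dict[Tuple[str, int], Dict] annotation describes: the outer key is a (domain, opset version) pair and the inner dict maps operator names to symbolic functions:

from typing import Callable, Dict, Tuple

_example_registry: Dict[Tuple[str, int], Dict] = {}

def _example_register(domain: str, version: int, opname: str, fn: Callable) -> None:
    # _registry[(domain, version)][op_name] = op_symbolic
    _example_registry.setdefault((domain, version), {})[opname] = fn

_example_register("", 9, "add", lambda g, a, b: g.op("Add", a, b))
assert "add" in _example_registry[("", 9)]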
24 changes: 15 additions & 9 deletions torch/onnx/utils.py
@@ -18,6 +18,7 @@
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
from typing import Union, Tuple, List


# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None,
if aten or export_raw_ir:
assert operator_export_type is None
assert aten ^ export_raw_ir
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
Collaborator Author:
There is no mention of OperatorExportTypes.ATEN in the source code, so I think this value was supposed to be OperatorExportTypes.ONNX_ATEN.

elif operator_export_type is None:
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
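A condensed sketch of the flag-resolution logic this fix touches, restricted to names visible in this hunk and in torch/_C/_onnx.pyi above; since the enum has no ATEN member, ONNX_ATEN is the value the legacy aten=True flag should map to. The helper name is hypothetical:

import torch
from torch.onnx import OperatorExportTypes

def _resolve_export_type(aten=False, export_raw_ir=False, operator_export_type=None):
    if aten or export_raw_ir:
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        # ONNX_ATEN, not the nonexistent ATEN
        return OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
    if operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            return OperatorExportTypes.ONNX_ATEN_FALLBACK
        return OperatorExportTypes.ONNX
    return operator_export_type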
@@ -351,6 +352,7 @@ def _trace_and_get_graph_from_model(model, args):

def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
torch_out = None
params: Union[List, Tuple]
if isinstance(model, torch.jit.ScriptModule):
try:
graph = model.forward.graph
@@ -442,7 +444,7 @@ def _model_to_graph(model, args, verbose=False,
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
params_dict = dict(zip(param_names, params))

if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
Collaborator Author:
is_originally_training is not defined anywhere, so I guess this part of the conditional is never executed.

if training is None or training == TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
@@ -476,7 +478,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
if aten or export_raw_ir:
assert operator_export_type is None
assert aten ^ export_raw_ir
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
Collaborator Author:
Same here: ATEN is likewise replaced with ONNX_ATEN.

elif operator_export_type is None:
operator_export_type = OperatorExportTypes.ONNX
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
@@ -1051,6 +1053,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
dims = [1]
isscalar = True
type = type.lower()
tensor: Union[torch.CharTensor, torch.ShortTensor,
torch.IntTensor, torch.LongTensor,
torch.HalfTensor, torch.FloatTensor,
torch.DoubleTensor]
if type == "char":
tensor = torch.CharTensor(*dims)
elif type == "short":
@@ -1068,7 +1074,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
else:
raise ValueError("Unknown type, type should be one of the following strings: "
"char, short, int, long, half, float, double")
tensor.fill_(value)
tensor.fill_(value) # type: ignore
if isscalar:
return g.op("Constant", *args, value_z=tensor, **kwargs)
return g.op("Constant", *args, value_t=tensor, **kwargs)
@@ -1141,8 +1147,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
dynamic_axes[key] = value_dict


torch._C.Graph.op = _graph_op
torch._C.Graph.at = _graph_at
torch._C.Block.op = _block_op
torch._C.Graph.constant = _graph_constant
torch._C.Node.__getitem__ = _node_getitem
torch._C.Graph.op = _graph_op # type: ignore
torch._C.Graph.at = _graph_at # type: ignore
torch._C.Block.op = _block_op # type: ignore
torch._C.Graph.constant = _graph_constant # type: ignore
torch._C.Node.__getitem__ = _node_getitem # type: ignore
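The trailing ignores are needed because Graph, Block, and Node are declared in the .pyi stub and mypy rejects assigning attributes the stub does not list. A minimal sketch of the same situation with a hypothetical extra patch, not part of this PR:

import torch

def _example_graph_identity(g):
    # hypothetical helper, only here to have something to attach
    return g

# mypy would flag this attribute as unknown on the stubbed class, hence the ignore
torch._C.Graph.example_identity = _example_graph_identity  # type: ignore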