Torch onnx (#48980)
Summary:
Fixes #45215

This is a follow-up PR to #45258 and #48782.

Pull Request resolved: #48980

Reviewed By: zhangguanheng66

Differential Revision: D25399823

Pulled By: ezyang

fbshipit-source-id: 798055f4abbbffecdfab0325884193c81addecec
guilhermeleobas authored and facebook-github-bot committed Dec 9, 2020
1 parent 5450614 commit 34cc77a
Showing 8 changed files with 110 additions and 52 deletions.
24 changes: 0 additions & 24 deletions mypy.ini
@@ -143,30 +143,6 @@ ignore_errors = True
[mypy-torch.nn.intrinsic.qat.modules.conv_fused]
ignore_errors = True

[mypy-torch.onnx.operators]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset8]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset9]
ignore_errors = True

[mypy-torch.onnx.symbolic_opset11]
ignore_errors = True

[mypy-torch.onnx.symbolic_caffe2]
ignore_errors = True

[mypy-torch.onnx.symbolic_helper]
ignore_errors = True

[mypy-torch.onnx.symbolic_registry]
ignore_errors = True

[mypy-torch.onnx.utils]
ignore_errors = True

[mypy-torch.multiprocessing.pool]
ignore_errors = True

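Removing these per-module `ignore_errors` overrides means mypy now type-checks the torch.onnx modules along with the rest of the codebase. As an illustrative sketch (not part of this diff), one way to confirm a single module still passes is mypy's programmatic API:

```python
# Illustrative only: run mypy on one of the newly checked modules.
# mypy.api.run returns (stdout, stderr, exit_status).
from mypy import api

stdout, stderr, exit_status = api.run(["torch/onnx/symbolic_helper.py"])
print(stdout)
assert exit_status == 0, "module no longer type-checks"
```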
72 changes: 70 additions & 2 deletions torch/_C/__init__.pyi.in
@@ -165,7 +165,10 @@ def wait(fut: Future) -> Any: ...
def _collect_all(futures: List[Future]) -> Future: ...

def unify_type_list(types: List[JitType]) -> JitType: ...
def _freeze_module(module: ScriptModule, preserved_attrs: List[str], freeze_interfaces: _bool = True) -> ScriptModule: ...
def _freeze_module(module: ScriptModule,
preserved_attrs: List[str] = [],
freeze_interfaces: _bool = True,
preserveParameters: _bool = True) -> ScriptModule: ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
@@ -217,6 +220,8 @@ def _jit_get_trigger_value(trigger_name: str) -> _int: ...
# Defined in torch/csrc/jit/python/script_init.cpp
ResolutionCallback = Callable[[str], Callable[..., Any]]

# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
def _create_function_from_graph(qualname: str, graph: Graph) -> Graph: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
@@ -246,6 +251,55 @@ def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallb
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...

def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
def _jit_pass_fixup_onnx_loop_node_inputs(n: Node) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(graph: Graph, addmm_fusion_enabled: _bool) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_onnx_prepare_inplace_ops_for_onnx(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
graph: Graph,
paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
graph: Graph,
paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(n: Node, opset_version: _int) -> None: ...
def _jit_pass_onnx_block(
old_block: Block,
new_block: Block,
operator_export_type: _onnx.OperatorExportTypes,
env: Dict[Value, Value]
) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> Node: ...

def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
def _jit_script_compile_overload(
qualname: str,
@@ -281,8 +335,18 @@ def import_ir_module_from_buffer(
extra_files: Dict[str, Any]
) -> ScriptModule: ...

def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
graph: Graph,
inputs: Tuple[Tensor, ...],
with_grad: _bool,
propagate: _bool
) -> Graph: ...

# Defined in torch/torch/csrc/jit/ir/ir.h
class Graph:
def eraseInput(self, i: _int) -> None: ...
...

# Defined in torch/csrc/jit/ir/ir.h
@@ -366,8 +430,8 @@ class ScriptFunction:
def qualified_name(self) -> str: ...

class ScriptMethod:
graph: Graph
...

class ModuleDict:
def __init__(self, mod: ScriptModule) -> None: ...
def items(self) -> List[Tuple[str, Any]]: ...
@@ -378,6 +442,10 @@ class ParameterDict:
class BufferDict:
def __init__(self, mod: ScriptModule) -> None: ...

# Defined in torch/csrc/jit/api/module.h
class Module:
...

# Defined in torch/csrc/Module.cpp
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
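The stubs above describe private `torch._C` bindings that `torch.onnx.utils` calls during export. A minimal sketch (not from this PR) of exercising one newly stubbed pass, with the signature declared above:

```python
# Illustrative sketch: run one of the newly stubbed JIT passes on a scripted graph.
# _jit_pass_onnx_function_substitution takes a Graph and returns None, per the stub.
import torch

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

graph = add_one.graph
torch._C._jit_pass_onnx_function_substitution(graph)
print(graph)
```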
1 change: 1 addition & 0 deletions torch/_C/_onnx.pyi
@@ -29,6 +29,7 @@ class OperatorExportTypes(Enum):
ONNX_ATEN = ...
ONNX_ATEN_FALLBACK = ...
RAW = ...
ONNX_FALLTHROUGH = ...

class TrainingMode(Enum):
EVAL = ...
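`ONNX_FALLTHROUGH` was missing from the stub even though the exporter already accepts it. A hedged usage sketch (not part of this diff):

```python
# Illustrative only: export with ONNX_FALLTHROUGH so ops that cannot be converted
# are left in the graph instead of raising during export.
import torch

model = torch.nn.Linear(4, 2)
dummy_input = torch.randn(1, 4)
torch.onnx.export(
    model,
    dummy_input,
    "linear.onnx",
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
```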
25 changes: 14 additions & 11 deletions torch/onnx/symbolic_helper.py
@@ -2,6 +2,7 @@
import torch
import warnings
from sys import maxsize as maxsize
from typing import Set

import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
@@ -125,7 +126,7 @@ def decorator(fn):
def wrapper(g, *args, **kwargs):
# some args may be optional, so the length may be smaller
assert len(arg_descriptors) >= len(args)
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] # type: ignore
# only support _outputs in kwargs
assert len(kwargs) <= 1
if len(kwargs) == 1:
@@ -232,18 +233,18 @@ def _select_helper(g, self, dim, index, apply_reshape=True):

def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import _slice
return _slice(g, input, axes, starts, ends)
from torch.onnx.symbolic_opset9 import _slice as _slice9
return _slice9(g, input, axes, starts, ends)
else:
from torch.onnx.symbolic_opset10 import _slice
return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
from torch.onnx.symbolic_opset10 import _slice as _slice10
return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)

def _hardtanh_helper(g, input, min_val, max_val):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import hardtanh
return hardtanh(g, input, min_val, max_val)
else:
from torch.onnx.symbolic_opset11 import hardtanh
from torch.onnx.symbolic_opset11 import hardtanh # type: ignore[no-redef]
return hardtanh(g, input, min_val, max_val)

def _is_fp(value):
@@ -380,23 +381,24 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode , align_
size = g.op("Concat", *size, axis_i=0)
scale_factor = _interpolate_size_to_scales(g, input, size, dim)
else:
return _unimplemented("Both size and scales are None in __interpolate")
return _unimplemented("interpolate", "Both size and scales are None in __interpolate")
return scale_factor, mode


def _unbind_helper(g, self, dim, _outputs):
if _export_onnx_opset_version <= 9:
from torch.onnx.symbolic_opset9 import unbind
else:
from torch.onnx.symbolic_opset11 import unbind
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
return unbind(g, self, dim, _outputs)


def _scatter_helper(g, self, dim, index, src):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
from torch.onnx.symbolic_opset11 import scatter
# for mypy, scatter was imported two lines above
from torch.onnx.symbolic_opset11 import scatter # type: ignore
return scatter(g, self, dim, index, src)


@@ -444,7 +446,8 @@ def _index_fill_reshape_helper(g, self, dim, index):
if _export_onnx_opset_version <= 10:
from torch.onnx.symbolic_opset9 import scatter
else:
from torch.onnx.symbolic_opset11 import scatter
# for mypy, scatter was imported two lines above
from torch.onnx.symbolic_opset11 import scatter # type: ignore

if self.type().dim() is None:
return _unimplemented("index_fill", "input rank not accesible")
@@ -632,4 +635,4 @@ def _cast_func_template(to_i, g, input, non_blocking):

# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
_quantized_ops = set()
_quantized_ops: Set[int] = set()
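The last hunk is the typical fix for module-level globals: mypy cannot infer the element type of an empty container, so the set gets an explicit annotation. A small sketch of the pattern (helper name is hypothetical):

```python
# Illustrative pattern, mirroring the `_quantized_ops: Set[int] = set()` change:
# annotate empty module-level containers so mypy knows their element type.
from typing import Set

_quantized_ops: Set[int] = set()

def _mark_quantized(op_id: int) -> None:
    # hypothetical helper for illustration
    _quantized_ops.add(op_id)
```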
2 changes: 1 addition & 1 deletion torch/onnx/symbolic_opset8.py
@@ -4,7 +4,7 @@
import torch.onnx.symbolic_opset9 as sym_opset9

from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import _cast_Float
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore

import warnings

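The `# type: ignore` on the `_cast_Float` import is needed because the `_cast_*` helpers are generated dynamically at import time in symbolic_opset9, so a static checker cannot see them. Roughly, the situation looks like this (simplified sketch, not the exact PyTorch code):

```python
# Simplified sketch of dynamically generated helpers that mypy cannot see.
from functools import partial

def _cast_func_template(to_type: str, value: float) -> str:
    return "Cast({} -> {})".format(value, to_type)

for _name in ("Float", "Long", "Int", "Bool"):
    globals()["_cast_" + _name] = partial(_cast_func_template, _name)

result = _cast_Float(1.0)  # type: ignore  # mypy flags the dynamic name; runtime is fine
print(result)
```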
9 changes: 6 additions & 3 deletions torch/onnx/symbolic_opset9.py
@@ -13,6 +13,8 @@
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented

from typing import Optional

import numpy
import math
import warnings
@@ -311,7 +313,7 @@ def _maybe_cast_reduce_op_input(g, self):
if dtype is not None:
# pytorch reduce-ops cast all other integral types to int64
if not sym_help._is_fp(self) and not (dtype == 'Long'):
self = _cast_Long(g, self, False)
self = _cast_Long(g, self, False) # type: ignore
return self


@@ -2092,7 +2094,7 @@ def _pack_padded_sequence(g, input, lengths, batch_first):
# It's really only necessary because those operators expand to something that
# only works with int32 types in Caffe2...
if lengths.type().scalarType() != 'Int':
lengths = _cast_Int(g, lengths, False)
lengths = _cast_Int(g, lengths, False) # type: ignore
return g.op("prim::PackPadded", input, lengths, outputs=2)


@@ -2436,7 +2438,7 @@ def _get_arange_dtype(dtype):


def masked_fill(g, self, mask, value):
mask = _cast_Bool(g, mask, False)
mask = _cast_Bool(g, mask, False) # type: ignore
value = sym_help._maybe_get_scalar(value)
return g.op('Where', mask, sym_help._if_scalar_type_as(g, value, self), self)

@@ -2734,6 +2736,7 @@ def as_strided(g, self, sizes, strides, offset=None):
sizes = sym_help._maybe_get_const(sizes, 'is')
rank = len(strides)
self_1d = g.op("Reshape", self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
ind: Optional[torch.Tensor]
if not sym_help._is_value(sizes):
ind = torch.tensor([0], dtype=torch.long)
for i, (size, stride) in enumerate(zip(sizes, strides)):
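The `ind: Optional[torch.Tensor]` line in `as_strided` pre-declares the variable's type so mypy accepts assignments from different branches. A standalone sketch of the pattern (function name is hypothetical):

```python
# Illustrative pattern: declare the variable's type once, then assign in branches.
from typing import Optional

import torch

def build_index(static_sizes: bool) -> Optional[torch.Tensor]:
    ind: Optional[torch.Tensor]
    if static_sizes:
        ind = torch.tensor([0], dtype=torch.long)
    else:
        ind = None
    return ind
```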
5 changes: 3 additions & 2 deletions torch/onnx/symbolic_registry.py
@@ -1,16 +1,17 @@
import warnings
import importlib
from inspect import getmembers, isfunction
from typing import Dict, Tuple, Any, Union

# The symbolic registry "_registry" is a dictionary that maps operators
# (for a specific domain and opset version) to their symbolic functions.
# An operator is defined by its domain, opset version, and opname.
# The keys are tuples (domain, version), (where domain is a string, and version is an int),
# and the operator's name (string).
# The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
_registry = {}
_registry: Dict[Tuple[str, int], Dict] = {}

_symbolic_versions = {}
_symbolic_versions: Dict[Union[int, str], Any] = {}
from torch.onnx.symbolic_helper import _onnx_stable_opsets
for opset_version in _onnx_stable_opsets:
module = importlib.import_module('torch.onnx.symbolic_opset{}'.format(opset_version))
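The registry now carries an explicit type: keys are `(domain, opset_version)` pairs, and the values map op names to symbolic functions. A slightly more specific sketch of the same structure (the `register_op` helper here is hypothetical, not the module's real API):

```python
# Illustrative sketch of the registry layout described in the comment above.
from typing import Any, Callable, Dict, Tuple

SymbolicFn = Callable[..., Any]
_registry: Dict[Tuple[str, int], Dict[str, SymbolicFn]] = {}

def register_op(domain: str, version: int, op_name: str, fn: SymbolicFn) -> None:
    _registry.setdefault((domain, version), {})[op_name] = fn
```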
24 changes: 15 additions & 9 deletions torch/onnx/utils.py
@@ -18,6 +18,7 @@
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
from typing import Union, Tuple, List


# the flag to tell the user whether it's in the middle of ONNX export or not
@@ -76,7 +77,7 @@ def export(model, args, f, export_params=True, verbose=False, training=None,
if aten or export_raw_ir:
assert operator_export_type is None
assert aten ^ export_raw_ir
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
elif operator_export_type is None:
if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
@@ -351,6 +352,7 @@ def _trace_and_get_graph_from_model(model, args):

def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
torch_out = None
params: Union[List, Tuple]
if isinstance(model, torch.jit.ScriptModule):
try:
graph = model.forward.graph
@@ -442,7 +444,7 @@ def _model_to_graph(model, args, verbose=False,
param_names = input_and_param_names[len(input_and_param_names) - len(params):]
params_dict = dict(zip(param_names, params))

if training is None or training == TrainingMode.EVAL or (training == TrainingMode.PRESERVE and not is_originally_training):
if training is None or training == TrainingMode.EVAL:
params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)

if do_constant_folding and _export_onnx_opset_version in torch.onnx.constant_folding_opset_versions:
@@ -476,7 +478,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
if aten or export_raw_ir:
assert operator_export_type is None
assert aten ^ export_raw_ir
operator_export_type = OperatorExportTypes.ATEN if aten else OperatorExportTypes.RAW
operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
elif operator_export_type is None:
operator_export_type = OperatorExportTypes.ONNX
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
@@ -1051,6 +1053,10 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
dims = [1]
isscalar = True
type = type.lower()
tensor: Union[torch.CharTensor, torch.ShortTensor,
torch.IntTensor, torch.LongTensor,
torch.HalfTensor, torch.FloatTensor,
torch.DoubleTensor]
if type == "char":
tensor = torch.CharTensor(*dims)
elif type == "short":
@@ -1068,7 +1074,7 @@ def _graph_constant(g, value, dims, type, *args, **kwargs):
else:
raise ValueError("Unknown type, type should be one of the following strings: "
"char, short, int, long, half, float, double")
tensor.fill_(value)
tensor.fill_(value) # type: ignore
if isscalar:
return g.op("Constant", *args, value_z=tensor, **kwargs)
return g.op("Constant", *args, value_t=tensor, **kwargs)
@@ -1141,8 +1147,8 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
dynamic_axes[key] = value_dict


torch._C.Graph.op = _graph_op
torch._C.Graph.at = _graph_at
torch._C.Block.op = _block_op
torch._C.Graph.constant = _graph_constant
torch._C.Node.__getitem__ = _node_getitem
torch._C.Graph.op = _graph_op # type: ignore
torch._C.Graph.at = _graph_at # type: ignore
torch._C.Block.op = _block_op # type: ignore
torch._C.Graph.constant = _graph_constant # type: ignore
torch._C.Node.__getitem__ = _node_getitem # type: ignore
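The monkey-patching assignments at the bottom need `# type: ignore` because the `torch._C.Graph`/`Block`/`Node` stubs do not declare these attributes. A minimal sketch of the same situation with a hypothetical helper:

```python
# Illustrative only: attaching a helper the stubs do not know about.
import torch

def _graph_node_count(g: torch._C.Graph) -> int:
    # hypothetical convenience helper, not part of torch.onnx
    return len(list(g.nodes()))

torch._C.Graph.node_count = _graph_node_count  # type: ignore
```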
