Skip to content

Commit

Permalink
[ONNX] Add a post-pass for If folding (#49410)
Browse files Browse the repository at this point in the history
* add pass

* add tests

* enable tests

* update dtype symbolic

* update pass

* clang format

* update utils file

* fix clang_tidy errors

* update pass

* update pass

* clang format

* update pass

* update pass

* update pass

* update pass

* update pass

* update pass

* update pass

* update pass

* fix mypy tests

* update comment

* update pass

* add new line

* update pass

* empty commit

* empty commit

* fix merge conflict
  • Loading branch information
KsenijaS committed Jan 6, 2021
1 parent e9cbf65 commit 33f26ed
Show file tree
Hide file tree
Showing 13 changed files with 494 additions and 32 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ namespace c10 {
_(onnx, ReduceL2) \
_(onnx, Conv) \
_(onnx, BatchNormalization) \
_(onnx, ReduceProd) \
FORALL_ATTR_BASE_SYMBOLS(_) \
_(attr, Subgraph) \
_(attr, ReverseSubgraph) \
Expand Down
132 changes: 120 additions & 12 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4355,8 +4355,7 @@ def test_embedding_bag(self):
input = torch.randint(10, (7, 5))
self.run_test(model, (input))

@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_1d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
Expand All @@ -4371,8 +4370,7 @@ def forward(self, embedding_matrix, input, offset, weights):
embedding_matrix = torch.rand(10, 15)
self.run_test(model, (embedding_matrix, x, offset, w))

@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
@skipIfUnsupportedMinOpsetVersion(10)
@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([12, 13]) # Due to ONNX Loop shape inference issue
def test_embedding_bag_2d_per_sample_weights(self):
class EmbeddingModel(torch.nn.Module):
Expand Down Expand Up @@ -4618,6 +4616,124 @@ def run():
self.assertEqual('Unsupported: ONNX export of Pad in opset 9. The sizes of the padding must be constant. ' +
'Please try opset version 11.', the_exception.args[0])

@skipIfUnsupportedMinOpsetVersion(9)
def test_if_fold(self):
class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() == 2:
y = y + 4
y = y + 2
else:
y = y - 1
return y
x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.numel() > 1:
y = y + 4
else:
y = y + 2
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() != 3:
y = y + 4
y = y + 2
else:
return y
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() >= 1:
y = y + 4
else:
y = y - 1
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() <= 1:
y = y + 4
else:
y = y + 2
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() < 3 and y.dtype == torch.int:
y = y + 4
y = y + 2
else:
return y
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.dim() == 3 and y.dtype == torch.int:
y = y + 4
y = y + 2
else:
y = y + 1
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, y):
if y.numel() != 0 and y.dim() == 2:
y = y + 4
y = y + 2
else:
return y
return y

x = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), x)

class IfFoldModel(torch.nn.Module):
def forward(self, x, y):
if x.numel() == y.numel():
y = x + y
else:
y = y - x
return y

x = torch.ones((3, 4), dtype=torch.int)
y = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), (x, y))

class IfFoldModel(torch.nn.Module):
def forward(self, x, y):
if x.numel() != y.numel():
y = x + y
else:
y = y - x
return y

x = torch.ones((3, 4), dtype=torch.int)
y = torch.ones((3, 4), dtype=torch.int)
self.run_test(IfFoldModel(), (x, y))

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfONNXShapeInference(False)
def test_uninitialized(self):
Expand Down Expand Up @@ -4769,7 +4885,6 @@ def forward(self, *tensor_list):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_crossentropyloss(self):
for ignore_index in [-100, 1]:
x = torch.randn(3, 5)
Expand Down Expand Up @@ -4936,7 +5051,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -4958,7 +5072,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_none(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -4981,7 +5094,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -5004,7 +5116,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_sum(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -5027,7 +5138,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_weights(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -5050,7 +5160,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_ignore_index(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand All @@ -5070,7 +5179,6 @@ def forward(self, input, target):

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(12)
@disableScriptTest() # shape/type inference
def test_nllloss_2d_mean_ignore_index_weights(self):
class NLLModel(torch.nn.Module):
def __init__(self):
Expand Down
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
"torch/csrc/jit/passes/onnx/list_model_parameters.cpp",
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
"torch/csrc/jit/passes/onnx/fold_if_node.cpp",
"torch/csrc/jit/passes/onnx/helper.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",
"torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp",
Expand Down
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def _jit_pass_onnx_scalar_type_analysis(graph: Graph) -> None: ...
def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_onnx_fold_if(graph: Graph) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
Expand Down

0 comments on commit 33f26ed

Please sign in to comment.