Skip to content

Commit

Permalink
[ONNX] Drop 'aten_graph' arg for 'DynamoExporter'
Browse files Browse the repository at this point in the history
ghstack-source-id: b15d9612edec3ebd2917ad381e5c52de79416d1b
Pull Request resolved: #99667
  • Loading branch information
BowenBao committed Apr 20, 2023
1 parent 5315317 commit e032366
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
18 changes: 11 additions & 7 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,17 @@ def forward(self, x):

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(self, SigmoidModel(), (x,))

@pytorch_test_common.xfail(
"RuntimeError: false INTERNAL ASSERT FAILED at "
"'/home/titaiwang/pytorch/build/aten/src/ATen/RegisterFunctionalization_0.cpp':3725,"
" please report a bug to PyTorch. mutating a non-functional tensor with a "
"functional tensor is not allowed. Please ensure that all of your inputs are "
"wrapped inside of a functionalize() call."
)
@skip_if_no_torchvision
def test_resnet18(self):
model = torchvision.models.resnet18(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224)

_run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
self,
model,
(dummy_input,),
)

@skip_if_no_torchvision
def test_shufflenet_v2(self):
model = torchvision.models.shufflenet_v2_x0_5(pretrained=False)
Expand Down
5 changes: 1 addition & 4 deletions torch/onnx/_internal/fx/dynamo_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ def export(self) -> torch.onnx.ExportOutput:
# TODO(wechi): There are several symbolic tracing mechanisms to convert
# nn.Module to FX graph. We should choose the right one after they are
# matured.
# TODO(titaiwang): Set `tracing_mode` according to `self.options.dynamic_shapes`
graph_module, graph_guard = torch._dynamo.export(
wrapped_model, *args, aten_graph=True, **kwargs
)
graph_module, graph_guard = torch._dynamo.export(wrapped_model, *args, **kwargs)
del graph_guard # Unused
torch._dynamo.reset()

Expand Down

0 comments on commit e032366

Please sign in to comment.