Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Remove unnecessary deepcopy on args in 'DynamoExport' #104736

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions test/onnx/test_fx_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import diagnostics
from torch.testing._internal import common_utils
from torch._subclasses import fake_tensor


def assert_has_diagnostics(
Expand Down Expand Up @@ -68,6 +69,16 @@ def func(x):
tensor_x = torch.randn(1, 1, 2)
_ = dynamo_export(func, tensor_x, export_options=self.export_options)

def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
    """Exporting must leave the caller-supplied tensors as real tensors."""

    def func(x, y):
        return x + y

    # Two ordinary (non-fake) input tensors owned by the caller.
    inputs = (torch.randn(1, 1, 2), torch.randn(1, 1, 2))
    dynamo_export(func, *inputs, export_options=self.export_options)
    # The export pipeline may trace with fake tensors internally, but the
    # arguments the caller passed in must not have been swapped out.
    for arg in inputs:
        self.assertNotIsInstance(arg, fake_tensor.FakeTensor)

def test_mnist(self):
class MNISTModel(nn.Module):
def __init__(self):
Expand Down
6 changes: 1 addition & 5 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,6 @@ def generate_fx(
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:
# args will be converted to symbolic tensor. Let's copy to avoid side effects.
args = copy.deepcopy(model_args)
kwargs = copy.deepcopy(model_kwargs)

# `dynamo.export` does not recognize custom user defined classes as output type.
# Apply wrapper to adapt the outputs back to `dynamo.export` compatible types,
# i.e. :class:`torch.Tensor`.
Expand All @@ -194,7 +190,7 @@ def generate_fx(
#
fx_mode = "symbolic" if options.dynamic_shapes else "fake"
graph_module, graph_guard = torch._dynamo.export(
wrapped_model, *args, tracing_mode=fx_mode, **kwargs
wrapped_model, *model_args, tracing_mode=fx_mode, **model_kwargs
)
del graph_guard # Unused
torch._dynamo.reset()
Expand Down