[export] ExportedProgram #102259

Closed · wants to merge 7 commits (diff below shows changes from 2 commits)
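Context for reviewers: this PR replaces the old pattern of unwrapping a torch.fx.GraphModule via export(...).find_method("forward") with a single ExportedProgram object returned by export. As the test changes below show, the ExportedProgram is callable, exposes .graph and .graph_module, carries call_spec metadata, and applies passes through .transform. A minimal before/after sketch inferred from those tests (f and its inputs are illustrative):

import torch
from torch._export import export

def f(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# Before this PR: unwrap a GraphModule from the export result.
#   gm = export(f, (torch.ones(3, 2),)).find_method("forward")
# After this PR: export returns an ExportedProgram directly.
ep = export(f, (torch.ones(3, 2),))
print(ep.graph_module.code)  # exported FX code
out = ep(torch.ones(3, 2))   # an ExportedProgram is callable
# Passes now return a new ExportedProgram instead of mutating a GraphModule:
#   ep = ep.transform(SomePass())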
388 changes: 15 additions & 373 deletions test/export/test_export.py

Large diffs are not rendered by default.

30 changes: 13 additions & 17 deletions test/export/test_pass_infra.py
@@ -25,18 +25,18 @@ def f(x: torch.Tensor) -> List[torch.Tensor]:
         class NullPass(ExportPassBase):
             pass
 
-        gm = export(f, (torch.ones(3, 2),)).find_method("forward")
-        new_gm = NullPass()(gm)
-        self.assertIsNotNone(new_gm)
-        new_nodes = new_gm.graph_module.graph.nodes
+        ep = export(f, (torch.ones(3, 2),))
+        old_nodes = ep.graph.nodes
+
+        ep = ep.transform(NullPass())
+        new_nodes = ep.graph.nodes
 
         for node in new_nodes:
             if node.op != "call_function":
                 continue
             self.assertTrue(hasattr(node, "stack_trace"))
             self.assertIsNotNone(node.stack_trace)
 
-        old_nodes = gm.graph.nodes
         self.assertEqual(len(new_nodes), len(old_nodes))
         for new_node, old_node in zip(new_nodes, old_nodes):
             self.assertEqual(new_node.op, old_node.op)
@@ -70,12 +70,10 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             z = x + y
             return torch.ops.aten.relu.default(z)
 
-        gm = export(
-            f, (torch.randn(2, 2), torch.randn(2, 2)),
-        ).find_method("forward")
+        ep = export(f, (torch.randn(2, 2), torch.randn(2, 2)))
         FileCheck().check("torch.ops.aten.add.Tensor").check(
             "torch.ops.aten.relu.default"
-        ).run(gm.code)
+        ).run(ep.graph_module.code)
 
         class AddReluFusionPass(ExportPassBase):
             def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -88,6 +86,7 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                     return backend_op(x, y)
 
                 subgraph_rewriter.replace_pattern(graph_module, pattern, replacement)
+                return PassResult(graph_module, True)
 
         class BackendNullPass(ExportPassBase):
             def get_valid_dialects(self) -> List[Type]:
@@ -117,17 +116,15 @@ def call_operator(self, op, args, kwargs, meta):
                 )
                 return super().call_operator(op, args, kwargs, meta)
 
-        AddReluFusionPass()(gm)
+        ep = ep.transform(AddReluFusionPass())
         FileCheck().check(
             "torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default"
-        ).run(gm.code)
+        ).run(ep.graph_module.code)
 
-        new_gm = BackendNullPass()(gm)
-        self.assertIsNotNone(new_gm)
-        new_gm = new_gm.graph_module
+        ep = ep.transform(BackendNullPass())
 
         with self.assertRaisesRegex(ExportPassBaseError, "Expecting op of dialects:"):
-            _ = BackendViolatePass()(gm)
+            ep.transform(BackendViolatePass())
 
     def test_cond(self) -> None:
         class M(torch.nn.Module):
@@ -151,9 +148,8 @@ def false_fn(x, y):
         x = torch.tensor([2])
         y = torch.tensor([5])
         mod = M()
-        gm = export(mod, (torch.tensor(True), x, y)).find_method("forward")
-
-        ExportPassBase()(gm)
+        _ = export(mod, (torch.tensor(True), x, y)).transform(ExportPassBase())
+
 
 if __name__ == '__main__':
     run_tests()
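The updated tests above all follow the same pass-application flow. A standalone sketch of that flow, assuming ExportPassBase lives where these tests import it from (the import path and CountingPass are illustrative, not part of this PR):

import torch
from torch.fx.passes.infra.pass_base import PassResult
from torch._export import export
from torch._export.pass_base import ExportPassBase  # assumed import path

class CountingPass(ExportPassBase):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Count call_function nodes and report the graph as unmodified.
        n = sum(node.op == "call_function" for node in graph_module.graph.nodes)
        print(f"{n} call_function nodes")
        return PassResult(graph_module, False)

def f(x):
    return torch.ops.aten.relu.default(x + 1)

ep = export(f, (torch.randn(2, 2),))
ep = ep.transform(CountingPass())  # returns a new ExportedProgram

Note that call must return a PassResult; the AddReluFusionPass change above adds exactly that missing return.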
49 changes: 21 additions & 28 deletions test/export/test_passes.py
@@ -4,7 +4,7 @@
 import torch
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch._dynamo.eval_frame import is_dynamo_supported
-from torch._export import export, dynamic_dim, _export
+from torch._export import export, dynamic_dim
 from torch._export.constraints import constrain_as_value
 from torch._export.passes import (
     ConstPropPass,
@@ -34,18 +34,14 @@ def test_replace_broken_ops(self) -> None:
         def f(inp: torch.Tensor) -> torch.Tensor:
             return model(inp)
 
-        gm = export(f, (x,)).find_method("forward")
-
-        new_gm = ReplaceViewOpsWithViewCopyOpsPass()(gm)
-        self.assertIsNotNone(new_gm)
-        new_gm = new_gm.graph_module
+        ep = export(f, (x,)).transform(ReplaceViewOpsWithViewCopyOpsPass())
 
         count_after = 0
-        for node in new_gm.graph.nodes:
+        for node in ep.graph.nodes:
             if node.target == torch.ops.aten.view.default:
                 count_after += 1
         self.assertEqual(count_after, 0)
-        self.assertTrue(torch.allclose(gm(x), f(x)))
+        self.assertTrue(torch.allclose(ep(x), f(x)))
 
     def test_const_prop_pass(self) -> None:
         class M(torch.nn.Module):
@@ -58,18 +54,16 @@ def forward(self, x):
                 c = torch.cat([self.a, b])
                 return (c + c) + x
 
-        def count_additions(gm) -> int:
+        def count_additions(ep) -> int:
             return sum(
-                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
+                (node.target == torch.ops.aten.add.Tensor) for node in ep.graph.nodes
             )
 
-        gm = export(M(), (torch.zeros(2, 2, 3),)).find_method("forward")
-        self.assertEqual(count_additions(gm), 3)
+        ep = export(M(), (torch.zeros(2, 2, 3),))
+        self.assertEqual(count_additions(ep), 3)
 
-        new_gm = ConstPropPass()(gm)
-        self.assertIsNotNone(new_gm)
-        new_gm = new_gm.graph_module
-        self.assertEqual(count_additions(new_gm), 1)
+        ep = ep.transform(ConstPropPass())
+        self.assertEqual(count_additions(ep), 1)
 
     def test_runtime_assert_one_dim(self) -> None:
         class M(torch.nn.Module):
@@ -112,19 +106,19 @@ def forward(self, x, y):
             dynamic_dim(x, 0) >= 3
         ]
 
-        gm = export(M(), (x, y), constraints=constraints).add_runtime_assertions().find_method("forward")
+        ep = export(M(), (x, y), constraints=constraints).add_runtime_assertions()
 
-        num_assert = count_call_function(gm.graph, torch.ops.aten._assert_async.msg)
-        num_scalar_tensor = count_call_function(gm.graph, torch.ops.aten.scalar_tensor.default)
+        num_assert = count_call_function(ep.graph, torch.ops.aten._assert_async.msg)
+        num_scalar_tensor = count_call_function(ep.graph, torch.ops.aten.scalar_tensor.default)
 
         self.assertEqual(num_assert, 6)
         self.assertEqual(num_scalar_tensor, 6)
 
         with self.assertRaisesRegex(RuntimeError, "Input arg0"):
-            gm(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
+            ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
 
         with self.assertRaisesRegex(RuntimeError, "Input arg1"):
-            gm(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
+            ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
 
     def test_runtime_assert_some_dims_not_specified(self) -> None:
         class M(torch.nn.Module):
@@ -214,12 +208,11 @@ def forward(self, x):
 
         x = torch.zeros(4, 2, 3)
 
-        gm = _export(M(), (x,))
-        self.assertEqual(count_call_function(gm.graph, torch.ops.aten.view.default), 1)
+        ep = export(M(), (x,))
+        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)
 
-        pass_result = ReplaceViewOpsWithViewCopyOpsPass()(gm)
-        self.assertTrue(pass_result.modified)
-        self.assertEqual(count_call_function(pass_result.graph_module.graph, torch.ops.aten.view.default), 0)
+        ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
+        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)
 
     def test_functionalization_with_view_copy(self) -> None:
         def foo(x):
@@ -231,8 +224,8 @@ def foo(x):
 
         ep = export(foo, (x,)).transform(ReplaceViewOpsWithViewCopyOpsPass())
         # After this pass, there shouldn't be any view nodes in the graph
-        self.assertTrue(count_call_function(ep.module.graph, torch.ops.aten.view.default) == 0)
-        self.assertTrue(count_call_function(ep.module.graph, torch.ops.aten.view_copy.default) > 0)
+        self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view.default) == 0)
+        self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0)
 
     def test_views_op_having_view_copy(self) -> None:
         schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
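A condensed, self-contained sketch of the runtime-assertion flow these tests exercise (the module and shapes here are illustrative):

import torch
from torch._export import export, dynamic_dim

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

x = torch.zeros(4, 2, 3)
constraints = [dynamic_dim(x, 0) >= 3]

# add_runtime_assertions() bakes the constraints into the graph as
# aten._assert_async checks, so violating inputs fail when the program runs.
ep = export(M(), (x,), constraints=constraints).add_runtime_assertions()
ep(torch.zeros(5, 2, 3))    # ok: dim 0 satisfies >= 3
# ep(torch.zeros(2, 2, 3))  # would raise a RuntimeError mentioning "Input arg0"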
20 changes: 8 additions & 12 deletions test/export/test_serialize.py
@@ -1,11 +1,11 @@
 # Owner(s): ["module: dynamo"]
+import copy
 import pickle
 import unittest
 
 import torch
 import torch._dynamo as torchdynamo
 from torch._export import export
-from torch._export.graph_module import get_export_meta
 from torch._export.serialize import convert_fake_tensor_to_tensor_meta, convert_tensor_meta_to_fake_tensor
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -31,15 +31,13 @@ def inner_false_fn(y):
             return control_flow.cond(x.shape[0] < 10, true_fn, false_fn, [x])
 
         inputs = (torch.ones(3),)
-        mmep = export(f, inputs)
-        gm = mmep.find_method("forward")
-        gm.print_readable()
+        ep = export(f, inputs)
 
         # Pickle the ExportGraphModule
-        pickled_gm = pickle.dumps(convert_fake_tensor_to_tensor_meta(gm)[0])
-        loaded_gm = convert_tensor_meta_to_fake_tensor(pickle.loads(pickled_gm))
+        pickled_ep = pickle.dumps(convert_fake_tensor_to_tensor_meta(copy.deepcopy(ep))[0])
+        loaded_ep = convert_tensor_meta_to_fake_tensor(pickle.loads(pickled_ep))
 
-        for node1, node2 in zip(loaded_gm.graph.nodes, gm.graph.nodes):
+        for node1, node2 in zip(loaded_ep.graph.nodes, ep.graph.nodes):
             val1 = node1.meta.get("val", None)
             val2 = node2.meta.get("val", None)
 
@@ -60,13 +58,11 @@ def inner_false_fn(y):
             # For expressions like 's0 < 10' can only compare through string
             self.assertEqual(str(val1), str(val2))
 
-        self.assertTrue(torch.allclose(loaded_gm(*inputs), gm(*inputs)))
+        self.assertTrue(torch.allclose(loaded_ep(*inputs), ep(*inputs)))
 
         # Check metadata
-        orig_meta = get_export_meta(gm)
-        new_meta = get_export_meta(loaded_gm)
-        self.assertEqual(orig_meta.in_spec, new_meta.in_spec)
-        self.assertEqual(orig_meta.out_spec, new_meta.out_spec)
+        self.assertEqual(ep.call_spec.in_spec, loaded_ep.call_spec.in_spec)
+        self.assertEqual(ep.call_spec.out_spec, loaded_ep.call_spec.out_spec)
 
 
 if __name__ == '__main__':
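The serialization test above round-trips an ExportedProgram through pickle. The same flow as a standalone sketch, assuming the two conversion helpers behave as the test uses them (the function f is illustrative):

import copy
import pickle
import torch
from torch._export import export
from torch._export.serialize import (
    convert_fake_tensor_to_tensor_meta,
    convert_tensor_meta_to_fake_tensor,
)

def f(x):
    return x + 1

ep = export(f, (torch.ones(3),))

# Replace unpicklable FakeTensors in node metadata with plain tensor metadata,
# pickle, then restore; element [0] of the conversion result is the program.
stripped = convert_fake_tensor_to_tensor_meta(copy.deepcopy(ep))[0]
loaded_ep = convert_tensor_meta_to_fake_tensor(pickle.loads(pickle.dumps(stripped)))

assert torch.allclose(loaded_ep(torch.ones(3)), ep(torch.ones(3)))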
2 changes: 1 addition & 1 deletion torch/_dynamo/eval_frame.py
@@ -1051,7 +1051,7 @@ def graph_with_interpreter(*args):
         if (shape_env := getattr(fake_mode, "shape_env", None)) is not None:
             # Inline constraints added by users correspond to unbacked symbols in shape_env,
             new_graph.meta["inline_constraints"] = {
-                k: v
+                k: (v.lower, v.upper)
                 for k, v in shape_env.var_to_range.items()
                 if re.match(r"^[if]\d+$", str(k))
             }
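The eval_frame.py change stores each unbacked symbol's range as a plain (lower, upper) tuple rather than the shape env's ValueRanges object, presumably so inline_constraints stays simple, picklable data. A hypothetical illustration of the resulting metadata shape (the symbol name and bounds are examples, not taken from this PR):

from sympy import Symbol  # shape-env symbols are sympy symbols

i0 = Symbol("i0")
inline_constraints = {i0: (0, 5)}  # one (lower, upper) tuple per unbacked symbol
lower, upper = inline_constraints[i0]
assert (lower, upper) == (0, 5)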