Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] Serialize symbolic values #103273

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
146 changes: 138 additions & 8 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,43 @@

import torch
import torch._dynamo as torchdynamo
from torch._export import export
from torch._export import dynamic_dim, export
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.serialize import (
ExportedProgramSerializer,
deserialize,
serialize,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
import torch.utils._pytree as pytree
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)


def get_filtered_export_db_tests():
unsupported_tags = {"torch.cond", "torch.map"}
unsupported_test_names = {
"dynamic_shape_constructor", # 'NoneType' object has no attribute 'from_tensor'
"dictionary", # Graph output must be a tuple()
"fn_with_kwargs", # export doesn't support kwargs yet
"scalar_output", # Tracing through 'f' must produce a single graph
}

return [
(name, case)
for name, case in all_examples().items()
if (
case.support_level == SupportLevel.SUPPORTED and
not (unsupported_tags & case.tags) and
name not in unsupported_test_names
)
]


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
Expand Down Expand Up @@ -143,12 +172,11 @@ def f(x: torch.Tensor) -> torch.Tensor:

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
def check_graph(self, fn, inputs) -> None:
def check_graph(self, fn, inputs, constraints=None) -> None:
"""Export a graph, serialize it, deserialize it, and compare the results."""
# TODO(angelayi): test better with some sort of wrapper around all
# export tests

ep = export(fn, inputs, [])
# TODO(angelayi): test better with some sort of wrapper
constraints = [] if constraints is None else constraints
ep = export(fn, inputs, constraints)
serialized_struct, state_dict = serialize(ep)
deserialized_ep = deserialize(serialized_struct, state_dict)

Expand All @@ -159,7 +187,38 @@ def check_graph(self, fn, inputs) -> None:
flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
self.assertTrue(torch.allclose(orig, loaded))
self.assertEqual(type(orig), type(loaded))
if isinstance(orig, torch.Tensor):
self.assertTrue(torch.allclose(orig, loaded))
else:
self.assertEqual(orig, loaded)

for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
# Check "val" metadata
val1 = node1.meta.get("val", None)
val2 = node2.meta.get("val", None)

if val1 is None or val2 is None:
# Either both are None
self.assertEqual(val1, val2)
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
# Or both are fake tensors with the same shape/dtype
self.assertEqual(len(val1.shape), len(val2.shape))
for s1, s2 in zip(val1.shape, val2.shape):
if is_concrete_int(s1) and is_concrete_int(s2):
self.assertEqual(s1, s2)
else:
self.assertEqual(str(s1), str(s2))
self.assertEqual(val1.dtype, val2.dtype)
elif isinstance(val1, list) and isinstance(val2, list):
# Or both are fake tensors lists with one element and with the
# same shape/dtype
self.assertTrue(len(val1) == 1 and len(val2) == 1)
self.assertEqual(val1[0].shape, val2[0].shape)
self.assertEqual(val1[0].dtype, val2[0].dtype)
else:
# For expressions like 's0 < 10' can only compare through string
self.assertEqual(str(val1), str(val2))

def test_multi_return(self) -> None:
"""
Expand Down Expand Up @@ -199,6 +258,77 @@ def forward(self, x):
inputs = (torch.ones([512], requires_grad=True),)
self.check_graph(MyModule(), inputs)

def test_dynamic(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
d_s0 = d.shape[0]
d_s1 = d.shape[1]
d_s3 = d_s0 * d_s1
e = d.view(d_s3)
return torch.cat([e, e])


inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
constraints = [
dynamic_dim(inputs[0], 0),
dynamic_dim(inputs[2], 0),
dynamic_dim(inputs[2], 0) == dynamic_dim(inputs[0], 0),
]
self.check_graph(DynamicShapeSimpleModel(), inputs, constraints)

def test_sym_bool(self):
def f(x, y):
return x.size(0) in y

self.check_graph(f, (torch.ones(2), torch.ones(3)))

def test_shape(self):
def f(x):
z, y = x.size()
return z + y + x[0], z

inputs = (torch.ones(2, 3),)
constraints = [
dynamic_dim(inputs[0], 0),
dynamic_dim(inputs[0], 1),
]
self.check_graph(f, inputs, constraints)

def test_module(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)

def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = torch.nn.functional.relu(x)
x = self.linear2(x)
return x

inputs = (torch.randn(3, 3),)
self.check_graph(M(), inputs)

@parametrize(
"name,case",
get_filtered_export_db_tests(),
name_fn=lambda name, case: "case_{}".format(name),
)
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
inputs = normalize_inputs(case.example_inputs)
self.check_graph(model, inputs.args)


instantiate_parametrized_tests(TestDeserialize)


if __name__ == '__main__':
run_tests()
3 changes: 2 additions & 1 deletion torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ExportGraphSignature,
_process_constraints,
)
from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
from torch._decomp import core_aten_decompositions
from torch._dynamo.eval_frame import Constraint
from torch._functorch.aot_autograd import aot_export_module
Expand Down Expand Up @@ -236,7 +237,7 @@ def export(
if _add_runtime_assertions:
exported_program = exported_program._add_runtime_assertions()

return exported_program
return exported_program.transform(_ReplaceSymSizeOpPass())

except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class InputDim(NamedTuple):

@dataclass
class RangeConstraint:
min_val: sympy.Integer
max_val: sympy.Integer
min_val: sympy.Expr
max_val: sympy.Expr


def _convert_to_int(val):
Expand Down Expand Up @@ -65,6 +65,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module = copy.deepcopy(graph_module)
graph = graph_module.graph

insert_loc = None
for node in graph.nodes:
if node.op != "placeholder":
continue
insert_loc = node
assert insert_loc is not None

# Add runtime asserts for input shape constraints. We do this after all
# placeholder nodes so that we can handle both (unary) predicates and
# (binary) relations.
Expand All @@ -80,15 +87,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
continue

fake_tensor_shape = node.meta["val"].shape
prev_node = node
for dim, shape in enumerate(fake_tensor_shape):
with graph.inserting_after(prev_node):
with graph.inserting_after(insert_loc):
dim_node = graph.call_function(
torch.ops.aten.sym_size.int, (node, dim)
)
input_dim = InputDim(node.name, dim)
inputdim_to_node[input_dim] = dim_node
prev_node = dim_node
insert_loc = dim_node

if isinstance(shape, SymInt):
# If the shape is dynamic, add range assertions
Expand Down
25 changes: 25 additions & 0 deletions torch/_export/passes/replace_sym_size_ops_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Dict

import torch
from torch.fx.passes.infra.pass_base import PassBase

replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default,
}


class _ReplaceSymSizeOpPass(PassBase):
"""
Replace torch.ops.aten.sym_size with torch.ops.aten.sym_size.int
and torch.ops.aten.sym_stride with torch.ops.aten.sym_stride.int
"""

def call(self, graph_module):
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.target in replacements:
node.target = replacements[node.target]
18 changes: 16 additions & 2 deletions torch/_export/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,21 @@ class Device:
index: Optional[int]


@dataclass
class SymExpr:
expr_str: str
hint: Optional[int]


@dataclass
class SymInt(_Union):
as_symbol: str
as_expr: SymExpr
angelayi marked this conversation as resolved.
Show resolved Hide resolved
as_int: int


@dataclass
class SymBool(_Union):
as_symbol: str
as_expr: str
as_bool: bool


Expand Down Expand Up @@ -185,6 +191,12 @@ class CallSpec:
out_spec: str


@dataclass
class RangeConstraint:
min_val: int
max_val: int


@dataclass
class GraphModule:
graph: Graph
Expand All @@ -197,3 +209,5 @@ class GraphModule:
class ExportedProgram:
graph_module: GraphModule
opset_version: Dict[str, int]
range_constraints: Dict[str, RangeConstraint]
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]