[export] Serialize symbolic values (#103273)
* Modified the SymInt schema to also store the hint of the SymInt when it is represented as a symbol, so that when we reconstruct the SymInt, the hint also exists on the node.
* GraphModuleDeserializer.deserialize now also optionally takes a map of symbol names to ranges.

ReplaceSymSizeOpPass should not be needed after #103107 lands
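As a rough illustration of the schema change (a minimal, self-contained sketch: the dataclass mirrors the SymExpr added to torch/_export/serde/schema.py below, and serialize_symint is a hypothetical helper that elides the real serializer's bookkeeping):

    from dataclasses import dataclass
    from typing import Optional, Union

    @dataclass
    class SymExpr:  # mirror of the schema.py dataclass in this diff
        expr_str: str
        hint: Optional[int]

    def serialize_symint(size: Union[int, str], hint: Optional[int] = None):
        # A concrete size serializes as a plain int (the as_int variant of
        # the SymInt union); a symbolic size keeps both its expression
        # string and its hint (the example value observed at trace time),
        # so the deserializer can rebuild node.meta["val"] with the hint
        # attached.
        if isinstance(size, int):
            return size
        return SymExpr(expr_str=size, hint=hint)

    print(serialize_symint(4))             # 4
    print(serialize_symint("s0", hint=2))  # SymExpr(expr_str='s0', hint=2)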

Pull Request resolved: #103273
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
angelayi authored and pytorchmergebot committed Jun 13, 2023
1 parent 876695d commit 8dc6001
Showing 7 changed files with 367 additions and 60 deletions.
146 changes: 138 additions & 8 deletions test/export/test_serialize.py
@@ -3,14 +3,43 @@

import torch
import torch._dynamo as torchdynamo
from torch._export import export
from torch._export import dynamic_dim, export
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.serialize import (
ExportedProgramSerializer,
deserialize,
serialize,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import is_concrete_int
import torch.utils._pytree as pytree
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)


def get_filtered_export_db_tests():
unsupported_tags = {"torch.cond", "torch.map"}
unsupported_test_names = {
"dynamic_shape_constructor", # 'NoneType' object has no attribute 'from_tensor'
"dictionary", # Graph output must be a tuple()
"fn_with_kwargs", # export doesn't support kwargs yet
"scalar_output", # Tracing through 'f' must produce a single graph
}

return [
(name, case)
for name, case in all_examples().items()
if (
case.support_level == SupportLevel.SUPPORTED and
not (unsupported_tags & case.tags) and
name not in unsupported_test_names
)
]


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
@@ -143,12 +172,11 @@ def f(x: torch.Tensor) -> torch.Tensor:

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
def check_graph(self, fn, inputs) -> None:
def check_graph(self, fn, inputs, constraints=None) -> None:
"""Export a graph, serialize it, deserialize it, and compare the results."""
# TODO(angelayi): test better with some sort of wrapper around all
# export tests

ep = export(fn, inputs, [])
# TODO(angelayi): test better with some sort of wrapper
constraints = [] if constraints is None else constraints
ep = export(fn, inputs, constraints)
serialized_struct, state_dict = serialize(ep)
deserialized_ep = deserialize(serialized_struct, state_dict)

@@ -159,7 +187,38 @@ def check_graph(self, fn, inputs) -> None:
flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)

for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
self.assertTrue(torch.allclose(orig, loaded))
self.assertEqual(type(orig), type(loaded))
if isinstance(orig, torch.Tensor):
self.assertTrue(torch.allclose(orig, loaded))
else:
self.assertEqual(orig, loaded)

for node1, node2 in zip(ep.graph.nodes, deserialized_ep.graph.nodes):
# Check "val" metadata
val1 = node1.meta.get("val", None)
val2 = node2.meta.get("val", None)

if val1 is None or val2 is None:
# Either both are None
self.assertEqual(val1, val2)
elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
# Or both are fake tensors with the same shape/dtype
self.assertEqual(len(val1.shape), len(val2.shape))
for s1, s2 in zip(val1.shape, val2.shape):
if is_concrete_int(s1) and is_concrete_int(s2):
self.assertEqual(s1, s2)
else:
self.assertEqual(str(s1), str(s2))
self.assertEqual(val1.dtype, val2.dtype)
elif isinstance(val1, list) and isinstance(val2, list):
# Or both are lists of fake tensors, each with one element of
# the same shape/dtype
self.assertTrue(len(val1) == 1 and len(val2) == 1)
self.assertEqual(val1[0].shape, val2[0].shape)
self.assertEqual(val1[0].dtype, val2[0].dtype)
else:
# Expressions like 's0 < 10' can only be compared via their string form
self.assertEqual(str(val1), str(val2))

def test_multi_return(self) -> None:
"""
@@ -199,6 +258,77 @@ def forward(self, x):
inputs = (torch.ones([512], requires_grad=True),)
self.check_graph(MyModule(), inputs)

def test_dynamic(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
d_s0 = d.shape[0]
d_s1 = d.shape[1]
d_s3 = d_s0 * d_s1
e = d.view(d_s3)
return torch.cat([e, e])


inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
constraints = [
dynamic_dim(inputs[0], 0),
dynamic_dim(inputs[2], 0),
dynamic_dim(inputs[2], 0) == dynamic_dim(inputs[0], 0),
]
self.check_graph(DynamicShapeSimpleModel(), inputs, constraints)

def test_sym_bool(self):
def f(x, y):
return x.size(0) in y

self.check_graph(f, (torch.ones(2), torch.ones(3)))

def test_shape(self):
def f(x):
z, y = x.size()
return z + y + x[0], z

inputs = (torch.ones(2, 3),)
constraints = [
dynamic_dim(inputs[0], 0),
dynamic_dim(inputs[0], 1),
]
self.check_graph(f, inputs, constraints)

def test_module(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(3, 3)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(3, 5)

def forward(self, x):
x = self.linear1(x)
x = self.linear1(x)
x = torch.nn.functional.relu(x)
x = self.linear2(x)
return x

inputs = (torch.randn(3, 3),)
self.check_graph(M(), inputs)

@parametrize(
"name,case",
get_filtered_export_db_tests(),
name_fn=lambda name, case: "case_{}".format(name),
)
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
inputs = normalize_inputs(case.example_inputs)
self.check_graph(model, inputs.args)


instantiate_parametrized_tests(TestDeserialize)


if __name__ == '__main__':
run_tests()
3 changes: 2 additions & 1 deletion torch/_export/__init__.py
@@ -17,6 +17,7 @@
ExportGraphSignature,
_process_constraints,
)
from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
from torch._decomp import core_aten_decompositions
from torch._dynamo.eval_frame import Constraint
from torch._functorch.aot_autograd import aot_export_module
@@ -236,7 +237,7 @@ def export(
if _add_runtime_assertions:
exported_program = exported_program._add_runtime_assertions()

return exported_program
return exported_program.transform(_ReplaceSymSizeOpPass())

except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
@@ -27,8 +27,8 @@ class InputDim(NamedTuple):

@dataclass
class RangeConstraint:
min_val: sympy.Integer
max_val: sympy.Integer
min_val: sympy.Expr
max_val: sympy.Expr


def _convert_to_int(val):
@@ -65,6 +65,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph_module = copy.deepcopy(graph_module)
graph = graph_module.graph

insert_loc = None
for node in graph.nodes:
if node.op != "placeholder":
continue
insert_loc = node
assert insert_loc is not None

# Add runtime asserts for input shape constraints. We do this after all
# placeholder nodes so that we can handle both (unary) predicates and
# (binary) relations.
@@ -80,15 +87,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
continue

fake_tensor_shape = node.meta["val"].shape
prev_node = node
for dim, shape in enumerate(fake_tensor_shape):
with graph.inserting_after(prev_node):
with graph.inserting_after(insert_loc):
dim_node = graph.call_function(
torch.ops.aten.sym_size.int, (node, dim)
)
input_dim = InputDim(node.name, dim)
inputdim_to_node[input_dim] = dim_node
prev_node = dim_node
insert_loc = dim_node

if isinstance(shape, SymInt):
# If the shape is dynamic, add range assertions
25 changes: 25 additions & 0 deletions torch/_export/passes/replace_sym_size_ops_pass.py
@@ -0,0 +1,25 @@
from typing import Dict

import torch
from torch.fx.passes.infra.pass_base import PassBase

replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride: torch.ops.aten.sym_stride.int,
torch.ops.aten.sym_numel: torch.ops.aten.sym_numel.default,
}


class _ReplaceSymSizeOpPass(PassBase):
"""
Replace torch.ops.aten.sym_size with torch.ops.aten.sym_size.int,
torch.ops.aten.sym_stride with torch.ops.aten.sym_stride.int,
and torch.ops.aten.sym_numel with torch.ops.aten.sym_numel.default
"""

def call(self, graph_module):
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
continue
for node in module.graph.nodes:
if node.target in replacements:
node.target = replacements[node.target]
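A contrived standalone exercise of the pass (the hand-built graph below is purely illustrative; within export itself the pass is applied via exported_program.transform, as shown in torch/_export/__init__.py above):

    import torch
    import torch.fx
    from torch._export.passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass

    # Build a tiny graph whose node targets the overload packet
    # torch.ops.aten.sym_size, then run the pass to pin it to the
    # .int overload.
    g = torch.fx.Graph()
    x = g.placeholder("x")
    s = g.call_function(torch.ops.aten.sym_size, (x, 0))
    g.output(s)
    gm = torch.fx.GraphModule(torch.nn.Module(), g)

    _ReplaceSymSizeOpPass().call(gm)
    assert all(n.target is not torch.ops.aten.sym_size for n in gm.graph.nodes)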
18 changes: 16 additions & 2 deletions torch/_export/serde/schema.py
@@ -70,15 +70,21 @@ class Device:
index: Optional[int]


@dataclass
class SymExpr:
expr_str: str
hint: Optional[int]


@dataclass
class SymInt(_Union):
as_symbol: str
as_expr: SymExpr
as_int: int


@dataclass
class SymBool(_Union):
as_symbol: str
as_expr: str
as_bool: bool


@@ -185,6 +191,12 @@ class CallSpec:
out_spec: str


@dataclass
class RangeConstraint:
min_val: int
max_val: int


@dataclass
class GraphModule:
graph: Graph
@@ -197,3 +209,5 @@ class GraphModule:
class ExportedProgram:
graph_module: GraphModule
opset_version: Dict[str, int]
range_constraints: Dict[str, RangeConstraint]
equality_constraints: List[Tuple[Tuple[str, int], Tuple[str, int]]]
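To make the new top-level fields concrete (a hedged sketch: the dataclass mirrors the RangeConstraint added above, and the symbol name, range, and input names are illustrative, loosely modeled on test_dynamic):

    from dataclasses import dataclass

    @dataclass
    class RangeConstraint:  # mirror of the schema.py dataclass above
        min_val: int
        max_val: int

    # One dynamic symbol s0 constrained to [2, 2**31 - 1], plus an equality
    # constraint tying dim 0 of inputs "a" and "c" together.
    range_constraints = {"s0": RangeConstraint(min_val=2, max_val=2**31 - 1)}
    equality_constraints = [(("a", 0), ("c", 0))]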
