Skip to content

Commit

Permalink
[export] Serialize constrain_as_size ops
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Aug 17, 2023
1 parent aa9f6a4 commit c7027cc
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
9 changes: 9 additions & 0 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch._dynamo as torchdynamo
from torch._export import dynamic_dim, export
from torch._export.constraints import constrain_as_size
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.serialize import (
Expand Down Expand Up @@ -366,6 +367,14 @@ def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
inputs = normalize_inputs(case.example_inputs)
self.check_graph(model, inputs.args)

def test_constraints(self):
def f(x, y):
n = x.item()
constrain_as_size(n, min=2)
return y.sum() + torch.ones(n, 5).sum()

self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))


instantiate_parametrized_tests(TestDeserialize)

Expand Down
28 changes: 27 additions & 1 deletion torch/_export/exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,39 @@ def transform(self, *passes: PassType) -> "ExportedProgram":
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None

def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, RangeConstraint]:
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: RangeConstraint(v.lower, v.upper) for k, v in shape_env.var_to_range.items()
}
return range_constraints

transformed_ep = ExportedProgram(
transformed_gm,
transformed_gm.graph,
copy.deepcopy(self.graph_signature),
copy.deepcopy(self.call_spec),
self.state_dict,
copy.deepcopy(self.range_constraints),
_get_updated_range_constraints(transformed_gm),
copy.deepcopy(self.equality_constraints),
copy.deepcopy(self._module_call_graph),
)
Expand Down
4 changes: 2 additions & 2 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,9 +654,9 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
return []
if _is_single_tensor_return(node.target):
return [Argument.create(as_tensor=self.serialize_tensor_output(node.name, meta_val))]
elif len(returns) == 1 and isinstance(returns[0].real_type, torch.SymIntType): # type: ignore[attr-defined]
elif len(returns) == 1 and isinstance(meta_val, torch.SymInt):
return [Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))]
elif len(returns) == 1 and isinstance(node.meta["val"], torch.SymBool):
elif len(returns) == 1 and isinstance(meta_val, torch.SymBool):
return [Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))]

# There are a two possibilities at this point:
Expand Down

0 comments on commit c7027cc

Please sign in to comment.