[export] Add run_decompositions() function to ExportedProgram #110236
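A minimal usage sketch, distilled from the tests added in this PR (module, names, and shapes are illustrative; the DECOMP_TABLE patch mirrors what the tests do so that export() itself does not pre-decompose):

import unittest.mock

import torch
from torch.export import export


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin(x)


m = M()
inp = (torch.randn(5, 10),)

# The tests patch the default decomposition table to None so that export()
# leaves ops such as aten.t undecomposed.
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
    ep = export(m, inp)

# Lower to the Core ATen opset on demand; with no argument the default
# core_aten_decompositions() table is used.
core_aten_ep = ep.run_decompositions()

# The decomposed program computes the same result as the eager module.
assert torch.allclose(core_aten_ep(*inp), m(*inp))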

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions test/export/test_export.py
@@ -8,8 +8,8 @@
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import map, cond
from torch import Tensor
from torch.export import Constraint
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export, capture_pre_autograd_graph
from torch.export import Constraint, Dim, export
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, capture_pre_autograd_graph, _export
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import (
get_buffer,
@@ -144,7 +144,7 @@ def forward(self, x, y):

orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
ep = export(
ep = _export(
orig_eager,
inps,
{},
@@ -1449,5 +1449,61 @@ def forward(self, x):
):
exported_program(torch.rand(2, 3), torch.rand(2, 3))

def test_export_decomps_simple(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin(x)

inp = (torch.randn(5, 10),)
m = M()
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
ep = export(m, inp)


FileCheck().check_count(
"torch.ops.aten.t.default", 1, exactly=True
).run(ep.graph_module.code)
self.assertTrue(torch.allclose(ep(*inp), m(*inp)))

core_aten_ep = ep.run_decompositions()
FileCheck().check_count(
"torch.ops.aten.permute.default", 1, exactly=True
).run(core_aten_ep.graph_module.code)
FileCheck().check_count(
"torch.ops.aten.t.default", 0, exactly=True
).run(core_aten_ep.graph_module.code)
self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))

def test_export_decomps_dynamic(self):
Contributor:
I am all for having more tests. Just a bit curious why dynamic shapes affect run_decompositions.

Contributor (Author):
I just wanted to check that the dynamic shapes were preserved through aot_export.

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin(x)

inp = (torch.randn(5, 10),)
m = M()
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
ep = export(m, inp, dynamic_shapes={"x": {0: Dim("batch")}})

core_aten_ep = ep.run_decompositions()

input_node = [node for node in core_aten_ep.graph.nodes if node.op == "placeholder"][-1]
self.assertTrue(isinstance(input_node.meta["val"].shape[0], torch.SymInt))

FileCheck().check_count(
"torch.ops.aten.permute.default", 1, exactly=True
).run(core_aten_ep.graph_module.code)
FileCheck().check_count(
"torch.ops.aten.t.default", 0, exactly=True
).run(core_aten_ep.graph_module.code)
self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))

if __name__ == '__main__':
run_tests()
7 changes: 5 additions & 2 deletions torch/_export/passes/lift_constant_tensor_pass.py
@@ -1,9 +1,12 @@
import torch
from torch._export import ExportedProgram
from torch._guards import detect_fake_mode


def lift_constant_tensor_pass(ep: ExportedProgram) -> ExportedProgram:
def lift_constant_tensor_pass(ep):
"""
Takes an ExportedProgram, modifies it in place so that constant tensors
are lifted as buffers, and returns the same ExportedProgram.
"""
if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
return ep

198 changes: 169 additions & 29 deletions torch/export/exported_program.py
@@ -422,43 +422,152 @@ def module(self) -> torch.nn.Module:

return unlift_exported_program_lifted_states(self)

def _transform(self, *passes: PassType) -> "ExportedProgram":
def run_decompositions(
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
) -> "ExportedProgram":
"""
Run a set of decompositions on the exported program and return a new
exported program. By default, the Core ATen decompositions are run to
produce the Core ATen IR.

For now, we do not decompose joint graphs.
"""
from torch._decomp import core_aten_decompositions
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
RangeConstraint,
_AddRuntimeAssertionsForInlineConstraintsPass,
InputDim,
)
from torch._export.passes.lift_constant_tensor_pass import (
lift_constant_tensor_pass,
)
from torch._export.passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
from torch._functorch.aot_autograd import aot_export_module

def _get_placeholders(gm):
placeholders = []
for node in gm.graph.nodes:
if node.op != "placeholder":
break
placeholders.append(node)
return placeholders

decomp_table = decomp_table or core_aten_decompositions()

old_placeholders = _get_placeholders(self.graph_module)
fake_args = [node.meta["val"] for node in old_placeholders]

gm, graph_signature = aot_export_module(
self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
)

# Update the signatures with the new placeholder names in case they
# changed when calling aot_export
new_placeholders = _get_placeholders(gm)
assert len(new_placeholders) == len(old_placeholders)
old_new_placeholder_map = {
old_node.name: new_node.name
for old_node, new_node in zip(old_placeholders, new_placeholders)
}
old_outputs = list(self.graph.nodes)[-1].args[0]
new_outputs = list(gm.graph.nodes)[-1].args[0]
assert len(new_outputs) == len(old_outputs)
old_new_output_map = {
old_node.name: new_node.name
for old_node, new_node in zip(old_outputs, new_outputs)
}

new_backward_signature = (
ExportBackwardSignature(
copy.deepcopy(
self.graph_signature.backward_signature.gradients_to_parameters
),
{
old_new_placeholder_map[inp]: param
for inp, param in self.graph_signature.backward_signature.gradients_to_parameters.items()
},
copy.deepcopy(self.graph_signature.backward_signature.loss_output),
)
if self.graph_signature.backward_signature is not None
else None
)

new_graph_signature = ExportGraphSignature(
copy.deepcopy(self.graph_signature.parameters),
copy.deepcopy(self.graph_signature.buffers),
[old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs],
[old_new_output_map[out] for out in self.graph_signature.user_outputs],
{
old_new_placeholder_map[inp]: param
for inp, param in self.graph_signature.inputs_to_parameters.items()
},
{
old_new_placeholder_map[inp]: buffer
for inp, buffer in self.graph_signature.inputs_to_buffers.items()
},
copy.deepcopy(self.graph_signature.buffers_to_mutate),
new_backward_signature,
copy.deepcopy(self.graph_signature.assertion_dep_token),
)

# NOTE: aot_export adds symint metadata for placeholders with int
# values; since these become specialized, we replace such metadata with
# the original values.
# Also, set the param/buffer metadata back to the placeholders.
for old_node, new_node in zip(old_placeholders, new_placeholders):
if not isinstance(old_node.meta["val"], torch.Tensor):
new_node.meta["val"] = old_node.meta["val"]

if (
new_node.target in new_graph_signature.inputs_to_parameters
or new_node.target in new_graph_signature.inputs_to_buffers
):
for k, v in old_node.meta.items():
new_node.meta[k] = v

# TODO unfortunately preserving graph-level metadata is not
# working well with aot_export. So we manually copy it.
# (The node-level meta is addressed above.)
gm.meta.update(self.graph_module.meta)

new_range_constraints = _get_updated_range_constraints(gm)

new_equality_constraints = [
(
InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
)
for inp_dim1, inp_dim2 in self.equality_constraints
]

exported_program = ExportedProgram(
gm,
gm.graph,
new_graph_signature,
copy.deepcopy(self.call_spec),
self.state_dict,
new_range_constraints,
new_equality_constraints,
copy.deepcopy(self.module_call_graph),
self.example_inputs,
self.dialect,
)

if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
exported_program = exported_program._transform(
_AddRuntimeAssertionsForInlineConstraintsPass(
Contributor:
Why do we need this?

Contributor (Author):
I'm also not sure if this is 100% needed. If it's possible that some decomposed operator introduces an unbacked symint, then we will also need to add assertions for it 😅

new_range_constraints, new_equality_constraints
)
)
exported_program = lift_constant_tensor_pass(exported_program)
Contributor:
So basically anytime we run aot_export, we need to re-run lift_constant_tensor_pass and _ReplaceSymSizeOpPass()?

Contributor (Author):
I'm not sure if this is 100% needed. Basically, if it's possible that some decomposed operator introduces a constant tensor, then we will need to lift it.
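For context, an illustrative sketch (not part of this PR) of what a graph-embedded constant tensor looks like: a tensor attribute that is neither an nn.Parameter nor a registered buffer shows up as a get_attr node after tracing, and it is nodes of this kind that lift_constant_tensor_pass re-registers as buffers. Names below are illustrative.

import torch
import torch.fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute: neither an nn.Parameter nor a registered buffer.
        self.scale = torch.tensor(2.0)

    def forward(self, x):
        return x * self.scale


gm = torch.fx.symbolic_trace(M())
# `scale` appears as a get_attr node referencing a constant tensor.
print([n.target for n in gm.graph.nodes if n.op == "get_attr"])  # ['scale']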


return exported_program._transform(_ReplaceSymSizeOpPass())

def _transform(self, *passes: PassType) -> "ExportedProgram":
pm = PassManager(list(passes))
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None

def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, RangeConstraint]:
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: RangeConstraint(v.lower, v.upper)
for k, v in shape_env.var_to_range.items()
}
return range_constraints

def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
@@ -562,3 +671,34 @@ def _validate(self):
if not isinstance(gm, torch.fx.GraphModule):
continue
verifier.check_valid(self.graph_module)


def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, Any]:
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
RangeConstraint,
)

def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: RangeConstraint(v.lower, v.upper) for k, v in shape_env.var_to_range.items()
}
return range_constraints