[export] Add run_decompositions() function to ExportedProgram #110236
@@ -422,43 +422,152 @@ def module(self) -> torch.nn.Module:

        return unlift_exported_program_lifted_states(self)

    def run_decompositions(
        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
    ) -> "ExportedProgram":
        """
        Run a set of decompositions on the exported program and return a new
        exported program. By default we run the Core ATen decompositions to
        get the Core ATen IR.

        For now, we do not decompose joint graphs.
        """
        from torch._decomp import core_aten_decompositions
        from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
            _AddRuntimeAssertionsForInlineConstraintsPass,
            InputDim,
        )
        from torch._export.passes.lift_constant_tensor_pass import (
            lift_constant_tensor_pass,
        )
        from torch._export.passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
        from torch._functorch.aot_autograd import aot_export_module

        def _get_placeholders(gm):
            placeholders = []
            for node in gm.graph.nodes:
                if node.op != "placeholder":
                    break
                placeholders.append(node)
            return placeholders

        decomp_table = decomp_table or core_aten_decompositions()

        old_placeholders = _get_placeholders(self.graph_module)
        fake_args = [node.meta["val"] for node in old_placeholders]

        gm, graph_signature = aot_export_module(
            self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
        )

        # Update the signatures with the new placeholder names in case they
        # changed when calling aot_export.
        new_placeholders = _get_placeholders(gm)
        assert len(new_placeholders) == len(old_placeholders)
        old_new_placeholder_map = {
            old_node.name: new_node.name
            for old_node, new_node in zip(old_placeholders, new_placeholders)
        }
        old_outputs = list(self.graph.nodes)[-1].args[0]
        new_outputs = list(gm.graph.nodes)[-1].args[0]
        assert len(new_outputs) == len(old_outputs)
        old_new_output_map = {
            old_node.name: new_node.name
            for old_node, new_node in zip(old_outputs, new_outputs)
        }

        new_backward_signature = (
            ExportBackwardSignature(
                copy.deepcopy(
                    self.graph_signature.backward_signature.gradients_to_parameters
                ),
                {
                    old_new_placeholder_map[inp]: param
                    for inp, param in self.graph_signature.backward_signature.gradients_to_user_inputs.items()
                },
                copy.deepcopy(self.graph_signature.backward_signature.loss_output),
            )
            if self.graph_signature.backward_signature is not None
            else None
        )

        new_graph_signature = ExportGraphSignature(
            copy.deepcopy(self.graph_signature.parameters),
            copy.deepcopy(self.graph_signature.buffers),
            [old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs],
            [old_new_output_map[out] for out in self.graph_signature.user_outputs],
            {
                old_new_placeholder_map[inp]: param
                for inp, param in self.graph_signature.inputs_to_parameters.items()
            },
            {
                old_new_placeholder_map[inp]: buffer
                for inp, buffer in self.graph_signature.inputs_to_buffers.items()
            },
            copy.deepcopy(self.graph_signature.buffers_to_mutate),
            new_backward_signature,
            copy.deepcopy(self.graph_signature.assertion_dep_token),
        )

        # NOTE: aot_export adds symint metadata for placeholders with int
        # values; since these become specialized, we replace such metadata with
        # the original values.
        # Also, set the param/buffer metadata back on the placeholders.
        for old_node, new_node in zip(old_placeholders, new_placeholders):
            if not isinstance(old_node.meta["val"], torch.Tensor):
                new_node.meta["val"] = old_node.meta["val"]

            if (
                new_node.target in new_graph_signature.inputs_to_parameters
                or new_node.target in new_graph_signature.inputs_to_buffers
            ):
                for k, v in old_node.meta.items():
                    new_node.meta[k] = v

        # TODO: unfortunately, preserving graph-level metadata does not work
        # well with aot_export, so we copy it over manually.
        # (The node-level meta is addressed above.)
        gm.meta.update(self.graph_module.meta)

        new_range_constraints = _get_updated_range_constraints(gm)

        new_equality_constraints = [
            (
                InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
                InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
            )
            for inp_dim1, inp_dim2 in self.equality_constraints
        ]

        exported_program = ExportedProgram(
            gm,
            gm.graph,
            new_graph_signature,
            copy.deepcopy(self.call_spec),
            self.state_dict,
            new_range_constraints,
            new_equality_constraints,
            copy.deepcopy(self.module_call_graph),
            self.example_inputs,
            self.dialect,
        )

        if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
            exported_program = exported_program._transform(
                _AddRuntimeAssertionsForInlineConstraintsPass(
                    new_range_constraints, new_equality_constraints
                )
            )
        exported_program = lift_constant_tensor_pass(exported_program)

        return exported_program._transform(_ReplaceSymSizeOpPass())

Review thread on `_AddRuntimeAssertionsForInlineConstraintsPass`:
- Why do we need this?
- I'm also not sure if this is 100% needed. If it's possible that some decomposed operator introduces an unbacked symint, then we will also need to add assertions for it 😅

Review thread on `lift_constant_tensor_pass(exported_program)`:
- So basically anytime we run aot_export, we need to re-run lift_constant_tensor_pass and _ReplaceSymSizeOpPass()?
- I'm not sure if this is 100% needed. Basically, if it's possible that some decomposed operator introduces a constant tensor, then we will need to lift it.
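The constant-tensor concern in the second thread can be made concrete with a user-supplied decomposition. A hypothetical sketch (the `gelu` decomposition and its constant are illustrative, not from this PR): if an entry in `decomp_table` materializes a tensor at trace time, the decomposed graph gains a new constant that `lift_constant_tensor_pass` would then need to lift.

```python
import torch
from torch._decomp import core_aten_decompositions

decomp_table = core_aten_decompositions()

# Hypothetical table entry: decompose aten.gelu via the tanh approximation.
# torch.tensor(...) below materializes a constant tensor during tracing,
# which is exactly the case that requires re-running the lifting pass.
def gelu_decomp(x, approximate="none"):
    c = torch.tensor(0.7978845608028654)  # sqrt(2 / pi), a new constant tensor
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x * x * x)))

decomp_table[torch.ops.aten.gelu.default] = gelu_decomp
# ep.run_decompositions(decomp_table) would then trace gelu through this decomposition.
```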
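For reference, a minimal usage sketch of the new API (the module and inputs are illustrative; assumes the public `torch.export.export` entry point):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, [4])

ep = torch.export.export(M(), (torch.randn(2, 4),))

# Decompose using the default Core ATen decomposition table; a new
# ExportedProgram is returned and the original is left untouched.
core_ep = ep.run_decompositions()
print(core_ep.graph_module.code)
```

Passing a custom `decomp_table` (as in the sketch above) selects which operators get decomposed; `None` falls back to `core_aten_decompositions()`.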
    def _transform(self, *passes: PassType) -> "ExportedProgram":
        pm = PassManager(list(passes))
        res = pm(self.graph_module)
        transformed_gm = res.graph_module if res is not None else self.graph_module
        assert transformed_gm is not None
(The previous definition of `_get_updated_range_constraints` is deleted here; the PR moves it to module level, below `_validate`.)
    def _get_updated_graph_signature(
        old_signature: ExportGraphSignature,
        new_gm: torch.fx.GraphModule,
@@ -562,3 +671,34 @@ def _validate(self):
            if not isinstance(gm, torch.fx.GraphModule):
                continue
            verifier.check_valid(self.graph_module)


def _get_updated_range_constraints(
    gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, Any]:
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        RangeConstraint,
    )

    def get_shape_env(gm):
        vals = [
            node.meta["val"]
            for node in gm.graph.nodes
            if node.meta.get("val", None) is not None
        ]
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode(vals)
        if fake_mode is not None:
            return fake_mode.shape_env
        for v in vals:
            if isinstance(v, torch.SymInt):
                return v.node.shape_env

    shape_env = get_shape_env(gm)
    if shape_env is None:
        return {}
    range_constraints = {
        k: RangeConstraint(v.lower, v.upper) for k, v in shape_env.var_to_range.items()
    }
    return range_constraints
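To see what this helper produces, export with a symbolic dimension and inspect `range_constraints` after decomposition. A sketch assuming the newer `Dim`/`dynamic_shapes` export API (the constraint-building API current at the time of this PR differs, but the idea is the same):

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

batch = Dim("batch", min=2, max=64)
ep = export(M(), (torch.randn(8, 3),), dynamic_shapes={"x": {0: batch}})

decomposed = ep.run_decompositions()
# Maps each sympy symbol from the ShapeEnv to its [lower, upper] value range.
print(decomposed.range_constraints)
```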
Review thread on the PR's new dynamic-shapes test:
- I am all for having more tests. Just a bit curious why dynamic shape affects run-decomposition.
- I just wanted to check that the dynamic shapes were preserved through aot export.
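A sketch of the kind of check being discussed, i.e. that a dynamic dimension survives `aot_export_module` rather than being specialized to its example value (again assuming the `Dim`/`dynamic_shapes` API; the PR's actual test may differ):

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x.cos() + x.sin()

ep = export(M(), (torch.randn(8, 4),), dynamic_shapes={"x": {0: Dim("batch")}})
decomposed = ep.run_decompositions()

# The placeholder's fake-tensor value should still carry a SymInt batch
# dimension rather than the concrete example size 8.
placeholder = next(
    n for n in decomposed.graph_module.graph.nodes if n.op == "placeholder"
)
assert isinstance(placeholder.meta["val"].shape[0], torch.SymInt)
```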