1 change: 1 addition & 0 deletions backends/vulkan/_passes/remove_redundant_ops.py
@@ -30,6 +30,7 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.aten.alias.default,
exir_ops.edge.aten.lift_fresh_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
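For readers unfamiliar with the pass, a minimal torch.fx sketch (not RemoveRedundantOpsTransform itself) of what adding an op to this set enables: once a node is known to be a no-op for the backend, the pass can replace it with its input and erase it. The toy module and the call_method match on "clone" are stand-ins for illustration.

import torch
import torch.fx as fx

class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Stand-in for a dtype-preserving _clone_dim_order node in an edge graph.
        return x.clone() + 1

gm = fx.symbolic_trace(ToyModel())
for node in list(gm.graph.nodes):
    # The real pass matches against its redundant-ops set, which now includes
    # exir_ops.edge.dim_order_ops._clone_dim_order.default.
    if node.op == "call_method" and node.target == "clone":
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()

x = torch.randn(2, 3)
assert torch.equal(gm(x), x + 1)  # behavior is unchanged with the clone removed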
26 changes: 26 additions & 0 deletions backends/vulkan/op_registry.py
@@ -301,6 +301,32 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
)


@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
def register_clone_dim_order_op():
# Similar to to_dim_order_copy, _clone_dim_order can be removed as long as the
# operator does not change the dtype, i.e. the call only modifies the dim order.
# Therefore, check that the input and output dtypes are the same; if so, the
# operator is safe to remove.
def check_clone_dim_order_node(node: torch.fx.Node) -> bool:
in_arg = node.args[0]
if not isinstance(in_arg, torch.fx.Node):
return False

in_tensor = in_arg.meta.get("val", None)
out_tensor = node.meta.get("val", None)

if in_tensor.dtype != out_tensor.dtype:
return False

return True

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
supports_resize=True,
are_node_inputs_supported_fn=check_clone_dim_order_node,
)


@update_features(
[
exir_ops.edge.aten.bmm.default,
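To make the condition above concrete, a self-contained sketch of the dtype check, using plain tensors as stand-ins for the FakeTensor values normally stored in node.meta["val"]; the helper name is illustrative.

import torch

def is_pure_dim_order_change(in_val: torch.Tensor, out_val: torch.Tensor) -> bool:
    # _clone_dim_order is only removable when it does not alter the dtype,
    # i.e. the clone changes memory layout (dim order) at most.
    return in_val.dtype == out_val.dtype

src = torch.empty(1, 3, 8, 8, dtype=torch.float32)
layout_only = src.clone(memory_format=torch.channels_last)  # dtype preserved
dtype_cast = src.to(torch.float16)                          # dtype changed

print(is_pure_dim_order_change(src, layout_only))  # True  -> safe to remove
print(is_pure_dim_order_change(src, dtype_cast))   # False -> node must stay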
19 changes: 14 additions & 5 deletions exir/passes/constant_prop_pass.py
@@ -295,7 +295,11 @@ def create_constant_nodes_and_return_specs(
return name_to_spec_dict


def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
# add _skip_dim_order to ensure the correct clone node is introduced for each dim order schema
# TODO(gasoonjia): rely only on _clone_dim_order once the _skip_dim_order option is removed from the EdgeCompileConfig
def _update_output_node_and_specs(
exported_program: ExportedProgram, _skip_dim_order: bool
) -> None:
"""
Update the output node and output specs in the exported program.
In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,15 +311,19 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
output_specs = exported_program.graph_signature.output_specs
assert len(output_nodes) == len(output_specs)

clone_op = (
exir_ops.edge.aten.clone.default
if _skip_dim_order
else exir_ops.edge.dim_order_ops._clone_dim_order.default
)

for i in range(len(output_specs)):
out_node = output_nodes[i]
if out_node not in updated_constant_placeholders:
continue

with exported_program.graph.inserting_after(out_node):
new_node = exported_program.graph.call_function(
exir_ops.edge.aten.clone.default, (out_node,)
)
new_node = exported_program.graph.call_function(clone_op, (out_node,))
assert "val" in out_node.meta
new_node.meta["val"] = out_node.meta["val"]
output_nodes[i] = new_node
@@ -329,6 +337,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
def constant_prop_pass(
exported_program: ExportedProgram,
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
_skip_dim_order: bool = True,
) -> ExportedProgram:
"""
This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +385,7 @@ def constant_prop_pass(
new_input_specs.append(name_to_spec_dict[node.name])
exported_program.graph_signature.input_specs = new_input_specs

_update_output_node_and_specs(exported_program)
_update_output_node_and_specs(exported_program, _skip_dim_order=_skip_dim_order)

# Cleanup the graph.
exported_program.graph.eliminate_dead_code()
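A hedged usage sketch of the new keyword: pass the same _skip_dim_order value the program was lowered with, so a clone inserted for a constant output matches the dim-order schema. TinyModule and the export flow below are illustrative, not taken from the PR.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass

class TinyModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + torch.ones(4)

# Lower with dim-order ops enabled (_skip_dim_order=False in the edge config).
edge = to_edge(
    torch.export.export(TinyModule(), (torch.randn(4),)),
    compile_config=EdgeCompileConfig(_skip_dim_order=False),
)

# Keep the pass consistent with the edge config so any clone it inserts for a
# constant output is _clone_dim_order rather than aten.clone.
edge._edge_programs["forward"] = constant_prop_pass(
    edge.exported_program("forward"), _skip_dim_order=False
)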
5 changes: 1 addition & 4 deletions exir/passes/memory_format_ops_pass.py
@@ -19,9 +19,6 @@
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.


class MemoryFormatOpsPass(ExportPass):
"""
@@ -43,7 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
# new kwargs with dim_order, and no memory_format for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

# get the "to" memory format for the EdgeOp
# get the target memory format for the EdgeOp
mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

# can always get the shape, assuming rank is specialized
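As context for the comment change above, a small self-contained sketch of what the popped memory_format feeds into: an explicit dim order. The helper below is made up for illustration; the actual pass uses ExecuTorch's own utilities.

import torch

def memory_format_to_dim_order(mem_format: torch.memory_format, rank: int) -> list[int]:
    # channels_last on a rank-4 tensor stores NCHW data in NHWC order.
    if mem_format == torch.channels_last and rank == 4:
        return [0, 2, 3, 1]
    # contiguous_format (the default when the kwarg is absent) is the identity order.
    return list(range(rank))

nkwargs = {"dtype": torch.float16, "memory_format": torch.channels_last}
mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
print(memory_format_to_dim_order(mem_format, rank=4))  # [0, 2, 3, 1]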
2 changes: 1 addition & 1 deletion exir/tests/test_passes.py
@@ -1192,7 +1192,7 @@ def forward(self) -> torch.Tensor:
)

edge._edge_programs["forward"] = constant_prop_pass(
edge.exported_program("forward")
edge.exported_program("forward"), _skip_dim_order=False
)

# Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
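A possible follow-up assertion after the call above (not part of the PR): with _skip_dim_order=False, any clone inserted for a constant output should be the dim-order variant, so the plain aten clone should not appear among the graph's call_function targets. `edge` here is the EdgeProgramManager from the test.

from executorch.exir.dialects._ops import ops as exir_ops

call_targets = {
    node.target
    for node in edge.exported_program("forward").graph.nodes
    if node.op == "call_function"
}
assert exir_ops.edge.aten.clone.default not in call_targets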