From 629c93e4e9dfd44cec933d2320e64e786a992cfb Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 4 Dec 2024 19:14:52 -0800 Subject: [PATCH] [et][dim order] Makes DimOrderOpsMap as an operator to operator mapping Pull Request resolved: https://github.com/pytorch/executorch/pull/7187 This diff updates DimOrderOpsMap from name-to-operator mapping to operator-to-operator mapping, which has multiple benefits: 1. Reduce dialect ambiguity. Different dialects op may map to same name (e.g. `aten.to_copy` and `exir_ops.edge.aten._to_copy.default`). Directly using op can diminish the ambiguity. 2. Auto-maintain MemoryFormatOpsMap by reverting DimOrderOpsMap ghstack-source-id: 256633613 @exported-using-ghexport Differential Revision: [D66773612](https://our.internmc.facebook.com/intern/diff/D66773612/) --- exir/passes/dim_order_ops_registry.py | 18 +++++------------- exir/passes/memory_format_ops_pass.py | 12 ++++++------ exir/verification/verifier.py | 1 + 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py index c4436aaa910..f3fc009f109 100644 --- a/exir/passes/dim_order_ops_registry.py +++ b/exir/passes/dim_order_ops_registry.py @@ -58,22 +58,14 @@ def _empty_dim_order_out_impl(*args, **kwargs): """ -Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup +Defines a map of edge ops to the corresponding dim_order ops for quick lookup """ DimOrderOpsMap = { - "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - "aten.empty.memory_format": exir_ops.edge.dim_order_ops._empty_dim_order.default, + exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default, } """ -Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup +Defines a map of edge ops to the corresponding memory format ops for quick lookup, which is the revert of DimOrderOpsMap """ -MemoryFormatOpsMap = { - "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default, - "dim_order_ops._empty_dim_order.default": exir_ops.edge.aten.empty.memory_format, -} - -# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts. -assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap) - -# TODO stricter check for 1:1 mapping +MemoryFormatOpsMap = {v: k for k, v in DimOrderOpsMap.items()} diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py index ba89a510a71..db597b2b233 100644 --- a/exir/passes/memory_format_ops_pass.py +++ b/exir/passes/memory_format_ops_pass.py @@ -32,7 +32,7 @@ class MemoryFormatOpsPass(ExportPass): """ def call_operator(self, op, args, kwargs, meta): - if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap): + if not (isinstance(op, EdgeOpOverload) and op in DimOrderOpsMap): return super().call_operator( op, args, @@ -61,10 +61,10 @@ def call_operator(self, op, args, kwargs, meta): nkwargs["dim_order"] = get_dim_order(mem_format, ndim) logger.debug( f"{op.__name__} = rank: {ndim}, memory_format: {mem_format}." - f" {DimOrderOpsMap[op.__name__].__name__} = dim_order: {nkwargs['dim_order']}" + f" {DimOrderOpsMap[op].__name__} = dim_order: {nkwargs['dim_order']}" ) - t = DimOrderOpsMap[op.__name__] + t = DimOrderOpsMap[op] return super().call_operator( t, @@ -80,7 +80,7 @@ class DimOrderOpsRevertPass(ExportPass): """ def call_operator(self, op, args, kwargs, meta): - if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap): + if not (isinstance(op, EdgeOpOverload) and op in MemoryFormatOpsMap): return super().call_operator( op, args, @@ -109,10 +109,10 @@ def call_operator(self, op, args, kwargs, meta): logger.debug( f" {op.__name__} = dim_order: {dim_order}." - f" {MemoryFormatOpsMap[op.__name__].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}." + f" {MemoryFormatOpsMap[op].__name__} = rank: {ndim}, memory_format: {nkwargs['memory_format']}." ) - t = MemoryFormatOpsMap[op.__name__] + t = MemoryFormatOpsMap[op] return super().call_operator( t, diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 2c45929bf23..f906623ca25 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -179,6 +179,7 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None: if validator.violating_ops: raise SpecViolationError( f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}" + "Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding " )