diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 75190b9c7be..001ab95d629 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2242,6 +2242,17 @@ def call_operator(self, op, args, kwargs, meta): ) +class CommonReplacePasses: + passes = [ + ReplaceSqueezeAndUnsqueezeWithViewPass, + ReplaceSplitWithSlicePass, + ReplaceSelectWithViewOpPass, + ReplaceMMWithAddMMPass, + ReplaceRepeatWithCatPass, + ReplaceFullLikeWithFullPass, + ] + + @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass(ExportPass): """ @@ -2260,19 +2271,15 @@ def call_operator(self, op, args, kwargs, meta): # This class encapsulates all the functions that replace/switch one op in the # graph with another. class CadenceReplaceOpsInGraph: - passes = [ + passes = CommonReplacePasses.passes + [ ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, ReplacePermuteWithTransposePass, ReplaceScalarWithTensorArgPass, ReplaceConvolutionOptionalArgsWithConcreteArgsPass, - ReplaceMMWithAddMMPass, - ReplaceSqueezeAndUnsqueezeWithViewPass, ReplaceAddMMWithLinearPass, RemoveNopSelectOpPass, - ReplaceSelectWithViewOpPass, - ReplaceRepeatWithCatPass, ReplacePadWithCatPass, ReplaceConstantPadNdWithSlicePass, ReplaceAtenConvolutionWithCadenceConvolutionPass, @@ -2287,7 +2294,6 @@ class CadenceReplaceOpsInGraph: ReplaceNopTransposeOrPermuteWithViewPass, ReplaceLinearWithFullyConnectedOpPass, ReplaceScalarTensorWithFullPass, - ReplaceFullLikeWithFullPass, ReplaceInfArgInFullWithValuePass, ReplaceLogicalNotBooleanWhereWithWherePass, ReplacePT2QuantWithCadenceQuantPass, @@ -2297,7 +2303,6 @@ class CadenceReplaceOpsInGraph: ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, ReplaceAtenApproxGeluWithApproxGeluPass, - ReplaceSplitWithSlicePass, ReplacePowWithMulPass, ReplaceMulTensorWithMulAndFullOpsPass, ]