diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 749f32a04e5..2982ebc2e01 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -6,6 +6,7 @@ from typing import ClassVar, Dict, final, List, Tuple import torch +from executorch import exir from executorch.backends.apple.mps.operators.node_visitor import ( get_node_visitors, @@ -35,6 +36,7 @@ from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass from executorch.exir.program._program import _transform +from executorch.exir.verification.verifier import EXIREdgeDialectVerifier from torch.export.exported_program import ExportedProgram FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -87,7 +89,19 @@ def preprocess( # the `output_ids` array in the schema. # TODO: Remove this once we have a better support for the dim-order ops. - edge_program = _transform(edge_program, DimOrderOpsRevertPass()) + # Need to override the verifier to skip the non dim-order ops from tripping the default verifier. + edge_program = _transform( + edge_program, + DimOrderOpsRevertPass(), + override_verifiers=[ + EXIREdgeDialectVerifier( + edge_compile_config=exir.EdgeCompileConfig( + _check_ir_validity=False, # Disable the edge dialect verifier, since we are in the mps backend. + ), + class_only=True, + ) + ], + ) mps_graph = MPSGraph( version="0",