66
77# pyre-strict
88
9- from typing import Any , cast , Dict , Sequence , Tuple
9+ from typing import Any , cast , Dict , List , Optional , Sequence , Tuple , Type
1010
1111import torch
12+ import torch .fx
13+ import torch .utils ._pytree as pytree
14+ from executorch .backends .cadence .aot .pass_utils import (
15+ CadencePassAttribute ,
16+ create_cadence_pass_filter ,
17+ register_cadence_pass ,
18+ )
1219from executorch .backends .cadence .aot .utils import get_edge_overload_packet
20+ from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
1321from executorch .exir .dialects ._ops import ops as exir_ops
1422from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
23+ from executorch .exir .pass_manager import PassManager , PassType
1524from executorch .exir .passes import dead_code_elimination_pass
25+ from executorch .exir .passes .scalar_to_tensor_pass import ScalarToTensorPass
1626from executorch .exir .passes .spec_prop_pass import SpecPropPass
1727from torch ._subclasses import FakeTensor
1828from torch .utils ._pytree import tree_map_only
1929
30+
31+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
32+ class InitializePipeline (ExportPass ):
33+ """
34+ Initialize the Jarvis pipeline. This should invariably be the first pass to
35+ run.
36+ """
37+
38+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
39+ dead_code_elimination_pass (graph_module )
40+ result = SpecPropPass ()(graph_module )
41+ assert result is not None
42+ return result
43+
44+
45+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
46+ class FinalizePipeline (ExportPass ):
47+ """
48+ The final cleanup pass after running the Jarvis pipeline.
49+ """
50+
51+ def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
52+ finalize_passes : List [PassType ] = [
53+ ScalarToTensorPass (),
54+ SpecPropPass (),
55+ ]
56+ result = PassManager (passes = finalize_passes )(graph_module )
57+ dead_code_elimination_pass (result .graph_module )
58+ return result
59+
60+
2061# Similar to what's done in executorch/exir/pass_base.py
2162Argument = Any # pyre-ignore
2263
@@ -131,7 +172,7 @@ def call_operator(
131172 )
132173
133174
134- class RemoveZeroSizedCatArgsPass (ExportPass ):
175+ class RemoveZeroSizedCatArgsPass (ExportPass ): # is this the latest?
135176 def call_operator (
136177 self ,
137178 op , # pyre-ignore
@@ -255,20 +296,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
255296 return result
256297
257298
258- class InitializePipeline (ExportPass ):
259- """
260- Initialize the Jarvis pipeline. This should invariably be the first pass to
261- run.
262- """
263-
264- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
265- dead_code_elimination_pass (graph_module )
266- result = SpecPropPass ()(graph_module )
267- assert result is not None
268- return result
269-
270-
271- class ReplaceSafeSoftmaxWithSoftmax (ExportPass ):
299+ class ReplaceSafeSoftmaxWithSoftmax (ExportPass ): # keep
272300 """
273301 Replace _safe_softmax with _softmax
274302 """
@@ -292,3 +320,33 @@ def call_operator(
292320 kwargs ,
293321 meta ,
294322 )
323+
324+
325+ def get_passes_in_default_order () -> List [Type [PassType ]]:
326+ passes = [
327+ InitializePipeline ,
328+ RemoveZeroSizedCatArgsPass ,
329+ ReplaceLogicalNotBooleanWhereWithWherePass ,
330+ ReplaceScalarTensorWithFullPass ,
331+ RemoveCloneOpsTransform ,
332+ RemoveNopExpandOpPass ,
333+ ReplaceSqueezeAndUnsqueezeWithViewPass ,
334+ ReplacePT2QuantWithCadenceQuantPass ,
335+ ReplacePT2DequantWithCadenceDequantPass ,
336+ # TODO: add the rest of the passes here.
337+ ]
338+ return pytree .tree_flatten (passes )[0 ]
339+
340+
341+ def get_cadence_passes (
342+ opt_level : int ,
343+ ) -> List [Optional [PassResult ]]:
344+ passes = get_passes_in_default_order ()
345+ pass_filter = create_cadence_pass_filter (opt_level )
346+ filtered_passes = [
347+ # pyre-fixme[20]: Call `torch.fx.passes.infra.pass_base.PassBase.__call__` expects argument `graph_module`.
348+ filtered_pass ()
349+ # pyre-fixme[6]: In call `filter.__new__` ... got `List[Type[typing.Callable[[GraphModule], Optional[PassResult]]]]`.
350+ for filtered_pass in list (filter (pass_filter , passes ))
351+ ]
352+ return filtered_passes
0 commit comments