diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0d6dc87de2f..180cda46207 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -895,7 +895,6 @@ def _to_edge_and_lower_llama_xnnpack(
     if gen_tag_fn is not None:
         from executorch.exir.passes.external_constants_pass import (
             delegate_external_constants_pass_unlifted,
-            external_constants_pass,
         )
 
         assert (
@@ -906,18 +905,14 @@ def _to_edge_and_lower_llama_xnnpack(
             gen_tag_fn=gen_tag_fn,
         )
 
-        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-        additional_passes.append(
-            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-        )
-
     builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
-    # we need builder.export_program
-
-    return builder.to_executorch(passes=additional_passes)
+    # Add gen_tag_fn to tag non-delegated weights as well.
+    return builder.to_executorch(
+        passes=additional_passes, external_constants_tag=gen_tag_fn
+    )
 
 
 def _to_edge_and_lower_llama_openvino(
diff --git a/exir/program/_program.py b/exir/program/_program.py
index 5a96e02082b..03c9aeed886 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -839,13 +839,11 @@ def edge_to_executorch_passes(
     Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
     """
     passes: List[PassType] = [
-        SpecPropPass(),
         # ExecuTorch backend ops are unable to handle unbacked symints. So after
         # this pass, passes cannot be Interpreter-based, because it will fail if
         # there exists an unbacked symint operation.
         *config.passes,
-        # config.passes may contain external_constants_pass. This pass has to
-        # run after SpecPropPass, which populates tensor names.
+        SpecPropPass(),
         EdgeToBackendOpsPass(),
         RemoveGraphAssertsPass(),
     ] + pre_memory_planning_passes(config, name)
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index f8c556f351c..675c0179ebb 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -473,7 +473,11 @@ def to_edge_transform_and_lower(
         return self
 
     def to_executorch(
-        self, passes: Optional[List[ExportPass]] = None
+        self,
+        passes: Optional[List[ExportPass]] = None,
+        external_constants_tag: Optional[
+            Callable[[torch.fx.Node], Optional[str]]
+        ] = None,
     ) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
@@ -506,6 +510,7 @@ def to_executorch(
                 do_quant_fusion_and_const_prop=True,
                 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+                external_constants=external_constants_tag,
             )
         )
         logging.info(