diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index a0b44fb9652..825cf923e7c 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -443,6 +443,13 @@ def build_args_parser() -> argparse.ArgumentParser: default=None, help="path to the input pruning token mapping file (token_map.json)", ) + + parser.add_argument( + "--export_only", + default=False, + action="store_true", + help="If true, stops right after torch.export() and saves the exported model.", + ) return parser @@ -587,12 +594,14 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) # export_to_edge - builder_exported_to_edge = ( - _prepare_for_llama_export(modelname, args) - .export() - .pt2e_quantize(quantizers) - .export_to_edge() - ) + builder_exported = _prepare_for_llama_export(modelname, args).export() + + if args.export_only: + exit() + + builder_exported_to_edge = builder_exported.pt2e_quantize( + quantizers + ).export_to_edge() modelname = builder_exported_to_edge.modelname diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bd12c374b51..311e788797f 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -186,22 +186,25 @@ def export(self) -> "LLMEdgeManager": # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as # `Module`. - self.pre_autograd_graph_module = torch.export.export( + exported_module = torch.export.export( self.model, self.example_inputs, self.example_kwarg_inputs, dynamic_shapes=dynamic_shape, strict=True, - ).module() + ) else: # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as # `Module`. - self.pre_autograd_graph_module = export_for_training( + exported_module = export_for_training( self.model, self.example_inputs, kwargs=self.example_kwarg_inputs, dynamic_shapes=dynamic_shape, - ).module() + ) + self.pre_autograd_graph_module = exported_module.module() + if hasattr(self.args, "export_only") and self.args.export_only: + torch.export.save(exported_module, self.args.output_name) return self