diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 71588f44ac9..ee54fe3660d 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -29,10 +29,10 @@
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
+from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
-from torch.export import export_for_training
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -190,9 +190,9 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
                     strict=True,
                 ).module()
             else:
-                self.pre_autograd_graph_module = export_for_training(
+                self.pre_autograd_graph_module = capture_pre_autograd_graph(
                     self.model, self.example_inputs, dynamic_shapes=dynamic_shape
-                ).module()
+                )
         return self
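
Note on the dropped `.module()` call in the second hunk: `torch.export.export_for_training` returns an `ExportedProgram`, whose `.module()` accessor yields the underlying graph module, while `torch._export.capture_pre_autograd_graph` returns the graph module directly, so the trailing `.module()` is no longer needed. A minimal standalone sketch of the difference, using a toy model invented for illustration (not from the ExecuTorch codebase; which of the two APIs is available depends on your torch version):

import torch
from torch._export import capture_pre_autograd_graph
from torch.export import export_for_training

# Toy model, purely illustrative.
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

model = Toy()
example_inputs = (torch.randn(2, 4),)

# export_for_training returns an ExportedProgram; .module()
# extracts the pre-autograd graph module from it.
gm_a = export_for_training(model, example_inputs).module()

# capture_pre_autograd_graph returns the graph module directly,
# hence no trailing .module() in the diff above.
gm_b = capture_pre_autograd_graph(model, example_inputs)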