From 2ad178a28f2bcd9cb348d8c9931fa50f669799b1 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Fri, 11 Oct 2024 13:15:16 -0700
Subject: [PATCH] Back out "Revert export_for_training migration in llm/export/builder.py"

Summary:
Revert back the change, since we have fixed the issue in
https://github.com/pytorch/pytorch/issues/137540 with D64080561.

Differential Revision: D64260221
---
 extension/llm/export/builder.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 11d92f32f56..37d90f8595b 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -29,10 +29,10 @@
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
-from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
+from torch.export import export_for_training
 from torch.nn.attention import SDPBackend
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -193,12 +193,12 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
                     strict=True,
                 ).module()
             else:
-                self.pre_autograd_graph_module = capture_pre_autograd_graph(
+                self.pre_autograd_graph_module = export_for_training(
                     self.model,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
-                )
+                ).module()
 
         return self
 
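(Illustrative note, not part of the patch.) Below is a minimal sketch of the API swap this diff performs, assuming a toy TinyModel and an XNNPACKQuantizer rather than the real LLM model and the composed quantizers in builder.py: torch.export.export_for_training returns an ExportedProgram, so .module() is needed to recover the pre-autograd graph module that capture_pre_autograd_graph used to return directly, before handing it to the prepare_pt2e/convert_pt2e entry points builder.py already imports.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for self.model; not the executorch LLM module."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyModel().eval()
example_inputs = (torch.randn(2, 8),)

# Old API (removed by this patch) returned the graph module directly:
#   graph_module = capture_pre_autograd_graph(model, example_inputs)
# The new API returns an ExportedProgram, so .module() is needed to get the
# pre-autograd graph module, which is why the patch appends .module().
graph_module = export_for_training(model, example_inputs).module()

# The captured module can then go through the usual pt2e quantization flow.
# XNNPACKQuantizer is just one example quantizer; builder.py composes its own.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)
print(type(quantized))

In builder.py the same call additionally passes kwargs=self.example_kwarg_inputs and dynamic_shapes=dynamic_shape, which export_for_training accepts as keyword arguments.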