From 2e07a9a2d1a63ec1d549da00c589a09cd905d4f7 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Wed, 10 Sep 2025 08:42:20 -0700 Subject: [PATCH] Replace export_for_training with export (#14073) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14073 export_for_training is deprecated, so replace it with export. Differential Revision: D81936329 --- .../coreml/test/test_coreml_quantizer.py | 6 ++-- backends/apple/mps/test/test_mps_utils.py | 2 +- .../test/test_quantize_op_fusion_pass.py | 10 ++----- backends/example/test_example_delegate.py | 8 ++---- backends/mediatek/quantizer/annotator.py | 6 ++-- backends/qualcomm/tests/utils.py | 2 +- backends/test/harness/stages/quantize.py | 4 +-- backends/test/suite/runner.py | 5 ---- .../test_duplicate_dynamic_quant_chain.py | 2 +- backends/vulkan/test/test_vulkan_passes.py | 5 ++-- backends/vulkan/test/utils.py | 8 ++---- .../test/ops/test_check_quant_params.py | 4 +-- .../test/quantizer/test_pt2e_quantization.py | 27 +++++++++--------- .../test/quantizer/test_representation.py | 4 +-- .../test/quantizer/test_xnnpack_quantizer.py | 16 ++++------- backends/xnnpack/test/test_xnnpack_utils.py | 4 +-- devtools/inspector/tests/inspector_test.py | 7 ++--- .../backend-delegates-xnnpack-reference.md | 4 +-- docs/source/backends-coreml.md | 4 +-- docs/source/backends-xnnpack.md | 4 +-- docs/source/bundled-io.md | 12 ++++---- docs/source/llm/export-custom-llm.md | 10 +++---- .../tutorial-xnnpack-delegate-lowering.md | 6 ++-- .../export-to-executorch-tutorial.py | 14 ++++------ examples/apple/mps/scripts/mps_example.py | 4 +-- examples/llm_manual/export_nanogpt.py | 4 +-- .../mediatek/aot_utils/oss_utils/utils.py | 4 +-- .../mediatek/model_export_scripts/llama.py | 2 +- examples/models/llama/eval_llama_lib.py | 2 +- .../models/phi-3-mini/export_phi-3-mini.py | 4 +-- examples/models/test/test_export.py | 4 +-- .../portable/scripts/export_and_delegate.py | 8 +++--- examples/xnnpack/aot_compiler.py | 4 +-- examples/xnnpack/quantization/example.py | 8 ++---- exir/backend/test/test_partitioner.py | 28 ++++++------------- exir/backend/test/test_passes.py | 6 ++-- exir/emit/test/test_emit.py | 7 ++--- exir/tests/test_extract_io_quant_params.py | 4 +-- exir/tests/test_passes.py | 8 ++---- exir/tests/test_quantization.py | 2 +- exir/tests/test_quantize_io_pass.py | 6 ++-- extension/export_util/utils.py | 4 +-- extension/llm/README.md | 2 +- extension/llm/export/builder.py | 10 +++---- extension/llm/export/test_export_passes.py | 6 ++-- .../test/resources/gen_bundled_program.py | 4 +-- extension/training/README.md | 1 - .../training/examples/XOR/export_model.py | 1 - 48 files changed, 123 insertions(+), 184 deletions(-) diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index d5754328796..eb8b9471345 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -15,7 +15,7 @@ ) from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer -from torch.export import export_for_training +from torch.export import export from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, @@ -32,9 +32,7 @@ def quantize_and_compare( ) -> None: assert quantization_type in {"PTQ", "QAT"} - pre_autograd_aten_dialect = export_for_training( - model, example_inputs, strict=True - ).module() + pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module() quantization_config = LinearQuantizerConfig.from_dict( { diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 674a4b0ba62..5afa604795e 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -206,7 +206,7 @@ def lower_module_and_test_output( expected_output = model(*sample_inputs) - model = torch.export.export_for_training( + model = torch.export.export( model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() diff --git a/backends/cortex_m/test/test_quantize_op_fusion_pass.py b/backends/cortex_m/test/test_quantize_op_fusion_pass.py index 3cd65208d1f..1595b0cfbc3 100644 --- a/backends/cortex_m/test/test_quantize_op_fusion_pass.py +++ b/backends/cortex_m/test/test_quantize_op_fusion_pass.py @@ -23,7 +23,7 @@ get_node_args, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.export import export, export_for_training +from torch.export import export from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -42,9 +42,7 @@ def _prepare_quantized_model(self, model_class): model = model_class() # Export and quantize - exported_model = export_for_training( - model.eval(), self.example_inputs, strict=True - ).module() + exported_model = export(model.eval(), self.example_inputs, strict=True).module() prepared_model = prepare_pt2e(exported_model, AddQuantizer()) quantized_model = convert_pt2e(prepared_model) @@ -242,9 +240,7 @@ def forward(self, x, y): inputs = (torch.randn(shape), torch.randn(shape)) model = SingleAddModel() - exported_model = export_for_training( - model.eval(), inputs, strict=True - ).module() + exported_model = export(model.eval(), inputs, strict=True).module() prepared_model = prepare_pt2e(exported_model, AddQuantizer()) quantized_model = convert_pt2e(prepared_model) diff --git a/backends/example/test_example_delegate.py b/backends/example/test_example_delegate.py index bc6ad4d7e4c..fd4c6652787 100644 --- a/backends/example/test_example_delegate.py +++ b/backends/example/test_example_delegate.py @@ -46,9 +46,7 @@ def get_example_inputs(): ) m = model.eval() - m = torch.export.export_for_training( - m, copy.deepcopy(example_inputs), strict=True - ).module() + m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module() # print("original model:", m) quantizer = ExampleQuantizer() # quantizer = XNNPACKQuantizer() @@ -84,9 +82,7 @@ def test_delegate_mobilenet_v2(self): ) m = model.eval() - m = torch.export.export_for_training( - m, copy.deepcopy(example_inputs), strict=True - ).module() + m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module() quantizer = ExampleQuantizer() m = prepare_pt2e(m, quantizer) diff --git a/backends/mediatek/quantizer/annotator.py b/backends/mediatek/quantizer/annotator.py index 8c0e42627e0..6fd4a1d23c3 100644 --- a/backends/mediatek/quantizer/annotator.py +++ b/backends/mediatek/quantizer/annotator.py @@ -10,7 +10,7 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.export import export_for_training +from torch.export import export from torch.fx import Graph, Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, @@ -158,9 +158,7 @@ def forward(self, x): return norm, {} for pattern_cls in (ExecuTorchPattern, MTKPattern): - pattern_gm = export_for_training( - pattern_cls(), (torch.randn(3, 3),), strict=True - ).module() + pattern_gm = export(pattern_cls(), (torch.randn(3, 3),), strict=True).module() matcher = SubgraphMatcherWithNameNodeMap( pattern_gm, ignore_literals=True, remove_overlapping_matches=False ) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 21153f2a3ff..93eee4dfc31 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -576,7 +576,7 @@ def get_prepared_qat_module( quant_dtype: QuantDtype = QuantDtype.use_8a8w, submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, ) -> torch.fx.GraphModule: - m = torch.export.export_for_training(module, inputs, strict=True).module() + m = torch.export.export(module, inputs, strict=True).module() quantizer = make_quantizer( quant_dtype=quant_dtype, diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py index b98c4faa3dd..9edb600e19f 100644 --- a/backends/test/harness/stages/quantize.py +++ b/backends/test/harness/stages/quantize.py @@ -7,7 +7,7 @@ DuplicateDynamicQuantChainPass, ) -from torch.export import export_for_training +from torch.export import export from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, @@ -47,7 +47,7 @@ def run( assert inputs is not None if self.is_qat: artifact.train() - captured_graph = export_for_training(artifact, inputs, strict=True).module() + captured_graph = export(artifact, inputs, strict=True).module() assert isinstance(captured_graph, torch.fx.GraphModule) diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 3729d94cdf3..1f84db9c730 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -5,7 +5,6 @@ import re import time import unittest -import warnings from datetime import timedelta from typing import Any @@ -283,10 +282,6 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter: def runner_main(): args = parse_args() - # Suppress deprecation warnings for export_for_training, as it generates a - # lot of log spam. We don't really need the warning here. - warnings.simplefilter("ignore", category=FutureWarning) - seed = args.seed or random.randint(0, 100_000_000) print(f"Running with seed {seed}.") diff --git a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py index 79bc56f8780..4d4b5d8cd5a 100644 --- a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py +++ b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py @@ -58,7 +58,7 @@ def _test_duplicate_chain( # program capture m = copy.deepcopy(m_eager) - m = torch.export.export_for_training(m, example_inputs, strict=True).module() + m = torch.export.export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index b277dff2a76..76e25f2e291 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -58,7 +58,7 @@ def quantize_and_lower_module( _check_ir_validity=False, ) - program = torch.export.export_for_training( + program = torch.export.export( model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() @@ -95,7 +95,6 @@ def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> class TestVulkanPasses(unittest.TestCase): - def test_fuse_int8pack_mm(self): K = 256 N = 256 @@ -184,7 +183,7 @@ def test_fuse_linear_qta8a_qga4w(self): _check_ir_validity=False, ) - program = torch.export.export_for_training( + program = torch.export.export( quantized_model, sample_inputs, strict=True ).module() diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index 363ee37058d..41c1d92bd00 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -35,7 +35,7 @@ _load_for_executorch_from_buffer, ) from executorch.extension.pytree import tree_flatten -from torch.export import export, export_for_training +from torch.export import export from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -53,7 +53,7 @@ def get_exported_graph( dynamic_shapes=None, qmode=QuantizationMode.NONE, ) -> torch.fx.GraphModule: - export_training_graph = export_for_training( + export_training_graph = export( model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() @@ -590,9 +590,7 @@ def op_ablation_test( # noqa: C901 logger.info("Starting fast binary search operator ablation test...") # Step 1: Export model to get edge_program and extract operators - export_training_graph = export_for_training( - model, sample_inputs, strict=True - ).module() + export_training_graph = export(model, sample_inputs, strict=True).module() program = export( export_training_graph, sample_inputs, diff --git a/backends/xnnpack/test/ops/test_check_quant_params.py b/backends/xnnpack/test/ops/test_check_quant_params.py index 8be59aab50e..e077462ee2c 100644 --- a/backends/xnnpack/test/ops/test_check_quant_params.py +++ b/backends/xnnpack/test/ops/test_check_quant_params.py @@ -9,7 +9,7 @@ ) from executorch.backends.xnnpack.utils.utils import get_param_tensor from executorch.exir import to_edge_transform_and_lower -from torch.export import export_for_training +from torch.export import export from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -52,7 +52,7 @@ def _test_check_quant_message(self, ep_modifier, expected_message): torch._dynamo.reset() mod = torch.nn.Linear(10, 10) quantizer = XNNPACKQuantizer() - captured = export_for_training(mod, (torch.randn(1, 10),), strict=True).module() + captured = export(mod, (torch.randn(1, 10),), strict=True).module() quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True)) prepared = prepare_pt2e(captured, quantizer) diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 5d76ecd2d54..f0456a3604e 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -22,7 +22,7 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.export import export_for_training +from torch.export import export from torch.testing._internal.common_quantization import ( NodeSpec as ns, TestHelperModules, @@ -58,7 +58,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: @@ -351,7 +351,7 @@ def test_disallow_eval_train(self) -> None: m.train() # After export: this is not OK - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -405,7 +405,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() # pyre-ignore[23] - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: bn_op = bn_train_op if train else bn_eval_op @@ -474,7 +474,7 @@ def forward(self, x): quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() weight_meta = None for n in m.graph.nodes: # pyre-ignore[16] if ( @@ -503,7 +503,7 @@ def test_reentrant(self) -> None: quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True, is_qat=True) ) - m.conv_bn_relu = export_for_training( # pyre-ignore[8] + m.conv_bn_relu = export( # pyre-ignore[8] m.conv_bn_relu, example_inputs, strict=True ).module() m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) # pyre-ignore[6,8] @@ -513,7 +513,7 @@ def test_reentrant(self) -> None: quantizer = XNNPACKQuantizer().set_module_type( torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) ) - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # pyre-ignore[6] m = convert_pt2e(m) @@ -575,7 +575,7 @@ def check_nn_module(node: torch.fx.Node) -> None: "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] ) - m.conv_bn_relu = export_for_training( # pyre-ignore[8] + m.conv_bn_relu = export( # pyre-ignore[8] m.conv_bn_relu, example_inputs, strict=True ).module() for node in m.conv_bn_relu.graph.nodes: # pyre-ignore[16] @@ -591,7 +591,7 @@ def test_speed(self) -> None: def dynamic_quantize_pt2e(model, example_inputs) -> torch.fx.GraphModule: torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = export(model, example_inputs, strict=True).module() # Per channel quantization for weight # Dynamic quantization for activation # Please read a detail: https://fburl.com/code/30zds51q @@ -648,7 +648,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 3, 5, 5),) m = M() - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(), ) @@ -724,11 +724,10 @@ def test_save_load(self) -> None: class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase): - def test_quantize_pt2e_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = export(m, example_inputs, strict=True) m = ep.module() quantizer = XNNPACKQuantizer().set_global( @@ -768,7 +767,7 @@ def test_quantize_pt2e_preserve_handle(self): def test_extract_results_from_loggers(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = export(m, example_inputs, strict=True) m = ep.module() m_ref_logger = prepare_for_propagation_comparison(m) @@ -792,7 +791,7 @@ def test_extract_results_from_loggers(self): def test_extract_results_from_loggers_list_output(self): m = TestHelperModules.Conv2dWithSplit() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs, strict=True) + ep = export(m, example_inputs, strict=True) m = ep.module() m_ref_logger = prepare_for_propagation_comparison(m) diff --git a/backends/xnnpack/test/quantizer/test_representation.py b/backends/xnnpack/test/quantizer/test_representation.py index 817f7f9e368..614bc8b83d6 100644 --- a/backends/xnnpack/test/quantizer/test_representation.py +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -8,7 +8,7 @@ XNNPACKQuantizer, ) from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.export import export_for_training +from torch.export import export from torch.testing._internal.common_quantization import ( NodeSpec as ns, QuantizationTestCase, @@ -33,7 +33,7 @@ def _test_representation( ) -> None: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training(model, example_inputs, strict=True).module() + model = export(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) # pyre-ignore[6] diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 84b1a932a5b..1e1a473dd59 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -27,7 +27,7 @@ convert_to_reference_fx, prepare_fx, ) -from torch.export import export_for_training +from torch.export import export from torch.testing._internal.common_quantization import ( NodeSpec as ns, skip_if_no_torchvision, @@ -500,7 +500,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # pyre-ignore[6] # Use a linear count instead of names because the names might change, but # the order should be the same. @@ -636,7 +636,7 @@ def test_propagate_annotation(self): example_inputs = (torch.randn(1, 3, 5, 5),) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -901,9 +901,7 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = export_for_training( - model_graph, example_inputs, strict=True - ).module() + model_graph = export(model_graph, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -963,9 +961,7 @@ def forward(self, input_tensor, hidden_tensor): model_fx = _convert_to_reference_decomposed_fx(model_fx) with torchdynamo.config.patch(allow_rnn=True): - model_graph = export_for_training( - model_graph, example_inputs, strict=True - ).module() + model_graph = export(model_graph, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=False, is_dynamic=False @@ -1173,7 +1169,7 @@ def test_resnet18(self): m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 1f9d8c47723..5a6c529b497 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -70,7 +70,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) -from torch.export import export_for_training +from torch.export import export from torch.testing import FileCheck @@ -317,7 +317,7 @@ def quantize_and_test_model_with_quantizer( module.eval() # program capture - m = export_for_training(module, example_inputs, strict=True).module() + m = export(module, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config() diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index cf1fdb7ec00..a3afed07ed8 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -662,7 +662,6 @@ def test_calculate_numeric_gap(self): ), patch.object( _inspector, "gen_graphs_from_etrecord" ): - # Call the constructor of Inspector inspector_instance = Inspector( etdump_path=ETDUMP_PATH, @@ -724,7 +723,7 @@ def test_calculate_numeric_gap(self): @unittest.skip("ci config values are not propagated") def test_intermediate_tensor_comparison_with_torch_export(self): - """Test intermediate tensor comparison using torch.export.export_for_training and to_edge_transform_and_lower.""" + """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower.""" class SimpleTestModel(torch.nn.Module): """A simple test model for demonstration purposes.""" @@ -759,8 +758,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model_path = os.path.join(tmp_dir, "model.pte") etrecord_path = os.path.join(tmp_dir, "etrecord.bin") - # Step 1: Export using torch.export.export_for_training - exported_program = torch.export.export_for_training(model, example_inputs) + # Step 1: Export using torch.export.export + exported_program = torch.export.export(model, example_inputs) self.assertIsNotNone(exported_program) # Step 2: Lower to XNNPACK with generate_etrecord=True diff --git a/docs/source/backend-delegates-xnnpack-reference.md b/docs/source/backend-delegates-xnnpack-reference.md index d38c5af60fa..cfb915aca59 100644 --- a/docs/source/backend-delegates-xnnpack-reference.md +++ b/docs/source/backend-delegates-xnnpack-reference.md @@ -106,9 +106,9 @@ quantizer.set_global(quantization_config) ### Quantizing your model with the XNNPACKQuantizer After configuring our quantizer, we are now ready to quantize our model ```python -from torch.export import export_for_training +from torch.export import export -exported_model = export_for_training(model_to_quantize, example_inputs).module() +exported_model = export(model_to_quantize, example_inputs).module() prepared_model = prepare_pt2e(exported_model, quantizer) print(prepared_model.graph) ``` diff --git a/docs/source/backends-coreml.md b/docs/source/backends-coreml.md index dbf87e7d697..26c0c570893 100644 --- a/docs/source/backends-coreml.md +++ b/docs/source/backends-coreml.md @@ -90,7 +90,7 @@ Quantization with the CoreML backend requires exporting the model for iOS 17 or To perform 8-bit quantization with the PT2E flow, follow these steps: 1) Create a [`coremltools.optimize.torch.quantization.LinearQuantizerConfig`](https://apple.github.io/coremltools/source/coremltools.optimize.torch.quantization.html#coremltools.optimize.torch.quantization.LinearQuantizerConfig) and use to to create an instance of a `CoreMLQuantizer`. -2) Use `torch.export.export_for_training` to export a graph module that will be prepared for quantization. +2) Use `torch.export.export` to export a graph module that will be prepared for quantization. 3) Call `prepare_pt2e` to prepare the model for quantization. 4) Run the prepared model with representative samples to calibrate the quantizated tensor activation ranges. 5) Call `convert_pt2e` to quantize the model. @@ -126,7 +126,7 @@ static_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig( quantizer = CoreMLQuantizer(static_8bit_config) # Step 2: Export the model for training -training_gm = torch.export.export_for_training(mobilenet_v2, sample_inputs).module() +training_gm = torch.export.export(mobilenet_v2, sample_inputs).module() # Step 3: Prepare the model for quantization prepared_model = prepare_pt2e(training_gm, quantizer) diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index b7fca261850..d1a120e69fa 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -79,7 +79,7 @@ Weight-only quantization is not currently supported on XNNPACK. To perform 8-bit quantization with the PT2E flow, perform the following steps prior to exporting the model: 1) Create an instance of the `XnnpackQuantizer` class. Set quantization parameters. -2) Use `torch.export.export_for_training` to prepare for quantization. +2) Use `torch.export.export` to prepare for quantization. 3) Call `prepare_pt2e` to prepare the model for quantization. 4) For static quantization, run the prepared model with representative samples to calibrate the quantizated tensor activation ranges. 5) Call `convert_pt2e` to quantize the model. @@ -103,7 +103,7 @@ qparams = get_symmetric_quantization_config(is_per_channel=True) # (1) quantizer = XNNPACKQuantizer() quantizer.set_global(qparams) -training_ep = torch.export.export_for_training(model, sample_inputs).module() # (2) +training_ep = torch.export.export(model, sample_inputs).module() # (2) prepared_model = prepare_pt2e(training_ep, quantizer) # (3) for cal_sample in [torch.randn(1, 3, 224, 224)]: # Replace with representative model inputs diff --git a/docs/source/bundled-io.md b/docs/source/bundled-io.md index 3e8accce80e..79897737268 100644 --- a/docs/source/bundled-io.md +++ b/docs/source/bundled-io.md @@ -96,7 +96,7 @@ from executorch.devtools.bundled_program.config import MethodTestCase, MethodTes from executorch.devtools.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) -from torch.export import export, export_for_training +from torch.export import export # Step 1: ExecuTorch Program Export @@ -130,7 +130,7 @@ capture_input = ( # Export method's FX Graph. method_graph = export( - export_for_training(model, capture_input).module(), + export(model, capture_input).module(), capture_input, ) @@ -238,7 +238,7 @@ from executorch.exir import to_edge from executorch.devtools import BundledProgram from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite -from torch.export import export, export_for_training +from torch.export import export class Module(torch.nn.Module): @@ -262,7 +262,7 @@ inputs = (torch.ones(2, 2, dtype=torch.float), ) # Find each method of model needs to be traced my its name, export its FX Graph. method_graph = export( - export_for_training(model, inputs).module(), + export(model, inputs).module(), inputs, ) @@ -374,7 +374,7 @@ from executorch.exir import to_edge from executorch.devtools import BundledProgram from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite -from torch.export import export, export_for_training +from torch.export import export class Module(torch.nn.Module): @@ -398,7 +398,7 @@ inputs = (torch.ones(2, 2, dtype=torch.float),) # Find each method of model needs to be traced my its name, export its FX Graph. method_graph = export( - export_for_training(model, inputs).module(), + export(model, inputs).module(), inputs, ) diff --git a/docs/source/llm/export-custom-llm.md b/docs/source/llm/export-custom-llm.md index bbdf596d21b..57537ba31d8 100644 --- a/docs/source/llm/export-custom-llm.md +++ b/docs/source/llm/export-custom-llm.md @@ -41,7 +41,7 @@ import torch from executorch.exir import EdgeCompileConfig, to_edge from torch.nn.attention import sdpa_kernel, SDPBackend -from torch.export import export, export_for_training +from torch.export import export from model import GPT @@ -66,7 +66,7 @@ dynamic_shape = ( # Trace the model, converting it to a portable intermediate representation. # The torch.no_grad() call tells PyTorch to exclude training-specific logic. with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module() + m = export(model, example_inputs, dynamic_shapes=dynamic_shape).module() traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape) # Convert the model into a runnable ExecuTorch program. @@ -125,7 +125,7 @@ from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower import torch from torch.export import export from torch.nn.attention import sdpa_kernel, SDPBackend -from torch.export import export_for_training +from torch.export import export from model import GPT @@ -152,7 +152,7 @@ dynamic_shape = ( # Trace the model, converting it to a portable intermediate representation. # The torch.no_grad() call tells PyTorch to exclude training-specific logic. with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - m = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shape).module() + m = export(model, example_inputs, dynamic_shapes=dynamic_shape).module() traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape) # Convert the model into a runnable ExecuTorch program. @@ -209,7 +209,7 @@ xnnpack_quant_config = get_symmetric_quantization_config( xnnpack_quantizer = XNNPACKQuantizer() xnnpack_quantizer.set_global(xnnpack_quant_config) -m = export_for_training(model, example_inputs).module() +m = export(model, example_inputs).module() # Annotate the model for quantization. This prepares the model for calibration. m = prepare_pt2e(m, xnnpack_quantizer) diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index 04cab007f65..bccd4e4add3 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -77,13 +77,13 @@ After lowering to the XNNPACK Program, we can then prepare it for executorch and The XNNPACK delegate can also execute symmetrically quantized models. To understand the quantization flow and learn how to quantize models, refer to [Custom Quantization](quantization-custom-quantization.md) note. For the sake of this tutorial, we will leverage the `quantize()` python helper function conveniently added to the `executorch/executorch/examples` folder. ```python -from torch.export import export_for_training +from torch.export import export from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) -mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path +mobilenet_v2 = export(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( @@ -110,7 +110,7 @@ def quantize(model, example_inputs): quantized_mobilenetv2 = quantize(mobilenet_v2, sample_inputs) ``` -Quantization requires a two stage export. First we use the `export_for_training` API to capture the model before giving it to `quantize` utility function. After performing the quantization step, we can now leverage the XNNPACK delegate to lower the quantized exported model graph. From here, the procedure is the same as for the non-quantized model lowering to XNNPACK. +Quantization requires a two stage export. First we use the `export` API to capture the model before giving it to `quantize` utility function. After performing the quantization step, we can now leverage the XNNPACK delegate to lower the quantized exported model graph. From here, the procedure is the same as for the non-quantized model lowering to XNNPACK. ```python # Continued from earlier... diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index 2ca6a207d17..4dd2818fabe 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -173,8 +173,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # ----------------------- # # To quantize a model, we first need to capture the graph with -# ``torch.export.export_for_training``, perform quantization, and then -# call ``torch.export``. ``torch.export.export_for_training`` returns a +# ``torch.export.export``, perform quantization, and then +# call ``torch.export``. ``torch.export.export`` returns a # graph which contains ATen operators which are Autograd safe, meaning they are # safe for eager-mode training, which is needed for quantization. We will call # the graph at this level, the ``Pre-Autograd ATen Dialect`` graph. @@ -187,12 +187,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # will annotate the nodes in the graph with information needed to quantize the # model properly for a specific backend. -from torch.export import export_for_training +from torch.export import export example_args = (torch.randn(1, 3, 256, 256),) -pre_autograd_aten_dialect = export_for_training( - SimpleConv(), example_args, strict=True -).module() +pre_autograd_aten_dialect = export(SimpleConv(), example_args, strict=True).module() print("Pre-Autograd ATen Dialect Graph") print(pre_autograd_aten_dialect) @@ -543,7 +541,7 @@ def forward(self, a, x, b): # Here is an example for an entire end-to-end workflow: import torch -from torch.export import export, export_for_training, ExportedProgram +from torch.export import export, export, ExportedProgram class M(torch.nn.Module): @@ -557,7 +555,7 @@ def forward(self, x): example_args = (torch.randn(3, 4),) -pre_autograd_aten_dialect = export_for_training(M(), example_args, strict=True).module() +pre_autograd_aten_dialect = export(M(), example_args, strict=True).module() # Optionally do quantization: # pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer)) aten_dialect = export(pre_autograd_aten_dialect, example_args, strict=True) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 42ea79435ed..46e0f6af242 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -170,9 +170,7 @@ def parse_args(): # pre-autograd export. eventually this will become torch.export with torch.no_grad(): - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() + model = torch.export.export(model, example_inputs, strict=True).module() edge: EdgeProgramManager = export_to_edge( model, example_inputs, diff --git a/examples/llm_manual/export_nanogpt.py b/examples/llm_manual/export_nanogpt.py index 8c948479f2a..9beb041ce27 100644 --- a/examples/llm_manual/export_nanogpt.py +++ b/examples/llm_manual/export_nanogpt.py @@ -15,7 +15,7 @@ from executorch.exir import to_edge from model import GPT -from torch.export import export, export_for_training +from torch.export import export from torch.nn.attention import sdpa_kernel, SDPBackend model = GPT.from_pretrained("gpt2") # use gpt2 weight as pretrained weight @@ -27,7 +27,7 @@ # Trace the model, converting it to a portable intermediate representation. # The torch.no_grad() call tells PyTorch to exclude training-specific logic. with sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - m = export_for_training( + m = export( model, example_inputs, dynamic_shapes=dynamic_shape, strict=True ).module() traced_model = export(m, example_inputs, dynamic_shapes=dynamic_shape, strict=True) diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py index e365309f10c..69c3beb7475 100755 --- a/examples/mediatek/aot_utils/oss_utils/utils.py +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -33,9 +33,7 @@ def build_executorch_binary( if quant_dtype not in Precision: raise AssertionError(f"No support for Precision {quant_dtype}.") - captured_model = torch.export.export_for_training( - model, inputs, strict=True - ).module() + captured_model = torch.export.export(model, inputs, strict=True).module() annotated_model = prepare_pt2e(captured_model, quantizer) print("Quantizing the model...") # calibration diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py index 60c57850d00..8953d3e9050 100644 --- a/examples/mediatek/model_export_scripts/llama.py +++ b/examples/mediatek/model_export_scripts/llama.py @@ -322,7 +322,7 @@ def export_to_et_ir( max_num_token, max_cache_size, True ) print("Getting pre autograd ATen Dialect Graph") - pre_autograd_aten_dialect = torch.export.export_for_training( + pre_autograd_aten_dialect = torch.export.export( model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() # NOTE: Will be replaced with export quantizer = NeuropilotQuantizer() diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 991ff72ae43..03f8f5cd759 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -223,7 +223,7 @@ def gen_eval_wrapper( ) else: # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch - # for quantizers. Currently export_for_training only works with --kv_cache, but + # for quantizers. Currently export only works with --kv_cache, but # fails without the kv_cache mode model = ( manager.model.eval().to(device="cuda") diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index d1239d9769d..017c15f783e 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -23,7 +23,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from torch.export import export_for_training +from torch.export import export as torch_export from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -108,7 +108,7 @@ def export(args) -> None: gm(*example_inputs) gm = convert_pt2e(gm) DuplicateDynamicQuantChainPass()(gm) - exported_program = export_for_training( + exported_program = torch_export( gm, example_inputs, dynamic_shapes=dynamic_shapes, strict=False ) diff --git a/examples/models/test/test_export.py b/examples/models/test/test_export.py index 306f54c0e89..53b6278bd6c 100644 --- a/examples/models/test/test_export.py +++ b/examples/models/test/test_export.py @@ -29,9 +29,7 @@ def collect_executorch_and_eager_outputs( Returns a tuple containing the outputs of the eager mode model and the executorch mode model. """ eager_model = eager_model.eval() - model = torch.export.export_for_training( - eager_model, example_inputs, strict=True - ).module() + model = torch.export.export(eager_model, example_inputs, strict=True).module() edge_model = export_to_edge(model, example_inputs) executorch_prog = edge_model.to_executorch() diff --git a/examples/portable/scripts/export_and_delegate.py b/examples/portable/scripts/export_and_delegate.py index 1c2adf67688..509916959b0 100644 --- a/examples/portable/scripts/export_and_delegate.py +++ b/examples/portable/scripts/export_and_delegate.py @@ -61,7 +61,7 @@ def export_composite_module_with_lower_graph(): m_compile_spec = m.get_compile_spec() # pre-autograd export. eventually this will become torch.export - m = torch.export.export_for_training(m, m_inputs, strict=True).module() + m = torch.export.export(m, m_inputs, strict=True).module() edge = export_to_edge(m, m_inputs) logging.info(f"Exported graph:\n{edge.exported_program().graph}") @@ -84,7 +84,7 @@ def forward(self, *args): m = CompositeModule() m = m.eval() # pre-autograd export. eventually this will become torch.export - m = torch.export.export_for_training(m, m_inputs, strict=True).module() + m = torch.export.export(m, m_inputs, strict=True).module() composited_edge = export_to_edge(m, m_inputs) # The graph module is still runnerable @@ -134,7 +134,7 @@ def get_example_inputs(self): m = Model() m_inputs = m.get_example_inputs() # pre-autograd export. eventually this will become torch.export - m = torch.export.export_for_training(m, m_inputs, strict=True).module() + m = torch.export.export(m, m_inputs, strict=True).module() edge = export_to_edge(m, m_inputs) logging.info(f"Exported graph:\n{edge.exported_program().graph}") @@ -171,7 +171,7 @@ def export_and_lower_the_whole_graph(): m_inputs = m.get_example_inputs() # pre-autograd export. eventually this will become torch.export - m = torch.export.export_for_training(m, m_inputs, strict=True).module() + m = torch.export.export(m, m_inputs, strict=True).module() edge = export_to_edge(m, m_inputs) logging.info(f"Exported graph:\n{edge.exported_program().graph}") diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py index 886f3123f85..81eeb75c72c 100644 --- a/examples/xnnpack/aot_compiler.py +++ b/examples/xnnpack/aot_compiler.py @@ -86,14 +86,14 @@ model = model.eval() # pre-autograd export. eventually this will become torch.export - ep = torch.export.export_for_training(model, example_inputs, strict=False) + ep = torch.export.export(model, example_inputs, strict=False) model = ep.module() if args.quantize: logging.info("Quantizing Model...") # TODO(T165162973): This pass shall eventually be folded into quantizer model = quantize(model, example_inputs, quant_type) - ep = torch.export.export_for_training(model, example_inputs, strict=False) + ep = torch.export.export(model, example_inputs, strict=False) edge = to_edge_transform_and_lower( ep, diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 93831ab8252..f5425155008 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -60,9 +60,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_ m = model # 1. pytorch 2.0 export quantization flow (recommended/default flow) - m = torch.export.export_for_training( - m, copy.deepcopy(example_inputs), strict=True - ).module() + m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) quantizer.set_global(quantization_config) @@ -179,9 +177,7 @@ def main() -> None: model = model.eval() # pre-autograd export. eventually this will become torch.export - model = torch.export.export_for_training( - model, example_inputs, strict=True - ).module() + model = torch.export.export(model, example_inputs, strict=True).module() start = time.perf_counter() quantized_model = quantize(model, example_inputs) end = time.perf_counter() diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py index d369a914fac..dedcfe52966 100644 --- a/exir/backend/test/test_partitioner.py +++ b/exir/backend/test/test_partitioner.py @@ -40,7 +40,7 @@ ) from executorch.extension.pytree import tree_flatten from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param -from torch.export import export, export_for_training +from torch.export import export from torch.fx.passes.operator_support import any_chain @@ -76,7 +76,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = export_for_training(mlp, example_inputs, strict=True).module() + model = export(mlp, example_inputs, strict=True).module() aten = export(model, example_inputs, strict=True) spec_key = "path" spec_value = "/a/b/c/d" @@ -137,7 +137,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = export_for_training(mlp, example_inputs, strict=True).module() + model = export(mlp, example_inputs, strict=True).module() aten = export(model, example_inputs, strict=True) edge = exir.to_edge(aten) @@ -177,7 +177,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = export_for_training(mlp, example_inputs, strict=True).module() + model = export(mlp, example_inputs, strict=True).module() edge = exir.to_edge(export(model, example_inputs, strict=True)) with self.assertRaisesRegex( @@ -229,9 +229,7 @@ def partition( partition_tags=partition_tags, ) - model = export_for_training( - self.AddConst(), (torch.ones(2, 2),), strict=True - ).module() + model = export(self.AddConst(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),), strict=True)) delegated = edge.to_backend(PartitionerNoTagData()) @@ -310,9 +308,7 @@ def partition( partition_tags=partition_tags, ) - model = export_for_training( - self.AddConst(), (torch.ones(2, 2),), strict=True - ).module() + model = export(self.AddConst(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),), strict=True)) delegated = edge.to_backend(PartitionerTagData()) @@ -387,9 +383,7 @@ def partition( partition_tags=partition_tags, ) - model = export_for_training( - self.AddConst(), (torch.ones(2, 2),), strict=True - ).module() + model = export(self.AddConst(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),), strict=True)) delegated = edge.to_backend(PartitionerTagData()) @@ -477,9 +471,7 @@ def partition( ) inputs = (torch.ones(2, 2),) - model = export_for_training( - ReuseConstData(), (torch.ones(2, 2),), strict=True - ).module() + model = export(ReuseConstData(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),), strict=True)) exec_prog = edge.to_backend(PartitionerTagData()).to_executorch() executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer) @@ -539,9 +531,7 @@ def partition( partition_tags=partition_tags, ) - model = export_for_training( - ReuseConstData(), (torch.ones(2, 2),), strict=True - ).module() + model = export(ReuseConstData(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),), strict=True)) with self.assertRaises(RuntimeError) as error: _ = edge.to_backend(PartitionerTagData()) diff --git a/exir/backend/test/test_passes.py b/exir/backend/test/test_passes.py index 1cdf494fa01..32fb75e90f9 100644 --- a/exir/backend/test/test_passes.py +++ b/exir/backend/test/test_passes.py @@ -12,7 +12,7 @@ duplicate_constant_node, ) from torch._export.utils import is_buffer -from torch.export import export_for_training +from torch.export import export from torch.testing import FileCheck @@ -28,9 +28,7 @@ def forward(self, x): z = x - self.const return y, z - model = export_for_training( - ReuseConstData(), (torch.ones(2, 2),), strict=True - ).module() + model = export(ReuseConstData(), (torch.ones(2, 2),), strict=True).module() edge = exir.to_edge( torch.export.export(model, (torch.ones(2, 2),), strict=True) ) diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 7d0da7170c6..43b5fcfa99b 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -66,7 +66,7 @@ from functorch.experimental import control_flow from torch import nn -from torch.export import Dim, export, export_for_training +from torch.export import Dim, export from torch.export.experimental import _export_forward_backward @@ -1751,8 +1751,8 @@ def forward(self, x): module_1(*example_inputs) module_2(*example_inputs) - ep1 = export_for_training(module_1, example_inputs, strict=True) - ep2 = export_for_training(module_2, example_inputs, strict=True) + ep1 = export(module_1, example_inputs, strict=True) + ep2 = export(module_2, example_inputs, strict=True) edge_program_manager = exir.to_edge( {"forward1": ep1, "forward2": ep2}, @@ -1794,7 +1794,6 @@ def forward(self, input, label): net = TrainingNet(Net()) # Captures the forward graph. The graph will look similar to the model definition now. - # Will move to export_for_training soon which is the api planned to be supported in the long term. ep = export( net, (torch.randn(1, 2), torch.ones(1, dtype=torch.int64)), strict=True ) diff --git a/exir/tests/test_extract_io_quant_params.py b/exir/tests/test_extract_io_quant_params.py index 84da01c673d..ec018e8ae68 100644 --- a/exir/tests/test_extract_io_quant_params.py +++ b/exir/tests/test_extract_io_quant_params.py @@ -41,7 +41,7 @@ def setUp(self): operator_config = get_symmetric_quantization_config() self.quantizer.set_global(operator_config) - exported = torch.export.export_for_training( + exported = torch.export.export( self.mod, copy.deepcopy(self.example_inputs), strict=True, @@ -54,7 +54,7 @@ def setUp(self): converted = convert_pt2e(prepared) # Export again with quant parameters - final_export = torch.export.export_for_training( + final_export = torch.export.export( converted, self.example_inputs, strict=True, diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 781c4d716e4..716b808b087 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1340,9 +1340,7 @@ def forward(self, query, key, value): value = torch.randn(32, 32, 32, 32) # Capture the model - m = torch.export.export_for_training( - M(32), (query, key, value), strict=True - ).module() + m = torch.export.export(M(32), (query, key, value), strict=True).module() # 8w16a quantization from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver @@ -1615,9 +1613,7 @@ def quantize_model( m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor] ) -> Tuple[EdgeProgramManager, int, int]: # program capture - m = torch.export.export_for_training( - m_eager, example_inputs, strict=True - ).module() + m = torch.export.export(m_eager, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config() diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index c7bcdeeeb5c..7cfa926dae3 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -51,7 +51,7 @@ def test_resnet(self) -> None: m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = torch.export.export_for_training( + m = torch.export.export( m, copy.deepcopy(example_inputs), strict=True ).module() diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py index f670594616a..1dde08ce15b 100644 --- a/exir/tests/test_quantize_io_pass.py +++ b/exir/tests/test_quantize_io_pass.py @@ -38,15 +38,13 @@ def _quantize(self, mod, example_inputs): quantizer = XNNPACKQuantizer() operator_config = get_symmetric_quantization_config() quantizer.set_global(operator_config) - m = torch.export.export_for_training( + m = torch.export.export( mod, copy.deepcopy(example_inputs), strict=True ).module() m = prepare_pt2e(m, quantizer) _ = m(*example_inputs) m = convert_pt2e(m) - exported_program = torch.export.export_for_training( - m, example_inputs, strict=True - ) + exported_program = torch.export.export(m, example_inputs, strict=True) return exported_program def _check_count(self, op, count, epm): diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py index aa3a736af3c..782b2b0ae63 100644 --- a/extension/export_util/utils.py +++ b/extension/export_util/utils.py @@ -14,7 +14,7 @@ import torch from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge from executorch.exir.tracer import Value -from torch.export import export, export_for_training, ExportedProgram +from torch.export import export, ExportedProgram _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( @@ -108,7 +108,7 @@ def export_to_exec_prog( ) -> ExecutorchProgramManager: m = model.eval() # pre-autograd export. eventually this will become torch.export - m = export_for_training(m, example_inputs, strict=True).module() + m = export(m, example_inputs, strict=True).module() core_aten_ep = _to_core_aten( m, diff --git a/extension/llm/README.md b/extension/llm/README.md index 0f71088eea1..de8e4e6d619 100644 --- a/extension/llm/README.md +++ b/extension/llm/README.md @@ -10,7 +10,7 @@ Commonly used methods in this class include: - _source_transform_: execute a series of source transform passes. Some transform passes include - weight only quantization, which can be done at source (eager mode) level. - replace some torch operators to a custom operator. For example, _replace_sdpa_with_custom_op_. -- _torch.export_for_training_: get a graph that is ready for pt2 graph-based quantization. +- _torch.export_: get a graph that is ready for pt2 graph-based quantization. - _pt2e_quantize_ with passed in quantizers. - util functions in _quantizer_lib.py_ can help to get different quantizers based on the needs. - _export_to_edge_: export to edge dialect diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 3c8b6b4aa2a..4fa220b0565 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -34,7 +34,7 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.export import export_for_training, ExportedProgram +from torch.export import export, ExportedProgram from torch.nn.attention import SDPBackend from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer @@ -234,7 +234,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: logging.info(f"inputs: {self.example_inputs}") logging.info(f"kwargs: {self.example_kwarg_inputs}") logging.info(f"dynamic shapes: {dynamic_shape}") - exported_module = export_for_training( + exported_module = export( self.model if not module else module, self.example_inputs, kwargs=self.example_kwarg_inputs, @@ -246,7 +246,7 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: def export(self) -> "LLMEdgeManager": """ Exports the model pre-autograd. This is not a full export, since it uses - torch.export_for_training() to keep autograd-safe ops from getting decomposed. + torch.export.export() to keep autograd-safe ops from getting decomposed. The full torch.export() if called later on during to_edge() or to_edge_transform_and_lower(). """ @@ -257,9 +257,7 @@ def export(self) -> "LLMEdgeManager": self.pre_autograd_graph_module = exported_module.module() if self.save_exported_program: export_output = f"{self.modelname}.pt2" - logging.info( - f"Saving torch.export()/export_for_training() result to {export_output}" - ) + logging.info(f"Saving torch.export() result to {export_output}") torch.export.save(exported_module, export_output) return self diff --git a/extension/llm/export/test_export_passes.py b/extension/llm/export/test_export_passes.py index 3b58e8218c4..1b59892ff88 100644 --- a/extension/llm/export/test_export_passes.py +++ b/extension/llm/export/test_export_passes.py @@ -7,13 +7,13 @@ ReplaceSDPAWithCustomSDPAPass, ) -from torch.export import export_for_training +from torch.export import export from torch.testing import FileCheck class RemoveRedundantTransposesPassTest(unittest.TestCase): def _export(self, model, example_inputs): - exported_module = export_for_training(model, example_inputs, strict=True) + exported_module = export(model, example_inputs, strict=True) return exported_module.module() def _check(self, model, example_inputs, key, before_count, after_count): @@ -177,7 +177,7 @@ def setUp(self): def _test(self, args, assume_causal_mask=False): m = self.TestModule() - gm = export_for_training(m, args, strict=True).module() + gm = export(m, args, strict=True).module() sdpa_key = "torch.ops.aten.scaled_dot_product_attention.default" custom_sdpa_key = "torch.ops.llama.custom_sdpa.default" diff --git a/extension/module/test/resources/gen_bundled_program.py b/extension/module/test/resources/gen_bundled_program.py index f1fa0a4a7e3..a85088dd817 100644 --- a/extension/module/test/resources/gen_bundled_program.py +++ b/extension/module/test/resources/gen_bundled_program.py @@ -8,7 +8,7 @@ ) from executorch.exir import to_edge_transform_and_lower -from torch.export import export, export_for_training +from torch.export import export # Step 1: ExecuTorch Program Export @@ -44,7 +44,7 @@ def main() -> None: # Export method's FX Graph. method_graph = export( - export_for_training(model, capture_input).module(), + export(model, capture_input).module(), capture_input, ) diff --git a/extension/training/README.md b/extension/training/README.md index f6f8d5139a6..ed2d65ef343 100644 --- a/extension/training/README.md +++ b/extension/training/README.md @@ -93,7 +93,6 @@ input = torch.randn(1, 2) label = torch.ones(1, dtype=torch.int64) # Captures the forward graph. The graph will look similar to the model definition now. -# Will move to export_for_training soon which is the api planned to be supported in the long term. ep = export(net, (input, label)) ``` diff --git a/extension/training/examples/XOR/export_model.py b/extension/training/examples/XOR/export_model.py index 98e04f09a2f..35f6c390b49 100644 --- a/extension/training/examples/XOR/export_model.py +++ b/extension/training/examples/XOR/export_model.py @@ -23,7 +23,6 @@ def _export_model(external_mutable_weights: bool = False): x = torch.randn(1, 2) # Captures the forward graph. The graph will look similar to the model definition now. - # Will move to export_for_training soon which is the api planned to be supported in the long term. ep = export(net, (x, torch.ones(1, dtype=torch.int64)), strict=True) # Captures the backward graph. The exported_program now contains the joint forward and backward graph. ep = _export_forward_backward(ep)