From c06c3868d096790d358eb7fac72036f76b55cd06 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 15 Sep 2025 14:28:55 +0200 Subject: [PATCH] Add strict export option to ExportRecipe Default is True, mirroring earlier behavior. Also update ExportSession to handle this. Signed-off-by: Erik Lundell Change-Id: Iaa47fabb713daf9d5c352408bb20625d05b7e476 --- export/export.py | 5 ++++- export/recipe.py | 3 +++ export/stages.py | 5 ++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/export/export.py b/export/export.py index 86a932d153c..1e9cdbde7c0 100644 --- a/export/export.py +++ b/export/export.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -200,7 +201,9 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: aten_transform_passes = list( self._export_recipe.aten_transform_passes ) - stage = TorchExportStage(aten_transform_passes) + stage = TorchExportStage( + aten_transform_passes, strict=self._export_recipe.strict + ) elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER: stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EDGE: diff --git a/export/recipe.py b/export/recipe.py index 18f4b8aebb9..4465da51956 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -151,6 +152,7 @@ class ExportRecipe: executorch_backend_config: Optional backend configuration for ExecuTorch pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline. mode: Export mode (debug or release) + strict: Set the strict flag in the torch export call. """ name: Optional[str] = None @@ -163,6 +165,7 @@ class ExportRecipe: executorch_backend_config: Optional[ExecutorchBackendConfig] = None pipeline_stages: Optional[List[StageType]] = None mode: Mode = Mode.RELEASE + strict: bool = True @classmethod def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe": diff --git a/export/stages.py b/export/stages.py index 323b327bfa4..3be801c6a14 100644 --- a/export/stages.py +++ b/export/stages.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -110,9 +111,11 @@ def __init__( aten_transform_passes: Optional[ List[Callable[[str, ExportedProgram], ExportedProgram]] ] = None, + strict=True, ) -> None: super().__init__() self._aten_transform_passes = aten_transform_passes + self.strict = strict @property def stage_type(self) -> str: @@ -147,7 +150,7 @@ def run(self, artifact: PipelineArtifact) -> None: model, example_inputs[method_name][0], dynamic_shapes=method_dynamic_shapes, - strict=True, + strict=self.strict, ) # Apply pre-edge transform passes if available