Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion export/export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -200,7 +201,9 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
aten_transform_passes = list(
self._export_recipe.aten_transform_passes
)
stage = TorchExportStage(aten_transform_passes)
stage = TorchExportStage(
aten_transform_passes, strict=self._export_recipe.strict
)
elif stage_type == StageType.TO_EDGE_TRANSFORM_AND_LOWER:
stage = EdgeTransformAndLowerStage.from_recipe(self._lowering_recipe)
elif stage_type == StageType.TO_EDGE:
Expand Down
3 changes: 3 additions & 0 deletions export/recipe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -151,6 +152,7 @@ class ExportRecipe:
executorch_backend_config: Optional backend configuration for ExecuTorch
pipeline_stages: Optional list of stages to execute, defaults to a standard pipeline.
mode: Export mode (debug or release)
strict: Set the strict flag in the torch export call.
"""

name: Optional[str] = None
Expand All @@ -163,6 +165,7 @@ class ExportRecipe:
executorch_backend_config: Optional[ExecutorchBackendConfig] = None
pipeline_stages: Optional[List[StageType]] = None
mode: Mode = Mode.RELEASE
strict: bool = True

@classmethod
def get_recipe(cls, recipe: "RecipeType", **kwargs) -> "ExportRecipe":
Expand Down
5 changes: 4 additions & 1 deletion export/stages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -110,9 +111,11 @@ def __init__(
aten_transform_passes: Optional[
List[Callable[[str, ExportedProgram], ExportedProgram]]
] = None,
strict=True,
) -> None:
super().__init__()
self._aten_transform_passes = aten_transform_passes
self.strict = strict

@property
def stage_type(self) -> str:
Expand Down Expand Up @@ -147,7 +150,7 @@ def run(self, artifact: PipelineArtifact) -> None:
model,
example_inputs[method_name][0],
dynamic_shapes=method_dynamic_shapes,
strict=True,
strict=self.strict,
)

# Apply pre-edge transform passes if available
Expand Down
Loading