From b9c5a51592897023ac1a7bdf996782fa3026a3e2 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 2 Jul 2025 09:52:58 -0700 Subject: [PATCH] Remove _skip_type_promotion config (#12149) Summary: I don't think we need this any more, assuming tests pass. It starts with an underscore, so it's private, so we can just remove it, right? Reviewed By: larryliu0820 Differential Revision: D77619493 Pulled By: swolchok --- examples/apple/coreml/llama/export.py | 1 - exir/capture/_config.py | 1 - exir/program/_program.py | 4 +--- extension/llm/export/builder.py | 1 - test/end2end/exported_module.py | 3 +-- test/models/export_program.py | 8 -------- 6 files changed, 2 insertions(+), 16 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 9aa232fa691..a367a14c595 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -206,7 +206,6 @@ def main() -> None: ], compile_config=EdgeCompileConfig( _check_ir_validity=False, - _skip_type_promotion=(float_dtype == torch.float16), _skip_dim_order=True, ), ) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 80a838737fc..d66bc24976d 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -43,7 +43,6 @@ class EdgeCompileConfig: _core_aten_ops_exception_list: List[torch._ops.OpOverload] = field( default_factory=list ) - _skip_type_promotion: bool = False # TODO(gasoonjia): remove this _skip_dim_order: bool = False diff --git a/exir/program/_program.py b/exir/program/_program.py index 0c4469c96de..8ef02f233ac 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -652,9 +652,7 @@ def _get_aten_to_edge_passes(config: EdgeCompileConfig): # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta. # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost. - pre_op_replace_passes = base_pre_op_replace_passes + ( - [] if config._skip_type_promotion else [RemoveMixedTypeOperators()] - ) + pre_op_replace_passes = base_pre_op_replace_passes + [RemoveMixedTypeOperators()] post_op_replace_passes = base_post_op_replace_passes diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 4128bfd8198..333a18cdf84 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -214,7 +214,6 @@ def _get_dynamic_shape(self) -> Any: def _get_edge_config(self) -> EdgeCompileConfig: edge_config = EdgeCompileConfig( _check_ir_validity=False, - _skip_type_promotion=bool(self.dtype == DType.fp16), _skip_dim_order=True, ) return edge_config diff --git a/test/end2end/exported_module.py b/test/end2end/exported_module.py index a8124d62dd4..e5630b8e89f 100644 --- a/test/end2end/exported_module.py +++ b/test/end2end/exported_module.py @@ -67,7 +67,6 @@ def export( ignore_to_out_var_failure: bool = False, dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND, capture_config=None, - skip_type_promotion: bool = False, export_joint_graph: bool = False, external_constants: bool = False, export_state_names: bool = False, @@ -194,7 +193,7 @@ def __init__(self, method): exec_prog = to_edge( exported_methods, compile_config=exir.EdgeCompileConfig( - _check_ir_validity=False, _skip_type_promotion=skip_type_promotion + _check_ir_validity=False, ), ).to_executorch( ExecutorchBackendConfig( diff --git a/test/models/export_program.py b/test/models/export_program.py index dac42ecee1c..fae75743eb3 100644 --- a/test/models/export_program.py +++ b/test/models/export_program.py @@ -269,7 +269,6 @@ def get_random_inputs(self): def export_module_to_program( module_class: Type[nn.Module], - skip_type_promotion: bool, external_constants: bool = False, ) -> ExecutorchProgramManager: """Exports the module and returns the serialized program data.""" @@ -293,7 +292,6 @@ def export_module_to_program( module = ExportedModule.export( module_class, methods, - skip_type_promotion=skip_type_promotion, export_joint_graph=export_joint, external_constants=external_constants, export_state_names=export_state_names, @@ -342,17 +340,11 @@ def main() -> None: # Export and write to the output files. os.makedirs(args.outdir, exist_ok=True) for module_name, module_class in module_names_to_classes.items(): - skip_type_promotion = False - if module_name == "ModuleAddHalf": - # Skip type promotion to keep the model in fp16. - # Type promotion will convert to fp32. - skip_type_promotion = True if args.external_constants: module_name = f"{module_name}Program" outfile = os.path.join(args.outdir, f"{module_name}.pte") prog = export_module_to_program( module_class, - skip_type_promotion=skip_type_promotion, external_constants=args.external_constants, ) with open(outfile, "wb") as fp: