Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/apple/coreml/llama/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def main() -> None:
],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_type_promotion=(float_dtype == torch.float16),
_skip_dim_order=True,
),
)
Expand Down
1 change: 0 additions & 1 deletion exir/capture/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class EdgeCompileConfig:
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
default_factory=list
)
_skip_type_promotion: bool = False
# TODO(gasoonjia): remove this
_skip_dim_order: bool = False

Expand Down
4 changes: 1 addition & 3 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,7 @@ def _get_aten_to_edge_passes(config: EdgeCompileConfig):
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.

pre_op_replace_passes = base_pre_op_replace_passes + (
[] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
)
pre_op_replace_passes = base_pre_op_replace_passes + [RemoveMixedTypeOperators()]

post_op_replace_passes = base_post_op_replace_passes

Expand Down
1 change: 0 additions & 1 deletion extension/llm/export/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def _get_dynamic_shape(self) -> Any:
def _get_edge_config(self) -> EdgeCompileConfig:
edge_config = EdgeCompileConfig(
_check_ir_validity=False,
_skip_type_promotion=bool(self.dtype == DType.fp16),
_skip_dim_order=True,
)
return edge_config
Expand Down
3 changes: 1 addition & 2 deletions test/end2end/exported_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def export(
ignore_to_out_var_failure: bool = False,
dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
capture_config=None,
skip_type_promotion: bool = False,
export_joint_graph: bool = False,
external_constants: bool = False,
export_state_names: bool = False,
Expand Down Expand Up @@ -194,7 +193,7 @@ def __init__(self, method):
exec_prog = to_edge(
exported_methods,
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False, _skip_type_promotion=skip_type_promotion
_check_ir_validity=False,
),
).to_executorch(
ExecutorchBackendConfig(
Expand Down
8 changes: 0 additions & 8 deletions test/models/export_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ def get_random_inputs(self):

def export_module_to_program(
module_class: Type[nn.Module],
skip_type_promotion: bool,
external_constants: bool = False,
) -> ExecutorchProgramManager:
"""Exports the module and returns the serialized program data."""
Expand All @@ -293,7 +292,6 @@ def export_module_to_program(
module = ExportedModule.export(
module_class,
methods,
skip_type_promotion=skip_type_promotion,
export_joint_graph=export_joint,
external_constants=external_constants,
export_state_names=export_state_names,
Expand Down Expand Up @@ -342,17 +340,11 @@ def main() -> None:
# Export and write to the output files.
os.makedirs(args.outdir, exist_ok=True)
for module_name, module_class in module_names_to_classes.items():
skip_type_promotion = False
if module_name == "ModuleAddHalf":
# Skip type promotion to keep the model in fp16.
# Type promotion will convert to fp32.
skip_type_promotion = True
if args.external_constants:
module_name = f"{module_name}Program"
outfile = os.path.join(args.outdir, f"{module_name}.pte")
prog = export_module_to_program(
module_class,
skip_type_promotion=skip_type_promotion,
external_constants=args.external_constants,
)
with open(outfile, "wb") as fp:
Expand Down
Loading