diff --git a/torchchat/export.py b/torchchat/export.py index 7a7923119..7c5243b68 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -78,11 +78,11 @@ def export_for_server( dynamic_shapes=dynamic_shapes, options=options, ) - + if package: from torch._inductor.package import package_aoti path = package_aoti(output_path, path) - + print(f"The generated packaged model can be found at: {path}") return path @@ -382,7 +382,7 @@ def main(args): if builder_args.max_seq_length is None: if ( - output_dso_path is not None + (output_dso_path is not None or output_aoti_package_path is not None) and not builder_args.dynamic_shapes ): print("Setting max_seq_length to 300 for DSO export.") diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index 1b649ffbc..005bb6ef2 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -11,7 +11,7 @@ from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -77,31 +77,39 @@ def unpack_packed_weights( def set_backend(dso, pte, aoti_package): global active_builder_args_dso global active_builder_args_pte + global active_builder_args_aoti_package active_builder_args_dso = dso active_builder_args_aoti_package = aoti_package active_builder_args_pte = pte class _Backend(Enum): - AOTI = (0,) + AOTI = 0 EXECUTORCH = 1 -def _active_backend() -> _Backend: +def _active_backend() -> Optional[_Backend]: global active_builder_args_dso global active_builder_args_aoti_package global active_builder_args_pte - # eager == aoti, which is when backend has not been explicitly set - if (not active_builder_args_pte) and (not active_builder_args_aoti_package): - return True + args = ( + active_builder_args_dso, + active_builder_args_pte, + active_builder_args_aoti_package, + ) + + # Return None, as default + if not any(args): + return None - if active_builder_args_pte and active_builder_args_aoti_package: + # Catch more than one arg + if sum(map(bool, args)) > 1: raise RuntimeError( - "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" + "Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!" ) - return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH + return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI def use_aoti_backend() -> bool: