diff --git a/.ci/scripts/test_torchao_huggingface_checkpoints.sh b/.ci/scripts/test_torchao_huggingface_checkpoints.sh
index 3c9ac598f8f..f06c794f88d 100644
--- a/.ci/scripts/test_torchao_huggingface_checkpoints.sh
+++ b/.ci/scripts/test_torchao_huggingface_checkpoints.sh
@@ -5,6 +5,7 @@ set -euxo pipefail
 # Args / flags
 # -------------------------
 TEST_WITH_RUNNER=0
+USE_TORCHAO_KERNELS=0
 MODEL_NAME=""
 
 # Parse args
@@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do
         --test_with_runner)
             TEST_WITH_RUNNER=1
             ;;
+        --use_torchao_kernels)
+            USE_TORCHAO_KERNELS=1
+            ;;
         -h|--help)
-            echo "Usage: $0 [--test_with_runner]"
+            echo "Usage: $0 [--test_with_runner] [--use_torchao_kernels]"
             echo "  model_name: qwen3_4b | phi_4_mini"
             echo "  --test_with_runner: build ET + run llama_main to sanity-check the export"
+            echo "  --use_torchao_kernels: use torchao kernels for linear and tied embedding"
             exit 0
             ;;
         *)
@@ -42,6 +47,13 @@ fi
 
 MODEL_OUT=model.pte
+
+# Default to XNNPACK
+BACKEND_ARGS="-X --xnnpack-extended-ops"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+    BACKEND_ARGS="--use-torchao-kernels"
+fi
+
 
 case "$MODEL_NAME" in
     qwen3_4b)
         echo "Running Qwen3-4B export..."
@@ -58,12 +70,12 @@ case "$MODEL_NAME" in
             --output_name $MODEL_OUT \
             -kv \
             --use_sdpa_with_kv_cache \
-            -X \
-            --xnnpack-extended-ops \
             --max_context_length 1024 \
             --max_seq_length 1024 \
+            --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+            --verbose \
             --dtype fp32 \
-            --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+            ${BACKEND_ARGS}
         ;;
 
     phi_4_mini)
@@ -81,12 +93,12 @@ case "$MODEL_NAME" in
             --output_name $MODEL_OUT \
             -kv \
             --use_sdpa_with_kv_cache \
-            -X \
-            --xnnpack-extended-ops \
             --max_context_length 1024 \
             --max_seq_length 1024 \
+            --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
+            --verbose \
             --dtype fp32 \
-            --metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
+            ${BACKEND_ARGS}
         ;;
 
     *)
@@ -104,6 +116,10 @@ if [[ $MODEL_SIZE -gt $EXPECTED_MODEL_SIZE_UPPER_BOUND ]]; then
 fi
 
 # Install ET with CMake
+EXECUTORCH_BUILD_KERNELS_TORCHAO="OFF"
+if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
+    EXECUTORCH_BUILD_KERNELS_TORCHAO="ON"
+fi
 if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
     echo "[runner] Building and testing llama_main ..."
     cmake -DPYTHON_EXECUTABLE=python \
@@ -120,6 +136,7 @@
         -DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
         -DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
         -DEXECUTORCH_BUILD_KERNELS_LLM=ON \
+        -DEXECUTORCH_BUILD_KERNELS_TORCHAO=${EXECUTORCH_BUILD_KERNELS_TORCHAO} \
         -Bcmake-out .
     cmake --build cmake-out -j16 --config Release --target install
 
diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml
index 362df17dc9b..ee2afb7576d 100644
--- a/.github/workflows/trunk.yml
+++ b/.github/workflows/trunk.yml
@@ -594,15 +594,22 @@ jobs:
     strategy:
       matrix:
         model: [qwen3_4b, phi_4_mini]
+        runner: [linux.2xlarge]
+        docker-image: [executorch-ubuntu-22.04-clang12]
+        backend: [xnnpack]
         include:
          - model: qwen3_4b
-           test_with_runner: true
+           runner: linux.arm64.2xlarge
+           docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+           backend: torchao
          - model: phi_4_mini
-           test_with_runner: false
+           runner: linux.arm64.2xlarge
+           docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+           backend: torchao
       fail-fast: false
     with:
-      runner: linux.2xlarge
-      docker-image: ci-image:executorch-ubuntu-22.04-clang12
+      runner: ${{ matrix.runner }}
+      docker-image: ci-image:${{ matrix.docker-image }}
       submodules: 'recursive'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       timeout: 900
@@ -612,9 +619,14 @@ conda activate "${CONDA_ENV}"
 
         PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake
+
+        if [[ "${{ matrix.backend }}" == "torchao" ]]; then
+          BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install third-party/ao
+        fi
+
         pip install -U "huggingface_hub[cli]"
 
-        bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
+        bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.model != 'phi_4_mini' && '--test_with_runner' || '' }} ${{ matrix.backend == 'torchao' && '--use_torchao_kernels' || '' }}
 
   test-multimodal-macos:
     if: ${{ !github.event.pull_request.head.repo.fork }}
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7192204a141..aa3b157c8da 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -417,6 +417,21 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
     )
+    parser.add_argument(
+        "--use-torchao-kernels",
+        action="store_true",
+        help="Delegate tied-embedding and quantized linear ops to torchao kernels",
+    )
+    parser.add_argument(
+        "--use-torchao-kernels-tied-embedding",
+        action="store_true",
+        help="Delegate tied-embedding ops to torchao kernels",
+    )
+    parser.add_argument(
+        "--use-torchao-kernels-linear",
+        action="store_true",
+        help="Delegate linear ops to torchao kernels",
+    )
     parser.add_argument("-V", "--vulkan", action="store_true")
     parser.add_argument("--vulkan-force-fp16", action="store_true")
     parser.add_argument("--mps", action="store_true")
@@ -741,6 +756,8 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             preq_group_size=llm_config.base.preq_group_size,
             preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
             local_global_attention=llm_config.model.local_global_attention,
+            use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
+            use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
         )
     )
 
@@ -1303,6 +1320,8 @@ def _get_source_transforms(  # noqa
     preq_group_size: Optional[int] = None,
     preq_embedding_quantize: Optional[str] = None,
     local_global_attention: Optional[List[int]] = None,
+    use_torchao_kernels_linear: bool = False,
+    use_torchao_kernels_tied_embedding: bool = False,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
@@ -1475,6 +1494,17 @@
             )
         )
 
+    if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]):
+        from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
+
+        transforms.append(
+            partial(
+                _convert_model_for_aarch64,
+                convert_linear=use_torchao_kernels_linear,
+                convert_tied_embedding=use_torchao_kernels_tied_embedding,
+            )
+        )
+
     return transforms
 
 
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index d756d1886ad..b13001c005b 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -452,6 +452,16 @@ class MPSConfig:
     enabled: bool = False
 
 
+@dataclass
+class TorchAOKernelsConfig:
+    """
+    Configures the torchao-kernels backend.
+ """ + + use_torchao_kernels_linear: bool = False + use_torchao_kernels_tied_embedding: bool = False + + @dataclass class BackendConfig: """ @@ -464,6 +474,7 @@ class BackendConfig: vulkan: VulkanConfig = field(default_factory=VulkanConfig) qnn: QNNConfig = field(default_factory=QNNConfig) mps: MPSConfig = field(default_factory=MPSConfig) + torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig) ################################################################################ @@ -632,6 +643,28 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 if hasattr(args, "mps"): llm_config.backend.mps.enabled = args.mps + # TorchAoKernels + if any( + hasattr(args, a) + for a in [ + "use_torchao_kernels", + "use_torchao_kernels_linear", + "use_torchao_kernels_tied_embedding", + ] + ): + if hasattr(args, "use_torchao_kernels") and args.use_torchao_kernels: + # Enable all conversions if torchao_kernels is specified + llm_config.backend.torchao.use_torchao_kernels_linear = True + llm_config.backend.torchao.use_torchao_kernels_tied_embedding = True + else: + # Otherwise, only enable the conversions that are specified + llm_config.backend.torchao.use_torchao_kernels_linear = getattr( + args, "use_torchao_kernels_linear", False + ) + llm_config.backend.torchao.use_torchao_kernels_tied_embedding = getattr( + args, "use_torchao_kernels_tied_embedding", False + ) + # DebugConfig if hasattr(args, "profile_memory"): llm_config.debug.profile_memory = args.profile_memory diff --git a/third-party/ao b/third-party/ao index b99904b34c0..b47f1a36550 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit b99904b34c0fd98f8a63ec57cbc1dc4993f74793 +Subproject commit b47f1a3655004b2b4dd3b4f01a5d8eebff1faa3c