31 changes: 24 additions & 7 deletions .ci/scripts/test_torchao_huggingface_checkpoints.sh
@@ -5,6 +5,7 @@ set -euxo pipefail
# Args / flags
# -------------------------
TEST_WITH_RUNNER=0
USE_TORCHAO_KERNELS=0
MODEL_NAME=""

# Parse args
@@ -22,10 +23,14 @@ while [[ $# -gt 0 ]]; do
--test_with_runner)
TEST_WITH_RUNNER=1
;;
--use_torchao_kernels)
USE_TORCHAO_KERNELS=1
;;
-h|--help)
echo "Usage: $0 <model_name> [--test_with_runner]"
echo "Usage: $0 <model_name> [--test_with_runner] [--use_torchao_kernels]"
echo " model_name: qwen3_4b | phi_4_mini"
echo " --test_with_runner: build ET + run llama_main to sanity-check the export"
echo " --use_torchao_kernels: use torchao kernels for linear and tied embedding"
exit 0
;;
*)
@@ -42,6 +47,13 @@ fi

MODEL_OUT=model.pte


# Default to XNNPACK
BACKEND_ARGS="-X --xnnpack-extended-ops"
if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
BACKEND_ARGS="--use-torchao-kernels"
fi

case "$MODEL_NAME" in
qwen3_4b)
echo "Running Qwen3-4B export..."
@@ -58,12 +70,12 @@ case "$MODEL_NAME" in
--output_name $MODEL_OUT \
-kv \
--use_sdpa_with_kv_cache \
-X \
--xnnpack-extended-ops \
--max_context_length 1024 \
--max_seq_length 1024 \
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
--verbose \
--dtype fp32 \
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
${BACKEND_ARGS}
;;

phi_4_mini)
@@ -81,12 +93,12 @@ case "$MODEL_NAME" in
--output_name $MODEL_OUT \
-kv \
--use_sdpa_with_kv_cache \
-X \
--xnnpack-extended-ops \
--max_context_length 1024 \
--max_seq_length 1024 \
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}' \
--verbose \
--dtype fp32 \
--metadata '{"get_bos_id":199999, "get_eos_ids":[200020,199999]}'
${BACKEND_ARGS}
;;

*)
@@ -104,6 +116,10 @@ if [[ $MODEL_SIZE -gt $EXPECTED_MODEL_SIZE_UPPER_BOUND ]]; then
fi

# Install ET with CMake
EXECUTORCH_BUILD_KERNELS_TORCHAO="OFF"
if [[ "$USE_TORCHAO_KERNELS" -eq 1 ]]; then
EXECUTORCH_BUILD_KERNELS_TORCHAO="ON"
fi
if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
echo "[runner] Building and testing llama_main ..."
cmake -DPYTHON_EXECUTABLE=python \
@@ -120,6 +136,7 @@ if [[ "$TEST_WITH_RUNNER" -eq 1 ]]; then
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
-DEXECUTORCH_BUILD_KERNELS_LLM=ON \
-DEXECUTORCH_BUILD_KERNELS_TORCHAO=${EXECUTORCH_BUILD_KERNELS_TORCHAO} \
-Bcmake-out .
cmake --build cmake-out -j16 --config Release --target install

22 changes: 17 additions & 5 deletions .github/workflows/trunk.yml
@@ -594,15 +594,22 @@ jobs:
strategy:
matrix:
model: [qwen3_4b, phi_4_mini]
runner: [linux.2xlarge]
docker-image: [executorch-ubuntu-22.04-clang12]
backend: [xnnpack]
include:
- model: qwen3_4b
test_with_runner: true
runner: linux.arm64.2xlarge
docker-image: executorch-ubuntu-22.04-gcc11-aarch64
backend: torchao
- model: phi_4_mini
test_with_runner: false
runner: linux.arm64.2xlarge
docker-image: executorch-ubuntu-22.04-gcc11-aarch64
backend: torchao
fail-fast: false
with:
runner: linux.2xlarge
docker-image: ci-image:executorch-ubuntu-22.04-clang12
runner: ${{ matrix.runner }}
docker-image: ci-image:${{ matrix.docker-image }}
submodules: 'recursive'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 900
@@ -612,9 +619,14 @@
conda activate "${CONDA_ENV}"

PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool cmake

if [[ "${{ matrix.backend }}" == "torchao" ]]; then
BUILD_TORCHAO_EXPERIMENTAL=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 TORCHAO_ENABLE_ARM_NEON_DOT=1 TORCHAO_PARALLEL_BACKEND=OPENMP pip install third-party/ao
fi

pip install -U "huggingface_hub[cli]"

bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.model != 'phi_4_mini' && '--test_with_runner' || '' }} ${{ matrix.backend == 'torchao' && '--use_torchao_kernels' || '' }}

test-multimodal-macos:
if: ${{ !github.event.pull_request.head.repo.fork }}
30 changes: 30 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -417,6 +417,21 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
)
parser.add_argument(
"--use-torchao-kernels",
action="store_true",
help="Delegate tied-embedding and quantized linear ops to torchao kernels",
)
Comment on lines +420 to +424 (review, Contributor):
why do we need this when it's combining the below two args?

parser.add_argument(
"--use-torchao-kernels-tied-embedding",
action="store_true",
help="Delegate tied-embedding ops to torchao kernels",
)
parser.add_argument(
"--use-torchao-kernels-linear",
action="store_true",
help="Delegate linear ops to torchao kernels",
)
parser.add_argument("-V", "--vulkan", action="store_true")
parser.add_argument("--vulkan-force-fp16", action="store_true")
parser.add_argument("--mps", action="store_true")
@@ -741,6 +756,8 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
preq_group_size=llm_config.base.preq_group_size,
preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
local_global_attention=llm_config.model.local_global_attention,
use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
)
)

@@ -1303,6 +1320,8 @@ def _get_source_transforms( # noqa
preq_group_size: Optional[int] = None,
preq_embedding_quantize: Optional[str] = None,
local_global_attention: Optional[List[int]] = None,
use_torchao_kernels_linear: bool = False,
use_torchao_kernels_tied_embedding: bool = False,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
"""
Return a list of functions that transform a graph.
@@ -1475,6 +1494,17 @@ def _get_source_transforms( # noqa
)
)

if any([use_torchao_kernels_linear, use_torchao_kernels_tied_embedding]):
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64

transforms.append(
partial(
_convert_model_for_aarch64,
convert_linear=use_torchao_kernels_linear,
convert_tied_embedding=use_torchao_kernels_tied_embedding,
)
)

return transforms
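
A sketch of how the new flags are intended to be used on the export command line. The flag names are taken verbatim from build_args_parser above; the export command itself is not shown in this diff, so <export_llama_cmd> is a placeholder for whatever entry point your setup uses. Setting either flag causes _convert_model_for_aarch64 from torchao to be appended to the source transforms.

```bash
# Hypothetical invocations; <export_llama_cmd> and the elided arguments (...) are placeholders.
<export_llama_cmd> ... --use-torchao-kernels                 # delegate quantized linear and tied embedding
<export_llama_cmd> ... --use-torchao-kernels-linear          # delegate quantized linear only
<export_llama_cmd> ... --use-torchao-kernels-tied-embedding  # delegate tied embedding only
```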


33 changes: 33 additions & 0 deletions extension/llm/export/config/llm_config.py
@@ -452,6 +452,16 @@ class MPSConfig:
enabled: bool = False


@dataclass
class TorchAOKernelsConfig:
"""
Configures the torchao-kernels backend.
"""

Comment (review, Contributor):
Can we follow the other backend config examples and use enabled?

use_torchao_kernels_linear: bool = False
use_torchao_kernels_tied_embedding: bool = False


@dataclass
class BackendConfig:
"""
@@ -464,6 +474,7 @@ class BackendConfig:
vulkan: VulkanConfig = field(default_factory=VulkanConfig)
qnn: QNNConfig = field(default_factory=QNNConfig)
mps: MPSConfig = field(default_factory=MPSConfig)
torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig)


################################################################################
@@ -632,6 +643,28 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
if hasattr(args, "mps"):
llm_config.backend.mps.enabled = args.mps

# TorchAoKernels
if any(
hasattr(args, a)
for a in [
"use_torchao_kernels",
"use_torchao_kernels_linear",
"use_torchao_kernels_tied_embedding",
]
):
if hasattr(args, "use_torchao_kernels") and args.use_torchao_kernels:
# Enable all conversions if torchao_kernels is specified
llm_config.backend.torchao.use_torchao_kernels_linear = True
llm_config.backend.torchao.use_torchao_kernels_tied_embedding = True
else:
# Otherwise, only enable the conversions that are specified
llm_config.backend.torchao.use_torchao_kernels_linear = getattr(
args, "use_torchao_kernels_linear", False
)
llm_config.backend.torchao.use_torchao_kernels_tied_embedding = getattr(
args, "use_torchao_kernels_tied_embedding", False
)

# DebugConfig
if hasattr(args, "profile_memory"):
llm_config.debug.profile_memory = args.profile_memory
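
A worked example of the precedence implemented in the branch above: when --use-torchao-kernels is present it enables both conversions and the granular flags are not consulted; otherwise only the conversions named on the command line are enabled. The flag spellings follow the argparse definitions in export_llama_lib.py, and <export_llama_cmd> is again a placeholder for the elided export command.

```bash
# Umbrella flag wins: both config fields end up True even though only the linear flag was also given.
#   backend.torchao.use_torchao_kernels_linear         = True
#   backend.torchao.use_torchao_kernels_tied_embedding = True
<export_llama_cmd> ... --use-torchao-kernels --use-torchao-kernels-linear

# Granular flag alone: only the named conversion is enabled.
#   backend.torchao.use_torchao_kernels_linear         = True
#   backend.torchao.use_torchao_kernels_tied_embedding = False
<export_llama_cmd> ... --use-torchao-kernels-linear
```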
2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 146 files