From 3558453dea54210dbadd0aba064b3f69e8b56b8b Mon Sep 17 00:00:00 2001 From: SS-JIA Date: Mon, 11 Aug 2025 15:30:44 -0400 Subject: [PATCH 1/8] Update [ghstack-poisoned] --- examples/devtools/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/devtools/CMakeLists.txt b/examples/devtools/CMakeLists.txt index 38a98e83dd7..a0d35f2aad8 100644 --- a/examples/devtools/CMakeLists.txt +++ b/examples/devtools/CMakeLists.txt @@ -65,6 +65,10 @@ target_link_libraries( portable_kernels ) +if (EXECUTORCH_BUILD_VULKAN) + target_link_libraries(example_runner vulkan_backend) +endif() + if(EXECUTORCH_BUILD_COREML) find_library(ACCELERATE_FRAMEWORK Accelerate) find_library(COREML_FRAMEWORK CoreML) From 4d13668d968b9f3b06702f60eb45b3bc815a24e1 Mon Sep 17 00:00:00 2001 From: SS-JIA Date: Mon, 11 Aug 2025 15:30:49 -0400 Subject: [PATCH 2/8] Update [ghstack-poisoned] --- backends/vulkan/test/scripts/test_model.sh | 181 ++++++++++++++++++ examples/vulkan/README.md | 80 ++++++++ examples/vulkan/__init__.py | 5 + examples/vulkan/aot_compiler.py | 204 ++++++++++++++++++++ examples/vulkan/export.py | 211 +++++++++++++++++++++ 5 files changed, 681 insertions(+) create mode 100755 backends/vulkan/test/scripts/test_model.sh create mode 100644 examples/vulkan/README.md create mode 100644 examples/vulkan/__init__.py create mode 100644 examples/vulkan/aot_compiler.py create mode 100644 examples/vulkan/export.py diff --git a/backends/vulkan/test/scripts/test_model.sh b/backends/vulkan/test/scripts/test_model.sh new file mode 100755 index 00000000000..6b5b9b02dc8 --- /dev/null +++ b/backends/vulkan/test/scripts/test_model.sh @@ -0,0 +1,181 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +set -exu + +# Initialize variables +RUN_BUILD=false +RUN_CORRECTNESS_TEST=false +RUN_CLEAN=false +RUN_RECOMPILE=false +MODEL_NAME="" +OUTPUT_DIRECTORY="." + +# Parse arguments +SKIP_NEXT=false +for i in $(seq 1 $#); do + if [[ "$SKIP_NEXT" == true ]]; then + SKIP_NEXT=false + continue + fi + + arg="${!i}" + case $arg in + --build|-b) + RUN_BUILD=true + ;; + --clean|-c) + RUN_CLEAN=true + ;; + --recompile|-rc) + RUN_RECOMPILE=true + ;; + --output_directory|-o) + next_i=$((i + 1)) + if [[ $next_i -le $# ]]; then + OUTPUT_DIRECTORY="${!next_i}" + SKIP_NEXT=true + else + echo "Error: --output_directory|-o requires a value" + exit 1 + fi + ;; + --*|-*) + echo "Unknown argument: $arg" + exit 1 + ;; + *) + if [[ -z "$MODEL_NAME" ]]; then + MODEL_NAME="$arg" + else + echo "Multiple model names provided: $MODEL_NAME and $arg" + exit 1 + fi + ;; + esac +done + +# Determine execution mode based on parsed arguments +if [[ "$RUN_BUILD" == true ]] && [[ -z "$MODEL_NAME" ]]; then + # Build-only mode + RUN_CORRECTNESS_TEST=false +elif [[ "$RUN_BUILD" == true ]] && [[ -n "$MODEL_NAME" ]]; then + # Build and test mode + RUN_CORRECTNESS_TEST=true +elif [[ "$RUN_BUILD" == false ]] && [[ -n "$MODEL_NAME" ]]; then + # Test-only mode + RUN_CORRECTNESS_TEST=true +else + echo "Invalid argument combination. 
Usage:" + echo " $0 --build|-b [--clean|-c] [--recompile|-rc] [-o|--output_directory DIR] # Build-only mode" + echo " $0 model_name [--build|-b] [--clean|-c] [--recompile|-rc] [-o|--output_directory DIR] # Test mode or build+test mode" + exit 1 +fi + +if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then + PYTHON_EXECUTABLE=python3 +fi +which "${PYTHON_EXECUTABLE}" + +CMAKE_OUTPUT_DIR=cmake-out + +# Only set EXPORTED_MODEL if running correctness test +if [[ "${RUN_CORRECTNESS_TEST}" == true ]]; then + EXPORTED_MODEL=${MODEL_NAME}_vulkan +fi + + +clean_build_directory() { + echo "Cleaning build directory: ${CMAKE_OUTPUT_DIR}" + rm -rf ${CMAKE_OUTPUT_DIR} +} + +recompile() { + cmake --build cmake-out -j64 --target install +} + +build_core_libraries_and_devtools() { + echo "Building core libraries and devtools with comprehensive Vulkan support..." + + # Build core libraries with all required components + cmake . \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -DEXECUTORCH_BUILD_TESTS=ON \ + -Bcmake-out && \ + cmake --build cmake-out -j64 --target install + + # Build devtools example runner + cmake examples/devtools \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -Bcmake-out/examples/devtools && \ + cmake --build cmake-out/examples/devtools -j16 --config Release +} + +run_example_runner() { + ./${CMAKE_OUTPUT_DIR}/examples/devtools/example_runner -bundled_program_path 
"${OUTPUT_DIRECTORY}/${EXPORTED_MODEL}.bpte" -output_verification +} + +test_bundled_model_with_vulkan() { + # Export model as bundled program with Vulkan backend + "${PYTHON_EXECUTABLE}" -m examples.vulkan.export --model_name="${MODEL_NAME}" --output_dir="${OUTPUT_DIRECTORY}" --bundled + + # Update exported model name for bundled program + EXPORTED_MODEL="${MODEL_NAME}_vulkan" + + # Verify the exported bundled model exists + if [[ ! -f "${OUTPUT_DIRECTORY}/${EXPORTED_MODEL}.bpte" ]]; then + echo "Error: Failed to export bundled model ${MODEL_NAME} with Vulkan backend" + exit 1 + fi + + # Note: Running bundled programs may require different executor runner + echo "Bundled program created successfully. Use appropriate bundled program runner to test." + + run_example_runner +} + + +# Main execution +if [[ "${RUN_BUILD}" == true ]]; then + if [[ "${RUN_CLEAN}" == true ]]; then + clean_build_directory + fi + build_core_libraries_and_devtools +fi + +if [[ "${RUN_RECOMPILE}" == true ]]; then + recompile +fi + +if [[ "${RUN_CORRECTNESS_TEST}" == true ]]; then + echo "Testing ${MODEL_NAME} with Vulkan backend..." + # Always use bundled program testing + test_bundled_model_with_vulkan + + # Check if test completed successfully + if [[ $? -eq 0 ]]; then + echo "Vulkan model test completed successfully!" + else + echo "Vulkan model test failed!" + exit 1 + fi +fi diff --git a/examples/vulkan/README.md b/examples/vulkan/README.md new file mode 100644 index 00000000000..71fdd0e4183 --- /dev/null +++ b/examples/vulkan/README.md @@ -0,0 +1,80 @@ +# Vulkan Delegate Export Examples + +This directory contains scripts for exporting models with the Vulkan delegate in ExecuTorch. Vulkan delegation allows you to run your models on devices with Vulkan-capable GPUs, potentially providing significant performance improvements over CPU execution. 
+ +## Scripts + +- `export.py`: Basic export script for models to use with Vulkan delegate +- `aot_compiler.py`: Advanced export script with quantization support + +## Usage + +### Basic Export + +```bash +python -m executorch.examples.vulkan.export -m <model_name> -o <output_dir> +``` + +### Export with Quantization (Experimental) + +```bash +python -m executorch.examples.vulkan.aot_compiler -m <model_name> -q -o <output_dir> +``` + +### Dynamic Shape Support + +```bash +python -m executorch.examples.vulkan.export -m <model_name> -d -o <output_dir> +``` + +### Additional Options + +- `-s/--strict`: Export with strict mode (default: True) +- `-a/--segment_alignment`: Specify segment alignment in hex (default: 0x1000) +- `-e/--external_constants`: Save constants in external .ptd file (default: False) +- `-r/--etrecord`: Generate and save an ETRecord to the given file location + +## Examples + +```bash +# Export MobileNetV2 with Vulkan delegate +python -m executorch.examples.vulkan.export -m mobilenet_v2 -o ./exported_models + +# Export MobileNetV3 with quantization +python -m executorch.examples.vulkan.aot_compiler -m mobilenet_v3 -q -o ./exported_models + +# Export with dynamic shapes +python -m executorch.examples.vulkan.export -m mobilenet_v2 -d -o ./exported_models + +# Export with ETRecord for debugging +python -m executorch.examples.vulkan.export -m mobilenet_v2 -r ./records/mobilenet_record.etrecord -o ./exported_models +``` + +## Supported Operations + +The Vulkan delegate supports various operations including: + +- Basic arithmetic (add, subtract, multiply, divide) +- Activations (ReLU, Sigmoid, Tanh, etc.) +- Convolutions (Conv1d, Conv2d, ConvTranspose2d) +- Pooling operations (MaxPool2d, AvgPool2d) +- Linear/Fully connected layers +- BatchNorm, GroupNorm +- Various tensor operations (cat, reshape, permute, etc.) + +For a complete list of supported operations, refer to the Vulkan delegate implementation in the ExecuTorch codebase. + +## Debugging and Optimization + +If you encounter issues with Vulkan delegation: + +1. 
Use `-r/--etrecord` to generate an ETRecord for debugging +2. Check if your operations are supported by the Vulkan delegate +3. Ensure your Vulkan drivers are up to date +4. Try using the export script with `--no-strict` if strict mode causes issues + +## Requirements + +- Vulkan runtime libraries (libvulkan.so.1) +- A Vulkan-capable GPU with appropriate drivers +- PyTorch with Vulkan support diff --git a/examples/vulkan/__init__.py b/examples/vulkan/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/vulkan/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/vulkan/aot_compiler.py b/examples/vulkan/aot_compiler.py new file mode 100644 index 00000000000..4f95ffa183a --- /dev/null +++ b/examples/vulkan/aot_compiler.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# Example script for compiling models with Vulkan delegation + +# pyre-unsafe + +import argparse +import logging + +import torch +from executorch.backends.transforms.convert_dtype_pass import I64toI32 +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program + +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer + +from ..models import MODEL_NAME_TO_MODEL +from ..models.model_factory import EagerModelFactory + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def quantize_and_lower_module( + model: torch.nn.Module, + sample_inputs, + quantizer: Quantizer, + dynamic_shapes=None, +) -> torch.nn.Module: + """Quantize a model and lower it with Vulkan delegation""" + compile_options = {} + if dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + ) + + program = torch.export.export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + program = prepare_pt2e(program, quantizer) + # Calibrate + program(*sample_inputs) + + program = convert_pt2e(program) + + program = torch.export.export(program, sample_inputs, dynamic_shapes=dynamic_shapes) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + 
required=True, + help=f"Model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + parser.add_argument( + "-q", + "--quantize", + action="store_true", + required=False, + default=False, + help="Produce a quantized model. Note: Quantization support may vary by model.", + ) + parser.add_argument( + "-d", + "--delegate", + action="store_true", + required=False, + default=True, + help="Produce a Vulkan delegated model", + ) + parser.add_argument( + "-y", + "--dynamic", + action="store_true", + required=False, + default=False, + help="Enable dynamic shape support", + ) + parser.add_argument( + "-r", + "--etrecord", + required=False, + default="", + help="Generate and save an ETRecord to the given file location", + ) + parser.add_argument("-o", "--output_dir", default=".", help="output directory") + + args = parser.parse_args() + + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] + ) + + model = model.eval() + + if args.dynamic and dynamic_shapes is None: + logging.warning("Dynamic shapes requested but not available for this model.") + + dynamic_shapes_to_use = dynamic_shapes if args.dynamic else None + + # Configure Edge compilation + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + _check_ir_validity=True, + ) + + # Setup compile options + compile_options = {} + if dynamic_shapes_to_use is not None: + compile_options["require_dynamic_shapes"] = True + + if args.quantize: + logging.info("Quantization for Vulkan not fully supported yet. Using experimental path.") + try: + # Try to import quantization utilities if available + try: + from ..quantization.utils import get_quantizer_for_model + quantizer = get_quantizer_for_model(args.model_name) + except ImportError: + # If the specific utility isn't available, create a basic quantizer + logging.warning("Quantization utils not found. 
Using default quantizer.") + from torchao.quantization.pt2e.quantizer import get_default_quantizer + quantizer = get_default_quantizer() + + edge = quantize_and_lower_module( + model, example_inputs, quantizer, dynamic_shapes=dynamic_shapes_to_use + ) + except (ImportError, NotImplementedError) as e: + logging.error(f"Quantization failed: {e}") + logging.info("Falling back to non-quantized path") + # Export the model using torch.export + program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes_to_use, strict=True + ) + + # Transform and lower with Vulkan partitioner + edge = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + else: + # Standard non-quantized path + # Export the model using torch.export + program = torch.export.export( + model, example_inputs, dynamic_shapes=dynamic_shapes_to_use, strict=True + ) + + # Transform and lower with Vulkan partitioner + edge = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + + logging.info(f"Exported and lowered graph:\n{edge.exported_program().graph}") + + exec_prog = edge.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=False) + ) + + if args.etrecord: + exec_prog.get_etrecord().save(args.etrecord) + logging.info(f"Saved ETRecord to {args.etrecord}") + + quant_tag = "q8" if args.quantize else "fp32" + model_name = f"{args.model_name}_vulkan_{quant_tag}" + save_pte_program(exec_prog, model_name, args.output_dir) + logging.info(f"Model exported and saved as {model_name}.pte in {args.output_dir}") diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py new file mode 100644 
index 00000000000..46c09a791d2 --- /dev/null +++ b/examples/vulkan/export.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting models to flatbuffer with the Vulkan delegate + +# pyre-unsafe + +import argparse +import logging + +import torch + +from executorch.backends.transforms.convert_dtype_pass import I64toI32 +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.devtools import BundledProgram +from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite +from executorch.devtools.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program +from executorch.extension.pytree import tree_flatten +from torch.export import export + +from ..models import MODEL_NAME_TO_MODEL +from ..models.model_factory import EagerModelFactory + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + + parser.add_argument( + "-s", + "--strict", + action=argparse.BooleanOptionalAction, + default=True, + help="whether to export with strict mode. Default is True", + ) + + parser.add_argument( + "-a", + "--segment_alignment", + required=False, + help="specify segment alignment in hex. Default is 0x1000. 
Use 0x4000 for iOS", + ) + + parser.add_argument( + "-e", + "--external_constants", + action=argparse.BooleanOptionalAction, + default=False, + help="Save constants in external .ptd file. Default is False", + ) + + parser.add_argument( + "-d", + "--dynamic", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable dynamic shape support. Default is False", + ) + + parser.add_argument( + "-r", + "--etrecord", + required=False, + default="", + help="Generate and save an ETRecord to the given file location", + ) + + parser.add_argument("-o", "--output_dir", default=".", help="output directory") + + parser.add_argument( + "-b", + "--bundled", + action=argparse.BooleanOptionalAction, + default=False, + help="Export as bundled program (.bpte) instead of regular program (.pte). Default is False", + ) + + args = parser.parse_args() + + if args.model_name not in MODEL_NAME_TO_MODEL: + raise RuntimeError( + f"Model {args.model_name} is not a valid name. " + f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." 
+ ) + + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] + ) + + # Prepare model + model.eval() + + # Setup compile options + compile_options = {} + if args.dynamic or dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + + # Configure Edge compilation + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + ) + + logging.info(f"Exporting model {args.model_name} with Vulkan delegate") + + # Export the model using torch.export + if dynamic_shapes is not None: + program = export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=args.strict + ) + else: + program = export(model, example_inputs, strict=args.strict) + + # Transform and lower with Vulkan partitioner + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + + logging.info( + f"Exported and lowered graph:\n{edge_program.exported_program().graph}" + ) + + # Configure backend options + backend_config = ExecutorchBackendConfig(external_constants=args.external_constants) + if args.segment_alignment is not None: + backend_config.segment_alignment = int(args.segment_alignment, 16) + + # Create executorch program + exec_prog = edge_program.to_executorch(config=backend_config) + + # Save ETRecord if requested + if args.etrecord: + exec_prog.get_etrecord().save(args.etrecord) + logging.info(f"Saved ETRecord to {args.etrecord}") + + # Save the program + output_filename = f"{args.model_name}_vulkan" + + if args.bundled: + # Create bundled program + logging.info("Creating bundled program with test cases") + + # Generate expected outputs by running the model + expected_outputs = [model(*example_inputs)] + + # Flatten sample inputs to match expected format + 
inputs_flattened, _ = tree_flatten(example_inputs) + + # Create test suite with the sample inputs and expected outputs + test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase( + inputs=inputs_flattened, + expected_outputs=expected_outputs, + ) + ], + ) + ] + + # Create bundled program + bp = BundledProgram(exec_prog, test_suites) + + # Serialize to flatbuffer + bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp) + + # Save bundled program + bundled_output_path = f"{args.output_dir}/{output_filename}.bpte" + with open(bundled_output_path, "wb") as file: + file.write(bp_buffer) + + logging.info( + f"Bundled program exported and saved as {output_filename}.bpte in {args.output_dir}" + ) + else: + # Save regular program + save_pte_program(exec_prog, output_filename, args.output_dir) + logging.info( + f"Model exported and saved as {output_filename}.pte in {args.output_dir}" + ) + + +if __name__ == "__main__": + with torch.no_grad(): + main() # pragma: no cover From bbfacb0e91167c2a499ece65338b05107f140116 Mon Sep 17 00:00:00 2001 From: SS-JIA Date: Mon, 11 Aug 2025 15:30:54 -0400 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- .github/workflows/pull.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index d39e9a43f25..3d096488bdd 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -862,6 +862,38 @@ jobs: PYTHON_EXECUTABLE=python bash examples/nxp/run_aot_example.sh + test-vulkan-models-linux: + name: test-vulkan-models-linux + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read + with: + runner: linux.2xlarge + docker-image: ci-image:executorch-ubuntu-22.04-clang12 + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 90 + script: | + set -eux + + # The generic 
Linux job chooses to use base env, not the one setup by the image + CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") + conda activate "${CONDA_ENV}" + + # Setup swiftshader and Vulkan SDK which are required to build the Vulkan delegate + source .ci/scripts/setup-vulkan-linux-deps.sh + + # Setup python + PYTHON_EXECUTABLE=python \ + CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \ + .ci/scripts/setup-linux.sh --build-tool "cmake" + + PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build + + # Test models serially + PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh mv2 + nxp-build-test: name: nxp-build-test uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main From c66e28609968f4a3a3ebf99514c233fbeede144d Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 11 Aug 2025 15:35:51 -0700 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index ceb95f3a304..743835a137b 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -594,7 +594,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { compute_graph->execute(); - for (size_t i = 0; i < compute_graph->outputs().size(); i++) { + for (size_t i = 0; i < compute_graph->outputs().size(); ++i) { const size_t o = i + num_inputs; const ValueRef oref = compute_graph->outputs()[i].value; if (compute_graph->val_is_tensor(oref)) { From eadf7ea147d81b3dd715c8a83ad64c2e49e850a1 Mon Sep 17 00:00:00 2001 From: SS-JIA Date: Tue, 12 Aug 2025 12:27:24 -0400 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- .../vulkan/partitioner/vulkan_partitioner.py | 38 ++++++++++++- backends/vulkan/utils.py | 57 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git 
a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 776d1d6e168..302b9af83e2 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -7,7 +7,7 @@ # pyre-strict import logging -from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple +from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple import executorch.backends.vulkan.utils as utils @@ -17,6 +17,7 @@ get_op_features, has_impl, OpFeatures, + OpKey, vulkan_supported_ops, ) @@ -55,11 +56,17 @@ def __init__( texture_limits: utils.ImageExtents, buffer_limit: int, require_dynamic_shape: bool = False, + operator_blocklist: Optional[Set[OpKey]] = None, + operator_allowlist: Optional[Set[OpKey]] = None, ) -> None: super().__init__() self.texture_limits: utils.ImageExtents = texture_limits self.buffer_limit = buffer_limit self.require_dynamic_shapes = require_dynamic_shape + self.operator_blocklist: Set[OpKey] = ( + operator_blocklist if operator_blocklist is not None else set() + ) + self.operator_allowlist = operator_allowlist def op_node_is_compatible( # noqa: C901: Function is too complex self, node: torch.fx.Node, features: Optional[OpFeatures] = None @@ -77,6 +84,17 @@ def op_node_is_compatible( # noqa: C901: Function is too complex assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() + # Operator allow list is only used for torch ops + if ( + utils.is_torch_op_node(node) + and (self.operator_allowlist is not None) + and (target not in self.operator_allowlist) + ): + return False, "op is not in allowlist" + + if target in self.operator_blocklist: + return False, "op is in blocklist" + # Extract the features for the node's operator, if no override was provided if features is None: if not has_impl(target): @@ -93,7 +111,7 @@ def op_node_is_compatible( # noqa: C901: Function is too complex if op_repsets.any_is_empty(): 
return ( False, - "No valid representations for a tensor in the operation", + f"no valid representations for op {utils.node_io_str(node)}", ) return True, "Op is compatible" @@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner): def __init__( self, compile_options: Optional[Dict[str, Any]] = None, + operator_blocklist: Optional[List[OpKey]] = None, + operator_allowlist: Optional[List[OpKey]] = None, ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: @@ -285,6 +305,18 @@ def __init__( compile_spec = parse_compile_options(self.options) self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec) + self.operator_blocklist: Set[OpKey] = set() + if operator_blocklist is not None: + for entry in operator_blocklist or []: + self.operator_blocklist.add(entry) + + self.operator_allowlist: Optional[Set[OpKey]] = None + if operator_allowlist is not None: + self.operator_allowlist = set() + for entry in operator_allowlist: + assert self.operator_allowlist is not None + self.operator_allowlist.add(entry) + def ops_to_not_decompose( self, ep: ExportedProgram ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: @@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: texture_limits, buffer_limit, require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + operator_blocklist=self.operator_blocklist, + operator_allowlist=self.operator_allowlist, ), allows_single_node_partition=True, ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index fa45063a4d3..1765f0b5e1c 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -18,6 +18,8 @@ format_target_name, ) +from executorch.exir.dialects.edge._ops import EdgeOpOverload + from executorch.exir.tensor import TensorSpec from torch._export.utils import is_buffer, is_param @@ -54,6 +56,18 @@ MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]] +def 
is_torch_op_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + if isinstance(node.target, EdgeOpOverload): + return True + if isinstance(node.target, torch._ops.OpOverload): + return True + + return False + + def is_dequant_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False @@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]: ## +def get_tensor_val_str(tensor_val: FakeTensor) -> str: + return f"{tensor_val.dtype}: {tensor_val.shape}" + + +def get_node_val_str(node: torch.fx.Node) -> str: + if is_single_tensor_node(node): + assert isinstance(node.meta["val"], FakeTensor) + return get_tensor_val_str(node.meta["val"]) + elif is_tensor_collection_node(node): + assert isinstance(node.meta["val"], (list, tuple)) + return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]" + else: + return str(node.meta["val"]) + + +def get_arg_node_val_str(arg_node: Any) -> str: + if isinstance(arg_node, torch.fx.Node): + return get_node_val_str(arg_node) + elif isinstance(arg_node, (list, tuple)): + return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]" + else: + return str(arg_node) + + +def node_io_str(node: torch.fx.Node) -> str: + target = node.target + if isinstance(target, EdgeOpOverload): + assert isinstance(target, EdgeOpOverload) + target_name = target.__name__ + elif isinstance(target, torch._ops.OpOverload): + assert isinstance(target, torch._ops.OpOverload) + target_name = target.name() + else: + target_name = str(target) + + out_str = f"{get_node_val_str(node)} = {target_name}(" + for arg in node.args: + out_str += get_arg_node_val_str(arg) + ", " + + out_str += " ...)" + return out_str + + def update_program_state_dict( program: ExportedProgram, buffer_name: str, From e904823d43ba488bb843babf5925af4c9ce75b01 Mon Sep 17 00:00:00 2001 From: SS-JIA Date: Tue, 12 Aug 2025 12:27:28 -0400 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- CMakeLists.txt | 4 
++++ backends/vulkan/CMakeLists.txt | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f2fba8921f5..e0c5e0fe840 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -763,6 +763,10 @@ if(EXECUTORCH_BUILD_PYBIND) list(APPEND _dep_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod) endif() + if(EXECUTORCH_BUILD_VULKAN) + list(APPEND _dep_libs vulkan_backend) + endif() + # compile options for pybind set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti -fexceptions diff --git a/backends/vulkan/CMakeLists.txt b/backends/vulkan/CMakeLists.txt index 72d5fb8d830..29ff90e7293 100644 --- a/backends/vulkan/CMakeLists.txt +++ b/backends/vulkan/CMakeLists.txt @@ -101,7 +101,7 @@ set_target_properties(vulkan_schema PROPERTIES LINKER_LANGUAGE CXX) target_include_directories( vulkan_schema INTERFACE - ${SCHEMA_INCLUDE_DIR} + $ $ ) From f1ed16a48d577a676e8ae990a498c6bd88306410 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 12 Aug 2025 12:25:59 -0700 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- backends/vulkan/cmake/ShaderLibrary.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/cmake/ShaderLibrary.cmake b/backends/vulkan/cmake/ShaderLibrary.cmake index 3e396b8a8a0..2583488f80d 100644 --- a/backends/vulkan/cmake/ShaderLibrary.cmake +++ b/backends/vulkan/cmake/ShaderLibrary.cmake @@ -51,8 +51,8 @@ function(gen_vulkan_shader_lib_cpp shaders_path) set(GEN_SPV_ARGS "--optimize") if(DEFINED ENV{ETVK_USING_SWIFTSHADER} - AND (("$ENV{ETVK_USING_SWIFTSHADER}" STREQUAL "1") - OR ("$ENV{ETVK_USING_SWIFTSHADER}" STREQUAL "True")) + and (("$ENV{ETVK_USING_SWIFTSHADER}" strequal "1") + or ("$ENV{ETVK_USING_SWIFTSHADER}" strequal "True")) ) list(APPEND GEN_SPV_ARGS "--replace-u16vecn") endif() From 2048d61a1cac02e1822112a9b945a16f65ba0a9e Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 12 Aug 2025 13:28:13 -0700 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- 
backends/vulkan/cmake/ShaderLibrary.cmake | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/cmake/ShaderLibrary.cmake b/backends/vulkan/cmake/ShaderLibrary.cmake index 2583488f80d..1b6838c4dfd 100644 --- a/backends/vulkan/cmake/ShaderLibrary.cmake +++ b/backends/vulkan/cmake/ShaderLibrary.cmake @@ -50,11 +50,12 @@ function(gen_vulkan_shader_lib_cpp shaders_path) set(VULKAN_SHADERGEN_OUT_PATH ${CMAKE_BINARY_DIR}/vulkan_compute_shaders) set(GEN_SPV_ARGS "--optimize") - if(DEFINED ENV{ETVK_USING_SWIFTSHADER} - and (("$ENV{ETVK_USING_SWIFTSHADER}" strequal "1") - or ("$ENV{ETVK_USING_SWIFTSHADER}" strequal "True")) - ) - list(APPEND GEN_SPV_ARGS "--replace-u16vecn") + if(DEFINED ENV{ETVK_USING_SWIFTSHADER}) + if("$ENV{ETVK_USING_SWIFTSHADER}" STREQUAL "1" + OR "$ENV{ETVK_USING_SWIFTSHADER}" STREQUAL "True" + ) + list(APPEND GEN_SPV_ARGS "--replace-u16vecn") + endif() endif() add_custom_command(