diff --git a/CMakeLists.txt b/CMakeLists.txt
index f2fba8921f5..e0c5e0fe840 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -763,6 +763,10 @@ if(EXECUTORCH_BUILD_PYBIND)
     list(APPEND _dep_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
   endif()
 
+  if(EXECUTORCH_BUILD_VULKAN)
+    list(APPEND _dep_libs vulkan_backend)
+  endif()
+
   # compile options for pybind
   set(_pybind_compile_options -Wno-deprecated-declarations -fPIC -frtti
                               -fexceptions
diff --git a/backends/vulkan/CMakeLists.txt b/backends/vulkan/CMakeLists.txt
index 72d5fb8d830..29ff90e7293 100644
--- a/backends/vulkan/CMakeLists.txt
+++ b/backends/vulkan/CMakeLists.txt
@@ -101,7 +101,7 @@ set_target_properties(vulkan_schema PROPERTIES LINKER_LANGUAGE CXX)
 
 target_include_directories(
   vulkan_schema INTERFACE
-  ${SCHEMA_INCLUDE_DIR}
+  $
   $
 )
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 776d1d6e168..302b9af83e2 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -7,7 +7,7 @@
 # pyre-strict
 
 import logging
-from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple
+from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
 
 import executorch.backends.vulkan.utils as utils
 
@@ -17,6 +17,7 @@
     get_op_features,
     has_impl,
     OpFeatures,
+    OpKey,
     vulkan_supported_ops,
 )
 
@@ -55,11 +56,17 @@ def __init__(
         texture_limits: utils.ImageExtents,
         buffer_limit: int,
         require_dynamic_shape: bool = False,
+        operator_blocklist: Optional[Set[OpKey]] = None,
+        operator_allowlist: Optional[Set[OpKey]] = None,
     ) -> None:
         super().__init__()
         self.texture_limits: utils.ImageExtents = texture_limits
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
+        self.operator_blocklist: Set[OpKey] = (
+            operator_blocklist if operator_blocklist is not None else set()
+        )
+        self.operator_allowlist = operator_allowlist
 
     def op_node_is_compatible(  # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -77,6 +84,17 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
             assert isinstance(first_arg, torch._ops.OpOverload)
             target = first_arg.name()
 
+        # Operator allow list is only used for torch ops
+        if (
+            utils.is_torch_op_node(node)
+            and (self.operator_allowlist is not None)
+            and (target not in self.operator_allowlist)
+        ):
+            return False, "op is not in allowlist"
+
+        if target in self.operator_blocklist:
+            return False, "op is in blocklist"
+
         # Extract the features for the node's operator, if no override was provided
         if features is None:
             if not has_impl(target):
@@ -93,7 +111,7 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
         if op_repsets.any_is_empty():
             return (
                 False,
-                "No valid representations for a tensor in the operation",
+                f"no valid representations for op {utils.node_io_str(node)}",
             )
 
         return True, "Op is compatible"
@@ -277,6 +295,8 @@ class VulkanPartitioner(Partitioner):
     def __init__(
        self,
         compile_options: Optional[Dict[str, Any]] = None,
+        operator_blocklist: Optional[List[OpKey]] = None,
+        operator_allowlist: Optional[List[OpKey]] = None,
     ) -> None:
         self.options: Dict[str, Any] = {}
         if compile_options is not None:
@@ -285,6 +305,18 @@ def __init__(
         compile_spec = parse_compile_options(self.options)
         self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)
 
+        self.operator_blocklist: Set[OpKey] = set()
+        if operator_blocklist is not None:
+            for entry in operator_blocklist or []:
+                self.operator_blocklist.add(entry)
+
+        self.operator_allowlist: Optional[Set[OpKey]] = None
+        if operator_allowlist is not None:
+            self.operator_allowlist = set()
+            for entry in operator_allowlist:
+                assert self.operator_allowlist is not None
+                self.operator_allowlist.add(entry)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
@@ -308,6 +340,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 texture_limits,
                 buffer_limit,
                 require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
+                operator_blocklist=self.operator_blocklist,
+                operator_allowlist=self.operator_allowlist,
             ),
             allows_single_node_partition=True,
         )
diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py
index d42d7ab33be..9b6d53c5d05 100644
--- a/backends/vulkan/runtime/gen_vulkan_spv.py
+++ b/backends/vulkan/runtime/gen_vulkan_spv.py
@@ -1083,7 +1083,6 @@ def compile_spirv(shader_paths_pair) -> Tuple[str, str]:
         for spv_out_path, glsl_out_path in pool.map(
             compile_spirv, self.output_file_map.items()
         ):
-            print(spv_to_glsl_map)
             spv_to_glsl_map[spv_out_path] = glsl_out_path
 
         return spv_to_glsl_map
diff --git a/backends/vulkan/test/scripts/test_model.sh b/backends/vulkan/test/scripts/test_model.sh
new file mode 100755
index 00000000000..5f06d2c039b
--- /dev/null
+++ b/backends/vulkan/test/scripts/test_model.sh
@@ -0,0 +1,180 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -exu
+
+# Initialize variables
+RUN_BUILD=false
+RUN_CORRECTNESS_TEST=false
+RUN_CLEAN=false
+RUN_RECOMPILE=false
+MODEL_NAME=""
+OUTPUT_DIRECTORY="."
+
+# Parse arguments
+SKIP_NEXT=false
+for i in $(seq 1 $#); do
+  if [[ "$SKIP_NEXT" == true ]]; then
+    SKIP_NEXT=false
+    continue
+  fi
+
+  arg="${!i}"
+  case $arg in
+    --build|-b)
+      RUN_BUILD=true
+      ;;
+    --clean|-c)
+      RUN_CLEAN=true
+      ;;
+    --recompile|-rc)
+      RUN_RECOMPILE=true
+      ;;
+    --output_directory|-o)
+      next_i=$((i + 1))
+      if [[ $next_i -le $# ]]; then
+        OUTPUT_DIRECTORY="${!next_i}"
+        SKIP_NEXT=true
+      else
+        echo "Error: --output_directory|-o requires a value"
+        exit 1
+      fi
+      ;;
+    --*|-*)
+      echo "Unknown argument: $arg"
+      exit 1
+      ;;
+    *)
+      if [[ -z "$MODEL_NAME" ]]; then
+        MODEL_NAME="$arg"
+      else
+        echo "Multiple model names provided: $MODEL_NAME and $arg"
+        exit 1
+      fi
+      ;;
+  esac
+done
+
+# Determine execution mode based on parsed arguments
+if [[ "$RUN_BUILD" == true ]] && [[ -z "$MODEL_NAME" ]]; then
+  # Build-only mode
+  RUN_CORRECTNESS_TEST=false
+elif [[ "$RUN_BUILD" == true ]] && [[ -n "$MODEL_NAME" ]]; then
+  # Build and test mode
+  RUN_CORRECTNESS_TEST=true
+elif [[ "$RUN_BUILD" == false ]] && [[ -n "$MODEL_NAME" ]]; then
+  # Test-only mode
+  RUN_CORRECTNESS_TEST=true
+else
+  echo "Invalid argument combination. Usage:"
Usage:" + echo " $0 --build|-b [--clean|-c] [--recompile|-rc] [-o|--output_directory DIR] # Build-only mode" + echo " $0 model_name [--build|-b] [--clean|-c] [--recompile|-rc] [-o|--output_directory DIR] # Test mode or build+test mode" + exit 1 +fi + +if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then + PYTHON_EXECUTABLE=python3 +fi +which "${PYTHON_EXECUTABLE}" + +CMAKE_OUTPUT_DIR=cmake-out + +# Only set EXPORTED_MODEL if running correctness test +if [[ "${RUN_CORRECTNESS_TEST}" == true ]]; then + EXPORTED_MODEL=${MODEL_NAME}_vulkan +fi + + +clean_build_directory() { + echo "Cleaning build directory: ${CMAKE_OUTPUT_DIR}" + rm -rf ${CMAKE_OUTPUT_DIR} +} + +recompile() { + cmake --build cmake-out -j64 --target install +} + +build_core_libraries_and_devtools() { + echo "Building core libraries and devtools with comprehensive Vulkan support..." + + # Build core libraries with all required components + cmake . \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=ON \ + -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_DEVTOOLS=ON \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -DEXECUTORCH_BUILD_XNNPACK=ON \ + -Bcmake-out && \ + cmake --build cmake-out -j64 --target install + + # Build devtools example runner + cmake examples/devtools \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ + -DEXECUTORCH_BUILD_VULKAN=ON \ + -Bcmake-out/examples/devtools && \ + cmake --build cmake-out/examples/devtools -j16 --config Release +} + +run_example_runner() { + ./${CMAKE_OUTPUT_DIR}/examples/devtools/example_runner -bundled_program_path "${OUTPUT_DIRECTORY}/${EXPORTED_MODEL}.bpte" -output_verification +} + +test_bundled_model_with_vulkan() { + # Export model as bundled program with Vulkan backend + "${PYTHON_EXECUTABLE}" -m examples.vulkan.export --model_name="${MODEL_NAME}" --output_dir="${OUTPUT_DIRECTORY}" --bundled + + # Update exported model name for bundled program + EXPORTED_MODEL="${MODEL_NAME}_vulkan" + + # Verify the exported bundled model exists + if [[ ! -f "${OUTPUT_DIRECTORY}/${EXPORTED_MODEL}.bpte" ]]; then + echo "Error: Failed to export bundled model ${MODEL_NAME} with Vulkan backend" + exit 1 + fi + + # Note: Running bundled programs may require different executor runner + echo "Bundled program created successfully. Use appropriate bundled program runner to test." + + run_example_runner +} + + +# Main execution +if [[ "${RUN_BUILD}" == true ]]; then + if [[ "${RUN_CLEAN}" == true ]]; then + clean_build_directory + fi + build_core_libraries_and_devtools +fi + +if [[ "${RUN_RECOMPILE}" == true ]]; then + recompile +fi + +if [[ "${RUN_CORRECTNESS_TEST}" == true ]]; then + echo "Testing ${MODEL_NAME} with Vulkan backend..." + # Always use bundled program testing + test_bundled_model_with_vulkan + + # Check if test completed successfully + if [[ $? -eq 0 ]]; then + echo "Vulkan model test completed successfully!" + else + echo "Vulkan model test failed!" 
+ exit 1 + fi +fi diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py new file mode 100644 index 00000000000..0d6776da6b7 --- /dev/null +++ b/backends/vulkan/test/utils.py @@ -0,0 +1,586 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import List, Optional, Tuple + +import executorch.backends.vulkan.utils as utils + +import torch + +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend +from executorch.devtools import BundledProgram +from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite +from executorch.devtools.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) +from executorch.exir import ExecutorchProgramManager, to_edge_transform_and_lower +from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, +) +from executorch.extension.pytree import tree_flatten +from torch.export import export, export_for_training + + +def export_model_to_vulkan( + model, + sample_inputs, + dynamic_shapes=None, + operator_blocklist=None, + operator_allowlist=None, +): + """Helper to export a model to Vulkan backend.""" + compile_options = {} + export_training_graph = export_for_training( + model, sample_inputs, strict=True + ).module() + program = export( + export_training_graph, + sample_inputs, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + edge_program = to_edge_transform_and_lower( + program, + partitioner=[ + VulkanPartitioner( + compile_options, + operator_blocklist=operator_blocklist, + operator_allowlist=operator_allowlist, + ) + ], + transform_passes=None, + compile_config=None, + ) + + executorch_program = edge_program.to_executorch() + + # Check if the delegate ID matches VulkanBackend + if ( + executorch_program.executorch_program.execution_plan[0].delegates[0].id + != VulkanBackend.__name__ + ): + raise RuntimeError( + f"Expected delegate ID {VulkanBackend.__name__}, but got {executorch_program.executorch_program.execution_plan[0].delegates[0].id}" + ) + + return executorch_program + + +def export_model_to_xnnpack(model, sample_inputs, dynamic_shapes=None): + """Helper to export a model to XNNPACK backend.""" + compile_options = {} + export_training_graph = export_for_training( + model, sample_inputs, strict=True + ).module() + program = export( + export_training_graph, + sample_inputs, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + edge_program = to_edge_transform_and_lower( + program, + partitioner=[XnnpackPartitioner(compile_options)], + transform_passes=None, + compile_config=None, + ) + + executorch_program = edge_program.to_executorch() + + # Check if the delegate ID matches XnnpackBackend + if ( + executorch_program.executorch_program.execution_plan[0].delegates[0].id + != XnnpackBackend.__name__ + ): + raise RuntimeError( + f"Expected delegate ID {XnnpackBackend.__name__}, but got {executorch_program.executorch_program.execution_plan[0].delegates[0].id}" + ) + + return executorch_program + + +def check_outputs_equal( + model_output, ref_output, atol=1e-03, rtol=1e-03, 
first_output_only=False +): + """ + Helper function that checks if model output and reference output are equal with some tolerance. + Returns True if equal, False otherwise. + """ + # Compare the result from executor and eager mode directly + if isinstance(ref_output, tuple) or isinstance(ref_output, list): + # Multiple outputs executor always returns tuple, even if there is one output + if len(ref_output) != len(model_output): + return False + if first_output_only: + return torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol) + else: + for i in range(len(ref_output)): + if not torch.allclose( + model_output[i], ref_output[i], atol=atol, rtol=rtol + ): + return False + return True + else: + # If one output, eager returns tensor while executor tuple of size 1 + return torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol) + + +def run_and_check_output( + reference_model: torch.nn.Module, + executorch_program: ExecutorchProgramManager, + sample_inputs: Tuple[torch.Tensor], + atol=1e-03, + rtol=1e-01, + first_output_only=False, +) -> bool: + """ + Utility function that accepts an already lowered ExecuTorch program, executes it with + the provided sample input, and checks the output for correctness. + + Args: + executorch_program: Already lowered ExecutorchProgramManager + sample_inputs: Sample inputs to run the program with + reference_model: Reference model to generate reference outputs for comparison + atol: Absolute tolerance for output comparison + rtol: Relative tolerance for output comparison + first_output_only: Whether to compare only the first output + + Returns: + bool: True if outputs match within tolerance, False otherwise + """ + # Load the ExecutorTorch program + executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) + + # Flatten inputs for execution + inputs_flattened, _ = tree_flatten(sample_inputs) + + # Run the ExecutorTorch program + model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) + + # Generate reference outputs using the reference model + ref_output = reference_model(*sample_inputs) + + # Check if outputs are equal + return check_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ) + + +def lower_module_and_test_output( + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + atol=1e-03, + rtol=1e-01, + dynamic_shapes=None, + test_inputs=None, + first_output_only=False, + operator_blocklist=None, + operator_allowlist=None, +) -> bool: + """ + Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with + the given sample inputs. It then runs the lowered module and compares its + outputs with the outputs of the eager module. + + Returns: + bool: True if all comparisons pass, False otherwise. 
+ """ + # Export model to Vulkan using the helper function + executorch_program = export_model_to_vulkan( + model, sample_inputs, dynamic_shapes, operator_blocklist, operator_allowlist + ) + + executorch_module = _load_for_executorch_from_buffer(executorch_program.buffer) + + inputs_flattened, _ = tree_flatten(sample_inputs) + + model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) + ref_output = model(*sample_inputs) + + if not check_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ): + return False + + if test_inputs is not None: + for test_input in test_inputs: + test_inputs_flattened, _ = tree_flatten(test_input) + model_output = executorch_module.run_method( + "forward", tuple(test_inputs_flattened) + ) + ref_output = model(*test_input) + + if not check_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ): + return False + + return True + + +def save_bundled_program( + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + output_path: str, + method_name: str = "forward", + et_program: Optional[ExecutorchProgramManager] = None, + dynamic_shapes=None, +) -> str: + """ + Export a bundled .pte file containing the model and test cases. + + Args: + model: The PyTorch model to export + sample_inputs: Sample inputs for the model + output_path: Path where the bundled .pte file should be saved (should end with .bpte) + method_name: Name of the method to test (default: "forward") + et_program: Optional pre-exported ExecutorchProgramManager. If None, will export to Vulkan + dynamic_shapes: Optional dynamic shapes for export + + Returns: + str: Path to the saved bundled program file + """ + # If no ExecutorchProgramManager provided, export to Vulkan + if et_program is None: + et_program = export_model_to_vulkan(model, sample_inputs, dynamic_shapes) + + # Generate expected outputs by running the model + expected_outputs = [getattr(model, method_name)(*sample_inputs)] + + # Flatten sample inputs to match expected format + inputs_flattened, _ = tree_flatten(sample_inputs) + + # Create test suite with the sample inputs and expected outputs + test_suites = [ + MethodTestSuite( + method_name=method_name, + test_cases=[ + MethodTestCase( + inputs=inputs_flattened, + expected_outputs=expected_outputs, + ) + ], + ) + ] + + # Create bundled program + bp = BundledProgram(et_program, test_suites) + + # Serialize to flatbuffer + bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp) + + # Ensure output path has correct extension + if not output_path.endswith(".bpte"): + output_path = output_path + ".bpte" + + # Write to file + with open(output_path, "wb") as file: + file.write(bp_buffer) + return output_path + + +def save_executorch_program( + executorch_program: ExecutorchProgramManager, + output_path: str, +) -> str: + """ + Save an ExecutorchProgramManager as a .pte file. 
+ + Args: + executorch_program: The ExecutorchProgramManager to save + output_path: Path where the .pte file should be saved (should end with .pte) + + Returns: + str: Path to the saved .pte file + """ + # Ensure output path has correct extension + if not output_path.endswith(".pte"): + output_path = output_path + ".pte" + + # Write to file + with open(output_path, "wb") as file: + executorch_program.write_to_file(file) + + return output_path + + +def print_occurrences(edge_program, operator_list: List): + """ + Print the input/output information for all occurrences of specified operators in the edge program. + + Args: + edge_program: The edge program created by to_edge_transform_and_lower + operator_list: List of operators to search for in the graph + """ + logger = logging.getLogger("") + logger.setLevel(logging.INFO) + + logger.info( + f"Searching for occurrences of {len(operator_list)} operators in the graph..." + ) + + occurrence_count = 0 + + for node in edge_program.exported_program().graph.nodes: + if utils.is_torch_op_node(node): + target = node.target + # Handle auto_functionalized nodes + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + if hasattr(first_arg, "name"): + target = first_arg.name() + elif hasattr(first_arg, "__name__"): + target = first_arg.__name__ + + # Check if this operator is in our list + if target in operator_list: + occurrence_count += 1 + logger.info(f"Occurrence {occurrence_count}: {node.format_node()}") + + # Get the node I/O string using the utils function + try: + io_str = utils.node_io_str(node) + logger.info(f" {io_str}") + except Exception as e: + logger.info(f" Error getting I/O string: {e}") + + if occurrence_count == 0: + logger.info("No occurrences of the specified operators found in the graph.") + else: + logger.info( + f"Found {occurrence_count} total occurrences of the specified operators." + ) + + +def op_ablation_test( # noqa: C901 + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + atol=1e-03, + rtol=1e-01, + dynamic_shapes=None, + test_inputs=None, + first_output_only=False, +) -> dict: + """ + Fast binary search utility function to determine which operators work correctly when delegated to Vulkan. + + This function uses a binary search approach to efficiently find bad operators: + 1. Split operators into two halves (least frequent first, most frequent second) + 2. Test each half to see if it produces correct output + 3. Add good halves to known_good_ops and recursively search bad halves + 4. 
Continue until all operators are classified + + Args: + model: The PyTorch model to test + sample_inputs: Sample inputs for the model + atol: Absolute tolerance for output comparison + rtol: Relative tolerance for output comparison + dynamic_shapes: Optional dynamic shapes for export + test_inputs: Optional additional test inputs + first_output_only: Whether to compare only the first output + + Returns: + dict: Dictionary with keys: + - 'good_operators': List of operators that work correctly + - 'bad_operators': List of operators that cause failures + - 'operator_frequencies': Dictionary mapping operators to their occurrence count + - 'all_operators': List of all unique operators found in the graph + - 'test_count': Number of tests performed + """ + logger = logging.getLogger("") + logger.setLevel(logging.INFO) + + logger.info("Starting fast binary search operator ablation test...") + + # Step 1: Export model to get edge_program and extract operators + export_training_graph = export_for_training( + model, sample_inputs, strict=True + ).module() + program = export( + export_training_graph, + sample_inputs, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + edge_program = to_edge_transform_and_lower( + program, + partitioner=[], # No partitioner to get the full graph + transform_passes=None, + compile_config=None, + ) + + # Step 2: Scan edge_program.graph_module to obtain unique operators and their frequencies + operator_frequencies = {} + for node in edge_program.exported_program().graph.nodes: + if utils.is_torch_op_node(node): + target = node.target + # Handle auto_functionalized nodes + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + if hasattr(first_arg, "name"): + target = first_arg.name() + elif hasattr(first_arg, "__name__"): + target = first_arg.__name__ + + if target in operator_frequencies: + operator_frequencies[target] += 1 + else: + operator_frequencies[target] = 1 + + all_operators = list(operator_frequencies.keys()) + logger.info(f"Found {len(all_operators)} unique operators in the graph") + + # Sort operators by frequency (least frequent first for binary search) + operators_by_frequency = sorted( + all_operators, key=lambda op: operator_frequencies[op] + ) + + logger.info("Operator frequencies (sorted by occurrence, least frequent first):") + for op in operators_by_frequency: + logger.info(f" {op}: {operator_frequencies[op]} occurrences") + + # Global test counter + test_count = 0 + + def test_operator_set(ops_to_test: List, known_good_ops: List) -> bool: + """Test if a set of operators works correctly when combined with known good operators.""" + nonlocal test_count + test_count += 1 + + test_allowlist = known_good_ops + ops_to_test + logger.info( + f"Test {test_count}: Testing {len(ops_to_test)} operators with {len(known_good_ops)} known good" + ) + + try: + success = lower_module_and_test_output( + model=model, + sample_inputs=sample_inputs, + atol=atol, + rtol=rtol, + dynamic_shapes=dynamic_shapes, + test_inputs=test_inputs, + first_output_only=first_output_only, + operator_allowlist=test_allowlist, + ) + logger.info(f" {'✓ PASS' if success else '✗ FAIL'}") + return success + except Exception as e: + logger.info(f" ! Error: {e}") + return False + + def find_bad_operators( + ops_to_test: List, known_good_ops: List + ) -> Tuple[List, List]: + """ + Recursively find bad operators using binary search. 
+ + Returns: + Tuple of (good_operators, bad_operators) from ops_to_test + """ + if not ops_to_test: + return [], [] + + if len(ops_to_test) == 1: + # Base case: single operator + op = ops_to_test[0] + if test_operator_set([op], known_good_ops): + logger.info(f" Single operator {op} is GOOD") + return [op], [] + else: + logger.info(f" Single operator {op} is BAD") + return [], [op] + + # Split ops_to_test into two halves + mid = len(ops_to_test) // 2 + first_half = ops_to_test[:mid] # Least frequent operators + second_half = ops_to_test[mid:] # Most frequent operators + + logger.info( + f"Splitting {len(ops_to_test)} operators: {len(first_half)} + {len(second_half)}" + ) + + # Test each half + first_half_good = test_operator_set(first_half, known_good_ops) + second_half_good = test_operator_set(second_half, known_good_ops) + + good_ops = [] + bad_ops = [] + + # Process first half + if first_half_good: + logger.info( + f"First half ({len(first_half)} ops) is good - adding to known good" + ) + good_ops.extend(first_half) + known_good_ops.extend(first_half) + if second_half_good: + logger.info( + f"Second half ({len(second_half)} ops) is good - adding to known good" + ) + good_ops.extend(second_half) + + if not first_half_good: + logger.info(f"First half ({len(first_half)} ops) is bad - recursing") + sub_good, sub_bad = find_bad_operators(first_half, known_good_ops) + good_ops.extend(sub_good) + bad_ops.extend(sub_bad) + known_good_ops.extend(sub_good) + if not second_half_good: + logger.info(f"Second half ({len(second_half)} ops) is bad - recursing") + sub_good, sub_bad = find_bad_operators(second_half, known_good_ops) + good_ops.extend(sub_good) + bad_ops.extend(sub_bad) + + return good_ops, bad_ops + + # Start the binary search + logger.info( + f"\n=== Starting binary search on {len(operators_by_frequency)} operators ===" + ) + good_operators, bad_operators = find_bad_operators(operators_by_frequency, []) + + # Summary of results + logger.info(f"\n=== Binary search complete after {test_count} tests ===") + logger.info(f"Good operators ({len(good_operators)}):") + for op in good_operators: + logger.info(f" ✓ {op} (frequency: {operator_frequencies[op]})") + + logger.info(f"Bad operators ({len(bad_operators)}):") + for op in bad_operators: + logger.info(f" ✗ {op} (frequency: {operator_frequencies[op]})") + + print_occurrences(edge_program, bad_operators) + + efficiency_gain = len(all_operators) - test_count + logger.info( + f"Efficiency: {test_count} tests instead of {len(all_operators)} (saved {efficiency_gain} tests)" + ) + + return { + "good_operators": good_operators, + "bad_operators": bad_operators, + "operator_frequencies": operator_frequencies, + "all_operators": all_operators, + "test_count": test_count, + } diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index fa45063a4d3..1765f0b5e1c 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -18,6 +18,8 @@ format_target_name, ) +from executorch.exir.dialects.edge._ops import EdgeOpOverload + from executorch.exir.tensor import TensorSpec from torch._export.utils import is_buffer, is_param @@ -54,6 +56,18 @@ MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]] +def is_torch_op_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + + if isinstance(node.target, EdgeOpOverload): + return True + if isinstance(node.target, torch._ops.OpOverload): + return True + + return False + + def is_dequant_node(node: torch.fx.Node) -> bool: if node.op != 
"call_function": return False @@ -1033,6 +1047,49 @@ def get_node_repr(node) -> Union[TensorRepr, TensorReprList]: ## +def get_tensor_val_str(tensor_val: FakeTensor) -> str: + return f"{tensor_val.dtype}: {tensor_val.shape}" + + +def get_node_val_str(node: torch.fx.Node) -> str: + if is_single_tensor_node(node): + assert isinstance(node.meta["val"], FakeTensor) + return get_tensor_val_str(node.meta["val"]) + elif is_tensor_collection_node(node): + assert isinstance(node.meta["val"], (list, tuple)) + return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]" + else: + return str(node.meta["val"]) + + +def get_arg_node_val_str(arg_node: Any) -> str: + if isinstance(arg_node, torch.fx.Node): + return get_node_val_str(arg_node) + elif isinstance(arg_node, (list, tuple)): + return f"[{', '.join(get_arg_node_val_str(n) for n in arg_node)}]" + else: + return str(arg_node) + + +def node_io_str(node: torch.fx.Node) -> str: + target = node.target + if isinstance(target, EdgeOpOverload): + assert isinstance(target, EdgeOpOverload) + target_name = target.__name__ + elif isinstance(target, torch._ops.OpOverload): + assert isinstance(target, torch._ops.OpOverload) + target_name = target.name() + else: + target_name = str(target) + + out_str = f"{get_node_val_str(node)} = {target_name}(" + for arg in node.args: + out_str += get_arg_node_val_str(arg) + ", " + + out_str += " ...)" + return out_str + + def update_program_state_dict( program: ExportedProgram, buffer_name: str, diff --git a/examples/vulkan/README.md b/examples/vulkan/README.md new file mode 100644 index 00000000000..71fdd0e4183 --- /dev/null +++ b/examples/vulkan/README.md @@ -0,0 +1,80 @@ +# Vulkan Delegate Export Examples + +This directory contains scripts for exporting models with the Vulkan delegate in ExecuTorch. Vulkan delegation allows you to run your models on devices with Vulkan-capable GPUs, potentially providing significant performance improvements over CPU execution. + +## Scripts + +- `export.py`: Basic export script for models to use with Vulkan delegate +- `aot_compiler.py`: Advanced export script with quantization support + +## Usage + +### Basic Export + +```bash +python -m executorch.examples.vulkan.export -m -o +``` + +### Export with Quantization (Experimental) + +```bash +python -m executorch.examples.vulkan.aot_compiler -m -q -o +``` + +### Dynamic Shape Support + +```bash +python -m executorch.examples.vulkan.export -m -d -o +``` + +### Additional Options + +- `-s/--strict`: Export with strict mode (default: True) +- `-a/--segment_alignment`: Specify segment alignment in hex (default: 0x1000) +- `-e/--external_constants`: Save constants in external .ptd file (default: False) +- `-r/--etrecord`: Generate and save an ETRecord to the given file location + +## Examples + +```bash +# Export MobileNetV2 with Vulkan delegate +python -m executorch.examples.vulkan.export -m mobilenet_v2 -o ./exported_models + +# Export MobileNetV3 with quantization +python -m executorch.examples.vulkan.aot_compiler -m mobilenet_v3 -q -o ./exported_models + +# Export with dynamic shapes +python -m executorch.examples.vulkan.export -m mobilenet_v2 -d -o ./exported_models + +# Export with ETRecord for debugging +python -m executorch.examples.vulkan.export -m mobilenet_v2 -r ./records/mobilenet_record.etrecord -o ./exported_models +``` + +## Supported Operations + +The Vulkan delegate supports various operations including: + +- Basic arithmetic (add, subtract, multiply, divide) +- Activations (ReLU, Sigmoid, Tanh, etc.) 
+- Convolutions (Conv1d, Conv2d, ConvTranspose2d) +- Pooling operations (MaxPool2d, AvgPool2d) +- Linear/Fully connected layers +- BatchNorm, GroupNorm +- Various tensor operations (cat, reshape, permute, etc.) + +For a complete list of supported operations, refer to the Vulkan delegate implementation in the ExecuTorch codebase. + +## Debugging and Optimization + +If you encounter issues with Vulkan delegation: + +1. Use `-r/--etrecord` to generate an ETRecord for debugging +2. Check if your operations are supported by the Vulkan delegate +3. Ensure your Vulkan drivers are up to date +4. Try using the export script with `--strict False` if strict mode causes issues + +## Requirements + +- Vulkan runtime libraries (libvulkan.so.1) +- A Vulkan-capable GPU with appropriate drivers +- PyTorch with Vulkan support diff --git a/examples/vulkan/__init__.py b/examples/vulkan/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/examples/vulkan/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py new file mode 100644 index 00000000000..b01bf7d37f3 --- /dev/null +++ b/examples/vulkan/export.py @@ -0,0 +1,241 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Example script for exporting models to flatbuffer with the Vulkan delegate + +# pyre-unsafe + +import argparse +import logging + +import backends.vulkan.test.utils as test_utils + +import torch + +from executorch.backends.transforms.convert_dtype_pass import I64toI32 +from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner +from executorch.devtools import BundledProgram +from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite +from executorch.devtools.bundled_program.serialize import ( + serialize_from_bundled_program_to_flatbuffer, +) +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.extension.export_util.utils import save_pte_program +from executorch.extension.pytree import tree_flatten +from torch.export import export + +from ..models import MODEL_NAME_TO_MODEL +from ..models.model_factory import EagerModelFactory + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def main() -> None: + logger = logging.getLogger("") + logger.setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model_name", + required=True, + help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", + ) + + parser.add_argument( + "-s", + "--strict", + action=argparse.BooleanOptionalAction, + default=True, + help="whether to export with strict mode. Default is True", + ) + + parser.add_argument( + "-a", + "--segment_alignment", + required=False, + help="specify segment alignment in hex. Default is 0x1000. Use 0x4000 for iOS", + ) + + parser.add_argument( + "-e", + "--external_constants", + action=argparse.BooleanOptionalAction, + default=False, + help="Save constants in external .ptd file. 
Default is False", + ) + + parser.add_argument( + "-d", + "--dynamic", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable dynamic shape support. Default is False", + ) + + parser.add_argument( + "-r", + "--etrecord", + required=False, + default="", + help="Generate and save an ETRecord to the given file location", + ) + + parser.add_argument("-o", "--output_dir", default=".", help="output directory") + + parser.add_argument( + "-b", + "--bundled", + action=argparse.BooleanOptionalAction, + default=False, + help="Export as bundled program (.bpte) instead of regular program (.pte). Default is False", + ) + + parser.add_argument( + "-t", + "--test", + action=argparse.BooleanOptionalAction, + default=False, + help="Execute lower_module_and_test_output to validate the model. Default is False", + ) + + args = parser.parse_args() + + if args.model_name not in MODEL_NAME_TO_MODEL: + raise RuntimeError( + f"Model {args.model_name} is not a valid name. " + f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}." + ) + + model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL[args.model_name] + ) + + # Prepare model + model.eval() + + # Setup compile options + compile_options = {} + if args.dynamic or dynamic_shapes is not None: + compile_options["require_dynamic_shapes"] = True + + # Configure Edge compilation + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # Proper handling for Vulkan memory format + ) + + logging.info(f"Exporting model {args.model_name} with Vulkan delegate") + + # Export the model using torch.export + if dynamic_shapes is not None: + program = export( + model, example_inputs, dynamic_shapes=dynamic_shapes, strict=args.strict + ) + else: + program = export(model, example_inputs, strict=args.strict) + + # Transform and lower with Vulkan partitioner + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + generate_etrecord=args.etrecord, + ) + + logging.info( + f"Exported and lowered graph:\n{edge_program.exported_program().graph}" + ) + + # Configure backend options + backend_config = ExecutorchBackendConfig(external_constants=args.external_constants) + if args.segment_alignment is not None: + backend_config.segment_alignment = int(args.segment_alignment, 16) + + # Create executorch program + exec_prog = edge_program.to_executorch(config=backend_config) + + # Save ETRecord if requested + if args.etrecord: + exec_prog.get_etrecord().save(args.etrecord) + logging.info(f"Saved ETRecord to {args.etrecord}") + + # Save the program + output_filename = f"{args.model_name}_vulkan" + + # Test the model if --test flag is provided + if args.test: + test_result = test_utils.run_and_check_output( + reference_model=model, + executorch_program=exec_prog, + sample_inputs=example_inputs, + ) + + if test_result: + logging.info( + "✓ Model test PASSED - outputs match reference within tolerance" + ) + else: + logging.error("✗ Model test FAILED - outputs do not match reference") + raise RuntimeError( + "Model validation failed: ExecutorTorch outputs do not match reference model outputs" + ) + + if args.bundled: + # Create bundled program + logging.info("Creating bundled program with test cases") + + # Generate expected outputs by running the model + expected_outputs = [model(*example_inputs)] + + # Flatten sample inputs to match expected format + 
inputs_flattened, _ = tree_flatten(example_inputs) + + # Create test suite with the sample inputs and expected outputs + test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase( + inputs=inputs_flattened, + expected_outputs=expected_outputs, + ) + ], + ) + ] + + # Create bundled program + bp = BundledProgram(exec_prog, test_suites) + + # Serialize to flatbuffer + bp_buffer = serialize_from_bundled_program_to_flatbuffer(bp) + + # Save bundled program + bundled_output_path = f"{args.output_dir}/{output_filename}.bpte" + with open(bundled_output_path, "wb") as file: + file.write(bp_buffer) + + logging.info( + f"Bundled program exported and saved as {output_filename}.bpte in {args.output_dir}" + ) + else: + # Save regular program + save_pte_program(exec_prog, output_filename, args.output_dir) + logging.info( + f"Model exported and saved as {output_filename}.pte in {args.output_dir}" + ) + + +if __name__ == "__main__": + with torch.no_grad(): + main() # pragma: no cover
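
Below is a minimal usage sketch (not part of the diff) of how the partitioner blocklist/allowlist options and the `op_ablation_test` utility introduced above might be combined. The toy model, input shape, and exact import paths are illustrative assumptions and may need adjusting to how executorch is installed in your environment.

```python
# Illustrative sketch only -- exercises APIs added in this change.
# SmallModel, the input shape, and the import paths are assumptions.
import torch
from torch.export import export

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.test.utils import (
    lower_module_and_test_output,
    op_ablation_test,
)
from executorch.exir import to_edge_transform_and_lower


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1.0)


model = SmallModel().eval()
sample_inputs = (torch.randn(1, 16),)

# Binary-search for operators that fail when delegated to Vulkan.
results = op_ablation_test(model, sample_inputs)
print("bad operators:", results["bad_operators"])

# Re-lower with the problematic operators excluded via the new blocklist,
# and verify outputs against eager mode.
assert lower_module_and_test_output(
    model, sample_inputs, operator_blocklist=results["bad_operators"]
)

# The same blocklist can be passed straight to the partitioner in an export flow.
edge = to_edge_transform_and_lower(
    export(model, sample_inputs, strict=True),
    partitioner=[VulkanPartitioner({}, operator_blocklist=results["bad_operators"])],
)
et_program = edge.to_executorch()
```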