From 1ad66ee8d4e4d3b73eedda39ea5411ae520dd65b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 11 Sep 2025 10:08:03 +0200 Subject: [PATCH 1/2] Reapply "Update default executor runner with new optional options" (#14193) This reverts commit dc190f93fa43da3892f71b8103f8ed785719adb7. Original PR, https://github.com/pytorch/executorch/pull/14017 was reverted by https://github.com/pytorch/executorch/pull/14193. This reverts the revert. The only difference from the original PR is fixing a printf format mismatch in inputs.cpp: %ld is changed to %zu. This should compile on both 32- and 64-bit ABIs. Change-Id: If487e6c6bde313844a94db99a7431af33dfcdd0a --- .../arm/scripts/build_executor_runner_vkml.sh | 2 + backends/arm/test/ops/test_acos.py | 13 +- backends/arm/test/ops/test_add.py | 9 +- backends/arm/test/runner_utils.py | 93 ++++++++------ .../executor_runner/executor_runner.cpp | 113 ++++++++++++++++-- extension/runner_util/inputs.cpp | 36 +++++- extension/runner_util/inputs.h | 11 +- extension/runner_util/inputs_aten.cpp | 9 +- extension/runner_util/inputs_portable.cpp | 10 +- 9 files changed, 235 insertions(+), 61 deletions(-) diff --git a/backends/arm/scripts/build_executor_runner_vkml.sh b/backends/arm/scripts/build_executor_runner_vkml.sh index ecd33768577..1df63acc425 100755 --- a/backends/arm/scripts/build_executor_runner_vkml.sh +++ b/backends/arm/scripts/build_executor_runner_vkml.sh @@ -64,6 +64,8 @@ fi echo "Building with extra flags: ${build_with_etdump_flags} ${extra_build_flags}" cmake \ + -Wall \ + -Werror \ -DCMAKE_BUILD_TYPE=${build_type} \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ diff --git a/backends/arm/test/ops/test_acos.py b/backends/arm/test/ops/test_acos.py index 102d979352e..28dadcf95be 100644 --- a/backends/arm/test/ops/test_acos.py +++ b/backends/arm/test/ops/test_acos.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Tuple +import pytest import torch from executorch.backends.arm.test import common @@ -102,8 +103,12 @@ def test_acos_vgf_FP(test_data: Tuple): [], [], tosa_version="TOSA-1.0+FP", + run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") @common.parametrize("test_data", test_data_suite) @@ -115,5 +120,9 @@ def test_acos_vgf_INT(test_data: Tuple): [], [], tosa_version="TOSA-1.0+INT", + run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 9d15cea815c..2eabd302df6 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -202,12 +202,7 @@ def test_add_tensor_u85_INT_2(test_data: input_t2): pipeline.run() -# TODO/MLETORCH-1282: remove once inputs are not hard coded to ones -skip_keys = {"5d_float", "1d_ones", "1d_randn"} -filtered_test_data = {k: v for k, v in Add.test_data.items() if k not in skip_keys} - - -@common.parametrize("test_data", filtered_test_data) +@common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter def test_add_tensor_vgf_FP(test_data: input_t1): pipeline = VgfPipeline[input_t1]( @@ -224,7 +219,7 @@ def test_add_tensor_vgf_FP(test_data: input_t1): pytest.skip(f"VKML executor_runner not found - not built - skip {e}") -@common.parametrize("test_data", filtered_test_data) +@common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter def test_add_tensor_vgf_INT(test_data: input_t1): pipeline = VgfPipeline[input_t1]( diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index f97a2e0c383..aeb0e3a56bd 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -223,13 +223,48 @@ def run_target( elif target_board == "vkml_emulation_layer": return run_vkml_emulation_layer( executorch_program_manager, + inputs, intermediate_path, elf_path, ) +def save_inputs_to_file( + exported_program: ExportedProgram, + inputs: Tuple[torch.Tensor], + intermediate_path: str | Path, +): + input_file_paths = [] + input_names = get_input_names(exported_program) + for input_name, input_ in zip(input_names, inputs): + input_path = save_bytes(intermediate_path, input_, input_name) + input_file_paths.append(input_path) + + return input_file_paths + + +def get_output_from_file( + exported_program: ExportedProgram, + intermediate_path: str | Path, + output_base_name: str, +): + output_np = [] + output_node = exported_program.graph_module.graph.output_node() + for i, node in enumerate(output_node.args[0]): + output_shape = node.meta["val"].shape + output_dtype = node.meta["val"].dtype + tosa_ref_output = np.fromfile( + os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"), + _torch_to_numpy_dtype_dict[output_dtype], + ) + + output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) + return tuple(output_np) + + def run_vkml_emulation_layer( executorch_program_manager: ExecutorchProgramManager, + inputs: Tuple[torch.Tensor], intermediate_path: str | Path, elf_path: str | Path, ): @@ -239,7 +274,7 @@ def run_vkml_emulation_layer( `intermediate_path`: Directory to save the .pte and capture outputs. `elf_path`: Path to the Vulkan-capable executor_runner binary. """ - + exported_program = executorch_program_manager.exported_program() intermediate_path = Path(intermediate_path) intermediate_path.mkdir(exist_ok=True) elf_path = Path(elf_path) @@ -251,26 +286,29 @@ def run_vkml_emulation_layer( with open(pte_path, "wb") as f: f.write(executorch_program_manager.buffer) - cmd_line = [str(elf_path), "-model_path", pte_path] + output_base_name = "out" + out_path = os.path.join(intermediate_path, output_base_name) + + cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path}" + + input_string = None + input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) + for input_path in input_paths: + if input_string is None: + input_string = f" -inputs={input_path}" + else: + input_string += f",{input_path}" + if input_string is not None: + cmd_line += input_string + cmd_line = cmd_line.split() + result = _run_cmd(cmd_line) - result_stdout = result.stdout.decode() # noqa: F841 # TODO: MLETORCH-1234: Support VGF e2e tests in VgfPipeline # TODO: Add regex to check for error or fault messages in stdout from Emulation Layer - # Regex to extract tensor values from stdout - output_np = [] - matches = re.findall( - r"Output\s+\d+:\s+tensor\(sizes=\[(.*?)\],\s+\[(.*?)\]\)", - result_stdout, - re.DOTALL, - ) - - for shape_str, values_str in matches: - shape = list(map(int, shape_str.split(","))) - values = list(map(float, re.findall(r"[-+]?\d*\.\d+|\d+", values_str))) - output_np.append(torch.tensor(values).reshape(shape)) + result_stdout = result.stdout.decode() # noqa: F841 - return tuple(output_np) + return get_output_from_file(exported_program, intermediate_path, output_base_name) def run_corstone( @@ -312,14 +350,10 @@ def run_corstone( with open(pte_path, "wb") as f: f.write(executorch_program_manager.buffer) - # Save inputs to file - input_names = get_input_names(exported_program) - input_paths = [] - for input_name, input_ in zip(input_names, inputs): - input_path = save_bytes(intermediate_path, input_, input_name) - input_paths.append(input_path) + input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) - out_path = os.path.join(intermediate_path, "out") + output_base_name = "out" + out_path = os.path.join(intermediate_path, output_base_name) cmd_line = f"executor_runner -m {pte_path} -o {out_path}" for input_path in input_paths: @@ -401,18 +435,7 @@ def run_corstone( f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}" ) - output_np = [] - output_node = exported_program.graph_module.graph.output_node() - for i, node in enumerate(output_node.args[0]): - output_shape = node.meta["val"].shape - output_dtype = node.meta["val"].dtype - tosa_ref_output = np.fromfile( - os.path.join(intermediate_path, f"out-{i}.bin"), - _torch_to_numpy_dtype_dict[output_dtype], - ) - - output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) - return tuple(output_np) + return get_output_from_file(exported_program, intermediate_path, output_base_name) def prep_data_for_save( diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 434b4783bac..0146a708364 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -1,7 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. - * Copyright 2024-2025 Arm Limited and/or its affiliates. * All rights reserved. + * Copyright 2024-2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -18,6 +18,7 @@ * all fp32 tensors. */ +#include #include #include @@ -49,6 +50,16 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); +DEFINE_string(inputs, "", "Comma-separated list of input files"); +DEFINE_string( + output_file, + "", + "Base name of output file. If not empty output will be written to the file(s)."); + +DEFINE_bool( + print_all_output, + false, + "Prints all output. By default only first and last 100 elements are printed."); DEFINE_uint32(num_executions, 1, "Number of times to run the model."); #ifdef ET_EVENT_TRACER_ENABLED DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path."); @@ -58,6 +69,8 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); +using executorch::aten::ScalarType; +using executorch::aten::Tensor; using executorch::extension::FileDataLoader; using executorch::runtime::Error; using executorch::runtime::EValue; @@ -70,6 +83,8 @@ using executorch::runtime::MethodMeta; using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; +using executorch::runtime::Tag; +using executorch::runtime::TensorInfo; /// Helper to manage resources for ETDump generation class EventTraceManager { @@ -156,6 +171,31 @@ int main(int argc, char** argv) { "FileDataLoader::from() failed: 0x%" PRIx32, (uint32_t)loader.error()); + std::vector inputs_storage; + std::vector> input_buffers; + + std::stringstream list_of_input_files(FLAGS_inputs); + std::string token; + + while (std::getline(list_of_input_files, token, ',')) { + std::ifstream input_file_handle(token, std::ios::binary | std::ios::ate); + if (!input_file_handle) { + ET_LOG(Error, "Failed to open input file: %s\n", token.c_str()); + return 1; + } + + std::streamsize file_size = input_file_handle.tellg(); + input_file_handle.seekg(0, std::ios::beg); + + inputs_storage.emplace_back(file_size, '\0'); + if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) { + ET_LOG(Error, "Failed to read input file: %s\n", token.c_str()); + return 1; + } + + input_buffers.emplace_back(&inputs_storage.back()[0], file_size); + } + // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. Result program = Program::load(&loader.get()); @@ -254,7 +294,8 @@ int main(int argc, char** argv) { // Run the model. for (uint32_t i = 0; i < FLAGS_num_executions; i++) { ET_LOG(Debug, "Preparing inputs."); - // Allocate input tensors and set all of their elements to 1. The `inputs` + // Allocate input tensors and set all of their elements to 1 or to the + // contents of input_buffers if available. The `inputs` // variable owns the allocated memory and must live past the last call to // `execute()`. // @@ -262,7 +303,8 @@ int main(int argc, char** argv) { // because inputs whose space gets reused by memory planning (if // any such inputs exist) will not be preserved for the next // execution. - auto inputs = executorch::extension::prepare_input_tensors(*method); + auto inputs = executorch::extension::prepare_input_tensors( + *method, {}, input_buffers); ET_CHECK_MSG( inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, @@ -295,10 +337,67 @@ int main(int argc, char** argv) { ET_LOG(Info, "%zu outputs: ", outputs.size()); Error status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); - // Print the first and last 100 elements of long lists of scalars. - std::cout << executorch::extension::evalue_edge_items(100); - for (int i = 0; i < outputs.size(); ++i) { - std::cout << "Output " << i << ": " << outputs[i] << std::endl; + + if (FLAGS_output_file.size() > 0) { + for (int i = 0; i < outputs.size(); ++i) { + if (outputs[i].isTensor()) { + Tensor tensor = outputs[i].toTensor(); + + char out_filename[255]; + snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i); + ET_LOG(Info, "Writing output to file: %s", out_filename); + FILE* out_file = fopen(out_filename, "wb"); + auto written_size = + fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); + fclose(out_file); + } + } + } + + if (FLAGS_print_all_output) { + for (int i = 0; i < outputs.size(); ++i) { + if (outputs[i].isTensor()) { + Tensor tensor = outputs[i].toTensor(); + + for (int j = 0; j < tensor.numel(); ++j) { + if (tensor.scalar_type() == ScalarType::Int) { + printf( + "Output[%d][%d]: (int) %d\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Float) { + printf( + "Output[%d][%d]: (float) %f\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Char) { + printf( + "Output[%d][%d]: (char) %d\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Bool) { + printf( + "Output[%d][%d]: (bool) %s (0x%x)\n", + i, + j, + tensor.const_data_ptr()[j] ? "true " : "false", + tensor.const_data_ptr()[j]); + } + } + } else { + printf("Output[%d]: Not Tensor\n", i); + } + } + } else { + // Print the first and last 100 elements of long lists of scalars. + std::cout << executorch::extension::evalue_edge_items(100); + + for (int i = 0; i < outputs.size(); ++i) { + std::cout << "OutputX " << i << ": " << outputs[i] << std::endl; + } } if (tracer.get_event_tracer()) { diff --git a/extension/runner_util/inputs.cpp b/extension/runner_util/inputs.cpp index df3727b77d9..eceaf3cfeca 100644 --- a/extension/runner_util/inputs.cpp +++ b/extension/runner_util/inputs.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -24,9 +25,22 @@ namespace extension { Result prepare_input_tensors( Method& method, - PrepareInputTensorsOptions options) { + PrepareInputTensorsOptions options, + const std::vector>& input_buffers) { MethodMeta method_meta = method.method_meta(); size_t num_inputs = method_meta.num_inputs(); + bool hard_code_inputs_to_ones = true; + + if (input_buffers.size() > 0) { + hard_code_inputs_to_ones = false; + + ET_CHECK_OR_RETURN_ERROR( + num_inputs == input_buffers.size(), + InvalidArgument, + "Wrong number of inputs allocated compared to method %zu ? %zu", + num_inputs, + input_buffers.size()); + } // A large number of small allocations could exhaust the heap even if the // total size is smaller than the limit. @@ -94,9 +108,25 @@ Result prepare_input_tensors( } inputs[num_allocated++] = data_ptr; + // Write input data for input tensor + if (!hard_code_inputs_to_ones) { + auto [buffer, buffer_size] = input_buffers.at(i); + if (buffer_size != tensor_meta->nbytes()) { + ET_LOG( + Error, + "input size (%zu) and tensor size (%zu) mismatch!", + buffer_size, + tensor_meta->nbytes()); + BufferCleanup cleanup({inputs, num_allocated}); + return Error::InvalidArgument; + } + std::memcpy(data_ptr, buffer, buffer_size); + } + // Create the tensor and set it as the input. - Error err = - internal::fill_and_set_input(method, tensor_meta.get(), i, data_ptr); + Error err = internal::fill_and_set_input( + method, tensor_meta.get(), i, data_ptr, hard_code_inputs_to_ones); + if (err != Error::Ok) { ET_LOG( Error, "Failed to prepare input %zu: 0x%" PRIx32, i, (uint32_t)err); diff --git a/extension/runner_util/inputs.h b/extension/runner_util/inputs.h index 214b76d67e3..1a30e2cc4df 100644 --- a/extension/runner_util/inputs.h +++ b/extension/runner_util/inputs.h @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -8,6 +9,8 @@ #pragma once +#include + #include #include #include @@ -84,18 +87,20 @@ struct PrepareInputTensorsOptions { */ executorch::runtime::Result prepare_input_tensors( Method& method, - PrepareInputTensorsOptions options = {}); + PrepareInputTensorsOptions options = {}, + const std::vector>& input_buffers = {}); namespace internal { /** * INTERNAL-ONLY: Creates a Tensor using the provided shape and buffer, - * fills it with ones, and sets the input at `input_index`. + * fills it with ones by default, and sets the input at `input_index`. */ executorch::runtime::Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr); + void* data_ptr, + bool fill_tensor = true); } // namespace internal } // namespace extension diff --git a/extension/runner_util/inputs_aten.cpp b/extension/runner_util/inputs_aten.cpp index b89562a2f69..c3fdd524a13 100644 --- a/extension/runner_util/inputs_aten.cpp +++ b/extension/runner_util/inputs_aten.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -26,7 +27,8 @@ Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr) { + void* data_ptr, + bool fill_tensor) { // Convert the sizes array from int32_t to int64_t. std::vector sizes; for (auto s : tensor_meta.sizes()) { @@ -34,7 +36,10 @@ Error fill_and_set_input( } at::Tensor t = at::from_blob( data_ptr, sizes, at::TensorOptions(tensor_meta.scalar_type())); - t.fill_(1.0f); + + if (fill_tensor) { + t.fill_(1.0f); + } return method.set_input(t, input_index); } diff --git a/extension/runner_util/inputs_portable.cpp b/extension/runner_util/inputs_portable.cpp index 6f31acc31e1..9ebaaa7ce85 100644 --- a/extension/runner_util/inputs_portable.cpp +++ b/extension/runner_util/inputs_portable.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -56,7 +57,8 @@ Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr) { + void* data_ptr, + bool fill_tensor) { TensorImpl impl = TensorImpl( tensor_meta.scalar_type(), /*dim=*/tensor_meta.sizes().size(), @@ -68,7 +70,11 @@ Error fill_and_set_input( data_ptr, const_cast(tensor_meta.dim_order().data())); Tensor t(&impl); - ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t)); + + if (fill_tensor) { + ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t)); + } + return method.set_input(t, input_index); } From 2ea12236931ebbdb0caa7271b13bb3325fb60684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Thu, 11 Sep 2025 15:48:48 +0200 Subject: [PATCH 2/2] Remove unused variable --- examples/portable/executor_runner/executor_runner.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 0146a708364..4f4208a5b53 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -347,8 +347,7 @@ int main(int argc, char** argv) { snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i); ET_LOG(Info, "Writing output to file: %s", out_filename); FILE* out_file = fopen(out_filename, "wb"); - auto written_size = - fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); + fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); fclose(out_file); } }