From 58998b051cbc8981c196ef7fcb3dc2c2aacbf1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 1 Sep 2025 15:56:06 +0200 Subject: [PATCH 1/3] Update default portable executor runner with real input In the executor runner all inputs get hard coded to ones. Add optional input option, in which case tensor inputs will be written from supplied binary input files. Also update the Arm VKML runner unit test runner as a user with real inputs. GenAI used, Blackduck Scan OK Change-Id: Ie0363e04f0bbcb2342781f3c560ac30837d88a31 --- backends/arm/test/ops/test_add.py | 9 +---- backends/arm/test/runner_utils.py | 39 +++++++++++++++---- .../executor_runner/executor_runner.cpp | 39 +++++++++++++++++-- extension/runner_util/inputs.cpp | 36 +++++++++++++++-- extension/runner_util/inputs.h | 11 ++++-- extension/runner_util/inputs_aten.cpp | 9 ++++- extension/runner_util/inputs_portable.cpp | 10 ++++- 7 files changed, 125 insertions(+), 28 deletions(-) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 0e825d57894..834be848d98 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -185,12 +185,7 @@ def test_add_tensor_u85_INT_2(test_data: input_t2): pipeline.run() -# TODO/MLETORCH-1282: remove once inputs are not hard coded to ones -skip_keys = {"5d_float", "1d_ones", "1d_randn"} -filtered_test_data = {k: v for k, v in Add.test_data.items() if k not in skip_keys} - - -@common.parametrize("test_data", filtered_test_data) +@common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter def test_add_tensor_vgf_FP(test_data: input_t1): pipeline = VgfPipeline[input_t1]( @@ -204,7 +199,7 @@ def test_add_tensor_vgf_FP(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", filtered_test_data) +@common.parametrize("test_data", Add.test_data) @common.SkipIfNoModelConverter def test_add_tensor_vgf_INT(test_data: input_t1): pipeline = VgfPipeline[input_t1]( diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9234f4dd7e5..8a81f335bd2 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -223,13 +223,29 @@ def run_target( elif target_board == "vkml_emulation_layer": return run_vkml_emulation_layer( executorch_program_manager, + inputs, intermediate_path, elf_path, ) +def save_inputs_to_file( + exported_program: ExportedProgram, + inputs: Tuple[torch.Tensor], + intermediate_path: str | Path, +): + input_file_paths = [] + input_names = get_input_names(exported_program) + for input_name, input_ in zip(input_names, inputs): + input_path = save_bytes(intermediate_path, input_, input_name) + input_file_paths.append(input_path) + + return input_file_paths + + def run_vkml_emulation_layer( executorch_program_manager: ExecutorchProgramManager, + inputs: Tuple[torch.Tensor], intermediate_path: str | Path, elf_path: str | Path, ): @@ -239,7 +255,7 @@ def run_vkml_emulation_layer( `intermediate_path`: Directory to save the .pte and capture outputs. `elf_path`: Path to the Vulkan-capable executor_runner binary. """ - + exported_program = executorch_program_manager.exported_program() intermediate_path = Path(intermediate_path) intermediate_path.mkdir(exist_ok=True) elf_path = Path(elf_path) @@ -251,7 +267,19 @@ def run_vkml_emulation_layer( with open(pte_path, "wb") as f: f.write(executorch_program_manager.buffer) - cmd_line = [elf_path, "-model_path", pte_path] + input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) + + cmd_line = f"{elf_path} -model_path {pte_path}" + input_string = None + for input_path in input_paths: + if input_string is None: + input_string = f" -inputs={input_path}" + else: + input_string += f",{input_path}" + if input_string is not None: + cmd_line += input_string + cmd_line = cmd_line.split() + result = _run_cmd(cmd_line) result_stdout = result.stdout.decode() # noqa: F841 @@ -312,12 +340,7 @@ def run_corstone( with open(pte_path, "wb") as f: f.write(executorch_program_manager.buffer) - # Save inputs to file - input_names = get_input_names(exported_program) - input_paths = [] - for input_name, input_ in zip(input_names, inputs): - input_path = save_bytes(intermediate_path, input_, input_name) - input_paths.append(input_path) + input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) out_path = os.path.join(intermediate_path, "out") diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 434b4783bac..908a8cc6ea9 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -1,7 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. - * Copyright 2024-2025 Arm Limited and/or its affiliates. * All rights reserved. + * Copyright 2024-2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -18,6 +18,7 @@ * all fp32 tensors. */ +#include #include #include @@ -49,6 +50,7 @@ DEFINE_string( model_path, "model.pte", "Model serialized in flatbuffer format."); +DEFINE_string(inputs, "", "Comma-separated list of input files"); DEFINE_uint32(num_executions, 1, "Number of times to run the model."); #ifdef ET_EVENT_TRACER_ENABLED DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path."); @@ -58,6 +60,8 @@ DEFINE_int32( -1, "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); +using executorch::aten::ScalarType; +using executorch::aten::Tensor; using executorch::extension::FileDataLoader; using executorch::runtime::Error; using executorch::runtime::EValue; @@ -70,6 +74,8 @@ using executorch::runtime::MethodMeta; using executorch::runtime::Program; using executorch::runtime::Result; using executorch::runtime::Span; +using executorch::runtime::Tag; +using executorch::runtime::TensorInfo; /// Helper to manage resources for ETDump generation class EventTraceManager { @@ -156,6 +162,31 @@ int main(int argc, char** argv) { "FileDataLoader::from() failed: 0x%" PRIx32, (uint32_t)loader.error()); + std::vector inputs_storage; + std::vector> input_buffers; + + std::stringstream list_of_input_files(FLAGS_inputs); + std::string token; + + while (std::getline(list_of_input_files, token, ',')) { + std::ifstream input_file_handle(token, std::ios::binary | std::ios::ate); + if (!input_file_handle) { + ET_LOG(Error, "Failed to open input file: %s\n", token.c_str()); + return 1; + } + + std::streamsize file_size = input_file_handle.tellg(); + input_file_handle.seekg(0, std::ios::beg); + + inputs_storage.emplace_back(file_size, '\0'); + if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) { + ET_LOG(Error, "Failed to read input file: %s\n", token.c_str()); + return 1; + } + + input_buffers.emplace_back(&inputs_storage.back()[0], file_size); + } + // Parse the program file. This is immutable, and can also be reused between // multiple execution invocations across multiple threads. Result program = Program::load(&loader.get()); @@ -254,7 +285,8 @@ int main(int argc, char** argv) { // Run the model. for (uint32_t i = 0; i < FLAGS_num_executions; i++) { ET_LOG(Debug, "Preparing inputs."); - // Allocate input tensors and set all of their elements to 1. The `inputs` + // Allocate input tensors and set all of their elements to 1 or to the + // contents of input_buffers if available. The `inputs` // variable owns the allocated memory and must live past the last call to // `execute()`. // @@ -262,7 +294,8 @@ int main(int argc, char** argv) { // because inputs whose space gets reused by memory planning (if // any such inputs exist) will not be preserved for the next // execution. - auto inputs = executorch::extension::prepare_input_tensors(*method); + auto inputs = executorch::extension::prepare_input_tensors( + *method, {}, input_buffers); ET_CHECK_MSG( inputs.ok(), "Could not prepare inputs: 0x%" PRIx32, diff --git a/extension/runner_util/inputs.cpp b/extension/runner_util/inputs.cpp index df3727b77d9..a4be0f93eba 100644 --- a/extension/runner_util/inputs.cpp +++ b/extension/runner_util/inputs.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -24,9 +25,22 @@ namespace extension { Result prepare_input_tensors( Method& method, - PrepareInputTensorsOptions options) { + PrepareInputTensorsOptions options, + const std::vector>& input_buffers) { MethodMeta method_meta = method.method_meta(); size_t num_inputs = method_meta.num_inputs(); + bool hard_code_inputs_to_ones = true; + + ET_CHECK_OR_RETURN_ERROR( + input_buffers.size() > 0 && num_inputs == input_buffers.size(), + InvalidArgument, + "Wrong number of inputs allocated compared to method %zu ? %zu", + num_inputs, + input_buffers.size()); + + if (input_buffers.size() > 0) { + hard_code_inputs_to_ones = false; + } // A large number of small allocations could exhaust the heap even if the // total size is smaller than the limit. @@ -94,9 +108,25 @@ Result prepare_input_tensors( } inputs[num_allocated++] = data_ptr; + // Write input data for input tensor + if (!hard_code_inputs_to_ones) { + auto [buffer, buffer_size] = input_buffers.at(i); + if (buffer_size != tensor_meta->nbytes()) { + ET_LOG( + Error, + "input size (%ld) and tensor size (%ld) mismatch!", + buffer_size, + tensor_meta->nbytes()); + BufferCleanup cleanup({inputs, num_allocated}); + return Error::InvalidArgument; + } + std::memcpy(data_ptr, buffer, buffer_size); + } + // Create the tensor and set it as the input. - Error err = - internal::fill_and_set_input(method, tensor_meta.get(), i, data_ptr); + Error err = internal::fill_and_set_input( + method, tensor_meta.get(), i, data_ptr, hard_code_inputs_to_ones); + if (err != Error::Ok) { ET_LOG( Error, "Failed to prepare input %zu: 0x%" PRIx32, i, (uint32_t)err); diff --git a/extension/runner_util/inputs.h b/extension/runner_util/inputs.h index 214b76d67e3..1a30e2cc4df 100644 --- a/extension/runner_util/inputs.h +++ b/extension/runner_util/inputs.h @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -8,6 +9,8 @@ #pragma once +#include + #include #include #include @@ -84,18 +87,20 @@ struct PrepareInputTensorsOptions { */ executorch::runtime::Result prepare_input_tensors( Method& method, - PrepareInputTensorsOptions options = {}); + PrepareInputTensorsOptions options = {}, + const std::vector>& input_buffers = {}); namespace internal { /** * INTERNAL-ONLY: Creates a Tensor using the provided shape and buffer, - * fills it with ones, and sets the input at `input_index`. + * fills it with ones by default, and sets the input at `input_index`. */ executorch::runtime::Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr); + void* data_ptr, + bool fill_tensor = true); } // namespace internal } // namespace extension diff --git a/extension/runner_util/inputs_aten.cpp b/extension/runner_util/inputs_aten.cpp index b89562a2f69..c3fdd524a13 100644 --- a/extension/runner_util/inputs_aten.cpp +++ b/extension/runner_util/inputs_aten.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -26,7 +27,8 @@ Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr) { + void* data_ptr, + bool fill_tensor) { // Convert the sizes array from int32_t to int64_t. std::vector sizes; for (auto s : tensor_meta.sizes()) { @@ -34,7 +36,10 @@ Error fill_and_set_input( } at::Tensor t = at::from_blob( data_ptr, sizes, at::TensorOptions(tensor_meta.scalar_type())); - t.fill_(1.0f); + + if (fill_tensor) { + t.fill_(1.0f); + } return method.set_input(t, input_index); } diff --git a/extension/runner_util/inputs_portable.cpp b/extension/runner_util/inputs_portable.cpp index 6f31acc31e1..9ebaaa7ce85 100644 --- a/extension/runner_util/inputs_portable.cpp +++ b/extension/runner_util/inputs_portable.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -56,7 +57,8 @@ Error fill_and_set_input( Method& method, TensorInfo& tensor_meta, size_t input_index, - void* data_ptr) { + void* data_ptr, + bool fill_tensor) { TensorImpl impl = TensorImpl( tensor_meta.scalar_type(), /*dim=*/tensor_meta.sizes().size(), @@ -68,7 +70,11 @@ Error fill_and_set_input( data_ptr, const_cast(tensor_meta.dim_order().data())); Tensor t(&impl); - ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t)); + + if (fill_tensor) { + ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t)); + } + return method.set_input(t, input_index); } From d9d24659af5992b4b7f8aecfa446f0876f45e449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Fri, 5 Sep 2025 13:39:20 +0200 Subject: [PATCH 2/3] Update default executor runner with output options By default not all output is printed. Adds option for printing all output. Also adds option to print output to file. Also update the Arm VKML unit test runner as a user that prints output to file. Enables acos_unit test to run on Vulkan runtime that depends on this. Change-Id: If61c1fe89c9da004fa9db4524e1413893549abce --- backends/arm/test/ops/test_acos.py | 13 +++- backends/arm/test/ops/test_add.py | 11 ++- backends/arm/test/runner_utils.py | 58 +++++++-------- .../executor_runner/executor_runner.cpp | 74 ++++++++++++++++++- extension/runner_util/inputs.cpp | 14 ++-- 5 files changed, 126 insertions(+), 44 deletions(-) diff --git a/backends/arm/test/ops/test_acos.py b/backends/arm/test/ops/test_acos.py index 102d979352e..28dadcf95be 100644 --- a/backends/arm/test/ops/test_acos.py +++ b/backends/arm/test/ops/test_acos.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from typing import Tuple +import pytest import torch from executorch.backends.arm.test import common @@ -102,8 +103,12 @@ def test_acos_vgf_FP(test_data: Tuple): [], [], tosa_version="TOSA-1.0+FP", + run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") @common.parametrize("test_data", test_data_suite) @@ -115,5 +120,9 @@ def test_acos_vgf_INT(test_data: Tuple): [], [], tosa_version="TOSA-1.0+INT", + run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 834be848d98..970de7a56e4 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -7,6 +7,7 @@ from typing import Tuple +import pytest import torch from executorch.backends.arm.quantizer import arm_quantizer from executorch.backends.arm.test import common, conftest @@ -196,7 +197,10 @@ def test_add_tensor_vgf_FP(test_data: input_t1): tosa_version="TOSA-1.0+FP", run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") @common.parametrize("test_data", Add.test_data) @@ -210,4 +214,7 @@ def test_add_tensor_vgf_INT(test_data: input_t1): tosa_version="TOSA-1.0+INT", run_on_vulkan_runtime=True, ) - pipeline.run() + try: + pipeline.run() + except FileNotFoundError as e: + pytest.skip(f"VKML executor_runner not found - not built - skip {e}") diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 8a81f335bd2..d5c42b05d2f 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -243,6 +243,25 @@ def save_inputs_to_file( return input_file_paths +def get_output_from_file( + exported_program: ExportedProgram, + intermediate_path: str | Path, + output_base_name: str, +): + output_np = [] + output_node = exported_program.graph_module.graph.output_node() + for i, node in enumerate(output_node.args[0]): + output_shape = node.meta["val"].shape + output_dtype = node.meta["val"].dtype + tosa_ref_output = np.fromfile( + os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"), + _torch_to_numpy_dtype_dict[output_dtype], + ) + + output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) + return tuple(output_np) + + def run_vkml_emulation_layer( executorch_program_manager: ExecutorchProgramManager, inputs: Tuple[torch.Tensor], @@ -267,10 +286,13 @@ def run_vkml_emulation_layer( with open(pte_path, "wb") as f: f.write(executorch_program_manager.buffer) - input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) + output_base_name = "out" + out_path = os.path.join(intermediate_path, output_base_name) + + cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path}" - cmd_line = f"{elf_path} -model_path {pte_path}" input_string = None + input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) for input_path in input_paths: if input_string is None: input_string = f" -inputs={input_path}" @@ -282,23 +304,11 @@ def run_vkml_emulation_layer( result = _run_cmd(cmd_line) - result_stdout = result.stdout.decode() # noqa: F841 # TODO: MLETORCH-1234: Support VGF e2e tests in VgfPipeline # TODO: Add regex to check for error or fault messages in stdout from Emulation Layer - # Regex to extract tensor values from stdout - output_np = [] - matches = re.findall( - r"Output\s+\d+:\s+tensor\(sizes=\[(.*?)\],\s+\[(.*?)\]\)", - result_stdout, - re.DOTALL, - ) - - for shape_str, values_str in matches: - shape = list(map(int, shape_str.split(","))) - values = list(map(float, re.findall(r"[-+]?\d*\.\d+|\d+", values_str))) - output_np.append(torch.tensor(values).reshape(shape)) + result_stdout = result.stdout.decode() # noqa: F841 - return tuple(output_np) + return get_output_from_file(exported_program, intermediate_path, output_base_name) def run_corstone( @@ -342,7 +352,8 @@ def run_corstone( input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path) - out_path = os.path.join(intermediate_path, "out") + output_base_name = "out" + out_path = os.path.join(intermediate_path, output_base_name) cmd_line = f"executor_runner -m {pte_path} -o {out_path}" for input_path in input_paths: @@ -424,18 +435,7 @@ def run_corstone( f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}" ) - output_np = [] - output_node = exported_program.graph_module.graph.output_node() - for i, node in enumerate(output_node.args[0]): - output_shape = node.meta["val"].shape - output_dtype = node.meta["val"].dtype - tosa_ref_output = np.fromfile( - os.path.join(intermediate_path, f"out-{i}.bin"), - _torch_to_numpy_dtype_dict[output_dtype], - ) - - output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) - return tuple(output_np) + return get_output_from_file(exported_program, intermediate_path, output_base_name) def prep_data_for_save( diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp index 908a8cc6ea9..0146a708364 100644 --- a/examples/portable/executor_runner/executor_runner.cpp +++ b/examples/portable/executor_runner/executor_runner.cpp @@ -51,6 +51,15 @@ DEFINE_string( "model.pte", "Model serialized in flatbuffer format."); DEFINE_string(inputs, "", "Comma-separated list of input files"); +DEFINE_string( + output_file, + "", + "Base name of output file. If not empty output will be written to the file(s)."); + +DEFINE_bool( + print_all_output, + false, + "Prints all output. By default only first and last 100 elements are printed."); DEFINE_uint32(num_executions, 1, "Number of times to run the model."); #ifdef ET_EVENT_TRACER_ENABLED DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path."); @@ -328,10 +337,67 @@ int main(int argc, char** argv) { ET_LOG(Info, "%zu outputs: ", outputs.size()); Error status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); - // Print the first and last 100 elements of long lists of scalars. - std::cout << executorch::extension::evalue_edge_items(100); - for (int i = 0; i < outputs.size(); ++i) { - std::cout << "Output " << i << ": " << outputs[i] << std::endl; + + if (FLAGS_output_file.size() > 0) { + for (int i = 0; i < outputs.size(); ++i) { + if (outputs[i].isTensor()) { + Tensor tensor = outputs[i].toTensor(); + + char out_filename[255]; + snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i); + ET_LOG(Info, "Writing output to file: %s", out_filename); + FILE* out_file = fopen(out_filename, "wb"); + auto written_size = + fwrite(tensor.const_data_ptr(), 1, tensor.nbytes(), out_file); + fclose(out_file); + } + } + } + + if (FLAGS_print_all_output) { + for (int i = 0; i < outputs.size(); ++i) { + if (outputs[i].isTensor()) { + Tensor tensor = outputs[i].toTensor(); + + for (int j = 0; j < tensor.numel(); ++j) { + if (tensor.scalar_type() == ScalarType::Int) { + printf( + "Output[%d][%d]: (int) %d\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Float) { + printf( + "Output[%d][%d]: (float) %f\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Char) { + printf( + "Output[%d][%d]: (char) %d\n", + i, + j, + tensor.const_data_ptr()[j]); + } else if (tensor.scalar_type() == ScalarType::Bool) { + printf( + "Output[%d][%d]: (bool) %s (0x%x)\n", + i, + j, + tensor.const_data_ptr()[j] ? "true " : "false", + tensor.const_data_ptr()[j]); + } + } + } else { + printf("Output[%d]: Not Tensor\n", i); + } + } + } else { + // Print the first and last 100 elements of long lists of scalars. + std::cout << executorch::extension::evalue_edge_items(100); + + for (int i = 0; i < outputs.size(); ++i) { + std::cout << "OutputX " << i << ": " << outputs[i] << std::endl; + } } if (tracer.get_event_tracer()) { diff --git a/extension/runner_util/inputs.cpp b/extension/runner_util/inputs.cpp index a4be0f93eba..eed37b5b27c 100644 --- a/extension/runner_util/inputs.cpp +++ b/extension/runner_util/inputs.cpp @@ -31,15 +31,15 @@ Result prepare_input_tensors( size_t num_inputs = method_meta.num_inputs(); bool hard_code_inputs_to_ones = true; - ET_CHECK_OR_RETURN_ERROR( - input_buffers.size() > 0 && num_inputs == input_buffers.size(), - InvalidArgument, - "Wrong number of inputs allocated compared to method %zu ? %zu", - num_inputs, - input_buffers.size()); - if (input_buffers.size() > 0) { hard_code_inputs_to_ones = false; + + ET_CHECK_OR_RETURN_ERROR( + num_inputs == input_buffers.size(), + InvalidArgument, + "Wrong number of inputs allocated compared to method %zu ? %zu", + num_inputs, + input_buffers.size()); } // A large number of small allocations could exhaust the heap even if the From 3a345f346becf4743ed43fe90e2ffa005f52a50d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A5ns=20Nilsson?= Date: Mon, 8 Sep 2025 10:51:21 +0200 Subject: [PATCH 3/3] Remove backends/arm/test/ops/.fuse_hidden00004ec100000001 The file was added by mistake. --- .../arm/test/ops/.fuse_hidden00004ec100000001 | 237 ------------------ 1 file changed, 237 deletions(-) delete mode 100644 backends/arm/test/ops/.fuse_hidden00004ec100000001 diff --git a/backends/arm/test/ops/.fuse_hidden00004ec100000001 b/backends/arm/test/ops/.fuse_hidden00004ec100000001 deleted file mode 100644 index 8376df47b39..00000000000 --- a/backends/arm/test/ops/.fuse_hidden00004ec100000001 +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -import pytest -import torch -from executorch.backends.arm.quantizer import arm_quantizer -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineINT, - EthosU85PipelineINT, - TosaPipelineFP, - TosaPipelineINT, - VgfPipeline, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.specification import get_tosa_spec -from executorch.backends.xnnpack.test.tester import Quantize -from torchao.quantization.pt2e import HistogramObserver -from torchao.quantization.pt2e.quantizer import QuantizationSpec - -aten_op = "torch.ops.aten.add.Tensor" -exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" - -input_t1 = Tuple[torch.Tensor] # Input x - - -class Add(torch.nn.Module): - def forward(self, x: torch.Tensor): - return x + x - - test_data: list[input_t1] = { - "5d_float": lambda: (torch.FloatTensor([1, 2, 3, 5, 7]),), - "1d_ones": lambda: ((3 * torch.ones(8),)), - "1d_randn": lambda: (10 * torch.randn(8),), - "4d_ones_1": lambda: (torch.ones(1, 1, 4, 4),), - "4d_ones_2": lambda: (torch.ones(1, 3, 4, 2),), - } - - -input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y - - -class Add2(torch.nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor): - return x + y - - test_data: list[input_t2] = { - "5d_float": lambda: ( - torch.FloatTensor([1, 2, 3, 5, 7]), - (torch.FloatTensor([2, 1, 2, 1, 10])), - ), - "4d_ones": lambda: (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)), - "4d_randn_1": lambda: (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), - "4d_randn_2": lambda: (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), - "4d_randn_big": lambda: ( - (1 << 30) * torch.randn(1, 1, 4, 4), - torch.randn(1, 1, 4, 1), - ), - "4d_randn_1_mutltiple_broadcasts": lambda: ( - torch.randn(1, 4, 4, 1), - torch.ones(1, 1, 4, 4), - ), - "4d_big_small": lambda: ( - (10e10) * torch.randn(1, 10, 20, 30), - torch.randn(1, 10, 20, 30), - ), - } - - -class Add3(torch.nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor): - return x + y - - test_data: list[input_t2] = { - "3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)), - "4d_randn_diff_rank": lambda: (torch.randn(1, 1, 4, 4), torch.randn(4, 1)), - "4d_randn_diff_rank_2": lambda: (torch.randn(4, 1), torch.randn(1, 1, 4, 5)), - } - - -@common.parametrize("test_data", Add.test_data) -def test_add_tensor_tosa_FP(test_data: input_t1): - pipeline = TosaPipelineFP[input_t1](Add(), test_data(), aten_op, exir_op) - pipeline.run() - - -@common.parametrize("test_data", Add.test_data) -def test_add_tensor_tosa_INT(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0) - pipeline.run() - - -@common.parametrize("test_data", Add.test_data) -def test_add_tensor_tosa_INT_i32(test_data: input_t1): - pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op) - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), - } - # Create a quantizer with int8 quantization on the input and output but int32 on everything else. - quantizer = arm_quantizer.TOSAQuantizer( - get_tosa_spec(common.get_tosa_compile_spec(tosa_profiles[tosa_version])) - ) - quantizer.set_io(arm_quantizer.get_symmetric_quantization_config()) - observer_options = {"eps": 2**-16} - observer = HistogramObserver.with_args(**observer_options) - input_act_qspec = QuantizationSpec( - torch.int32, - observer, - qscheme=torch.per_tensor_symmetric, - quant_max=2**31 - 1, - quant_min=-(2**31), - ) - output_act_qspec = QuantizationSpec( - torch.int32, - observer, - qscheme=torch.per_tensor_symmetric, - quant_max=2**31 - 1, - quant_min=-(2**31), - ) - # This quantization_config will be set as global config. - quantization_config = arm_quantizer.QuantizationConfig( - input_act_qspec, output_act_qspec, None, None - ) - quantize_stage = Quantize(quantizer, quantization_config) - pipeline.change_args("quantize", quantize_stage) - - # Check that we get the additional (dq -> q - pipeline.add_stage_after( - "export", pipeline.tester.check_count, {"torch.ops.quantized_decomposed": 8} - ) - pipeline.run() - - -@common.parametrize("test_data", Add.test_data) -@common.XfailIfNoCorstone300 -def test_add_tensor_u55_INT(test_data: input_t1): - pipeline = EthosU55PipelineINT[input_t1]( - Add(), test_data(), aten_op, exir_op, run_on_fvp=True - ) - pipeline.run() - - -@common.parametrize("test_data", Add.test_data) -@common.XfailIfNoCorstone320 -def test_add_tensor_u85_INT(test_data: input_t1): - pipeline = EthosU85PipelineINT[input_t1]( - Add(), test_data(), aten_op, exir_op, run_on_fvp=True - ) - pipeline.run() - - -@common.parametrize("test_data", Add2.test_data) -def test_add_tensor_tosa_FP_2(test_data: input_t2): - pipeline = TosaPipelineFP[input_t2](Add2(), test_data(), aten_op, exir_op) - pipeline.run() - - -@common.parametrize("test_data", Add3.test_data) -def test_add_tensor_tosa_FP_3(test_data: input_t2): - pipeline = TosaPipelineFP[input_t2](Add3(), test_data(), aten_op, exir_op) - pipeline.run() - - -@common.parametrize("test_data", Add3.test_data) -def test_add_tensor_tosa_INT_3(test_data: input_t2): - pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op, qtol=0) - pipeline.run() - - -@common.parametrize("test_data", Add2.test_data) -def test_add_tensor_tosa_INT_2(test_data: input_t2): - pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op, qtol=0) - pipeline.run() - - -@common.parametrize("test_data", Add2.test_data) -@common.XfailIfNoCorstone300 -def test_add_tensor_u55_INT_2(test_data: input_t2): - pipeline = EthosU55PipelineINT[input_t2]( - Add2(), test_data(), aten_op, exir_op, run_on_fvp=True - ) - pipeline.run() - - -@common.parametrize("test_data", Add2.test_data) -@common.XfailIfNoCorstone320 -def test_add_tensor_u85_INT_2(test_data: input_t2): - pipeline = EthosU85PipelineINT[input_t2]( - Add2(), test_data(), aten_op, exir_op, run_on_fvp=True - ) - pipeline.run() - - -# TODO/MLETORCH-1282: remove once inputs are not hard coded to ones -skip_keys = {"5d_float", "1d_ones", "1d_randn"} -filtered_test_data = {k: v for k, v in Add.test_data.items() if k not in skip_keys} - - -@common.parametrize("test_data", filtered_test_data) -@common.SkipIfNoModelConverter -def test_add_tensor_vgf_FP(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( - Add(), - test_data(), - aten_op, - exir_op, - tosa_version="TOSA-1.0+FP", - run_on_vulkan_runtime=True, - ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}") - - -@common.parametrize("test_data", filtered_test_data) -@common.SkipIfNoModelConverter -def test_add_tensor_vgf_INT(test_data: input_t1): - pipeline = VgfPipeline[input_t1]( - Add(), - test_data(), - aten_op, - exir_op, - tosa_version="TOSA-1.0+INT", - run_on_vulkan_runtime=True, - ) - try: - pipeline.run() - except FileNotFoundError as e: - pytest.skip(f"VKML executor_runner not found - not built - skip {e}")