Merged
13 changes: 11 additions & 2 deletions backends/arm/test/ops/test_acos.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Tuple

import pytest
import torch

from executorch.backends.arm.test import common
@@ -102,8 +103,12 @@ def test_acos_vgf_FP(test_data: Tuple):
[],
[],
tosa_version="TOSA-1.0+FP",
run_on_vulkan_runtime=True,
)
pipeline.run()
try:
Contributor:
Nit: can we do it like this -

class TOSAPipelineMaker(BasePipelineMaker, Generic[T]):

Else I imagine we have to write this for every single test..

Collaborator Author:
Yes indeed, we should not do a try/except in every unit test; I just wanted to enable one test that exercises this. We will either do a central try/except or add xfails for the remaining tests in coming patches.
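
A minimal sketch of that centralized approach, for illustration (the helper name is hypothetical; the real fix may instead live in a shared pipeline base class as suggested above):

import pytest

def run_or_skip_if_runner_missing(pipeline):
    # Run a VGF test pipeline, skipping the test when the VKML
    # executor_runner binary has not been built; pipeline.run() raises
    # FileNotFoundError when it fails to launch the runner.
    try:
        pipeline.run()
    except FileNotFoundError as e:
        pytest.skip(f"VKML executor_runner not found - not built - skip {e}")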

pipeline.run()
except FileNotFoundError as e:
pytest.skip(f"VKML executor_runner not found - not built - skip {e}")


@common.parametrize("test_data", test_data_suite)
@@ -115,5 +120,9 @@ def test_acos_vgf_INT(test_data: Tuple):
[],
[],
tosa_version="TOSA-1.0+INT",
run_on_vulkan_runtime=True,
)
pipeline.run()
try:
pipeline.run()
except FileNotFoundError as e:
pytest.skip(f"VKML executor_runner not found - not built - skip {e}")
9 changes: 2 additions & 7 deletions backends/arm/test/ops/test_add.py
@@ -202,12 +202,7 @@ def test_add_tensor_u85_INT_2(test_data: input_t2):
pipeline.run()


# TODO/MLETORCH-1282: remove once inputs are not hard coded to ones
skip_keys = {"5d_float", "1d_ones", "1d_randn"}
filtered_test_data = {k: v for k, v in Add.test_data.items() if k not in skip_keys}


@common.parametrize("test_data", filtered_test_data)
@common.parametrize("test_data", Add.test_data)
@common.SkipIfNoModelConverter
@pytest.mark.xfail(reason="'Failed to load VKML extensions' error in ci.")
def test_add_tensor_vgf_FP(test_data: input_t1):
@@ -225,7 +220,7 @@ def test_add_tensor_vgf_FP(test_data: input_t1):
pytest.skip(f"VKML executor_runner not found - not built - skip {e}")


@common.parametrize("test_data", filtered_test_data)
@common.parametrize("test_data", Add.test_data)
@common.SkipIfNoModelConverter
@pytest.mark.xfail(reason="'Failed to load VKML extensions' error in ci.")
def test_add_tensor_vgf_INT(test_data: input_t1):
93 changes: 58 additions & 35 deletions backends/arm/test/runner_utils.py
@@ -223,13 +223,48 @@ def run_target(
elif target_board == "vkml_emulation_layer":
return run_vkml_emulation_layer(
executorch_program_manager,
inputs,
intermediate_path,
elf_path,
)


def save_inputs_to_file(
exported_program: ExportedProgram,
inputs: Tuple[torch.Tensor],
intermediate_path: str | Path,
):
input_file_paths = []
input_names = get_input_names(exported_program)
for input_name, input_ in zip(input_names, inputs):
input_path = save_bytes(intermediate_path, input_, input_name)
input_file_paths.append(input_path)

return input_file_paths


def get_output_from_file(
exported_program: ExportedProgram,
intermediate_path: str | Path,
output_base_name: str,
):
output_np = []
output_node = exported_program.graph_module.graph.output_node()
for i, node in enumerate(output_node.args[0]):
output_shape = node.meta["val"].shape
output_dtype = node.meta["val"].dtype
tosa_ref_output = np.fromfile(
os.path.join(intermediate_path, f"{output_base_name}-{i}.bin"),
_torch_to_numpy_dtype_dict[output_dtype],
)

output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
return tuple(output_np)


def run_vkml_emulation_layer(
executorch_program_manager: ExecutorchProgramManager,
inputs: Tuple[torch.Tensor],
intermediate_path: str | Path,
elf_path: str | Path,
):
@@ -239,7 +274,7 @@ def run_vkml_emulation_layer(
`intermediate_path`: Directory to save the .pte and capture outputs.
`elf_path`: Path to the Vulkan-capable executor_runner binary.
"""

exported_program = executorch_program_manager.exported_program()
intermediate_path = Path(intermediate_path)
intermediate_path.mkdir(exist_ok=True)
elf_path = Path(elf_path)
@@ -251,26 +286,29 @@
with open(pte_path, "wb") as f:
f.write(executorch_program_manager.buffer)

cmd_line = [str(elf_path), "-model_path", pte_path]
output_base_name = "out"
out_path = os.path.join(intermediate_path, output_base_name)

cmd_line = f"{elf_path} -model_path {pte_path} -output_file {out_path}"

input_string = None
input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)
for input_path in input_paths:
if input_string is None:
input_string = f" -inputs={input_path}"
else:
input_string += f",{input_path}"
if input_string is not None:
cmd_line += input_string
cmd_line = cmd_line.split()

result = _run_cmd(cmd_line)

result_stdout = result.stdout.decode() # noqa: F841
# TODO: MLETORCH-1234: Support VGF e2e tests in VgfPipeline
# TODO: Add regex to check for error or fault messages in stdout from Emulation Layer
# Regex to extract tensor values from stdout
output_np = []
matches = re.findall(
r"Output\s+\d+:\s+tensor\(sizes=\[(.*?)\],\s+\[(.*?)\]\)",
result_stdout,
re.DOTALL,
)

for shape_str, values_str in matches:
shape = list(map(int, shape_str.split(",")))
values = list(map(float, re.findall(r"[-+]?\d*\.\d+|\d+", values_str)))
output_np.append(torch.tensor(values).reshape(shape))
result_stdout = result.stdout.decode() # noqa: F841

return tuple(output_np)
return get_output_from_file(exported_program, intermediate_path, output_base_name)
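
# For reference, with two inputs the command assembled above comes out as
# (all paths are placeholders):
#   <elf_path> -model_path <pte_path> -output_file <intermediate_path>/out \
#       -inputs=<input_0.bin>,<input_1.bin>
# executor_runner then writes one raw tensor per model output as
# <intermediate_path>/out-<i>.bin, which get_output_from_file() reads back
# using the shape and dtype recorded on the exported program's output nodes.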


def run_corstone(
@@ -312,14 +350,10 @@ def run_corstone(
with open(pte_path, "wb") as f:
f.write(executorch_program_manager.buffer)

# Save inputs to file
input_names = get_input_names(exported_program)
input_paths = []
for input_name, input_ in zip(input_names, inputs):
input_path = save_bytes(intermediate_path, input_, input_name)
input_paths.append(input_path)
input_paths = save_inputs_to_file(exported_program, inputs, intermediate_path)

out_path = os.path.join(intermediate_path, "out")
output_base_name = "out"
out_path = os.path.join(intermediate_path, output_base_name)

cmd_line = f"executor_runner -m {pte_path} -o {out_path}"
for input_path in input_paths:
@@ -401,18 +435,7 @@ def run_corstone(
f"Corstone simulation failed:\ncmd: {' '.join(command_args)}\nlog: \n {result_stdout}\n{result.stderr.decode()}"
)

output_np = []
output_node = exported_program.graph_module.graph.output_node()
for i, node in enumerate(output_node.args[0]):
output_shape = node.meta["val"].shape
output_dtype = node.meta["val"].dtype
tosa_ref_output = np.fromfile(
os.path.join(intermediate_path, f"out-{i}.bin"),
_torch_to_numpy_dtype_dict[output_dtype],
)

output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
return tuple(output_np)
return get_output_from_file(exported_program, intermediate_path, output_base_name)


def prep_data_for_save(
113 changes: 106 additions & 7 deletions examples/portable/executor_runner/executor_runner.cpp
@@ -1,7 +1,7 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright 2024-2025 Arm Limited and/or its affiliates.
* All rights reserved.
* Copyright 2024-2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
@@ -18,6 +18,7 @@
* all fp32 tensors.
*/

#include <fstream>
#include <iostream>
#include <memory>

@@ -49,6 +50,16 @@ DEFINE_string(
model_path,
"model.pte",
"Model serialized in flatbuffer format.");
DEFINE_string(inputs, "", "Comma-separated list of input files");
Contributor:
JFYI you can also use bundled-io in the PTE.

Collaborator Author:
Yes, perhaps add that in a separate patch

Collaborator:
Outputs can also be saved with ETDump, but that is maybe a bigger/separate patch; this file should probably be re-synced with examples/devtools/example_runner/example_runner.cpp someday.

Or @digantdesai do you want to avoid adding "inputs" as an argument and only handle it via BundleIO?

E.g. is this something you prefer us to fix now in this PR or later after 1.0/GA?

Collaborator:
Just spotted your "LGTM", so that answers my question :)
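
For reference, a sketch of the bundled-io route mentioned above, based on my reading of the executorch devtools API (import paths and signatures are assumptions worth double-checking):

import torch
from executorch.devtools import BundledProgram
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (
    serialize_from_bundled_program_to_flatbuffer,
)
from executorch.exir import to_edge
from torch.export import export

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x + x

model = TinyModel()
example_inputs = (torch.ones(1, 4),)
exec_prog = to_edge(export(model, example_inputs)).to_executorch()

# Bundle the test inputs (and expected outputs, so the runtime can verify
# results) into the program instead of passing -inputs files at run time.
test_suites = [
    MethodTestSuite(
        method_name="forward",
        test_cases=[
            MethodTestCase(
                inputs=example_inputs,
                expected_outputs=(model(*example_inputs),),
            )
        ],
    )
]
bundled = BundledProgram(exec_prog, test_suites)
with open("model.bpte", "wb") as f:
    f.write(serialize_from_bundled_program_to_flatbuffer(bundled))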

DEFINE_string(
output_file,
"",
"Base name of output file. If not empty output will be written to the file(s).");

DEFINE_bool(
print_all_output,
false,
"Prints all output. By default only first and last 100 elements are printed.");
DEFINE_uint32(num_executions, 1, "Number of times to run the model.");
#ifdef ET_EVENT_TRACER_ENABLED
DEFINE_string(etdump_path, "model.etdump", "Write ETDump data to this path.");
@@ -58,6 +69,8 @@ DEFINE_int32(
-1,
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::extension::FileDataLoader;
using executorch::runtime::Error;
using executorch::runtime::EValue;
Expand All @@ -70,6 +83,8 @@ using executorch::runtime::MethodMeta;
using executorch::runtime::Program;
using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::Tag;
using executorch::runtime::TensorInfo;

/// Helper to manage resources for ETDump generation
class EventTraceManager {
@@ -156,6 +171,31 @@ int main(int argc, char** argv) {
"FileDataLoader::from() failed: 0x%" PRIx32,
(uint32_t)loader.error());

std::vector<std::string> inputs_storage;
std::vector<std::pair<char*, size_t>> input_buffers;

std::stringstream list_of_input_files(FLAGS_inputs);
std::string token;

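// Read each comma-separated input file fully into inputs_storage and record
// a (pointer, size) view in input_buffers; inputs_storage must outlive the
// prepare_input_tensors() call below, which consumes these buffers.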
while (std::getline(list_of_input_files, token, ',')) {
std::ifstream input_file_handle(token, std::ios::binary | std::ios::ate);
if (!input_file_handle) {
ET_LOG(Error, "Failed to open input file: %s\n", token.c_str());
return 1;
}

std::streamsize file_size = input_file_handle.tellg();
input_file_handle.seekg(0, std::ios::beg);

inputs_storage.emplace_back(file_size, '\0');
if (!input_file_handle.read(&inputs_storage.back()[0], file_size)) {
ET_LOG(Error, "Failed to read input file: %s\n", token.c_str());
return 1;
}

input_buffers.emplace_back(&inputs_storage.back()[0], file_size);
}

// Parse the program file. This is immutable, and can also be reused between
// multiple execution invocations across multiple threads.
Result<Program> program = Program::load(&loader.get());
@@ -254,15 +294,17 @@
// Run the model.
for (uint32_t i = 0; i < FLAGS_num_executions; i++) {
ET_LOG(Debug, "Preparing inputs.");
// Allocate input tensors and set all of their elements to 1. The `inputs`
// Allocate input tensors and set all of their elements to 1 or to the
// contents of input_buffers if available. The `inputs`
// variable owns the allocated memory and must live past the last call to
// `execute()`.
//
// NOTE: we have to re-prepare input tensors on every execution
// because inputs whose space gets reused by memory planning (if
// any such inputs exist) will not be preserved for the next
// execution.
auto inputs = executorch::extension::prepare_input_tensors(*method);
auto inputs = executorch::extension::prepare_input_tensors(
*method, {}, input_buffers);
ET_CHECK_MSG(
inputs.ok(),
"Could not prepare inputs: 0x%" PRIx32,
@@ -295,10 +337,67 @@
ET_LOG(Info, "%zu outputs: ", outputs.size());
Error status = method->get_outputs(outputs.data(), outputs.size());
ET_CHECK(status == Error::Ok);
// Print the first and last 100 elements of long lists of scalars.
std::cout << executorch::extension::evalue_edge_items(100);
for (int i = 0; i < outputs.size(); ++i) {
std::cout << "Output " << i << ": " << outputs[i] << std::endl;

if (FLAGS_output_file.size() > 0) {
for (int i = 0; i < outputs.size(); ++i) {
if (outputs[i].isTensor()) {
Tensor tensor = outputs[i].toTensor();

char out_filename[255];
snprintf(out_filename, 255, "%s-%d.bin", FLAGS_output_file.c_str(), i);
ET_LOG(Info, "Writing output to file: %s", out_filename);
FILE* out_file = fopen(out_filename, "wb");
auto written_size =
fwrite(tensor.const_data_ptr<char>(), 1, tensor.nbytes(), out_file);
fclose(out_file);
}
}
}

if (FLAGS_print_all_output) {
for (int i = 0; i < outputs.size(); ++i) {
if (outputs[i].isTensor()) {
Tensor tensor = outputs[i].toTensor();

for (int j = 0; j < tensor.numel(); ++j) {
if (tensor.scalar_type() == ScalarType::Int) {
printf(
"Output[%d][%d]: (int) %d\n",
i,
j,
tensor.const_data_ptr<int>()[j]);
} else if (tensor.scalar_type() == ScalarType::Float) {
printf(
"Output[%d][%d]: (float) %f\n",
i,
j,
tensor.const_data_ptr<float>()[j]);
} else if (tensor.scalar_type() == ScalarType::Char) {
printf(
"Output[%d][%d]: (char) %d\n",
i,
j,
tensor.const_data_ptr<int8_t>()[j]);
} else if (tensor.scalar_type() == ScalarType::Bool) {
printf(
"Output[%d][%d]: (bool) %s (0x%x)\n",
i,
j,
tensor.const_data_ptr<int8_t>()[j] ? "true " : "false",
tensor.const_data_ptr<int8_t>()[j]);
}
}
} else {
printf("Output[%d]: Not Tensor\n", i);
}
}
} else {
// Print the first and last 100 elements of long lists of scalars.
std::cout << executorch::extension::evalue_edge_items(100);

for (int i = 0; i < outputs.size(); ++i) {
std::cout << "Output " << i << ": " << outputs[i] << std::endl;
}
}

if (tracer.get_event_tracer()) {