Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backends/arm/arm_vela.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
np_path = os.path.join(tmpdir, "output", "out_vela.npz")
else:
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
blocks = b""

blocks = b""
with np.load(np_path, allow_pickle=False) as data:
# Construct our modified output_blocks with data in a form easily
# digested on the device side
Expand All @@ -92,7 +92,7 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False)
if not isinstance(data["scratch_shape"][0], np.int64):
raise RuntimeError("Expected scratch to be int64")
block_length = int(data["scratch_shape"][0])
bin_blocks["scratch_data"] = b"\x00" * block_length
bin_blocks["scratch_size"] = struct.pack("<I", block_length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice


# Capture inputs and outputs
bin_blocks["inputs"] = vela_bin_pack_io("input", data)
Expand Down
37 changes: 26 additions & 11 deletions backends/arm/runtime/EthosUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;

#define ETHOSU_NUM_BASE_ADDRS 3

namespace executorch {
namespace backends {
namespace arm {
Expand Down Expand Up @@ -181,23 +183,33 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
}
EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope);

MemoryAllocator* temp_allocator = context.get_temp_allocator();
// Use a temporary allocator for the intermediate tensors of the
// computation. The allocator is released in runtime/executor/method.cpp at
// the end of the execution of the Ethos-U custom delegate
char* ethosu_scratch =
static_cast<char*>(temp_allocator->allocate(handles.scratch_data_size));
extern size_t ethosu_fast_scratch_size;
extern unsigned char* ethosu_fast_scratch;
Comment on lines +192 to +193
Copy link
Contributor

@digantdesai digantdesai May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this is not the cleanest. Let me think of a better way to do this.

ET_LOG(
Debug,
"EthosUBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n",
"EthosUBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n fast scratch %p %zu\n",
handles.cmd_data,
handles.cmd_data_size,
handles.weight_data,
handles.weight_data_size,
handles.scratch_data,
handles.scratch_data_size);
ethosu_scratch,
handles.scratch_data_size,
ethosu_fast_scratch,
ethosu_fast_scratch_size);

// Write argument values (from EValue tensor) into Ethos-U scratch
// TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM
// or DRAM output for compatible data layouts.
for (int i = 0; i < handles.inputs->count; i++) {
auto tensor_count = 1, io_count = 1;
auto tensor_in = args[i]->toTensor();
char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset;
char* scratch_addr = ethosu_scratch + handles.inputs->io[i].offset;

// We accept:
bool supported = 0;
Expand Down Expand Up @@ -294,13 +306,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
// Ethos-U low level driver expected order for Ethos U-55, we have
// constant weight data, then scratch (which contains input and output)
// scratch is written above in this function.
uint64_t bases[2] = {

uint64_t bases[ETHOSU_NUM_BASE_ADDRS] = {
static_cast<uint64_t>(
reinterpret_cast<uintptr_t>((handles.weight_data))),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(ethosu_scratch)),
static_cast<uint64_t>(
reinterpret_cast<uintptr_t>((handles.scratch_data)))};
size_t bases_size[2] = {
handles.weight_data_size, handles.scratch_data_size};
reinterpret_cast<uintptr_t>(ethosu_fast_scratch))};
size_t bases_size[ETHOSU_NUM_BASE_ADDRS] = {
handles.weight_data_size,
handles.scratch_data_size,
ethosu_fast_scratch_size};
int result = 0;
EXECUTORCH_PROF_START(
event_tracer, event_tracer_local_scope, "+EthosUBackend::execute()NPU");
Expand All @@ -310,7 +326,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
handles.cmd_data_size,
bases,
bases_size,
2, /* fixed array of pointers to binary interface*/
3, /* fixed array of pointers to binary interface*/
nullptr);
EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope);

Expand All @@ -325,8 +341,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
// Write outputs from scratch into EValue pointers
for (int i = 0; i < handles.outputs->count; i++) {
int tensor_count = 1, io_count = 1;
const char* output_addr =
handles.scratch_data + handles.outputs->io[i].offset;
const char* output_addr = ethosu_scratch + handles.outputs->io[i].offset;
// Process input EValue into scratch
// Outputs are in the index immediately after inputs
auto tensor_out = args[handles.inputs->count + i]->toTensor();
Expand Down
9 changes: 5 additions & 4 deletions backends/arm/runtime/VelaBinStream.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Arm Limited and/or its affiliates.
* Copyright 2023, 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -71,9 +71,10 @@ bool vela_bin_read(const char* data, VelaHandles* handles, int size) {
} else if (!strncmp(b->name, "weight_data", strlen("weight_data"))) {
handles->weight_data = b->data;
handles->weight_data_size = b->size;
} else if (!strncmp(b->name, "scratch_data", strlen("scratch_data"))) {
handles->scratch_data = b->data;
handles->scratch_data_size = b->size;
} else if (!strncmp(b->name, "scratch_size", strlen("scratch_size"))) {
const uint32_t* scratch_size_ptr =
reinterpret_cast<const uint32_t*>(b->data);
handles->scratch_data_size = *scratch_size_ptr;
} else if (!strncmp(b->name, "inputs", strlen("inputs"))) {
handles->inputs = (VelaIOs*)b->data;
} else if (!strncmp(b->name, "outputs", strlen("outputs"))) {
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/scripts/build_executor_runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ then
memory_mode="Shared_Sram"
if [[ ${target} =~ "ethos-u85" ]]
then
memory_mode="Sram_Only"
memory_mode="Dedicated_Sram_384KB"
fi
fi

Expand Down
20 changes: 10 additions & 10 deletions backends/arm/test/ops/test_conv_combos.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,28 @@ def __init__(self):
# (t, c, n, s) = (6, 96, 1, 1)
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
self.pointwise_conv2d = torch.nn.Conv2d(
in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1
) ## (1, 384, 81, 81)
self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False)
in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1
) ## (1, 128, 81, 81)
self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False)
self.relu6 = torch.nn.ReLU6()

# 2. 3x3 DepthwiseConv2d + ReLu6
self.depthwise_conv2d = torch.nn.Conv2d(
in_channels=384,
out_channels=384,
in_channels=128,
out_channels=128,
kernel_size=3,
padding=1,
stride=1,
groups=384,
) ## (1, 384, H, W)
groups=128,
) ## (1, 128, H, W)

# 3. Linear 1x1 Conv2d
self.pointwise_conv2d_linear = torch.nn.Conv2d(
in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1
) ## (1, 64, 81, 81)
in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1
) ## (1, 32, 81, 81)

def get_inputs(self) -> Tuple[torch.Tensor]:
return (torch.randn(1, 64, 81, 81),)
return (torch.randn(1, 32, 81, 81),)

def forward(self, x):
input = x
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/test/test_arm_baremetal.sh
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ test_models_ethos-u85() { # End to End model tests using model_test.py

# Ethos-U85
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=mv2 --extra_flags="-DET_ATOL=2.00 -DET_RTOL=2.00"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-512 --model=mv3 --extra_flags="-DET_ATOL=5.00 -DET_RTOL=5.00"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=mv2 --extra_flags="-DET_ATOL=2.00 -DET_RTOL=2.00"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-512 --model=mv3 --extra_flags="-DET_ATOL=5.00 -DET_RTOL=5.00"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=lstm --extra_flags="-DET_ATOL=0.03 -DET_RTOL=0.03"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=w2l --extra_flags="-DET_ATOL=0.01 -DET_RTOL=0.01"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=ic4 --extra_flags="-DET_ATOL=0.8 -DET_RTOL=0.8" --timeout=2400
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=w2l --extra_flags="-DET_ATOL=0.01 -DET_RTOL=0.01"
python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=ic4 --extra_flags="-DET_ATOL=0.8 -DET_RTOL=0.8" --timeout=2400

echo "${TEST_SUITE_NAME}: PASS"
}
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_args():
if "u55" in args.target:
args.memory_mode = "Shared_Sram"
elif "u85" in args.target:
args.memory_mode = "Sram_Only"
args.memory_mode = "Dedicated_Sram_384KB"
else:
raise RuntimeError(f"Invalid target name {args.target}")

Expand Down
82 changes: 45 additions & 37 deletions examples/arm/executor_runner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ project(arm_executor_runner)

option(SEMIHOSTING "Enable semihosting" OFF)
option(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE "Set ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE to specify memory alloction pool size" OFF)
option(ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE "Set ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE to specify temp alloction pool size" OFF)
option(ET_BUNDLE_IO "Set to compile in BundleIO support" OFF)
option(ET_ATOL "Set atol to use for BundleIO testing" OFF)
option(ET_RTOL "Set rtol to use for BundleIO testing" OFF)
Expand Down Expand Up @@ -99,20 +98,45 @@ if(NOT ${SEMIHOSTING})
get_filename_component(ET_PTE_FILE_PATH ${ET_PTE_FILE_PATH} REALPATH)
endif()

if(SYSTEM_CONFIG MATCHES "Ethos_U55")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target)
elseif(SYSTEM_CONFIG MATCHES "Ethos_U85")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target)
else()
message(FATAL_ERROR "Unsupported SYSTEM_CONFIG ${SYSTEM_CONFIG}.")
endif()

if(MEMORY_MODE MATCHES "Dedicated_Sram")
target_compile_definitions(ethosu_target_common INTERFACE
ETHOSU_MODEL=1
ETHOSU_ARENA=1)
elseif(MEMORY_MODE MATCHES "Shared_Sram" OR MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_target_common INTERFACE
ETHOSU_MODEL=1
ETHOSU_ARENA=0)
else()
message(FATAL_ERROR "Unsupported MEMORY_MODE ${MEMORY_MODE}. Memory_mode can be Shared_Sram, Sram_Only or Dedicated_Sram(applicable for the Ethos-U85)")
endif()

# By default, use 2MB of temporary scratch buffer
# For Dedicated_Sram, use 128MB for the temporary scratch buffer and
# 384KB for the fast scratch buffer(the cache, applicable only for Ethos-U65 and Ethos-U85)
set(ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x200000)
if(MEMORY_MODE MATCHES "Dedicated_Sram")
set(ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x8000000)
set(ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x60000)
endif()
message(STATUS "ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE = ${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}")
message(STATUS "ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE = ${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}")

# Dependencies from the Ethos-U Core This is the platform target of
# Corstone-300, that includes ethosu_core_driver and bare-metal bringup
# libraries. We link against ethosu_target_init which includes all of these
# dependencies.
if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target)
if(SYSTEM_CONFIG MATCHES "Ethos_U55_High_End_Embedded")
set(TARGET_BOARD "corstone-300")
if(MEMORY_MODE STREQUAL "Shared_Sram")
if(MEMORY_MODE MATCHES "Shared_Sram")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -144,7 +168,7 @@ if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded")
ETHOSU_TA_HISTBIN_1=0
ETHOSU_TA_HISTCNT_1=0
)
elseif(MEMORY_MODE STREQUAL "Sram_Only")
elseif(MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_target_common INTERFACE
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -180,14 +204,11 @@ if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded")
else()
message(FATAL_ERROR "Unsupported memory_mode ${MEMORY_MODE} for the Ethos-U55. The Ethos-U55 supports only Shared_Sram and Sram_Only.")
endif()
elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded")
elseif(SYSTEM_CONFIG MATCHES "Ethos_U55_Deep_Embedded")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target)
set(TARGET_BOARD "corstone-300")
if(MEMORY_MODE STREQUAL "Shared_Sram")
if(MEMORY_MODE MATCHES "Shared_Sram")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -219,9 +240,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded")
ETHOSU_TA_HISTBIN_1=0
ETHOSU_TA_HISTCNT_1=0
)
elseif(MEMORY_MODE STREQUAL "Sram_Only")
elseif(MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_target_common INTERFACE
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -256,14 +276,11 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded")
else()
message(FATAL_ERROR "Unsupported memory_mode ${MEMORY_MODE} for the Ethos-U55. The Ethos-U55 supports only Shared_Sram and Sram_Only.")
endif()
elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low")
elseif(SYSTEM_CONFIG MATCHES "Ethos_U85_SYS_DRAM_Low")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target)
set(TARGET_BOARD "corstone-320")
if(MEMORY_MODE STREQUAL "Dedicated_Sram")
if(MEMORY_MODE MATCHES "Dedicated_Sram")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -295,11 +312,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low")
ETHOSU_TA_HISTBIN_1=0
ETHOSU_TA_HISTCNT_1=0
)
elseif(MEMORY_MODE STREQUAL "Sram_Only")
elseif(MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -333,13 +347,9 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low")
)
endif()
elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Mid" OR SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_High")
add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target)
set(TARGET_BOARD "corstone-320")
if(MEMORY_MODE STREQUAL "Dedicated_Sram")
if(MEMORY_MODE MATCHES "Dedicated_Sram")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -371,11 +381,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Mid" OR SYSTEM_CONFIG STREQUAL
ETHOSU_TA_HISTBIN_1=0
ETHOSU_TA_HISTCNT_1=0
)
elseif(MEMORY_MODE STREQUAL "Sram_Only")
elseif(MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_target_common INTERFACE
# ETHOSU_MODEL=0 place pte file/data in SRAM area
# ETHOSU_MODEL=1 place pte file/data in DDR area
ETHOSU_MODEL=1
# Configure NPU architecture timing adapters
# This is just example numbers and you should make this match your hardware
# SRAM
Expand Down Expand Up @@ -434,7 +441,7 @@ endif()
# the memory traffic of Region 1 should pass via the external memory(3) and the traffic for Region 2 should pass via the SRAM(0)
#

if(MEMORY_MODE STREQUAL "Sram_Only")
if(MEMORY_MODE MATCHES "Sram_Only")
target_compile_definitions(ethosu_core_driver PRIVATE
NPU_QCONFIG=1
NPU_REGIONCFG_0=1
Expand All @@ -445,7 +452,7 @@ if(MEMORY_MODE STREQUAL "Sram_Only")
NPU_REGIONCFG_5=0
NPU_REGIONCFG_6=0
NPU_REGIONCFG_7=0)
elseif(MEMORY_MODE STREQUAL "Dedicated_Sram")
elseif(MEMORY_MODE MATCHES "Dedicated_Sram")
target_compile_definitions(ethosu_core_driver PRIVATE
NPU_QCONFIG=3
NPU_REGIONCFG_0=3
Expand Down Expand Up @@ -632,8 +639,9 @@ if(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE)
target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE})
endif()

if(ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE)
target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE})
target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE})
if(DEFINED ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE)
target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE})
endif()

if(ET_BUNDLE_IO)
Expand Down
Loading
Loading