diff --git a/backends/arm/arm_vela.py b/backends/arm/arm_vela.py index c931d49547f..c47a5c58f49 100644 --- a/backends/arm/arm_vela.py +++ b/backends/arm/arm_vela.py @@ -73,8 +73,8 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False) np_path = os.path.join(tmpdir, "output", "out_vela.npz") else: np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz") - blocks = b"" + blocks = b"" with np.load(np_path, allow_pickle=False) as data: # Construct our modified output_blocks with data in a form easily # digested on the device side @@ -92,7 +92,7 @@ def vela_compile(tosa_flatbuffer: bytes, args: List[str], verbose: bool = False) if not isinstance(data["scratch_shape"][0], np.int64): raise RuntimeError("Expected scratch to be int64") block_length = int(data["scratch_shape"][0]) - bin_blocks["scratch_data"] = b"\x00" * block_length + bin_blocks["scratch_size"] = struct.pack("(temp_allocator->allocate(handles.scratch_data_size)); + extern size_t ethosu_fast_scratch_size; + extern unsigned char* ethosu_fast_scratch; ET_LOG( Debug, - "EthosUBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n", + "EthosUBackend::execute: Running program data:\n cmd %p %zu\n weight %p %zu\n scratch %p %zu\n fast scratch %p %zu\n", handles.cmd_data, handles.cmd_data_size, handles.weight_data, handles.weight_data_size, - handles.scratch_data, - handles.scratch_data_size); + ethosu_scratch, + handles.scratch_data_size, + ethosu_fast_scratch, + ethosu_fast_scratch_size); // Write argument values (from EValue tensor) into Ethos-U scratch // TODO(MLETORCH-123): Optimise into direct write from Vela into the SRAM @@ -197,7 +209,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { for (int i = 0; i < handles.inputs->count; i++) { auto tensor_count = 1, io_count = 1; auto tensor_in = args[i]->toTensor(); - char* scratch_addr = handles.scratch_data + handles.inputs->io[i].offset; + char* scratch_addr = ethosu_scratch + handles.inputs->io[i].offset; // We accept: bool supported = 0; @@ -294,13 +306,17 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // Ethos-U low level driver expected order for Ethos U-55, we have // constant weight data, then scratch (which contains input and output) // scratch is written above in this function. - uint64_t bases[2] = { + + uint64_t bases[ETHOSU_NUM_BASE_ADDRS] = { static_cast( reinterpret_cast((handles.weight_data))), + static_cast(reinterpret_cast(ethosu_scratch)), static_cast( - reinterpret_cast((handles.scratch_data)))}; - size_t bases_size[2] = { - handles.weight_data_size, handles.scratch_data_size}; + reinterpret_cast(ethosu_fast_scratch))}; + size_t bases_size[ETHOSU_NUM_BASE_ADDRS] = { + handles.weight_data_size, + handles.scratch_data_size, + ethosu_fast_scratch_size}; int result = 0; EXECUTORCH_PROF_START( event_tracer, event_tracer_local_scope, "+EthosUBackend::execute()NPU"); @@ -310,7 +326,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { handles.cmd_data_size, bases, bases_size, - 2, /* fixed array of pointers to binary interface*/ + 3, /* fixed array of pointers to binary interface*/ nullptr); EXECUTORCH_PROF_END(event_tracer, event_tracer_local_scope); @@ -325,8 +341,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // Write outputs from scratch into EValue pointers for (int i = 0; i < handles.outputs->count; i++) { int tensor_count = 1, io_count = 1; - const char* output_addr = - handles.scratch_data + handles.outputs->io[i].offset; + const char* output_addr = ethosu_scratch + handles.outputs->io[i].offset; // Process input EValue into scratch // Outputs are in the index immediately after inputs auto tensor_out = args[handles.inputs->count + i]->toTensor(); diff --git a/backends/arm/runtime/VelaBinStream.cpp b/backends/arm/runtime/VelaBinStream.cpp index a26fe9f23e2..fbd9e2daadb 100644 --- a/backends/arm/runtime/VelaBinStream.cpp +++ b/backends/arm/runtime/VelaBinStream.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2023 Arm Limited and/or its affiliates. + * Copyright 2023, 2025 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -71,9 +71,10 @@ bool vela_bin_read(const char* data, VelaHandles* handles, int size) { } else if (!strncmp(b->name, "weight_data", strlen("weight_data"))) { handles->weight_data = b->data; handles->weight_data_size = b->size; - } else if (!strncmp(b->name, "scratch_data", strlen("scratch_data"))) { - handles->scratch_data = b->data; - handles->scratch_data_size = b->size; + } else if (!strncmp(b->name, "scratch_size", strlen("scratch_size"))) { + const uint32_t* scratch_size_ptr = + reinterpret_cast(b->data); + handles->scratch_data_size = *scratch_size_ptr; } else if (!strncmp(b->name, "inputs", strlen("inputs"))) { handles->inputs = (VelaIOs*)b->data; } else if (!strncmp(b->name, "outputs", strlen("outputs"))) { diff --git a/backends/arm/scripts/build_executor_runner.sh b/backends/arm/scripts/build_executor_runner.sh index 807821d427f..9e2f3954c53 100755 --- a/backends/arm/scripts/build_executor_runner.sh +++ b/backends/arm/scripts/build_executor_runner.sh @@ -103,7 +103,7 @@ then memory_mode="Shared_Sram" if [[ ${target} =~ "ethos-u85" ]] then - memory_mode="Sram_Only" + memory_mode="Dedicated_Sram_384KB" fi fi diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index c06a6e666ec..bddc30f04ab 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -41,28 +41,28 @@ def __init__(self): # (t, c, n, s) = (6, 96, 1, 1) # 1. 1x1 CONV2d + ReLU6 (Pointwise) self.pointwise_conv2d = torch.nn.Conv2d( - in_channels=64, out_channels=384, kernel_size=1, stride=1, groups=1 - ) ## (1, 384, 81, 81) - self.batch_norm2d_16 = torch.nn.BatchNorm2d(384, affine=False) + in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1 + ) ## (1, 128, 81, 81) + self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False) self.relu6 = torch.nn.ReLU6() # 2. 3x3 DepthwiseConv2d + ReLu6 self.depthwise_conv2d = torch.nn.Conv2d( - in_channels=384, - out_channels=384, + in_channels=128, + out_channels=128, kernel_size=3, padding=1, stride=1, - groups=384, - ) ## (1, 384, H, W) + groups=128, + ) ## (1, 128, H, W) # 3. Linear 1x1 Conv2d self.pointwise_conv2d_linear = torch.nn.Conv2d( - in_channels=384, out_channels=64, kernel_size=1, stride=1, groups=1 - ) ## (1, 64, 81, 81) + in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1 + ) ## (1, 32, 81, 81) def get_inputs(self) -> Tuple[torch.Tensor]: - return (torch.randn(1, 64, 81, 81),) + return (torch.randn(1, 32, 81, 81),) def forward(self, x): input = x diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh index 330f0f138d0..6764dd27d96 100755 --- a/backends/arm/test/test_arm_baremetal.sh +++ b/backends/arm/test/test_arm_baremetal.sh @@ -206,11 +206,11 @@ test_models_ethos-u85() { # End to End model tests using model_test.py # Ethos-U85 echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=mv2 --extra_flags="-DET_ATOL=2.00 -DET_RTOL=2.00" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-512 --model=mv3 --extra_flags="-DET_ATOL=5.00 -DET_RTOL=5.00" + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=mv2 --extra_flags="-DET_ATOL=2.00 -DET_RTOL=2.00" + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-512 --model=mv3 --extra_flags="-DET_ATOL=5.00 -DET_RTOL=5.00" python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=lstm --extra_flags="-DET_ATOL=0.03 -DET_RTOL=0.03" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=w2l --extra_flags="-DET_ATOL=0.01 -DET_RTOL=0.01" - python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=ic4 --extra_flags="-DET_ATOL=0.8 -DET_RTOL=0.8" --timeout=2400 + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-128 --model=w2l --extra_flags="-DET_ATOL=0.01 -DET_RTOL=0.01" + python3 backends/arm/test/test_model.py --test_output=arm_test/test_model --target=ethos-u85-256 --model=ic4 --extra_flags="-DET_ATOL=0.8 -DET_RTOL=0.8" --timeout=2400 echo "${TEST_SUITE_NAME}: PASS" } diff --git a/backends/arm/test/test_model.py b/backends/arm/test/test_model.py index b0fd2f2a381..072583ef862 100755 --- a/backends/arm/test/test_model.py +++ b/backends/arm/test/test_model.py @@ -81,7 +81,7 @@ def get_args(): if "u55" in args.target: args.memory_mode = "Shared_Sram" elif "u85" in args.target: - args.memory_mode = "Sram_Only" + args.memory_mode = "Dedicated_Sram_384KB" else: raise RuntimeError(f"Invalid target name {args.target}") diff --git a/examples/arm/executor_runner/CMakeLists.txt b/examples/arm/executor_runner/CMakeLists.txt index 1568bef0301..6816a55d443 100644 --- a/examples/arm/executor_runner/CMakeLists.txt +++ b/examples/arm/executor_runner/CMakeLists.txt @@ -8,7 +8,6 @@ project(arm_executor_runner) option(SEMIHOSTING "Enable semihosting" OFF) option(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE "Set ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE to specify memory alloction pool size" OFF) -option(ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE "Set ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE to specify temp alloction pool size" OFF) option(ET_BUNDLE_IO "Set to compile in BundleIO support" OFF) option(ET_ATOL "Set atol to use for BundleIO testing" OFF) option(ET_RTOL "Set rtol to use for BundleIO testing" OFF) @@ -99,20 +98,45 @@ if(NOT ${SEMIHOSTING}) get_filename_component(ET_PTE_FILE_PATH ${ET_PTE_FILE_PATH} REALPATH) endif() +if(SYSTEM_CONFIG MATCHES "Ethos_U55") + add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target) +elseif(SYSTEM_CONFIG MATCHES "Ethos_U85") + add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target) +else() + message(FATAL_ERROR "Unsupported SYSTEM_CONFIG ${SYSTEM_CONFIG}.") +endif() +if(MEMORY_MODE MATCHES "Dedicated_Sram") + target_compile_definitions(ethosu_target_common INTERFACE + ETHOSU_MODEL=1 + ETHOSU_ARENA=1) +elseif(MEMORY_MODE MATCHES "Shared_Sram" OR MEMORY_MODE MATCHES "Sram_Only") + target_compile_definitions(ethosu_target_common INTERFACE + ETHOSU_MODEL=1 + ETHOSU_ARENA=0) +else() + message(FATAL_ERROR "Unsupported MEMORY_MODE ${MEMORY_MODE}. Memory_mode can be Shared_Sram, Sram_Only or Dedicated_Sram(applicable for the Ethos-U85)") +endif() + +# By default, use 2MB of temporary scratch buffer +# For Dedicated_Sram, use 128MB for the temporary scratch buffer and +# 384KB for the fast scratch buffer(the cache, applicable only for Ethos-U65 and Ethos-U85) +set(ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x200000) +if(MEMORY_MODE MATCHES "Dedicated_Sram") + set(ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x8000000) + set(ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE 0x60000) +endif() +message(STATUS "ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE = ${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}") +message(STATUS "ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE = ${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}") # Dependencies from the Ethos-U Core This is the platform target of # Corstone-300, that includes ethosu_core_driver and bare-metal bringup # libraries. We link against ethosu_target_init which includes all of these # dependencies. -if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded") - add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target) +if(SYSTEM_CONFIG MATCHES "Ethos_U55_High_End_Embedded") set(TARGET_BOARD "corstone-300") - if(MEMORY_MODE STREQUAL "Shared_Sram") + if(MEMORY_MODE MATCHES "Shared_Sram") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -144,7 +168,7 @@ if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded") ETHOSU_TA_HISTBIN_1=0 ETHOSU_TA_HISTCNT_1=0 ) - elseif(MEMORY_MODE STREQUAL "Sram_Only") + elseif(MEMORY_MODE MATCHES "Sram_Only") target_compile_definitions(ethosu_target_common INTERFACE # This is just example numbers and you should make this match your hardware # SRAM @@ -180,14 +204,11 @@ if(SYSTEM_CONFIG STREQUAL "Ethos_U55_High_End_Embedded") else() message(FATAL_ERROR "Unsupported memory_mode ${MEMORY_MODE} for the Ethos-U55. The Ethos-U55 supports only Shared_Sram and Sram_Only.") endif() -elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded") +elseif(SYSTEM_CONFIG MATCHES "Ethos_U55_Deep_Embedded") add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-300 target) set(TARGET_BOARD "corstone-300") - if(MEMORY_MODE STREQUAL "Shared_Sram") + if(MEMORY_MODE MATCHES "Shared_Sram") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -219,9 +240,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded") ETHOSU_TA_HISTBIN_1=0 ETHOSU_TA_HISTCNT_1=0 ) - elseif(MEMORY_MODE STREQUAL "Sram_Only") + elseif(MEMORY_MODE MATCHES "Sram_Only") target_compile_definitions(ethosu_target_common INTERFACE - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -256,14 +276,11 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U55_Deep_Embedded") else() message(FATAL_ERROR "Unsupported memory_mode ${MEMORY_MODE} for the Ethos-U55. The Ethos-U55 supports only Shared_Sram and Sram_Only.") endif() -elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low") +elseif(SYSTEM_CONFIG MATCHES "Ethos_U85_SYS_DRAM_Low") add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target) set(TARGET_BOARD "corstone-320") - if(MEMORY_MODE STREQUAL "Dedicated_Sram") + if(MEMORY_MODE MATCHES "Dedicated_Sram") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -295,11 +312,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low") ETHOSU_TA_HISTBIN_1=0 ETHOSU_TA_HISTCNT_1=0 ) - elseif(MEMORY_MODE STREQUAL "Sram_Only") + elseif(MEMORY_MODE MATCHES "Sram_Only") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -333,13 +347,9 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Low") ) endif() elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Mid" OR SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_High") - add_subdirectory(${ETHOS_SDK_PATH}/core_platform/targets/corstone-320 target) set(TARGET_BOARD "corstone-320") - if(MEMORY_MODE STREQUAL "Dedicated_Sram") + if(MEMORY_MODE MATCHES "Dedicated_Sram") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -371,11 +381,8 @@ elseif(SYSTEM_CONFIG STREQUAL "Ethos_U85_SYS_DRAM_Mid" OR SYSTEM_CONFIG STREQUAL ETHOSU_TA_HISTBIN_1=0 ETHOSU_TA_HISTCNT_1=0 ) - elseif(MEMORY_MODE STREQUAL "Sram_Only") + elseif(MEMORY_MODE MATCHES "Sram_Only") target_compile_definitions(ethosu_target_common INTERFACE - # ETHOSU_MODEL=0 place pte file/data in SRAM area - # ETHOSU_MODEL=1 place pte file/data in DDR area - ETHOSU_MODEL=1 # Configure NPU architecture timing adapters # This is just example numbers and you should make this match your hardware # SRAM @@ -434,7 +441,7 @@ endif() # the memory traffic of Region 1 should pass via the external memory(3) and the traffic for Region 2 should pass via the SRAM(0) # -if(MEMORY_MODE STREQUAL "Sram_Only") +if(MEMORY_MODE MATCHES "Sram_Only") target_compile_definitions(ethosu_core_driver PRIVATE NPU_QCONFIG=1 NPU_REGIONCFG_0=1 @@ -445,7 +452,7 @@ if(MEMORY_MODE STREQUAL "Sram_Only") NPU_REGIONCFG_5=0 NPU_REGIONCFG_6=0 NPU_REGIONCFG_7=0) - elseif(MEMORY_MODE STREQUAL "Dedicated_Sram") + elseif(MEMORY_MODE MATCHES "Dedicated_Sram") target_compile_definitions(ethosu_core_driver PRIVATE NPU_QCONFIG=3 NPU_REGIONCFG_0=3 @@ -632,8 +639,9 @@ if(ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE) target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE}) endif() -if(ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE) - target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE}) +target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}) +if(DEFINED ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE) + target_compile_definitions(arm_executor_runner PUBLIC ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=${ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE}) endif() if(ET_BUNDLE_IO) diff --git a/examples/arm/executor_runner/arm_executor_runner.cpp b/examples/arm/executor_runner/arm_executor_runner.cpp index ed93d2acd8b..e5313345f6c 100644 --- a/examples/arm/executor_runner/arm_executor_runner.cpp +++ b/examples/arm/executor_runner/arm_executor_runner.cpp @@ -128,17 +128,41 @@ const float et_rtol = 0.01; * The temp_allocation_pool is used for allocating temporary data during kernel * or delegate execution. This will be reset after each kernel or delegate call. * Currently a MemoryAllocator is used but a PlatformMemoryAllocator is probably - * a better fit + * a better fit. + * + * The Corstone-300/Corstone-320 platforms have 2MB/4MB of SRAM respectively. + * For Shared_Sram, ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE is + * 2MB and the linker script places the .bss.tensor_arena symbol in the SRAM. + * For Dedicated_Sram, the .bss.tensor_arena symbol is placed in the DDR in the + * linker script. Hence, we allocate 128MB in DDR and 384KB in the SRAM + * (.bss.ethosu_scratch is placed in the SRAM). The examples/arm/CMakeLists.txt + * contains the logic for the sizes of + * ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE and + * ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE */ -#if !defined(ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE) -#define ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE (1 * 1024 * 1024) -#endif const size_t temp_allocation_pool_size = - ET_ARM_BAREMETAL_TEMP_ALLOCATOR_POOL_SIZE; + ET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE; unsigned char __attribute__(( - section("input_data_sec"), + section(".bss.tensor_arena"), aligned(16))) temp_allocation_pool[temp_allocation_pool_size]; +namespace executorch { +namespace backends { +namespace arm { +#if defined(ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE) +size_t ethosu_fast_scratch_size = + ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE; +unsigned char __attribute__((section(".bss.ethosu_scratch"), aligned(16))) +dedicated_sram[ET_ARM_BAREMETAL_FAST_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE]; +unsigned char* ethosu_fast_scratch = dedicated_sram; +#else +size_t ethosu_fast_scratch_size = 0; +unsigned char* ethosu_fast_scratch = nullptr; +#endif +} // namespace arm +} // namespace backends +} // namespace executorch + void et_pal_init(void) { // Enable ARM PMU Clock ARM_PMU_Enable(); @@ -207,7 +231,7 @@ namespace { class ArmMemoryAllocator : public executorch::runtime::MemoryAllocator { public: ArmMemoryAllocator(uint32_t size, uint8_t* base_address) - : MemoryAllocator(size, base_address), used_(0) {} + : MemoryAllocator(size, base_address), used_(0), peak_used_(0) {} void* allocate(size_t size, size_t alignment = kDefaultAlignment) override { void* ret = executorch::runtime::MemoryAllocator::allocate(size, alignment); @@ -222,6 +246,8 @@ class ArmMemoryAllocator : public executorch::runtime::MemoryAllocator { } else { used_ = (used_ | (alignment - 1)) + 1 + size; } + if (used_ > peak_used_) + peak_used_ = used_; } return ret; } @@ -231,13 +257,25 @@ class ArmMemoryAllocator : public executorch::runtime::MemoryAllocator { return used_; } + // Returns the peak memory usage of the allocator's memory buffer + // Peak usage is useful when doing multiple allocations & resets + size_t peak_used() const { + return peak_used_; + } + // Returns the free size of the allocator's memory buffer. size_t free_size() const { return executorch::runtime::MemoryAllocator::size() - used_; } + void reset() { + executorch::runtime::MemoryAllocator::reset(); + used_ = 0; + } + private: size_t used_; + size_t peak_used_; }; Result prepare_input_tensors( @@ -682,11 +720,11 @@ int main(int argc, const char* argv[]) { if (temp_allocator.size() > 0) { ET_LOG( Info, - "temp_allocator_used: %zu / %zu free: %zu ( used: %zu %% ) ", - temp_allocator.used_size(), + "peak_temp_allocator: %zu / %zu free: %zu ( used: %zu %% ) ", + temp_allocator.peak_used(), temp_allocator.size(), temp_allocator.free_size(), - 100 * temp_allocator.used_size() / temp_allocator.size()); + 100 * temp_allocator.peak_used() / temp_allocator.size()); } if (status != Error::Ok) { diff --git a/examples/arm/run.sh b/examples/arm/run.sh index 89ac5cd30a8..750c251596c 100755 --- a/examples/arm/run.sh +++ b/examples/arm/run.sh @@ -110,7 +110,7 @@ then memory_mode="Shared_Sram" if [[ ${target} =~ "ethos-u85" ]] then - memory_mode="Sram_Only" + memory_mode="Dedicated_Sram_384KB" fi fi