diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 4ec02aee921..a6ed7e354a9 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,8 @@ namespace executorch { namespace runtime { +using internal::PlatformMemoryAllocator; + /** * Runtime state for a backend delegate. */ @@ -548,7 +551,16 @@ Result Method::load( const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer) { - Method method(program, memory_manager, event_tracer); + MemoryAllocator* temp_allocator = memory_manager->temp_allocator(); + if (temp_allocator == nullptr) { + PlatformMemoryAllocator* platform_allocator = + ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR( + memory_manager->method_allocator(), PlatformMemoryAllocator); + new (platform_allocator) PlatformMemoryAllocator(); + temp_allocator = platform_allocator; + } + Method method(program, memory_manager, event_tracer, temp_allocator); + Error err = method.init(s_plan); if (err != Error::Ok) { return err; @@ -1039,16 +1051,14 @@ Error Method::execute_instruction() { auto instruction = instructions->Get(step_state_.instr_idx); size_t next_instr_idx = step_state_.instr_idx + 1; Error err = Error::Ok; + switch (instruction->instr_args_type()) { case executorch_flatbuffer::InstructionArguments::KernelCall: { EXECUTORCH_SCOPE_PROF("OPERATOR_CALL"); internal::EventTracerProfileScope event_tracer_scope = internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL"); // TODO(T147221312): Also expose tensor resizer via the context. 
- // The temp_allocator passed can be null, but calling allocate_temp will - // fail - KernelRuntimeContext context( - event_tracer_, memory_manager_->temp_allocator()); + KernelRuntimeContext context(event_tracer_, temp_allocator_); auto args = chain.argument_lists_[step_state_.instr_idx]; chain.kernels_[step_state_.instr_idx](context, args.data()); // We reset the temp_allocator after the switch statement @@ -1096,7 +1106,7 @@ Error Method::execute_instruction() { step_state_.instr_idx); BackendExecutionContext backend_execution_context( /*event_tracer*/ event_tracer_, - /*temp_allocator*/ memory_manager_->temp_allocator()); + /*temp_allocator*/ temp_allocator_); err = delegates_[delegate_idx].Execute( backend_execution_context, chain.argument_lists_[step_state_.instr_idx].data()); @@ -1168,8 +1178,8 @@ Error Method::execute_instruction() { err = Error::InvalidProgram; } // Reset the temp allocator for every instruction. - if (memory_manager_->temp_allocator() != nullptr) { - memory_manager_->temp_allocator()->reset(); + if (temp_allocator_ != nullptr) { + temp_allocator_->reset(); } if (err == Error::Ok) { step_state_.instr_idx = next_instr_idx; diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 7d96096accf..0a35d6b9282 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -53,6 +53,7 @@ class Method final { : step_state_(rhs.step_state_), program_(rhs.program_), memory_manager_(rhs.memory_manager_), + temp_allocator_(rhs.temp_allocator_), serialization_plan_(rhs.serialization_plan_), event_tracer_(rhs.event_tracer_), n_value_(rhs.n_value_), @@ -273,10 +274,12 @@ class Method final { Method( const Program* program, MemoryManager* memory_manager, - EventTracer* event_tracer) + EventTracer* event_tracer, + MemoryAllocator* temp_allocator) : step_state_(), program_(program), memory_manager_(memory_manager), + temp_allocator_(temp_allocator), serialization_plan_(nullptr), event_tracer_(event_tracer), n_value_(0), @@ -319,6 
+322,7 @@ class Method final { StepState step_state_; const Program* program_; MemoryManager* memory_manager_; + MemoryAllocator* temp_allocator_; executorch_flatbuffer::ExecutionPlan* serialization_plan_; EventTracer* event_tracer_; diff --git a/runtime/executor/platform_memory_allocator.h b/runtime/executor/platform_memory_allocator.h new file mode 100644 index 00000000000..09195a460ac --- /dev/null +++ b/runtime/executor/platform_memory_allocator.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace runtime { +namespace internal { + +/** + * PlatformMemoryAllocator is a memory allocator that uses a linked list to + * manage allocated nodes. It overrides the allocate method of MemoryAllocator + * using the PAL fallback allocator method `et_pal_allocate`. + */ +class PlatformMemoryAllocator final : public MemoryAllocator { + private: + // We allocate a little more than requested and use that memory as a node in + // a linked list, pushing the allocated buffers onto a list that's iterated + // and freed when the allocator is reset or destroyed. + struct AllocationNode { + void* data; + AllocationNode* next; + }; + + AllocationNode* head_ = nullptr; + + public: + PlatformMemoryAllocator() : MemoryAllocator(0, nullptr) {} + + void* allocate(size_t size, size_t alignment = kDefaultAlignment) override { + if (!isPowerOf2(alignment)) { + ET_LOG(Error, "Alignment %zu is not a power of 2", alignment); + return nullptr; + } + + // Allocate enough memory for the node, the data and the alignment bump. 
+ size_t alloc_size = sizeof(AllocationNode) + size + alignment; + void* node_memory = et_pal_allocate(alloc_size); + + // If allocation failed, log message and return nullptr. + if (node_memory == nullptr) { + ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size); + return nullptr; + } + + // Compute data pointer. + uint8_t* data_ptr = + reinterpret_cast(node_memory) + sizeof(AllocationNode); + + // Align the data pointer. + void* aligned_data_ptr = alignPointer(data_ptr, alignment); + + // Assert that the alignment didn't overflow the allocated memory. + ET_DCHECK_MSG( + reinterpret_cast(aligned_data_ptr) + size <= + reinterpret_cast(node_memory) + alloc_size, + "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu", + aligned_data_ptr, + size, + node_memory, + alloc_size); + + // Construct the node. + AllocationNode* new_node = reinterpret_cast(node_memory); + new_node->data = aligned_data_ptr; + new_node->next = head_; + head_ = new_node; + + // Return the aligned data pointer. + return head_->data; + } + + void reset() override { + AllocationNode* current = head_; + while (current != nullptr) { + AllocationNode* next = current->next; + et_pal_free(current); + current = next; + } + head_ = nullptr; + } + + ~PlatformMemoryAllocator() override { + reset(); + } + + private: + // Disable copy and move. + PlatformMemoryAllocator(const PlatformMemoryAllocator&) = delete; + PlatformMemoryAllocator& operator=(const PlatformMemoryAllocator&) = delete; + PlatformMemoryAllocator(PlatformMemoryAllocator&&) noexcept = delete; + PlatformMemoryAllocator& operator=(PlatformMemoryAllocator&&) noexcept = + delete; +}; + +} // namespace internal +} // namespace runtime +} // namespace executorch diff --git a/runtime/executor/program.h b/runtime/executor/program.h index a599cc958e0..f7469eb2192 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -123,7 +123,8 @@ class Program final { * * @param[in] method_name The name of the method to load. 
* @param[in] memory_manager The allocators to use during initialization and - * execution of the loaded method. + * execution of the loaded method. If `memory_manager->temp_allocator()` is + * null, the runtime will allocate temp memory using `et_pal_allocate()`. * @param[in] event_tracer The event tracer to use for this method run. * * @returns The loaded method on success, or an error on failure. diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 46f997a80ad..cc91255d7b5 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -65,6 +65,9 @@ def define_common_targets(): "tensor_parser_exec_aten.cpp", "tensor_parser{}.cpp".format(aten_suffix if aten_mode else "_portable"), ], + headers = [ + "platform_memory_allocator.h", + ], exported_headers = [ "method.h", "method_meta.h", diff --git a/runtime/executor/test/kernel_integration_test.cpp b/runtime/executor/test/kernel_integration_test.cpp index 616398b7416..4f1ac0240b9 100644 --- a/runtime/executor/test/kernel_integration_test.cpp +++ b/runtime/executor/test/kernel_integration_test.cpp @@ -34,6 +34,7 @@ using executorch::runtime::FreeableBuffer; using executorch::runtime::Kernel; using executorch::runtime::KernelKey; using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::MemoryAllocator; using executorch::runtime::Method; using executorch::runtime::Program; using executorch::runtime::Result; @@ -59,10 +60,26 @@ struct KernelControl { // returning. Error fail_value = Error::Ok; + // If true, the kernel should allocate temporary memory. + bool allocate_temp_memory = false; + + // If true, the kernel should simulate allocating temporary memory. + bool simulate_temp_memory_allocation = false; + + // The size of the temporary memory to allocate. + int temp_memory_size = 0; + + // The total size of all allocations. 
+ int total_allocated_size = 0; + void reset() { call_count = 0; call_context_fail = false; fail_value = Error::Ok; + allocate_temp_memory = false; + simulate_temp_memory_allocation = false; + temp_memory_size = 0; + total_allocated_size = 0; } /** @@ -117,6 +134,33 @@ struct KernelControl { if (control->call_context_fail) { context.fail(control->fail_value); } + + // Allocate temporary memory. + if (control->allocate_temp_memory) { + Result temp_mem_res = + context.allocate_temp(control->temp_memory_size); + if (temp_mem_res.ok()) { + control->total_allocated_size += control->temp_memory_size; + // Write to the memory to verify that the default allocation actually + // succeeded. + uint8_t* array = (uint8_t*)(temp_mem_res.get()); + for (int i = 0; i < control->temp_memory_size; i++) { + array[i] = i % 256; + } + } + } + + // Simulate allocating temporary memory. This tests that, when a temp + // allocator is provided, the kernel uses it instead of allocating memory + // with the default platform memory allocator. The TempMemoryAllocator + // class in this file only simulates allocation; it does not actually + // allocate anything. + if (control->simulate_temp_memory_allocation) { + Result temp_mem_res = + context.allocate_temp(control->temp_memory_size); + control->total_allocated_size += control->temp_memory_size; + EXPECT_EQ(temp_mem_res.error(), Error::Ok); + } } static bool registered_; @@ -126,6 +170,44 @@ struct KernelControl { bool KernelControl::registered_ = false; KernelControl KernelControl::singleton_; +/** + * MemoryAllocator that keeps track of the number/sizes of its allocations, + * to test the case where the user provides a temp allocator. + */ +class TempMemoryAllocator final : public MemoryAllocator { + public: + TempMemoryAllocator() : MemoryAllocator(0, nullptr) {} + + // The number of times allocate() has been called. + int number_of_allocations = 0; + + // The number of times reset() has been called. 
+ int number_of_resets = 0; + + // The amount of memory currently allocated (should go to 0 when reset is + // called). + int currently_allocated_size = 0; + + // The total size of all allocations. + int total_allocated_size = 0; + + void* allocate(size_t size, ET_UNUSED size_t alignment = kDefaultAlignment) + override { + number_of_allocations += 1; + currently_allocated_size += size; + total_allocated_size += size; + // This is a simulation, we don't actually allocate memory. But we need to + // return a non-null pointer, so we return a bad, non-zero address that will + // crash if anyone tries to dereference it. + return (void*)1; + } + + void reset() override { + number_of_resets += 1; + currently_allocated_size = 0; + } +}; + class KernelIntegrationTest : public ::testing::Test { protected: void SetUp() override { @@ -152,7 +234,9 @@ class KernelIntegrationTest : public ::testing::Test { // Load the forward method. mmm_ = std::make_unique( - kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + kDefaultNonConstMemBytes, + kDefaultRuntimeMemBytes, + temp_allocator_.get()); Result method = program_->load_method("forward", &mmm_->get()); ASSERT_EQ(method.error(), Error::Ok); method_ = std::make_unique(std::move(method.get())); @@ -185,6 +269,19 @@ class KernelIntegrationTest : public ::testing::Test { // The KernelControl associated with method_. KernelControl* control_; + + // The temp memory allocator provided by the user. By default, none is + // provided. + std::unique_ptr temp_allocator_ = nullptr; +}; + +class KernelTempMemoryAllocatorIntegrationTest : public KernelIntegrationTest { + protected: + void SetUp() override { + // Create a temp allocator for the test before calling the parent SetUp. 
+ temp_allocator_ = std::make_unique(); + KernelIntegrationTest::SetUp(); + } }; TEST_F(KernelIntegrationTest, KernelHookIsCalled) { @@ -222,3 +319,63 @@ TEST_F(KernelIntegrationTest, FailurePropagates) { EXPECT_EQ(err, Error::Ok); EXPECT_EQ(control_->call_count, 3); } + +TEST_F(KernelIntegrationTest, DefaultPlatformMemoryAllocator) { + // Tell the kernel to allocate memory. Since no temp allocator is provided, + // this will allocate memory using the default platform memory allocator. + control_->allocate_temp_memory = true; + + control_->temp_memory_size = 4; + // This is not a simulation. This actually allocates memory, using the + // default platform memory allocator. + Error err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 1); + EXPECT_EQ(control_->total_allocated_size, 4); + + control_->temp_memory_size = 8; + // This is not a simulation. This actually allocates memory, using the + // default platform memory allocator. + err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 2); + EXPECT_EQ(control_->total_allocated_size, 12); +} + +TEST_F(KernelTempMemoryAllocatorIntegrationTest, UsingTempMemoryAllocator) { + // In this test we provide a temp allocator to the method, and tell the kernel + // to allocate memory using it. We want to make sure that the kernel uses the + // temp allocator, and that the temp allocator is reset after the execution. + // Since we are testing that the kernel uses the temp allocator, and not the + // temp allocator itself, we don't need to test the actual allocation of + // memory. Therefore, we set simulate_temp_memory_allocation to true, so that + // the kernel will not actually allocate memory, but will instead simulate + // allocating memory. + // The provided TempMemoryAllocator, simulates allocating memory by increasing + // total_allocated_size and currently_allocated_size by the requested size. 
+ // We simulate resetting the allocator by setting currently_allocated_size + // back to 0. + control_->simulate_temp_memory_allocation = true; + + control_->temp_memory_size = 4; + Error err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 1); + EXPECT_EQ(control_->total_allocated_size, 4); + EXPECT_EQ(temp_allocator_->number_of_allocations, 1); + EXPECT_EQ(temp_allocator_->total_allocated_size, 4); + // The temp allocator should have been reset after the execution. + EXPECT_EQ(temp_allocator_->number_of_resets, 1); + EXPECT_EQ(temp_allocator_->currently_allocated_size, 0); + + control_->temp_memory_size = 8; + err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 2); + EXPECT_EQ(control_->total_allocated_size, 12); + EXPECT_EQ(temp_allocator_->number_of_allocations, 2); + EXPECT_EQ(temp_allocator_->total_allocated_size, 12); + // The temp allocator should have been reset after the execution. + EXPECT_EQ(temp_allocator_->number_of_resets, 2); + EXPECT_EQ(temp_allocator_->currently_allocated_size, 0); +} diff --git a/runtime/executor/test/managed_memory_manager.h b/runtime/executor/test/managed_memory_manager.h index 667aa35ca24..a01091527b0 100644 --- a/runtime/executor/test/managed_memory_manager.h +++ b/runtime/executor/test/managed_memory_manager.h @@ -27,7 +27,8 @@ class ManagedMemoryManager { public: ManagedMemoryManager( size_t planned_memory_bytes, - size_t method_allocator_bytes) + size_t method_allocator_bytes, + MemoryAllocator* temp_allocator = nullptr) : planned_memory_buffer_(new uint8_t[planned_memory_bytes]), planned_memory_span_( planned_memory_buffer_.get(), @@ -35,7 +36,7 @@ class ManagedMemoryManager { planned_memory_({&planned_memory_span_, 1}), method_allocator_pool_(new uint8_t[method_allocator_bytes]), method_allocator_(method_allocator_bytes, method_allocator_pool_.get()), - memory_manager_(&method_allocator_, &planned_memory_) {} + memory_manager_(&method_allocator_, 
&planned_memory_, temp_allocator) {} MemoryManager& get() { return memory_manager_; diff --git a/runtime/platform/default/minimal.cpp b/runtime/platform/default/minimal.cpp index e1db2083f4a..8236f993188 100644 --- a/runtime/platform/default/minimal.cpp +++ b/runtime/platform/default/minimal.cpp @@ -47,3 +47,9 @@ void et_pal_emit_log_message( ET_UNUSED size_t line, ET_UNUSED const char* message, ET_UNUSED size_t length) {} + +void* et_pal_allocate(ET_UNUSED size_t size) { + return nullptr; +} + +void et_pal_free(ET_UNUSED void* ptr) {} diff --git a/runtime/platform/default/posix.cpp b/runtime/platform/default/posix.cpp index cfc8cafc491..aba504f53e0 100644 --- a/runtime/platform/default/posix.cpp +++ b/runtime/platform/default/posix.cpp @@ -170,3 +170,26 @@ void et_pal_emit_log_message( message); fflush(ET_LOG_OUTPUT_FILE); } + +/** + * NOTE: Core runtime code must not call this directly. It may only be called by + * a MemoryAllocator wrapper. + * + * Allocates size bytes of memory via malloc. + * + * @param[in] size Number of bytes to allocate. + * @returns the allocated memory, or nullptr on failure. Must be freed using + * et_pal_free(). + */ +void* et_pal_allocate(size_t size) { + return malloc(size); +} + +/** + * Frees memory allocated by et_pal_allocate(). + * + * @param[in] ptr Pointer to memory to free. May be nullptr. + */ +void et_pal_free(void* ptr) { + free(ptr); +} diff --git a/runtime/platform/platform.h b/runtime/platform/platform.h index e29dad8e9a8..03cdef8eb2f 100644 --- a/runtime/platform/platform.h +++ b/runtime/platform/platform.h @@ -115,4 +115,23 @@ void et_pal_emit_log_message( const char* message, size_t length) ET_INTERNAL_PLATFORM_WEAKNESS; +/** + * NOTE: Core runtime code must not call this directly. It may only be called by + * a MemoryAllocator wrapper. + * + * Allocates size bytes of memory. + * + * @param[in] size Number of bytes to allocate. + * @returns the allocated memory, or nullptr on failure. 
Must be freed using + * et_pal_free(). + */ +void* et_pal_allocate(size_t size) ET_INTERNAL_PLATFORM_WEAKNESS; + +/** + * Frees memory allocated by et_pal_allocate(). + * + * @param[in] ptr Pointer to memory to free. May be nullptr. + */ +void et_pal_free(void* ptr) ET_INTERNAL_PLATFORM_WEAKNESS; + } // extern "C" diff --git a/runtime/platform/test/executor_pal_override_test.cpp b/runtime/platform/test/executor_pal_override_test.cpp index bb9ea2ce589..9bc500e652e 100644 --- a/runtime/platform/test/executor_pal_override_test.cpp +++ b/runtime/platform/test/executor_pal_override_test.cpp @@ -53,12 +53,29 @@ class PalSpy : public PlatformIntercept { last_log_message_args.length = length; } + void* allocate(size_t size) override { + ++allocate_call_count; + last_allocated_size = size; + last_allocated_ptr = (void*)0x1234; + return nullptr; + } + + void free(void* ptr) override { + ++free_call_count; + last_freed_ptr = ptr; + } + virtual ~PalSpy() = default; size_t init_call_count = 0; size_t current_ticks_call_count = 0; size_t emit_log_message_call_count = 0; et_tick_ratio_t tick_ns_multiplier = {1, 1}; + size_t allocate_call_count = 0; + size_t free_call_count = 0; + size_t last_allocated_size = 0; + void* last_allocated_ptr = nullptr; + void* last_freed_ptr = nullptr; /// The args that were passed to the most recent call to emit_log_message(). struct { @@ -158,4 +175,33 @@ TEST(ExecutorPalOverrideTest, TickToNsMultiplier) { EXPECT_EQ(et_pal_ticks_to_ns_multiplier().denominator, 1); } +TEST(ExecutorPalOverrideTest, AllocateSmokeTest) { + PalSpy spy; + InterceptWith iw(spy); + + // Validate that et_pal_allocate is overridden. 
+ EXPECT_EQ(spy.allocate_call_count, 0); + EXPECT_EQ(spy.last_allocated_ptr, nullptr); + et_pal_allocate(4); + EXPECT_EQ(spy.allocate_call_count, 1); + EXPECT_EQ(spy.last_allocated_size, 4); + EXPECT_EQ(spy.last_allocated_ptr, (void*)0x1234); +} + +TEST(ExecutorPalOverrideTest, FreeSmokeTest) { + PalSpy spy; + InterceptWith iw(spy); + + et_pal_allocate(4); + EXPECT_EQ(spy.last_allocated_size, 4); + EXPECT_EQ(spy.last_allocated_ptr, (void*)0x1234); + + // Validate that et_pal_free is overridden. + EXPECT_EQ(spy.free_call_count, 0); + EXPECT_EQ(spy.last_freed_ptr, nullptr); + et_pal_free(spy.last_allocated_ptr); + EXPECT_EQ(spy.free_call_count, 1); + EXPECT_EQ(spy.last_freed_ptr, (void*)0x1234); +} + #endif diff --git a/runtime/platform/test/stub_platform.cpp b/runtime/platform/test/stub_platform.cpp index f7ad2f9ee63..8cee404e4e1 100644 --- a/runtime/platform/test/stub_platform.cpp +++ b/runtime/platform/test/stub_platform.cpp @@ -75,6 +75,16 @@ void et_pal_emit_log_message( timestamp, level, filename, function, line, message, length); } +void* et_pal_allocate(size_t size) { + ASSERT_INTERCEPT_INSTALLED(); + return platform_intercept->allocate(size); +} + +void et_pal_free(void* ptr) { + ASSERT_INTERCEPT_INSTALLED(); + platform_intercept->free(ptr); +} + } // extern "C" #include diff --git a/runtime/platform/test/stub_platform.h b/runtime/platform/test/stub_platform.h index af3756f3136..de5599b53b0 100644 --- a/runtime/platform/test/stub_platform.h +++ b/runtime/platform/test/stub_platform.h @@ -45,6 +45,12 @@ class PlatformIntercept { ET_UNUSED const char* message, ET_UNUSED size_t length) {} + virtual void* allocate(ET_UNUSED size_t size) { + return nullptr; + } + + virtual void free(ET_UNUSED void* ptr) {} + virtual ~PlatformIntercept() = default; };