Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/platform_memory_allocator.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/tensor_parser.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
Expand All @@ -29,6 +30,8 @@
namespace executorch {
namespace runtime {

using internal::PlatformMemoryAllocator;

/**
* Runtime state for a backend delegate.
*/
Expand Down Expand Up @@ -548,7 +551,16 @@ Result<Method> Method::load(
const Program* program,
MemoryManager* memory_manager,
EventTracer* event_tracer) {
Method method(program, memory_manager, event_tracer);
MemoryAllocator* temp_allocator = memory_manager->temp_allocator();
if (temp_allocator == nullptr) {
PlatformMemoryAllocator* platform_allocator =
ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
memory_manager->method_allocator(), PlatformMemoryAllocator);
new (platform_allocator) PlatformMemoryAllocator();
temp_allocator = platform_allocator;
}
Method method(program, memory_manager, event_tracer, temp_allocator);

Error err = method.init(s_plan);
if (err != Error::Ok) {
return err;
Expand Down Expand Up @@ -1039,16 +1051,14 @@ Error Method::execute_instruction() {
auto instruction = instructions->Get(step_state_.instr_idx);
size_t next_instr_idx = step_state_.instr_idx + 1;
Error err = Error::Ok;

switch (instruction->instr_args_type()) {
case executorch_flatbuffer::InstructionArguments::KernelCall: {
EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
internal::EventTracerProfileScope event_tracer_scope =
internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL");
// TODO(T147221312): Also expose tensor resizer via the context.
// The temp_allocator passed can be null, but calling allocate_temp will
// fail
KernelRuntimeContext context(
event_tracer_, memory_manager_->temp_allocator());
KernelRuntimeContext context(event_tracer_, temp_allocator_);
auto args = chain.argument_lists_[step_state_.instr_idx];
chain.kernels_[step_state_.instr_idx](context, args.data());
// We reset the temp_allocator after the switch statement
Expand Down Expand Up @@ -1096,7 +1106,7 @@ Error Method::execute_instruction() {
step_state_.instr_idx);
BackendExecutionContext backend_execution_context(
/*event_tracer*/ event_tracer_,
/*temp_allocator*/ memory_manager_->temp_allocator());
/*temp_allocator*/ temp_allocator_);
err = delegates_[delegate_idx].Execute(
backend_execution_context,
chain.argument_lists_[step_state_.instr_idx].data());
Expand Down Expand Up @@ -1168,8 +1178,8 @@ Error Method::execute_instruction() {
err = Error::InvalidProgram;
}
// Reset the temp allocator for every instruction.
if (memory_manager_->temp_allocator() != nullptr) {
memory_manager_->temp_allocator()->reset();
if (temp_allocator_ != nullptr) {
temp_allocator_->reset();
}
if (err == Error::Ok) {
step_state_.instr_idx = next_instr_idx;
Expand Down
6 changes: 5 additions & 1 deletion runtime/executor/method.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Method final {
: step_state_(rhs.step_state_),
program_(rhs.program_),
memory_manager_(rhs.memory_manager_),
temp_allocator_(rhs.temp_allocator_),
serialization_plan_(rhs.serialization_plan_),
event_tracer_(rhs.event_tracer_),
n_value_(rhs.n_value_),
Expand Down Expand Up @@ -273,10 +274,12 @@ class Method final {
Method(
const Program* program,
MemoryManager* memory_manager,
EventTracer* event_tracer)
EventTracer* event_tracer,
MemoryAllocator* temp_allocator)
: step_state_(),
program_(program),
memory_manager_(memory_manager),
temp_allocator_(temp_allocator),
serialization_plan_(nullptr),
event_tracer_(event_tracer),
n_value_(0),
Expand Down Expand Up @@ -319,6 +322,7 @@ class Method final {
StepState step_state_;
const Program* program_;
MemoryManager* memory_manager_;
MemoryAllocator* temp_allocator_;
executorch_flatbuffer::ExecutionPlan* serialization_plan_;
EventTracer* event_tracer_;

Expand Down
111 changes: 111 additions & 0 deletions runtime/executor/platform_memory_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <stdio.h>
#include <cinttypes>
#include <cstdint>

#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>

namespace executorch {
namespace runtime {
namespace internal {

/**
* PlatformMemoryAllocator is a memory allocator that uses a linked list to
* manage allocated nodes. It overrides the allocate method of MemoryAllocator
* using the PAL fallback allocator method `et_pal_allocate`.
*/
class PlatformMemoryAllocator final : public MemoryAllocator {
private:
// We allocate a little more than requested and use that memory as a node in
// a linked list, pushing the allocated buffers onto a list that's iterated
// and freed when the KernelRuntimeContext is destroyed.
struct AllocationNode {
void* data;
AllocationNode* next;
};

AllocationNode* head_ = nullptr;

public:
PlatformMemoryAllocator() : MemoryAllocator(0, nullptr) {}

void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
if (!isPowerOf2(alignment)) {
ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
return nullptr;
}

// Allocate enough memory for the node, the data and the alignment bump.
size_t alloc_size = sizeof(AllocationNode) + size + alignment;
void* node_memory = et_pal_allocate(alloc_size);

// If allocation failed, log message and return nullptr.
if (node_memory == nullptr) {
ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size);
return nullptr;
}

// Compute data pointer.
uint8_t* data_ptr =
reinterpret_cast<uint8_t*>(node_memory) + sizeof(AllocationNode);

// Align the data pointer.
void* aligned_data_ptr = alignPointer(data_ptr, alignment);

// Assert that the alignment didn't overflow the allocated memory.
ET_DCHECK_MSG(
reinterpret_cast<uintptr_t>(aligned_data_ptr) + size <=
reinterpret_cast<uintptr_t>(node_memory) + alloc_size,
"aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu",
aligned_data_ptr,
size,
node_memory,
alloc_size);

// Construct the node.
AllocationNode* new_node = reinterpret_cast<AllocationNode*>(node_memory);
new_node->data = aligned_data_ptr;
new_node->next = head_;
head_ = new_node;

// Return the aligned data pointer.
return head_->data;
}

void reset() override {
AllocationNode* current = head_;
while (current != nullptr) {
AllocationNode* next = current->next;
et_pal_free(current);
current = next;
}
head_ = nullptr;
}

~PlatformMemoryAllocator() override {
reset();
}

private:
// Disable copy and move.
PlatformMemoryAllocator(const PlatformMemoryAllocator&) = delete;
PlatformMemoryAllocator& operator=(const PlatformMemoryAllocator&) = delete;
PlatformMemoryAllocator(PlatformMemoryAllocator&&) noexcept = delete;
PlatformMemoryAllocator& operator=(PlatformMemoryAllocator&&) noexcept =
delete;
};

} // namespace internal
} // namespace runtime
} // namespace executorch
3 changes: 2 additions & 1 deletion runtime/executor/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ class Program final {
*
* @param[in] method_name The name of the method to load.
* @param[in] memory_manager The allocators to use during initialization and
* execution of the loaded method.
* execution of the loaded method. If `memory_manager.temp_allocator()` is
* null, the runtime will allocate temp memory using `et_pal_allocate()`.
* @param[in] event_tracer The event tracer to use for this method run.
*
* @returns The loaded method on success, or an error on failure.
Expand Down
3 changes: 3 additions & 0 deletions runtime/executor/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def define_common_targets():
"tensor_parser_exec_aten.cpp",
"tensor_parser{}.cpp".format(aten_suffix if aten_mode else "_portable"),
],
headers = [
"platform_memory_allocator.h",
],
exported_headers = [
"method.h",
"method_meta.h",
Expand Down
Loading