Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions runtime/backend/backend_init_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
*/

#pragma once
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/event_tracer.h>
#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/core/named_data_map.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>

#include <cstring>

#ifdef __GNUC__
// Disable -Wdeprecated-declarations, as some builds use 'Werror'.
Expand All @@ -29,15 +35,17 @@ class BackendInitContext final {
MemoryAllocator* runtime_allocator,
EventTracer* event_tracer = nullptr,
const char* method_name = nullptr,
const NamedDataMap* named_data_map = nullptr)
const NamedDataMap* named_data_map = nullptr,
Span<const BackendOption> runtime_specs = {})
: runtime_allocator_(runtime_allocator),
#ifdef ET_EVENT_TRACER_ENABLED
event_tracer_(event_tracer),
#else
event_tracer_(nullptr),
#endif
method_name_(method_name),
named_data_map_(named_data_map) {
named_data_map_(named_data_map),
runtime_specs_(runtime_specs) {
}

/** Get the runtime allocator passed from Method. It's the same runtime
Expand Down Expand Up @@ -75,11 +83,58 @@ class BackendInitContext final {
return named_data_map_;
}

/**
* Get the runtime specs (load-time options) for this backend.
* These are per-delegate options passed at Module::load() time.
*
* @return Span of BackendOption containing the runtime specs, or empty span
* if no runtime specs were provided.
*/
Span<const BackendOption> runtime_specs() const {
return runtime_specs_;
}

/**
* Get a runtime spec value by key and type.
*
* @tparam T The expected type (bool, int, or const char*)
* @param key The option key to look up.
* @return Result containing the value if found and type matches,
* Error::NotFound if key doesn't exist,
* Error::InvalidArgument if key exists but type doesn't match.
*/
template <typename T>
Result<T> get_runtime_spec(const char* key) const {
static_assert(
std::is_same_v<T, bool> || std::is_same_v<T, int> ||
std::is_same_v<T, const char*>,
"get_runtime_spec<T> only supports bool, int, and const char*");

for (size_t i = 0; i < runtime_specs_.size(); ++i) {
const auto& opt = runtime_specs_[i];
if (std::strcmp(opt.key, key) == 0) {
if constexpr (std::is_same_v<T, const char*>) {
if (auto* arr = std::get_if<std::array<char, kMaxOptionValueLength>>(
&opt.value)) {
return arr->data();
}
} else {
if (auto* val = std::get_if<T>(&opt.value)) {
return *val;
}
}
return Error::InvalidArgument;
}
}
return Error::NotFound;
}

private:
MemoryAllocator* runtime_allocator_ = nullptr;
EventTracer* event_tracer_ = nullptr;
const char* method_name_ = nullptr;
const NamedDataMap* named_data_map_ = nullptr;
Span<const BackendOption> runtime_specs_;
};

} // namespace ET_RUNTIME_NAMESPACE
Expand Down
204 changes: 204 additions & 0 deletions runtime/backend/test/backend_init_context_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/backend/backend_init_context.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendOption;
using executorch::runtime::BackendOptions;
using executorch::runtime::Error;
using executorch::runtime::Span;

class BackendInitContextTest : public ::testing::Test {
protected:
void SetUp() override {
executorch::runtime::runtime_init();
}
};

// Test default constructor without runtime specs
TEST_F(BackendInitContextTest, DefaultConstructorNoRuntimeSpecs) {
BackendInitContext context(nullptr);

auto specs = context.runtime_specs();
EXPECT_EQ(specs.size(), 0);
}

// Test constructor with runtime specs
TEST_F(BackendInitContextTest, ConstructorWithRuntimeSpecs) {
BackendOptions<4> opts;
opts.set_option("compute_unit", "cpu_and_gpu");
opts.set_option("num_threads", 4);
opts.set_option("enable_profiling", true);

// Create a const span from the mutable view
auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(
nullptr, // runtime_allocator
nullptr, // event_tracer
"forward", // method_name
nullptr, // named_data_map
const_span // runtime_specs
);

auto specs = context.runtime_specs();
EXPECT_EQ(specs.size(), 3);
}

// Test get_runtime_spec<bool> with valid key
TEST_F(BackendInitContextTest, GetRuntimeSpecBoolValid) {
BackendOptions<2> opts;
opts.set_option("enable_profiling", true);
opts.set_option("debug_mode", false);

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span);

auto result1 = context.get_runtime_spec<bool>("enable_profiling");
EXPECT_TRUE(result1.ok());
EXPECT_TRUE(result1.get());

auto result2 = context.get_runtime_spec<bool>("debug_mode");
EXPECT_TRUE(result2.ok());
EXPECT_FALSE(result2.get());
}

// Test get_runtime_spec<int> with valid key
TEST_F(BackendInitContextTest, GetRuntimeSpecIntValid) {
BackendOptions<2> opts;
opts.set_option("num_threads", 8);
opts.set_option("batch_size", 32);

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span);

auto result1 = context.get_runtime_spec<int>("num_threads");
EXPECT_TRUE(result1.ok());
EXPECT_EQ(result1.get(), 8);

auto result2 = context.get_runtime_spec<int>("batch_size");
EXPECT_TRUE(result2.ok());
EXPECT_EQ(result2.get(), 32);
}

// Test get_runtime_spec<const char*> with valid key
TEST_F(BackendInitContextTest, GetRuntimeSpecStringValid) {
BackendOptions<2> opts;
opts.set_option("compute_unit", "cpu_and_gpu");
opts.set_option("cache_dir", "/tmp/cache");

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span);

auto result1 = context.get_runtime_spec<const char*>("compute_unit");
EXPECT_TRUE(result1.ok());
EXPECT_STREQ(result1.get(), "cpu_and_gpu");

auto result2 = context.get_runtime_spec<const char*>("cache_dir");
EXPECT_TRUE(result2.ok());
EXPECT_STREQ(result2.get(), "/tmp/cache");
}

// Test get_runtime_spec<T> with non-existent key returns NotFound
TEST_F(BackendInitContextTest, GetRuntimeSpecNotFound) {
BackendOptions<1> opts;
opts.set_option("key", "value");

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span);

auto bool_result = context.get_runtime_spec<bool>("nonexistent");
EXPECT_FALSE(bool_result.ok());
EXPECT_EQ(bool_result.error(), Error::NotFound);

auto int_result = context.get_runtime_spec<int>("nonexistent");
EXPECT_FALSE(int_result.ok());
EXPECT_EQ(int_result.error(), Error::NotFound);

auto string_result = context.get_runtime_spec<const char*>("nonexistent");
EXPECT_FALSE(string_result.ok());
EXPECT_EQ(string_result.error(), Error::NotFound);
}

// Test get_runtime_spec<T> with wrong type returns InvalidArgument
TEST_F(BackendInitContextTest, GetRuntimeSpecTypeMismatch) {
BackendOptions<3> opts;
opts.set_option("bool_opt", true);
opts.set_option("int_opt", 42);
opts.set_option("string_opt", "hello");

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(nullptr, nullptr, nullptr, nullptr, const_span);

// Try to get bool as int
auto result1 = context.get_runtime_spec<int>("bool_opt");
EXPECT_FALSE(result1.ok());
EXPECT_EQ(result1.error(), Error::InvalidArgument);

// Try to get int as string
auto result2 = context.get_runtime_spec<const char*>("int_opt");
EXPECT_FALSE(result2.ok());
EXPECT_EQ(result2.error(), Error::InvalidArgument);

// Try to get string as bool
auto result3 = context.get_runtime_spec<bool>("string_opt");
EXPECT_FALSE(result3.ok());
EXPECT_EQ(result3.error(), Error::InvalidArgument);
}

// Test empty runtime specs
TEST_F(BackendInitContextTest, EmptyRuntimeSpecs) {
Span<const BackendOption> empty_span;
BackendInitContext context(nullptr, nullptr, nullptr, nullptr, empty_span);

EXPECT_EQ(context.runtime_specs().size(), 0);

// All lookups should return NotFound
auto bool_result = context.get_runtime_spec<bool>("any_key");
EXPECT_FALSE(bool_result.ok());
EXPECT_EQ(bool_result.error(), Error::NotFound);
}

// Test that other context fields still work
TEST_F(BackendInitContextTest, OtherFieldsStillWork) {
BackendOptions<1> opts;
opts.set_option("key", "value");

auto view = opts.view();
Span<const BackendOption> const_span(view.data(), view.size());

BackendInitContext context(
nullptr, // runtime_allocator
nullptr, // event_tracer
"forward", // method_name
nullptr, // named_data_map
const_span // runtime_specs
);

EXPECT_EQ(context.get_runtime_allocator(), nullptr);
EXPECT_EQ(context.event_tracer(), nullptr);
EXPECT_STREQ(context.get_method_name(), "forward");
EXPECT_EQ(context.get_named_data_map(), nullptr);
}
9 changes: 9 additions & 0 deletions runtime/backend/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def define_common_targets():
],
)

runtime.cxx_test(
name = "backend_init_context_test",
srcs = ["backend_init_context_test.cpp"],
deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/backend:interface",
],
)

runtime.cxx_test(
name = "backend_interface_update_test",
srcs = ["backend_interface_update_test.cpp"],
Expand Down
Loading