diff --git a/runtime/backend/backend_execution_context.h b/runtime/backend/backend_execution_context.h index 7890f4f9528..d2790b158ef 100644 --- a/runtime/backend/backend_execution_context.h +++ b/runtime/backend/backend_execution_context.h @@ -21,8 +21,11 @@ class BackendExecutionContext final { public: BackendExecutionContext( EventTracer* event_tracer = nullptr, - MemoryAllocator* temp_allocator = nullptr) - : event_tracer_(event_tracer), temp_allocator_(temp_allocator) {} + MemoryAllocator* temp_allocator = nullptr, + const char* method_name = nullptr) + : event_tracer_(event_tracer), + temp_allocator_(temp_allocator), + method_name_(method_name) {} /** * Returns a pointer to an instance of EventTracer to do profiling/debugging @@ -52,9 +55,17 @@ class BackendExecutionContext final { return temp_allocator_; } + /** + * Get the name of the executing method from the ExecuTorch runtime. + */ + const char* get_method_name() const { + return method_name_; + } + private: EventTracer* event_tracer_ = nullptr; MemoryAllocator* temp_allocator_ = nullptr; + const char* method_name_ = nullptr; }; } // namespace runtime diff --git a/runtime/backend/backend_init_context.h b/runtime/backend/backend_init_context.h index 7541349318e..051266662c6 100644 --- a/runtime/backend/backend_init_context.h +++ b/runtime/backend/backend_init_context.h @@ -18,8 +18,10 @@ namespace runtime { */ class BackendInitContext final { public: - explicit BackendInitContext(MemoryAllocator* runtime_allocator) - : runtime_allocator_(runtime_allocator) {} + explicit BackendInitContext( + MemoryAllocator* runtime_allocator, + const char* method_name = nullptr) + : runtime_allocator_(runtime_allocator), method_name_(method_name) {} /** Get the runtime allocator passed from Method. It's the same runtime * executor used by the standard executor runtime and the life span is the @@ -29,8 +31,20 @@ class BackendInitContext final { return runtime_allocator_; } + /** Get the loaded method name from ExecuTorch runtime. Usually it's + * "forward", however, if there are multiple methods in the .pte file, it can + * be different. One example is that we may have prefill and decode methods in + * the same .pte file. In this case, when client loads "prefill" method, the + * `get_method_name` function will return "prefill", when client loads + * "decode" method, the `get_method_name` function will return "decode". + */ + const char* get_method_name() const { + return method_name_; + } + private: MemoryAllocator* runtime_allocator_ = nullptr; + const char* method_name_ = nullptr; }; } // namespace runtime diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 8e74a508f3e..4208cf36c55 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -626,7 +626,9 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) { for (size_t i = 0; i < n_delegate; ++i) { const auto& delegate = *delegates->Get(i); - BackendInitContext backend_init_context(method_allocator); + BackendInitContext backend_init_context( + method_allocator, + /*method_name=*/serialization_plan_->name()->c_str()); Error err = BackendDelegate::Init( delegate, program_, backend_init_context, &delegates_[i]); if (err != Error::Ok) { @@ -1097,8 +1099,9 @@ Error Method::execute_instruction() { n_delegate_, step_state_.instr_idx); BackendExecutionContext backend_execution_context( - /*event_tracer*/ event_tracer_, - /*temp_allocator*/ temp_allocator_); + /*event_tracer=*/event_tracer_, + /*temp_allocator=*/temp_allocator_, + /*method_name=*/serialization_plan_->name()->c_str()); err = delegates_[delegate_idx].Execute( backend_execution_context, chain.argument_lists_[step_state_.instr_idx].data()); diff --git a/runtime/executor/test/backend_integration_test.cpp b/runtime/executor/test/backend_integration_test.cpp index 2db1a16af17..37a653b4d9c 100644 --- a/runtime/executor/test/backend_integration_test.cpp +++ b/runtime/executor/test/backend_integration_test.cpp @@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface { } Error execute( - ET_UNUSED BackendExecutionContext& context, + BackendExecutionContext& context, DelegateHandle* handle, EValue** args) const override { if (execute_fn_) { @@ -530,6 +530,48 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) { EXPECT_EQ(backend_load_was_called, using_segments()); } +TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) { + Result loader = FileDataLoader::from(program_path()); + ASSERT_EQ(loader.error(), Error::Ok); + const void* processed_data = nullptr; + StubBackend::singleton().install_init( + [&](FreeableBuffer* processed, + ET_UNUSED ArrayRef compile_specs, + ET_UNUSED BackendInitContext& backend_init_context) + -> Result { + auto method_name = backend_init_context.get_method_name(); + // Ensure that we can get the method name during init via context + EXPECT_STREQ(method_name, "forward"); + processed_data = processed->data(); + return nullptr; + }); + Result program = Program::load(&loader.get()); + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program->load_method("forward", &mmm.get()); + EXPECT_TRUE(method.ok()); + ASSERT_EQ(program.error(), Error::Ok); +} + +TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) { + Result loader = FileDataLoader::from(program_path()); + ASSERT_EQ(loader.error(), Error::Ok); + StubBackend::singleton().install_execute( + [&](BackendExecutionContext& backend_execution_context, + ET_UNUSED DelegateHandle* handle, + ET_UNUSED EValue** args) -> Error { + // Ensure that we can get the method name during execution via context + auto method_name = backend_execution_context.get_method_name(); + EXPECT_STREQ(method_name, "forward"); + return Error::Ok; + }); + Result program = Program::load(&loader.get()); + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program->load_method("forward", &mmm.get()); + EXPECT_TRUE(method.ok()); + Error err = method->execute(); + ASSERT_EQ(err, Error::Ok); +} + // TODO: Add more tests for the runtime-to-backend interface. E.g.: // - Errors during init() or execute() result in runtime init/execution failures // - Correct values are passed to init()/execute()