From 8d89b6881e9c8b9b16c5978254d96cd7624c1eea Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Tue, 12 Nov 2024 19:59:25 -0800 Subject: [PATCH] Expose method name as part of backend init context (#6622) Summary: Provide the method name to backend so they can load the corresponding method name accordingly. The most immediate need is that the qnn context binary can include two methods, one for prefill and one for decode. Since we don't allow backend access multi methods at the moment, we do it in a hacky way via following ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backends ops, and both will have two context binary ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode() custom_op_decode() -> context_binary (two graphs) ) ``` Since two context binary from these two customs ops will be exactly the same and they can be deduplicate during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instrucions [ "prefill" [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload:: Dict[bytes, index]) ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name, and load the same method as the top level method ``` Result QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef compile_specs) { const char* method_name = context.get_method_name() // for example, "prefill" handle = qnn_backend.load(method_name) return handle } ``` This is to unblock sharing weight between prefill and decode for using htp backend. Reviewed By: dbort Differential Revision: D65386597 --- runtime/backend/backend_execution_context.h | 15 ++++++- runtime/backend/backend_init_context.h | 18 +++++++- runtime/executor/method.cpp | 9 ++-- .../test/backend_integration_test.cpp | 44 ++++++++++++++++++- 4 files changed, 78 insertions(+), 8 deletions(-) diff --git a/runtime/backend/backend_execution_context.h b/runtime/backend/backend_execution_context.h index 7890f4f9528..d2790b158ef 100644 --- a/runtime/backend/backend_execution_context.h +++ b/runtime/backend/backend_execution_context.h @@ -21,8 +21,11 @@ class BackendExecutionContext final { public: BackendExecutionContext( EventTracer* event_tracer = nullptr, - MemoryAllocator* temp_allocator = nullptr) - : event_tracer_(event_tracer), temp_allocator_(temp_allocator) {} + MemoryAllocator* temp_allocator = nullptr, + const char* method_name = nullptr) + : event_tracer_(event_tracer), + temp_allocator_(temp_allocator), + method_name_(method_name) {} /** * Returns a pointer to an instance of EventTracer to do profiling/debugging @@ -52,9 +55,17 @@ class BackendExecutionContext final { return temp_allocator_; } + /** + * Get the name of the executing method from the ExecuTorch runtime. + */ + const char* get_method_name() const { + return method_name_; + } + private: EventTracer* event_tracer_ = nullptr; MemoryAllocator* temp_allocator_ = nullptr; + const char* method_name_ = nullptr; }; } // namespace runtime diff --git a/runtime/backend/backend_init_context.h b/runtime/backend/backend_init_context.h index 7541349318e..051266662c6 100644 --- a/runtime/backend/backend_init_context.h +++ b/runtime/backend/backend_init_context.h @@ -18,8 +18,10 @@ namespace runtime { */ class BackendInitContext final { public: - explicit BackendInitContext(MemoryAllocator* runtime_allocator) - : runtime_allocator_(runtime_allocator) {} + explicit BackendInitContext( + MemoryAllocator* runtime_allocator, + const char* method_name = nullptr) + : runtime_allocator_(runtime_allocator), method_name_(method_name) {} /** Get the runtime allocator passed from Method. It's the same runtime * executor used by the standard executor runtime and the life span is the @@ -29,8 +31,20 @@ class BackendInitContext final { return runtime_allocator_; } + /** Get the loaded method name from ExecuTorch runtime. Usually it's + * "forward", however, if there are multiple methods in the .pte file, it can + * be different. One example is that we may have prefill and decode methods in + * the same .pte file. In this case, when client loads "prefill" method, the + * `get_method_name` function will return "prefill", when client loads + * "decode" method, the `get_method_name` function will return "decode". + */ + const char* get_method_name() const { + return method_name_; + } + private: MemoryAllocator* runtime_allocator_ = nullptr; + const char* method_name_ = nullptr; }; } // namespace runtime diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 8e74a508f3e..4208cf36c55 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -626,7 +626,9 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) { for (size_t i = 0; i < n_delegate; ++i) { const auto& delegate = *delegates->Get(i); - BackendInitContext backend_init_context(method_allocator); + BackendInitContext backend_init_context( + method_allocator, + /*method_name=*/serialization_plan_->name()->c_str()); Error err = BackendDelegate::Init( delegate, program_, backend_init_context, &delegates_[i]); if (err != Error::Ok) { @@ -1097,8 +1099,9 @@ Error Method::execute_instruction() { n_delegate_, step_state_.instr_idx); BackendExecutionContext backend_execution_context( - /*event_tracer*/ event_tracer_, - /*temp_allocator*/ temp_allocator_); + /*event_tracer=*/event_tracer_, + /*temp_allocator=*/temp_allocator_, + /*method_name=*/serialization_plan_->name()->c_str()); err = delegates_[delegate_idx].Execute( backend_execution_context, chain.argument_lists_[step_state_.instr_idx].data()); diff --git a/runtime/executor/test/backend_integration_test.cpp b/runtime/executor/test/backend_integration_test.cpp index 2db1a16af17..37a653b4d9c 100644 --- a/runtime/executor/test/backend_integration_test.cpp +++ b/runtime/executor/test/backend_integration_test.cpp @@ -95,7 +95,7 @@ class StubBackend final : public BackendInterface { } Error execute( - ET_UNUSED BackendExecutionContext& context, + BackendExecutionContext& context, DelegateHandle* handle, EValue** args) const override { if (execute_fn_) { @@ -530,6 +530,48 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) { EXPECT_EQ(backend_load_was_called, using_segments()); } +TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) { + Result loader = FileDataLoader::from(program_path()); + ASSERT_EQ(loader.error(), Error::Ok); + const void* processed_data = nullptr; + StubBackend::singleton().install_init( + [&](FreeableBuffer* processed, + ET_UNUSED ArrayRef compile_specs, + ET_UNUSED BackendInitContext& backend_init_context) + -> Result { + auto method_name = backend_init_context.get_method_name(); + // Ensure that we can get the method name during init via context + EXPECT_STREQ(method_name, "forward"); + processed_data = processed->data(); + return nullptr; + }); + Result program = Program::load(&loader.get()); + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program->load_method("forward", &mmm.get()); + EXPECT_TRUE(method.ok()); + ASSERT_EQ(program.error(), Error::Ok); +} + +TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) { + Result loader = FileDataLoader::from(program_path()); + ASSERT_EQ(loader.error(), Error::Ok); + StubBackend::singleton().install_execute( + [&](BackendExecutionContext& backend_execution_context, + ET_UNUSED DelegateHandle* handle, + ET_UNUSED EValue** args) -> Error { + // Ensure that we can get the method name during execution via context + auto method_name = backend_execution_context.get_method_name(); + EXPECT_STREQ(method_name, "forward"); + return Error::Ok; + }); + Result program = Program::load(&loader.get()); + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program->load_method("forward", &mmm.get()); + EXPECT_TRUE(method.ok()); + Error err = method->execute(); + ASSERT_EQ(err, Error::Ok); +} + // TODO: Add more tests for the runtime-to-backend interface. E.g.: // - Errors during init() or execute() result in runtime init/execution failures // - Correct values are passed to init()/execute()