diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index 5321c20aaa4..b4e3165d8f4 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -76,7 +76,7 @@ Result WebGPUBackend::init( } try { - graph->build(flatbuffer_data, constant_data); + graph->build(flatbuffer_data, constant_data, context.get_named_data_map()); } catch (const std::exception& e) { ET_LOG(Error, "WebGPU graph build failed: %s", e.what()); graph->~WebGPUGraph(); diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 91404fb164f..2af5917c296 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() { void WebGPUGraph::build( const void* flatbuffer_data, - const uint8_t* constant_data) { + const uint8_t* constant_data, + const executorch::runtime::NamedDataMap* named_data_map) { if (!device_) { auto* ctx = get_default_webgpu_context(); if (ctx) { @@ -165,6 +167,31 @@ void WebGPUGraph::build( const uint8_t* src = constant_data + vk_bytes->offset(); wgpuQueueWriteBuffer( queue_, tensor.buffer, 0, src, tensor.nbytes); + } else if ( + vk_bytes->named_key() != nullptr && + named_data_map != nullptr) { + // Constant stored in the PTE named-data map. + auto buf = + named_data_map->get_data(vk_bytes->named_key()->c_str()); + if (!buf.ok()) { + throw std::runtime_error( + std::string("WebGPU: named constant '") + + vk_bytes->named_key()->c_str() + + "' not found in NamedDataMap"); + } + if (buf->size() < tensor.nbytes) { + throw std::runtime_error( + std::string("WebGPU: named constant '") + + vk_bytes->named_key()->c_str() + "' undersized: have " + + std::to_string(buf->size()) + " bytes, need " + + std::to_string(tensor.nbytes)); + } + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, buf->data(), tensor.nbytes); + buf->Free(); + } else { + throw std::runtime_error( + "WebGPU: constant has no inline offset and no named-data key"); } } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 3aa96917a4e..749c9f8c841 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -15,6 +15,8 @@ #include #include +#include + namespace executorch { namespace backends { namespace webgpu { @@ -66,7 +68,10 @@ class WebGPUGraph { // Build the graph from a deserialized VkGraph flatbuffer and constant data. // The flatbuffer_data pointer must remain valid during build(). - void build(const void* flatbuffer_data, const uint8_t* constant_data); + void build( + const void* flatbuffer_data, + const uint8_t* constant_data, + const executorch::runtime::NamedDataMap* named_data_map = nullptr); // Copy input tensor data from host pointers into GPU buffers. void copy_inputs(const std::vector>& inputs);