NOTE(review): this patch arrived with its line breaks collapsed and every
`<...>` span stripped (template arguments, #include targets). The line
structure below is rebuilt from the hunk counts. Template arguments on added
lines are restored from how each value is used inside the patch;
std::pair<const void*, size_t> and std::function<void(WebGPUGraph&)> are the
least certain of these -- confirm against the original commit. #include
targets and context lines are left exactly as received because nothing in the
patch determines them.

diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt
index 3351c213d4a..9b1476f2290 100644
--- a/backends/webgpu/CMakeLists.txt
+++ b/backends/webgpu/CMakeLists.txt
@@ -34,6 +34,7 @@ set(WEBGPU_SRCS
     runtime/ops/add/BinaryOp.cpp
     runtime/ops/rms_norm/RmsNorm.cpp
     runtime/ops/update_cache/UpdateCache.cpp
+    runtime/ops/select_as_symint/SelectAsSymint.cpp
 )
 
 add_library(webgpu_backend ${WEBGPU_SRCS})
diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp
index b4e3165d8f4..e934b529c58 100644
--- a/backends/webgpu/runtime/WebGPUBackend.cpp
+++ b/backends/webgpu/runtime/WebGPUBackend.cpp
@@ -106,6 +106,12 @@ Error WebGPUBackend::execute(
   }
   graph->copy_inputs(inputs);
 
+  // Populate live SymInts (dynamic input_pos) from inputs before execute.
+  graph->update_symints_from_inputs(inputs);
+
+  // Re-derive dispatch state for changed SymInts (Vulkan propagate_resize).
+  graph->propagate_resize();
+
   // Execute the compute graph
   graph->execute();
 
diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp
index 6aa4d3b1422..5bd4de0ac11 100644
--- a/backends/webgpu/runtime/WebGPUGraph.cpp
+++ b/backends/webgpu/runtime/WebGPUGraph.cpp
@@ -59,6 +59,89 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
   return buffer;
 }
 
+// Refreshes every registered SymInt from the host input buffers. For each
+// source it finds the tensor's position in input_ids_, normalizes negative
+// dim/index, and reads the element at flat offset index * (product of the
+// dims after dim) as int64 (narrowed to int32) or int32, chosen by
+// bytes / numel. Throws std::runtime_error when the tensor is not a graph
+// input, dim or index is out of range, or the element size is not 8 or 4.
+void WebGPUGraph::update_symints_from_inputs(
+    const std::vector<std::pair<const void*, size_t>>& inputs) {
+  for (const auto& src : symint_sources_) {
+    int pos = -1;
+    for (size_t i = 0; i < input_ids_.size(); i++) {
+      if (input_ids_[i] == src.input_tensor_id) {
+        pos = static_cast<int>(i);
+        break;
+      }
+    }
+    if (pos < 0 || pos >= static_cast<int>(inputs.size())) {
+      throw std::runtime_error(
+          "select_as_symint: source tensor is not a graph input");
+    }
+    const auto& dims = tensors_[src.input_tensor_id].dims;
+    const int ndim = static_cast<int>(dims.size());
+    const int dim = src.dim < 0 ? src.dim + ndim : src.dim;
+    // An out-of-range dim would wrap when cast to size_t for the stride loop.
+    if (dim < 0 || dim >= ndim) {
+      throw std::runtime_error("select_as_symint: dim is out of range");
+    }
+    const int64_t extent = static_cast<int64_t>(dims[dim]);
+    const int64_t index = src.index < 0 ? src.index + extent : src.index;
+    // An out-of-range index would read outside the host input buffer.
+    if (index < 0 || index >= extent) {
+      throw std::runtime_error("select_as_symint: index is out of range");
+    }
+    int64_t numel = 1;
+    for (int64_t d : dims) {
+      numel *= d;
+    }
+    int64_t stride = 1;
+    for (size_t i = static_cast<size_t>(dim) + 1; i < dims.size(); i++) {
+      stride *= dims[i];
+    }
+    const int64_t offset = index * stride;
+    const void* host = inputs[pos].first;
+    // Element width is inferred from the byte size, not from a dtype tag.
+    const size_t elem_size =
+        numel > 0 ? inputs[pos].second / static_cast<size_t>(numel) : 0;
+    int32_t val;
+    if (elem_size == sizeof(int64_t)) {
+      val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
+    } else if (elem_size == sizeof(int32_t)) {
+      val = static_cast<const int32_t*>(host)[offset];
+    } else {
+      throw std::runtime_error(
+          "select_as_symint: unsupported input element size");
+    }
+    set_symint(src.symint_id, val);
+  }
+}
+
+// Stores val in SymInt `id` and uploads it to that slot's buffer with
+// wgpuQueueWriteBuffer; does nothing when the value is unchanged. A changed
+// id is added to dirty_symints_. Throws if id is not a SymInt.
+void WebGPUGraph::set_symint(int id, int32_t val) {
+  auto it = symints_.find(id);
+  if (it == symints_.end()) {
+    throw std::runtime_error("WebGPUGraph::set_symint: id is not a SymInt");
+  }
+  if (it->second.value != val) {
+    it->second.value = val;
+    wgpuQueueWriteBuffer(
+        queue_, it->second.buffer, 0, &it->second.value, sizeof(int32_t));
+    dirty_symints_.insert(id);
+  }
+}
+
+// Runs each hook whose SymInt is in dirty_symints_, then clears the set;
+// returns immediately when nothing is dirty.
+void WebGPUGraph::propagate_resize() {
+  if (dirty_symints_.empty()) {
+    return;
+  }
+  for (auto& hook : resize_hooks_) {
+    if (dirty_symints_.count(hook.symint_id) != 0) {
+      hook.fn(*this);
+    }
+  }
+  dirty_symints_.clear();
+}
+
 WebGPUGraph::~WebGPUGraph() {
   for (size_t i = 0; i < tensors_.size(); i++) {
     if (tensors_[i].buffer &&
@@ -76,6 +165,16 @@ WebGPUGraph::~WebGPUGraph() {
       wgpuBufferRelease(buf);
     }
   }
+  for (auto& buf : owned_uniform_buffers_) {
+    if (buf) {
+      wgpuBufferRelease(buf);
+    }
+  }
+  for (auto& kv : symints_) {
+    if (kv.second.buffer) {
+      wgpuBufferRelease(kv.second.buffer);
+    }
+  }
   for (auto& buf : output_staging_buffers_) {
     if (buf) {
       wgpuBufferRelease(buf);
@@ -236,6 +335,27 @@ void WebGPUGraph::build(
         bools_[i] = val->value_as_Bool()->bool_val();
         break;
       }
+      case vkgraph::GraphTypes::SymInt: {
+        // Live scalar: small Uniform buffer the CPU rewrites per execute.
+        value_types_[i] = ValueType::SymInt;
+        SymIntSlot slot;
+        slot.value = static_cast<int32_t>(val->value_as_SymInt()->value());
+        // 16B satisfies the WGSL uniform min binding size; int32 in first 4.
+        constexpr size_t kSymIntUniformBytes = 16;
+        WGPUBufferDescriptor d = {};
+        d.size = kSymIntUniformBytes;
+        d.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+        d.mappedAtCreation = true;
+        slot.buffer = wgpuDeviceCreateBuffer(device_, &d);
+        void* mapped =
+            wgpuBufferGetMappedRange(slot.buffer, 0, kSymIntUniformBytes);
+        std::memset(mapped, 0, kSymIntUniformBytes);
+        std::memcpy(mapped, &slot.value, sizeof(int32_t));
+        wgpuBufferUnmap(slot.buffer);
+        symints_[i] = slot;
+        add_uniform_buffer_bytes(kSymIntUniformBytes);
+        break;
+      }
       default:
         value_types_[i] = ValueType::Null;
         break;
diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h
index aa3dadc13ab..558b37f8bd1 100644
--- a/backends/webgpu/runtime/WebGPUGraph.h
+++ b/backends/webgpu/runtime/WebGPUGraph.h
@@ -11,8 +11,10 @@
 
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 
@@ -104,6 +106,52 @@ class WebGPUGraph {
     return ints_[id];
   }
 
+  // Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
+  void set_symint(int id, int32_t val); // rewrites the live buffer in place
+  // read_symint throws (fail-loud) if id is not a SymInt.
+  int32_t read_symint(int id) const {
+    return symints_.at(id).value;
+  }
+  // symint_buffer returns nullptr if id is not a SymInt.
+  WGPUBuffer symint_buffer(int id) const {
+    auto it = symints_.find(id);
+    return it == symints_.end() ? nullptr : it->second.buffer;
+  }
+
+  // Records that a SymInt's value is read from input_tensor[index] along dim.
+  struct SymIntSource {
+    int symint_id;
+    int input_tensor_id;
+    int dim;
+    int index;
+  };
+  void
+  add_symint_source(int symint_id, int input_tensor_id, int dim, int index) {
+    symint_sources_.push_back({symint_id, input_tensor_id, dim, index});
+  }
+  const std::vector<SymIntSource>& symint_sources() const {
+    return symint_sources_;
+  }
+
+  // Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl.
+  void update_symints_from_inputs(
+      const std::vector<std::pair<const void*, size_t>>& inputs);
+
+  // Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize.
+  void add_resize_hook(int symint_id, std::function<void(WebGPUGraph&)> fn) {
+    resize_hooks_.push_back({symint_id, std::move(fn)});
+  }
+  // Run hooks for changed SymInts then clear; call before execute().
+  void propagate_resize();
+
+  // Mutable dispatch access for resize hooks (to rewrite workgroup_count_x).
+  WebGPUDispatch& dispatch_at(size_t i) {
+    return dispatches_[i];
+  }
+  size_t num_dispatches() const {
+    return dispatches_.size();
+  }
+
   WGPUDevice device() const {
     return device_;
   }
@@ -119,6 +167,11 @@ class WebGPUGraph {
     uniform_buffer_bytes_ += bytes;
   }
 
+  // Keep a uniform alive for the graph's lifetime; released in the dtor.
+  void own_uniform_buffer(WGPUBuffer buffer) {
+    owned_uniform_buffers_.push_back(buffer);
+  }
+
   // Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
   WGPUBuffer create_scratch_buffer(size_t nbytes);
 
@@ -149,7 +202,7 @@ class WebGPUGraph {
     return static_cast(value_types_.size());
   }
 
-  enum class ValueType { Tensor, Int, Double, Bool, Null, String };
+  enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt };
 
   ValueType get_value_type(int id) const {
     return value_types_[id];
@@ -168,6 +221,22 @@ class WebGPUGraph {
   std::vector doubles_;
   std::vector bools_;
 
+  // SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse.
+  struct SymIntSlot {
+    WGPUBuffer buffer = nullptr;
+    int32_t value = 0;
+  };
+  std::unordered_map<int, SymIntSlot> symints_;
+  std::vector<SymIntSource> symint_sources_;
+
+  // Resize hooks + the set of SymInts changed since the last propagate_resize.
+  struct ResizeHook {
+    int symint_id;
+    std::function<void(WebGPUGraph&)> fn;
+  };
+  std::vector<ResizeHook> resize_hooks_;
+  std::unordered_set<int> dirty_symints_;
+
   std::vector input_ids_;
   std::vector output_ids_;
 
@@ -179,6 +248,9 @@ class WebGPUGraph {
   // Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries).
   std::vector scratch_buffers_;
 
+  // Uniform buffers owned for the graph's lifetime; released in the dtor.
+  std::vector<WGPUBuffer> owned_uniform_buffers_;
+
   // Staging buffers for reading back outputs (MapRead | CopyDst).
   std::vector output_staging_buffers_;
 
diff --git a/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp
new file mode 100644
index 00000000000..fdfc8244d07
--- /dev/null
+++ b/backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+#include
+
+namespace executorch::backends::webgpu {
+
+namespace {
+
+// et_vk.select_as_symint: out SymInt = x[index] along dim; read at execute.
+void select_as_symint_impl(WebGPUGraph& graph, const std::vector<int>& args) {
+  const int x_id = args.at(0);
+  const int dim_id = args.at(1);
+  const int index_id = args.at(2);
+  const int out_id = args.at(3);
+
+  if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) {
+    throw std::runtime_error("select_as_symint: output is not a SymInt");
+  }
+  graph.add_symint_source(
+      out_id,
+      x_id,
+      static_cast<int>(graph.get_int(dim_id)),
+      static_cast<int>(graph.get_int(index_id)));
+}
+
+} // namespace
+
+WEBGPU_REGISTER_OPERATORS {
+  WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl);
+}
+
+} // namespace executorch::backends::webgpu