Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ set(WEBGPU_SRCS
runtime/ops/add/BinaryOp.cpp
runtime/ops/rms_norm/RmsNorm.cpp
runtime/ops/update_cache/UpdateCache.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
6 changes: 6 additions & 0 deletions backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ Error WebGPUBackend::execute(
}
graph->copy_inputs(inputs);

// Populate live SymInts (dynamic input_pos) from inputs before execute.
graph->update_symints_from_inputs(inputs);

// Re-derive dispatch state for changed SymInts (Vulkan propagate_resize).
graph->propagate_resize();

// Execute the compute graph
graph->execute();

Expand Down
101 changes: 101 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,76 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
return buffer;
}

// Refreshes every SymInt registered via add_symint_source() from the host
// input buffers of the current execute call.
//
// For each source, the element at `index` along `dim` (all other coordinates
// zero) is read from the matching graph input and forwarded to set_symint().
// Negative `dim` / `index` values count from the end.
//
// `inputs` holds one (host pointer, byte size) pair per graph input, indexed
// by the source tensor's position in input_ids_. The element width is inferred
// as byte size / element count; only 4-byte and 8-byte integers are accepted.
//
// Throws std::runtime_error when the source tensor is not a graph input, when
// dim or index fall outside the tensor's shape, when the host pointer is
// null, when the element size is unsupported, or when a 64-bit value does not
// fit in int32.
void WebGPUGraph::update_symints_from_inputs(
    const std::vector<std::pair<const void*, size_t>>& inputs) {
  for (const auto& src : symint_sources_) {
    // Find which graph input slot carries the source tensor.
    int pos = -1;
    for (size_t i = 0; i < input_ids_.size(); i++) {
      if (input_ids_[i] == src.input_tensor_id) {
        pos = static_cast<int>(i);
        break;
      }
    }
    if (pos < 0 || pos >= static_cast<int>(inputs.size())) {
      throw std::runtime_error(
          "select_as_symint: source tensor is not a graph input");
    }

    const auto& dims = tensors_[src.input_tensor_id].dims;
    const int rank = static_cast<int>(dims.size());

    // Normalize and validate dim before it is used as an array index; an
    // unchecked negative dim would otherwise wrap when cast to size_t.
    const int dim = src.dim < 0 ? src.dim + rank : src.dim;
    if (dim < 0 || dim >= rank) {
      throw std::runtime_error("select_as_symint: dim is out of range");
    }

    // Normalize and validate index against the extent of the selected dim so
    // the host read below can never leave the input buffer.
    const int64_t extent = static_cast<int64_t>(dims[dim]);
    const int64_t index = src.index < 0 ? src.index + extent : src.index;
    if (index < 0 || index >= extent) {
      throw std::runtime_error("select_as_symint: index is out of range");
    }

    int64_t numel = 1;
    for (int64_t d : dims) {
      numel *= d;
    }
    // Contiguous stride of `dim`: product of all trailing dimensions.
    int64_t stride = 1;
    for (size_t i = static_cast<size_t>(dim) + 1; i < dims.size(); i++) {
      stride *= dims[i];
    }
    const size_t offset = static_cast<size_t>(index * stride);

    const void* host = inputs[pos].first;
    if (host == nullptr) {
      throw std::runtime_error("select_as_symint: input buffer is null");
    }
    const unsigned char* bytes = static_cast<const unsigned char*>(host);

    // Infer the element width from the buffer size; an empty tensor yields 0
    // and is rejected as unsupported below.
    const size_t elem_size =
        numel > 0 ? inputs[pos].second / static_cast<size_t>(numel) : 0;

    // memcpy avoids assuming the host pointer is aligned for the element type.
    int32_t val = 0;
    if (elem_size == sizeof(int64_t)) {
      int64_t wide = 0;
      std::memcpy(&wide, bytes + offset * sizeof(int64_t), sizeof(wide));
      if (wide < INT32_MIN || wide > INT32_MAX) {
        throw std::runtime_error(
            "select_as_symint: value does not fit in int32");
      }
      val = static_cast<int32_t>(wide);
    } else if (elem_size == sizeof(int32_t)) {
      std::memcpy(&val, bytes + offset * sizeof(int32_t), sizeof(val));
    } else {
      throw std::runtime_error(
          "select_as_symint: unsupported input element size");
    }
    set_symint(src.symint_id, val);
  }
}

// Stores `val` as the current value of SymInt `id`.
//
// When the value actually changes, the first 4 bytes of the SymInt's buffer
// are rewritten through the queue and the id is added to dirty_symints_ so
// that propagate_resize() runs the matching hooks. Setting the same value
// again is a no-op.
//
// Throws std::runtime_error if `id` has no entry in symints_.
void WebGPUGraph::set_symint(int id, int32_t val) {
  const auto entry = symints_.find(id);
  if (entry == symints_.end()) {
    throw std::runtime_error("WebGPUGraph::set_symint: id is not a SymInt");
  }

  auto& slot = entry->second;
  if (slot.value == val) {
    return; // Unchanged: skip the queue write and leave dirty state alone.
  }

  slot.value = val;
  wgpuQueueWriteBuffer(queue_, slot.buffer, 0, &slot.value, sizeof(int32_t));
  dirty_symints_.insert(id);
}

// Invokes, in registration order, every resize hook whose SymInt changed since
// the previous call, then resets the dirty set. Does nothing when no SymInt is
// dirty.
void WebGPUGraph::propagate_resize() {
  if (!dirty_symints_.empty()) {
    for (auto& registered : resize_hooks_) {
      const bool changed =
          dirty_symints_.find(registered.symint_id) != dirty_symints_.end();
      if (changed) {
        registered.fn(*this);
      }
    }
    dirty_symints_.clear();
  }
}

WebGPUGraph::~WebGPUGraph() {
for (size_t i = 0; i < tensors_.size(); i++) {
if (tensors_[i].buffer &&
Expand All @@ -76,6 +146,16 @@ WebGPUGraph::~WebGPUGraph() {
wgpuBufferRelease(buf);
}
}
for (auto& buf : owned_uniform_buffers_) {
if (buf) {
wgpuBufferRelease(buf);
}
}
for (auto& kv : symints_) {
if (kv.second.buffer) {
wgpuBufferRelease(kv.second.buffer);
}
}
for (auto& buf : output_staging_buffers_) {
if (buf) {
wgpuBufferRelease(buf);
Expand Down Expand Up @@ -236,6 +316,27 @@ void WebGPUGraph::build(
bools_[i] = val->value_as_Bool()->bool_val();
break;
}
case vkgraph::GraphTypes::SymInt: {
// Live scalar: small Uniform buffer the CPU rewrites per execute.
value_types_[i] = ValueType::SymInt;
SymIntSlot slot;
// Seed the host-side copy with the serialized value, cast to int32.
slot.value = static_cast<int32_t>(val->value_as_SymInt()->value());
// 16B satisfies the WGSL uniform min binding size; int32 in first 4.
constexpr size_t kSymIntUniformBytes = 16;
WGPUBufferDescriptor d = {};
d.size = kSymIntUniformBytes;
// CopyDst is what allows set_symint() to refresh the buffer later with
// wgpuQueueWriteBuffer.
d.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
// Create the buffer already mapped so the initial value can be written
// directly through the mapped range below.
d.mappedAtCreation = true;
slot.buffer = wgpuDeviceCreateBuffer(device_, &d);
// NOTE(review): neither slot.buffer nor the mapped pointer is checked for
// null before use — confirm creation cannot fail here.
void* mapped =
wgpuBufferGetMappedRange(slot.buffer, 0, kSymIntUniformBytes);
// Zero the whole 16 bytes, then place the int32 value in the first 4.
std::memset(mapped, 0, kSymIntUniformBytes);
std::memcpy(mapped, &slot.value, sizeof(int32_t));
wgpuBufferUnmap(slot.buffer);
// Keyed by value id; the buffer is released in the destructor.
symints_[i] = slot;
add_uniform_buffer_bytes(kSymIntUniformBytes);
break;
}
default:
value_types_[i] = ValueType::Null;
break;
Expand Down
74 changes: 73 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#include <webgpu/webgpu.h>

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <executorch/runtime/core/named_data_map.h>
Expand Down Expand Up @@ -104,6 +106,52 @@ class WebGPUGraph {
return ints_[id];
}

// Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
// set_symint stores val for SymInt `id`. If val differs from the cached value
// it rewrites the live buffer in place and marks the SymInt dirty for
// propagate_resize(); otherwise it does nothing.
// Throws std::runtime_error if id is not a SymInt.
void set_symint(int id, int32_t val);
// Returns the cached host-side value of SymInt `id`.
// Fail-loud: the lookup uses unordered_map::at, which throws
// std::out_of_range when id is not a SymInt.
int32_t read_symint(int id) const {
  const auto& slot = symints_.at(id);
  return slot.value;
}
// Returns the buffer backing SymInt `id`, or nullptr when id is not a SymInt.
WGPUBuffer symint_buffer(int id) const {
  const auto found = symints_.find(id);
  if (found != symints_.end()) {
    return found->second.buffer;
  }
  return nullptr;
}

// Records that a SymInt's value is read from input_tensor[index] along dim.
// Consumed by update_symints_from_inputs() on every execute.
struct SymIntSource {
// Value id of the SymInt to refresh (forwarded to set_symint).
int symint_id;
// Value id of the source tensor; it must be one of the graph inputs.
int input_tensor_id;
// Dimension to select along; a negative value counts from the last dimension.
int dim;
// Position along dim; a negative value counts from the end of that dimension.
int index;
};
void
add_symint_source(int symint_id, int input_tensor_id, int dim, int index) {
symint_sources_.push_back({symint_id, input_tensor_id, dim, index});
}
// Read-only view of all registered SymInt sources, in registration order.
const std::vector<SymIntSource>& symint_sources() const {
return symint_sources_;
}

// Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl.
// For every registered SymIntSource, reads one integer element from the
// matching host input and passes it to set_symint().
// `inputs` holds one (host pointer, byte size) pair per graph input, indexed
// by the source tensor's position in the graph's input list.
// Throws std::runtime_error if a source tensor is not a graph input or the
// inferred element size is neither 4 nor 8 bytes.
void update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs);

// Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize.
void add_resize_hook(int symint_id, std::function<void(WebGPUGraph&)> fn) {
resize_hooks_.push_back({symint_id, std::move(fn)});
}
// Run hooks for changed SymInts then clear; call before execute().
// A hook runs only if its SymInt was modified by set_symint() since the
// previous propagate_resize(); returns immediately when nothing changed.
void propagate_resize();

// Mutable dispatch access for resize hooks (to rewrite workgroup_count_x).
// Indexes dispatches_ directly; no bounds check is performed here, so callers
// should keep i below num_dispatches().
WebGPUDispatch& dispatch_at(size_t i) {
return dispatches_[i];
}
// Number of dispatches currently recorded in the graph.
size_t num_dispatches() const {
return dispatches_.size();
}

// Returns the WGPUDevice handle held by this graph.
WGPUDevice device() const {
return device_;
}
Expand All @@ -119,6 +167,11 @@ class WebGPUGraph {
uniform_buffer_bytes_ += bytes;
}

// Keep a uniform alive for the graph's lifetime; released in the dtor.
// The handle is appended as-is; the destructor skips null entries and calls
// wgpuBufferRelease on the rest.
void own_uniform_buffer(WGPUBuffer buffer) {
owned_uniform_buffers_.push_back(buffer);
}

// Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
WGPUBuffer create_scratch_buffer(size_t nbytes);

Expand Down Expand Up @@ -149,7 +202,7 @@ class WebGPUGraph {
return static_cast<int>(value_types_.size());
}

enum class ValueType { Tensor, Int, Double, Bool, Null, String };
enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt };

ValueType get_value_type(int id) const {
return value_types_[id];
Expand All @@ -168,6 +221,22 @@ class WebGPUGraph {
std::vector<double> doubles_;
std::vector<bool> bools_;

// SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse.
struct SymIntSlot {
// Uniform buffer holding the value; rewritten by set_symint() and released
// in the destructor.
WGPUBuffer buffer = nullptr;
// Host-side copy of the value last written to the buffer.
int32_t value = 0;
};
// Keyed by value id; only ids whose value type is SymInt have an entry.
std::unordered_map<int, SymIntSlot> symints_;
// Sources registered through add_symint_source(), in registration order.
std::vector<SymIntSource> symint_sources_;
std::vector<SymIntSource> symint_sources_;

// Resize hooks + the set of SymInts changed since the last propagate_resize.
struct ResizeHook {
// SymInt whose change triggers this hook.
int symint_id;
// Callback invoked with the owning graph by propagate_resize().
std::function<void(WebGPUGraph&)> fn;
};
// Hooks in registration order; several hooks may watch the same SymInt.
std::vector<ResizeHook> resize_hooks_;
// Ids inserted by set_symint() on a value change; cleared by
// propagate_resize().
std::unordered_set<int> dirty_symints_;

std::vector<int> input_ids_;
std::vector<int> output_ids_;

Expand All @@ -179,6 +248,9 @@ class WebGPUGraph {
// Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries).
std::vector<WGPUBuffer> scratch_buffers_;

// Uniform buffers owned for the graph's lifetime; released in the dtor.
std::vector<WGPUBuffer> owned_uniform_buffers_;

// Staging buffers for reading back outputs (MapRead | CopyDst).
std::vector<WGPUBuffer> output_staging_buffers_;

Expand Down
41 changes: 41 additions & 0 deletions backends/webgpu/runtime/ops/select_as_symint/SelectAsSymint.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>

#include <stdexcept>

namespace executorch::backends::webgpu {

namespace {

// et_vk.select_as_symint: out SymInt = x[index] along dim; read at execute.
//
// Build-time handler. It does not read any tensor data itself; it only records
// a SymInt source on the graph so that update_symints_from_inputs() can fetch
// the value from the host input on every execute.
//
// args layout: [0] source tensor id, [1] dim id, [2] index id, [3] output
// SymInt id.
//
// Throws std::out_of_range (from vector::at) if fewer than four args are
// given, and std::runtime_error if dim or index are not Int values or the
// output is not a SymInt.
void select_as_symint_impl(WebGPUGraph& graph, const std::vector<int>& args) {
  const int x_id = args.at(0);
  const int dim_id = args.at(1);
  const int index_id = args.at(2);
  const int out_id = args.at(3);

  // get_int() does not verify the value type, so reject non-Int ids here
  // instead of silently recording a meaningless dim/index.
  if (graph.get_value_type(dim_id) != WebGPUGraph::ValueType::Int) {
    throw std::runtime_error("select_as_symint: dim is not an Int");
  }
  if (graph.get_value_type(index_id) != WebGPUGraph::ValueType::Int) {
    throw std::runtime_error("select_as_symint: index is not an Int");
  }
  if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) {
    throw std::runtime_error("select_as_symint: output is not a SymInt");
  }

  graph.add_symint_source(
      out_id,
      x_id,
      static_cast<int>(graph.get_int(dim_id)),
      static_cast<int>(graph.get_int(index_id)));
}

} // namespace

// Associates the op name et_vk.select_as_symint.default with
// select_as_symint_impl.
// NOTE(review): the registration macros come from OperatorRegistry.h, which is
// not shown here — presumably this runs at static-initialization time; confirm.
WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl);
}

} // namespace executorch::backends::webgpu
Loading