diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index cb958cefea3..201278ac61b 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -350,6 +351,28 @@ class ComputeGraph final { return values_.at(idx).toTensor().logical_limits_ubo(); } + inline PushConstantDataInfo sizes_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorSizes); + } + + inline PushConstantDataInfo strides_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), + api::kTensorStrides); + } + + inline PushConstantDataInfo logical_limits_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), + api::kTensorLogicalLimits); + } + + inline PushConstantDataInfo numel_pc_of(const ValueRef idx) const { + return PushConstantDataInfo( + values_.at(idx).toConstTensor().get_uniform_data(), api::kTensorNumel); + } + // // Scalar Value Extraction // diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 5823f1f7728..87b4b5b5480 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -14,6 +14,22 @@ namespace vkcompute { +uint32_t PushConstantDataInfo::write( + void* dst, + const uint32_t dst_offset, + const uint32_t max_dst_size) const { + if (tensorUniformData != nullptr) { + return tensorUniformData->write_attribute( + dst, dst_offset, max_dst_size, payload_.attr); + } + + VK_CHECK_COND( + (dst_offset + payload_.dataSize) <= max_dst_size, + "Attempting to write push constant data outside data boundary."); + memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize); + return payload_.dataSize; +} + DispatchNode::DispatchNode( ComputeGraph& graph, const vkapi::ShaderInfo& shader, @@ -23,13 +39,15 @@ DispatchNode::DispatchNode( const vkapi::ParamsBindList& params, const vkapi::SpecVarList& spec_vars, const ResizeFunction& resize_fn, - const std::vector& resize_args) + const std::vector& resize_args, + const std::vector& push_constants) : ExecuteNode(resize_fn, resize_args, args, shader.kernel_name), shader_(shader), global_workgroup_size_(global_workgroup_size), local_workgroup_size_(local_workgroup_size), params_(params), - spec_vars_(spec_vars) { + spec_vars_(spec_vars), + push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ true); } @@ -57,8 +75,22 @@ void DispatchNode::encode(ComputeGraph* graph) { bind_params_to_descriptor_set(params_, descriptor_set, idx); + std::array push_constants_data; + uint32_t push_constants_offset = 0; + + for (const auto& push_constant : push_constants_) { + push_constants_offset += push_constant.write( + push_constants_data.data(), + push_constants_offset, + kMaxPushConstantSize); + } context->register_shader_dispatch( - descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); + descriptor_set, + pipeline_barrier, + shader_, + global_workgroup_size_, + push_constants_data.data(), + push_constants_offset); context->report_shader_dispatch_end(); } diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index ba7613bd14d..958637218e2 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -18,6 +18,51 @@ namespace vkcompute { class ComputeGraph; +constexpr uint32_t kMaxPushConstantSize = 128; +/* + * Represents a push constant data entry + * Which is either shared pointer to a tensor's uniform data with an attribute + * Or data with a maximum size of 16 bytes + */ +class PushConstantDataInfo { + std::shared_ptr tensorUniformData; + union Payload { + struct { + api::vTensor::Attribute attr; + }; + struct { + uint8_t data[16]; + uint32_t dataSize; + }; + }; + + Payload payload_; + + public: + explicit PushConstantDataInfo( + const std::shared_ptr& tensorUniformData, + api::vTensor::Attribute attr) + : tensorUniformData(tensorUniformData) { + payload_.attr = attr; + } + + explicit PushConstantDataInfo(const void* data, uint32_t dataLen) + : tensorUniformData(nullptr) { + VK_CHECK_COND( + dataLen <= 16, "Single push constant data size must be <= 16 bytes"); + payload_.dataSize = dataLen; + memcpy(payload_.data, data, payload_.dataSize); + } + + /* + * Function writes push constant data to the destination buffer + */ + uint32_t write( + void* dst, + const uint32_t dst_offset, + const uint32_t max_dst_size) const; +}; + /* * Represents a single shader execution op in a ML model. */ @@ -34,7 +79,8 @@ class DispatchNode final : public ExecuteNode { const vkapi::ParamsBindList& params, const vkapi::SpecVarList& spec_vars = {}, const ResizeFunction& resize_fn = nullptr, - const std::vector& resize_args = {}); + const std::vector& resize_args = {}, + const std::vector& push_constants = {}); ~DispatchNode() override = default; @@ -46,6 +92,7 @@ class DispatchNode final : public ExecuteNode { const utils::uvec3 local_workgroup_size_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; + const std::vector push_constants_; public: operator bool() const { diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index 8d45e65b396..599879514e3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -19,8 +19,10 @@ layout(std430) buffer; ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} -${layout_declare_ubo(2, "ivec4", "out_sizes")} -${layout_declare_ubo(3, "ivec4", "in_sizes")} +layout(push_constant) uniform PRECISION restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 060696a4fa6..fc5c7075222 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -76,12 +76,14 @@ void add_view_node( {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, // Parameter Buffers - {t_out->sizes_ubo(), t_in->sizes_ubo()}, + {}, // Specialization Constants {SV(t_in->packed_dim()), SV(t_out->packed_dim())}, // Resizing Logic resize_view_node, - {sizes})); + {sizes}, + // Push Constants + {{graph.sizes_pc_of(out), graph.sizes_pc_of(in)}})); } void view(ComputeGraph& graph, const std::vector& args) {