diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h
index ad4434cf5af..c084a563544 100644
--- a/backends/vulkan/runtime/utils/VecUtils.h
+++ b/backends/vulkan/runtime/utils/VecUtils.h
@@ -479,5 +479,69 @@ inline int64_t multiply_integers(Iter begin, Iter end) {
       begin, end, static_cast<int64_t>(1), std::multiplies<>());
 }
 
+/*
+ * Packed representation of a compute shader's local workgroup size. The x, y
+ * and z extents share one uint32_t: bits [0, 11) hold x, bits [11, 22) hold y
+ * and bits [22, 32) hold z.
+ */
+class WorkgroupSize final {
+  // Each axis is shifted by a multiple of 11 bits, since a local workgroup
+  // axis can be 1024 (0x400) at most. The z axis only gets the remaining 10
+  // bits of the uint32_t, so it cannot store 1024.
+  static constexpr int kAxisBits = 11;
+  static constexpr uint32_t kAxisMask = 0x7ffu;
+  static constexpr uint32_t kMaxZ = 0x3ffu;
+
+  uint32_t val;
+
+  // Validates that every extent fits in its bit field and returns the packed
+  // value. Throws std::runtime_error instead of silently corrupting the
+  // neighbouring fields when an extent is too large.
+  static uint32_t pack(const uint32_t x, const uint32_t y, const uint32_t z) {
+    if (x > kAxisMask || y > kAxisMask) {
+      throw std::runtime_error(
+          "Workgroup size in x and y axes cannot exceed 2047 because each is stored in 11 bits");
+    }
+    if (z > kMaxZ) {
+      throw std::runtime_error(
+          "Workgroup size in z axis cannot exceed 1023 because it would overflow uint32_t storage");
+    }
+    return x | (y << kAxisBits) | (z << (2 * kAxisBits));
+  }
+
+ public:
+  // Creates a workgroup size whose packed value is 0.
+  explicit WorkgroupSize() : val(0) {}
+
+  // Packs the given extents. Throws std::runtime_error if x or y exceeds 2047
+  // or z exceeds 1023.
+  explicit WorkgroupSize(const uint32_t x, const uint32_t y, const uint32_t z)
+      : val(pack(x, y, z)) {}
+
+  // Packs components 0, 1 and 2 of vec as x, y and z; same limits as above.
+  explicit WorkgroupSize(const uvec3& vec)
+      : WorkgroupSize(vec[0u], vec[1u], vec[2u]) {}
+
+  // Unpacks the extents into an (x, y, z) vector.
+  explicit inline operator uvec3() const {
+    return {
+        val & kAxisMask,
+        (val >> kAxisBits) & kAxisMask,
+        (val >> (2 * kAxisBits)),
+    };
+  }
+
+  // Returns the raw packed value.
+  explicit inline operator uint32_t() const {
+    return val;
+  }
+
+  // Returns the extent along axis idx (0 = x, 1 = y, 2 = z). idx is not range
+  // checked; values outside [0, 2] are invalid.
+  inline constexpr uint32_t operator[](const int idx) const {
+    return (val >> (kAxisBits * idx)) & kAxisMask;
+  }
+};
+
 } // namespace utils
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.cpp b/backends/vulkan/runtime/vk_api/Pipeline.cpp
index 0c66a085ad9..51b59ed4d1f 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.cpp
+++ b/backends/vulkan/runtime/vk_api/Pipeline.cpp
@@ -174,6 +174,14 @@ void SpecVarList::append(const SpecVarList& other) {
   vars.insert(vars.end(), other.vars.begin(), other.vars.end());
 }
 
+void SpecVarList::reserve(const size_t size) {
+  vars.reserve(size);
+}
+
+void SpecVarList::append(const SpecVar& other) {
+  vars.push_back(other);
+}
+
 std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
     const {
   std::vector<VkSpecializationMapEntry> map_entries;
diff --git a/backends/vulkan/runtime/vk_api/Pipeline.h b/backends/vulkan/runtime/vk_api/Pipeline.h
index 5460a0acba7..b9f4e3d2a35 100644
--- a/backends/vulkan/runtime/vk_api/Pipeline.h
+++ b/backends/vulkan/runtime/vk_api/Pipeline.h
@@ -82,6 +82,12 @@ class SpecVarList final {
 
   void append(const SpecVarList& other);
 
+  // Reserves capacity in the underlying vars container for `size` entries.
+  void reserve(const size_t size);
+
+  // Appends a copy of `other` to the end of the list.
+  void append(const SpecVar& other);
+
   std::vector<VkSpecializationMapEntry> generate_map_entries() const;
 
   friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);