diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.h b/backends/vulkan/runtime/api/containers/ParamsBuffer.h index df8d7946d6e..fed7c8fa729 100644 --- a/backends/vulkan/runtime/api/containers/ParamsBuffer.h +++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.h @@ -56,12 +56,29 @@ class ParamsBuffer final { } // Fill the uniform buffer with data in block { - vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::MemoryAccessType::WRITE); + vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kWrite); Block* data_ptr = mapping.template data(); *data_ptr = block; } } + + template + T read() const { + T val; + if (sizeof(val) != nbytes_) { + VK_THROW( + "Attempted to store value from ParamsBuffer to type of different size"); + } + // Read value from uniform buffer and store in val + { + vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kRead); + T* data_ptr = mapping.template data(); + + val = *data_ptr; + } + return val; + } }; } // namespace api diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 64f24e3012d..967c892b0b4 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -416,6 +416,10 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { get_symint(idx)->set(val); } +int32_t ComputeGraph::read_symint(const ValueRef idx) { + return get_symint(idx)->get(); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d61ff7e61f6..a23cfb94c27 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -536,6 +536,8 @@ class ComputeGraph final { void set_symint(const ValueRef idx, const int32_t val); + int32_t read_symint(const ValueRef idx); + /* * Convenience function to add an input tensor along with its staging buffer */ diff --git a/backends/vulkan/runtime/graph/containers/SymInt.cpp b/backends/vulkan/runtime/graph/containers/SymInt.cpp index c91db84b787..a59a2d40141 100644 --- a/backends/vulkan/runtime/graph/containers/SymInt.cpp +++ b/backends/vulkan/runtime/graph/containers/SymInt.cpp @@ -17,6 +17,10 @@ void SymInt::set(const int32_t val) { gpu_buffer.update(val); } +int32_t SymInt::get() { + return gpu_buffer.read(); +} + void SymInt::operator=(const int32_t val) { gpu_buffer.update(val); } diff --git a/backends/vulkan/runtime/graph/containers/SymInt.h b/backends/vulkan/runtime/graph/containers/SymInt.h index 0c9fbee5fe2..bd361aabe5a 100644 --- a/backends/vulkan/runtime/graph/containers/SymInt.h +++ b/backends/vulkan/runtime/graph/containers/SymInt.h @@ -35,6 +35,8 @@ struct SymInt final { void set(const int32_t val); + int32_t get(); + void operator=(const int32_t val); }; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 4d291eec42a..fc92739669a 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1433,6 +1433,9 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) { int scalar_val = i - 3.0f; graph.set_symint(scalar, scalar_val); + int32_t scalar_val_read = graph.read_symint(scalar); + EXPECT_TRUE(scalar_val_read == scalar_val); + float val_a = i + 2.0f; float val_out = val_a + scalar_val;