From 93fc4c61a4af82dd2aeb790db7b5265fb4e56bb9 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 30 Sep 2024 09:30:00 -0700 Subject: [PATCH] [ET-VK][ez] Add API to read value of `SymInt` and`ParamsBuffer` ## Context This diff adds an API to read the value from a `SymInt`. This functionality will be useful because `SymInt`s may be needed to set tensor sizes, in addition to being used as arguments to shaders. Differential Revision: [D63642093](https://our.internmc.facebook.com/intern/diff/D63642093/) [ghstack-poisoned] --- .../runtime/api/containers/ParamsBuffer.h | 19 ++++++++++++++++++- .../vulkan/runtime/graph/ComputeGraph.cpp | 4 ++++ backends/vulkan/runtime/graph/ComputeGraph.h | 2 ++ .../runtime/graph/containers/SymInt.cpp | 4 ++++ .../vulkan/runtime/graph/containers/SymInt.h | 2 ++ .../vulkan/test/vulkan_compute_api_test.cpp | 3 +++ 6 files changed, 33 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.h b/backends/vulkan/runtime/api/containers/ParamsBuffer.h index df8d7946d6e..fed7c8fa729 100644 --- a/backends/vulkan/runtime/api/containers/ParamsBuffer.h +++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.h @@ -56,12 +56,29 @@ class ParamsBuffer final { } // Fill the uniform buffer with data in block { - vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::MemoryAccessType::WRITE); + vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kWrite); Block* data_ptr = mapping.template data(); *data_ptr = block; } } + + template + T read() const { + T val; + if (sizeof(val) != nbytes_) { + VK_THROW( + "Attempted to store value from ParamsBuffer to type of different size"); + } + // Read value from uniform buffer and store in val + { + vkapi::MemoryMap mapping(vulkan_buffer_, vkapi::kRead); + T* data_ptr = mapping.template data(); + + val = *data_ptr; + } + return val; + } }; } // namespace api diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 64f24e3012d..967c892b0b4 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -416,6 +416,10 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { get_symint(idx)->set(val); } +int32_t ComputeGraph::read_symint(const ValueRef idx) { + return get_symint(idx)->get(); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d61ff7e61f6..a23cfb94c27 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -536,6 +536,8 @@ class ComputeGraph final { void set_symint(const ValueRef idx, const int32_t val); + int32_t read_symint(const ValueRef idx); + /* * Convenience function to add an input tensor along with its staging buffer */ diff --git a/backends/vulkan/runtime/graph/containers/SymInt.cpp b/backends/vulkan/runtime/graph/containers/SymInt.cpp index c91db84b787..a59a2d40141 100644 --- a/backends/vulkan/runtime/graph/containers/SymInt.cpp +++ b/backends/vulkan/runtime/graph/containers/SymInt.cpp @@ -17,6 +17,10 @@ void SymInt::set(const int32_t val) { gpu_buffer.update(val); } +int32_t SymInt::get() { + return gpu_buffer.read(); +} + void SymInt::operator=(const int32_t val) { gpu_buffer.update(val); } diff --git a/backends/vulkan/runtime/graph/containers/SymInt.h b/backends/vulkan/runtime/graph/containers/SymInt.h index 0c9fbee5fe2..bd361aabe5a 100644 --- a/backends/vulkan/runtime/graph/containers/SymInt.h +++ b/backends/vulkan/runtime/graph/containers/SymInt.h @@ -35,6 +35,8 @@ struct SymInt final { void set(const int32_t val); + int32_t get(); + void operator=(const int32_t val); }; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 9a99b11f758..0fe3467ed05 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1361,6 +1361,9 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) { int scalar_val = i - 3.0f; graph.set_symint(scalar, scalar_val); + int32_t scalar_val_read = graph.read_symint(scalar); + EXPECT_TRUE(scalar_val_read == scalar_val); + float val_a = i + 2.0f; float val_out = val_a + scalar_val;