From 3d1a8084a72cfc174d79568c4d0b4c5650b5325a Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 24 Sep 2024 09:12:46 -0700 Subject: [PATCH] [ET-VK][ez] Use `MemoryAccessFlags` instead of `MemoryAccessType` when binding ## Context Correct `ComputeGraph` functions to accept `MemoryAccessFlags` instead of `MemoryAccessType`. `MemoryAccessType` correspond to only a single bit, i.e. `READ` or `WRITE` but `MemoryAccessFlags` allows us to express a combination of bits which is the intended behaviour. Differential Revision: [D63327080](https://our.internmc.facebook.com/intern/diff/D63327080/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ops/ExecuteNode.h | 6 +++--- backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp | 2 +- backends/vulkan/runtime/graph/ops/utils/BindingUtils.h | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index dece9ddb50d..92efda30229 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -21,16 +21,16 @@ class ComputeGraph; * access permission. */ struct ArgGroup { - ArgGroup(const ValueRef ref, const vkapi::MemoryAccessType access) + ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access) : refs{ref}, access(access) {} ArgGroup( const std::vector& refs, - const vkapi::MemoryAccessType access) + const vkapi::MemoryAccessFlags access) : refs(refs), access(access) {} const std::vector refs; - const vkapi::MemoryAccessType access; + const vkapi::MemoryAccessFlags access; }; /* diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp index 2cfb34a052e..b3a72e27c43 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp @@ -13,7 +13,7 @@ namespace vkcompute { void bind_tensor_to_descriptor_set( api::vTensor& tensor, vkapi::PipelineBarrier& pipeline_barrier, - const vkapi::MemoryAccessType accessType, + const vkapi::MemoryAccessFlags accessType, vkapi::DescriptorSet& descriptor_set, const uint32_t idx) { if (tensor.buffer()) { diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h index eed39a97979..671a18f7e91 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h @@ -19,7 +19,7 @@ namespace vkcompute { void bind_tensor_to_descriptor_set( api::vTensor& tensor, vkapi::PipelineBarrier& pipeline_barrier, - const vkapi::MemoryAccessType accessType, + const vkapi::MemoryAccessFlags accessType, vkapi::DescriptorSet& descriptor_set, const uint32_t idx);