diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index f4ba98b31fd..6ee29d45f18 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -38,6 +38,10 @@ # Basic configuration settings for shaders DEFAULT_ENV: Dict[str, Any] = { "PRECISION": "highp", + # B is shorthand for "binding". This is used to automatically increment the + # layout binding index when declaring layout bindings. Note that a container + # type is used because integers are immutable in Python. + "B": [0], } # Establishes relationships between different tensor types and different GLSL types @@ -179,8 +183,14 @@ def get_access_qualifier(access_type: Optional[str]) -> str: raise AssertionError(f"Invalid access type: {access_type}") +def get_slot_val(slot: Union[int, List[int]]) -> int: + if isinstance(slot, list): + return slot[0] + return slot + + def layout_declare_buffer( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -192,15 +202,18 @@ def layout_declare_buffer( array_type = buffer_scalar_type(dtype) out_str = f""" -layout(set = 0, binding = {slot}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{ +layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{ {array_type} {var_name}[]; }}; """ + + if isinstance(slot, list): + slot[0] = slot[0] + 1 return out_str def layout_declare_image( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -209,11 +222,16 @@ def layout_declare_image( ) -> str: image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype] image_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype] - return f"layout(set = 0, binding = {slot}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};" + + ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};" + + if isinstance(slot, list): + slot[0] = slot[0] + 1 + return ret_str def layout_declare_sampler( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -222,11 +240,16 @@ def layout_declare_sampler( image_ndim: int = 3, ) -> str: sampler_type = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype] - return f"layout(set = 0, binding = {slot}) uniform {precision} {sampler_type} {var_name};" + + ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} {sampler_type} {var_name};" + + if isinstance(slot, list): + slot[0] = slot[0] + 1 + return ret_str def layout_declare_tensor( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -262,7 +285,9 @@ def layout_declare_tensor( ) -def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str: +def layout_declare_ubo( + slot: Union[int, List[int]], *args, precision: str = "PRECISION" +) -> str: assert len(args) % 2 == 0 var_list = list(zip(args[::2], args[1::2])) @@ -272,12 +297,14 @@ def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str: ubo_name += var_name + "_" out_str = f""" -layout(set = 0, binding = {slot}) uniform {precision} restrict readonly {ubo_name}UBO {{ +layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{ """ for type_name, var_name in var_list: out_str += f"{type_name} {var_name};\n" out_str += "};" + if isinstance(slot, list): + slot[0] = slot[0] + 1 return out_str diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index a8f57f57d2a..37c54f959de 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -47,6 +47,72 @@ VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt) #undef VALUE_PTR_CLASS_IMPL +// +// TmpTensor +// + +TmpTensor::TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout) + : graph_p(graph_ptr), + sobj_idx(get_sobj_idx()), + vref(graph_p->add_tensor( + sizes, + dtype, + storage_type, + memory_layout, + sobj_idx)) {} + +TmpTensor::TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type) + : graph_p(graph_ptr), + sobj_idx(get_sobj_idx()), + vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {} + +TmpTensor::TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::GPUMemoryLayout memory_layout) + : graph_p(graph_ptr), + sobj_idx(get_sobj_idx()), + vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {} + +TmpTensor::TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype) + : graph_p(graph_ptr), + sobj_idx(get_sobj_idx()), + vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {} + +TmpTensor::~TmpTensor() { + // Lifetime of this temporary tensor is expired; return the shared object to + // the pool, as long as the sobj index is valid + if (sobj_idx >= 0) { + graph_p->tmp_shared_object_idxs_.emplace(sobj_idx); + } +} + +int64_t TmpTensor::get_sobj_idx() { + int64_t sobj_idx; + // If no available temporary shared objects, request a new one to be created + if (graph_p->tmp_shared_object_idxs_.empty()) { + sobj_idx = graph_p->shared_objects_.size(); + } else { + // Get the first available shared object idx + sobj_idx = graph_p->tmp_shared_object_idxs_.top(); + graph_p->tmp_shared_object_idxs_.pop(); + } + return sobj_idx; +} + // // ComputeGraph // diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index ac5e0d6c9d1..210b03c4cad 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -11,6 +11,7 @@ // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName #include +#include #include @@ -67,6 +68,79 @@ DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt); #undef DECL_VALUE_PTR_CLASS +// +// TmpTensor +// + +/* + * This struct is used to recycle the memory of temporary tensors that are + * created during the execution of a node. Upon construction, this struct will + * check the `tmp_shared_object_idxs_` of the provided `ComputeGraph` instance + * if any shared objects are available; if not, then a new one is created. A + * tensor value is then added to the `ComputeGraph` instance with the requested + * specifications. Upon destruction, the shared object index of the temporary + * tensor is returned to `tmp_shared_object_idxs_`. + * + * Note that instances of this struct can be used as if they were `ValueRef` due + * to implementation of a custom casting operator. + * + * This class should only be used to create tensors whose lifetimes exist only + * in a well defined scope (i.e. within a function). + */ +struct TmpTensor { + ComputeGraph* graph_p; + int64_t sobj_idx; + ValueRef vref; + + // + // Match all available overloads of `add_tensor` + // + + TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout); + + TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type); + + TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::GPUMemoryLayout memory_layout); + + TmpTensor( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype); + + // No copy construction or assignment + TmpTensor(TmpTensor& other) = delete; + TmpTensor& operator=(TmpTensor& other) = delete; + + // No move construction or assignment + TmpTensor(TmpTensor&& other) = delete; + TmpTensor& operator=(TmpTensor&& other) = delete; + + // Custom cast to ValueRef + operator ValueRef() const { + return vref; + }; + + ~TmpTensor(); + + private: + // Helper function to get first available shared object index or request a new + // one to be created. + int64_t get_sobj_idx(); +}; + // // ComputeGraph // @@ -94,7 +168,12 @@ class ComputeGraph final { vkapi::DescriptorPoolConfig execute_descriptor_counts_; std::unique_ptr context_; + std::vector shared_objects_; + // This stack is used by `TmpTensor` instances to recycle shared objects + // for temporary tensors. See the comments of `TmpTensor` for more details + std::stack tmp_shared_object_idxs_; + std::vector values_; std::vector param_ubos_; @@ -593,6 +672,8 @@ class ComputeGraph final { friend class BoolListPtr; friend class ValueListPtr; friend class SymIntPtr; + + friend struct TmpTensor; }; template diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 50a2b5e548c..8773f0c0b04 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -29,6 +29,11 @@ inline bool is_valid(ValueRef value_ref) { struct IOValueRef { ValueRef value; ValueRef staging; + + // Custom cast to ValueRef + operator ValueRef() const { + return value; + }; }; /* diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a0bfefafa02..aa48117129d 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1518,6 +1518,105 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { } } +TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) { + GraphConfig config; + ComputeGraph graph(config); + + std::vector size_big = {8, 64, 124}; + std::vector size_small = {8, 1, 124}; + + // Build graph + + IOValueRef a = graph.add_input_tensor( + size_big, vkapi::kFloat, /*shared_object_idx = */ 0); + IOValueRef b = graph.add_input_tensor( + size_small, vkapi::kFloat, /*shared_object_idx = */ 1); + + IOValueRef out = {}; + + out.value = + graph.add_tensor(size_big, vkapi::kFloat, /*shared_object_idx = */ 2); + + // Perform the following compute + // + // a, b, out; + // { + // inter; + // { + // tmp = a + b + // tmp2 = tmp + a + // inter = tmp2 + b + // } + // { + // tmp = inter + b; + // tmp2 = tmp + a + // out = tmp2 + b; + // } + // } + { + TmpTensor inter(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(inter.sobj_idx == 3); + { + TmpTensor tmp(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp.sobj_idx == 4); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {a, b, kDummyValueRef, tmp}); + + TmpTensor tmp2(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp2.sobj_idx == 5); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp, a, kDummyValueRef, tmp2}); + + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp2, b, kDummyValueRef, inter}); + } + { + TmpTensor tmp(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp.sobj_idx == 4); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {inter, b, kDummyValueRef, tmp}); + + TmpTensor tmp2(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp2.sobj_idx == 5); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp, a, kDummyValueRef, tmp2}); + + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp2, b, kDummyValueRef, out}); + } + } + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + for (float i = 5.0f; i < 30.0f; i += 10.0f) { + float val_a = i + 2.0f; + float val_b = i + 1.5f; + float val_tmp = val_a + val_b; + float val_tmp2 = val_tmp + val_a; + float val_inter = val_tmp2 + val_b; + float val_tmp_2 = val_inter + val_b; + float val_tmp2_2 = val_tmp_2 + val_a; + float val_out = val_tmp2_2 + val_b; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + CHECK_VALUE(data_out, i, val_out); + } + } +} + TEST(VulkanComputeGraphTest, test_large_graph) { auto build_start_time = std::chrono::system_clock::now(); GraphConfig config;