From efee9908511bcf803f59d6ac407bda5b0b92da96 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Sep 2024 13:11:40 -0700 Subject: [PATCH 1/7] [ET-VK] Add test to track sizes of various objects ## Context Add a simple test to track the sizes of various important objects in the Vulkan compute graph API over time. The test uses some loose thresholds to alert when an object has grown unexpectedly large. Differential Revision: [D62144400](https://our.internmc.facebook.com/intern/diff/D62144400/) [ghstack-poisoned] --- .../vulkan/test/vulkan_compute_api_test.cpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f3c60a21376..2f9c3d22f57 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -992,6 +992,28 @@ TEST_F(VulkanComputeAPITest, texture_virtual_resize) { graph.get_tensor(name.value)->staging_buffer_numel()); \ graph.copy_from_staging(name.staging, data_##name.data(), data_##name.size()); +// The purpose of this test is simply to track the size of various classes over +// time, in the interest of making sure that they doesn't grow too large. +TEST_F(VulkanComputeAPITest, print_object_sizes) { +#define PRINT_SIZE(name) \ + std::cout << #name << " size: " << sizeof(name) << " B" << std::endl + PRINT_SIZE(vTensor); + PRINT_SIZE(Value); + PRINT_SIZE(StagingBuffer); + PRINT_SIZE(ComputeGraph); + PRINT_SIZE(ExecuteNode); +#undef PRINT_SIZE + + // The actual sizes of each object is dependent on the platform. However, we + // can alert ourselves to any significant changes in the sizes of these + // objects by checking the `sizeof()` the class against some loose thresholds. + EXPECT_TRUE(sizeof(vTensor) < 1800); + EXPECT_TRUE(sizeof(Value) < 2400); + EXPECT_TRUE(sizeof(StagingBuffer) < 500); + EXPECT_TRUE(sizeof(ComputeGraph) < 500); + EXPECT_TRUE(sizeof(ExecuteNode) < 500); +} + TEST(VulkanComputeGraphTest, test_values_scalars) { GraphConfig config; ComputeGraph graph(config); From d4e400c0fe83cf6777872b434a6523eb5d8dd325 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Sep 2024 13:11:42 -0700 Subject: [PATCH 2/7] [ET-VK] Add type for symbolic integers ## Context Introduce the `SymInt` class which allows representation of symbolic integers in a Vulkan graph. Please see the comments documentation of the `SymInt` class for more details regarding why the `Int` type is not sufficient for symbolic integers. Differential Revision: [D62144399](https://our.internmc.facebook.com/intern/diff/D62144399/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ComputeGraph.cpp | 24 ++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 18 +++++- .../runtime/graph/containers/SymInt.cpp | 24 ++++++++ .../vulkan/runtime/graph/containers/SymInt.h | 41 +++++++++++++ .../vulkan/runtime/graph/containers/Types.cpp | 1 + .../vulkan/runtime/graph/containers/Types.h | 1 + .../vulkan/runtime/graph/containers/Value.h | 9 +++ .../vulkan/test/glsl/scalar_add_texture.glsl | 29 ++++++++++ .../vulkan/test/vulkan_compute_api_test.cpp | 58 +++++++++++++++++++ 9 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 backends/vulkan/runtime/graph/containers/SymInt.cpp create mode 100644 backends/vulkan/runtime/graph/containers/SymInt.h create mode 100644 backends/vulkan/test/glsl/scalar_add_texture.glsl diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 6c3ec88eaa7..a8f57f57d2a 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -43,6 +43,7 @@ VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector, IntList) VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector, DoubleList) VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector, BoolList) VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector, ValueList) +VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt) #undef VALUE_PTR_CLASS_IMPL @@ -261,6 +262,13 @@ ValueRef ComputeGraph::add_string(std::string&& str) { return idx; } +ValueRef ComputeGraph::add_symint(const int32_t val) { + ValueRef idx(static_cast(values_.size())); + check_no_active_value_ptrs(); + values_.emplace_back(SymInt(context(), val)); + return idx; +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { @@ -300,6 +308,22 @@ ValueRef ComputeGraph::set_output_tensor( return idx; } +vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer( + const ValueRef idx) { + if (values_.at(idx).isInt()) { + const int32_t val = extract_scalar(idx); + create_params_buffer(val); + } else if (values_.at(idx).isSymInt()) { + SymIntPtr symint = get_symint(idx); + return vkapi::BufferBindInfo(symint->gpu_buffer.buffer()); + } + VK_THROW("Cannot create a int param buffer for the given value"); +} + +void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) { + get_symint(idx)->set(val); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 9b04b08a70e..ac5e0d6c9d1 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -63,6 +63,7 @@ DECL_VALUE_PTR_CLASS(IntListPtr, std::vector) DECL_VALUE_PTR_CLASS(DoubleListPtr, std::vector) DECL_VALUE_PTR_CLASS(BoolListPtr, std::vector) DECL_VALUE_PTR_CLASS(ValueListPtr, std::vector) +DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt); #undef DECL_VALUE_PTR_CLASS @@ -154,6 +155,7 @@ class ComputeGraph final { GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(DoubleListPtr, double_list, DoubleList) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(BoolListPtr, bool_list, BoolList) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(ValueListPtr, value_list, ValueList) + GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(SymIntPtr, symint, SymInt); #undef GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS @@ -422,15 +424,28 @@ class ComputeGraph final { ValueRef add_string(std::string&& str); + ValueRef add_symint(const int32_t val); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); template - const vkapi::BufferBindInfo create_params_buffer(const Block& data) { + vkapi::BufferBindInfo create_params_buffer(const Block& data) { param_ubos_.emplace_back(api::ParamsBuffer(context_.get(), data)); return vkapi::BufferBindInfo(param_ubos_.back().buffer()); } + /* + * Given a ValueRef, do the following depending on the type of the Value: + * - If it is a SymInt, return the BufferBindInfo of the ParamsBuffer object + * backing the SymInt. + * - If it is a regular Int, create a new ParamsBuffer using the integer value + * and return the BufferBindInfo of the created ParamsBuffer. + */ + vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx); + + void set_symint(const ValueRef idx, const int32_t val); + /* * Convenience function to add an input tensor along with its staging buffer */ @@ -577,6 +592,7 @@ class ComputeGraph final { friend class DoubleListPtr; friend class BoolListPtr; friend class ValueListPtr; + friend class SymIntPtr; }; template diff --git a/backends/vulkan/runtime/graph/containers/SymInt.cpp b/backends/vulkan/runtime/graph/containers/SymInt.cpp new file mode 100644 index 00000000000..c91db84b787 --- /dev/null +++ b/backends/vulkan/runtime/graph/containers/SymInt.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace vkcompute { + +SymInt::SymInt(api::Context* context_p, const int32_t val) + : gpu_buffer(context_p, val){}; + +void SymInt::set(const int32_t val) { + gpu_buffer.update(val); +} + +void SymInt::operator=(const int32_t val) { + gpu_buffer.update(val); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/containers/SymInt.h b/backends/vulkan/runtime/graph/containers/SymInt.h new file mode 100644 index 00000000000..0c9fbee5fe2 --- /dev/null +++ b/backends/vulkan/runtime/graph/containers/SymInt.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +/* + * Represents a symbolic integer whose value can be variable. It is implemented + * as a thin wrapper around a `ParamsBuffer` object that holds the value of the + * integer. The `ParamsBuffer` object allows the value of the symbolic integer + * to be changed from the CPU and have those changes be visible to all shaders + * that use the symbolic integer; it also allows the value of the symbolic + * integer to be the result of a compute shader. + * + * Regular scalar types represented by `TypeTag::INT` cannot be used for + * symbolic integers because their value is assumed to be constant; therefore + * the `Value` instance holding the value of the scalar does not contain + * any reference to the GPU buffers used to pass its value into compute shaders. + * Therefore, updating the value of the scalar does not impact the value seen + * by compute shaders. + */ +struct SymInt final { + api::ParamsBuffer gpu_buffer; + + explicit SymInt(api::Context* context_p, const int32_t val); + + void set(const int32_t val); + + void operator=(const int32_t val); +}; + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/containers/Types.cpp b/backends/vulkan/runtime/graph/containers/Types.cpp index c5ffc65add1..e7a8951a552 100644 --- a/backends/vulkan/runtime/graph/containers/Types.cpp +++ b/backends/vulkan/runtime/graph/containers/Types.cpp @@ -29,6 +29,7 @@ std::ostream& operator<<(std::ostream& out, const TypeTag& tag) { PRINT_CASE(BOOLLIST) PRINT_CASE(VALUELIST) PRINT_CASE(STRING) + PRINT_CASE(SYMINT) } return out; } diff --git a/backends/vulkan/runtime/graph/containers/Types.h b/backends/vulkan/runtime/graph/containers/Types.h index 79edbd50d3a..5840d1695ee 100644 --- a/backends/vulkan/runtime/graph/containers/Types.h +++ b/backends/vulkan/runtime/graph/containers/Types.h @@ -36,6 +36,7 @@ enum class TypeTag : uint32_t { // Special Type VALUELIST, STRING, + SYMINT, }; std::ostream& operator<<(std::ostream& out, const TypeTag& tag); diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 6e03bbd4a21..50a2b5e548c 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -13,6 +13,7 @@ #include #include +#include #include namespace vkcompute { @@ -67,6 +68,8 @@ struct Value final { std::string as_string; + SymInt as_symint; + Payload() : u() {} // NOLINTNEXTLINE ~Payload(){}; @@ -123,6 +126,7 @@ struct Value final { TypeTag::VALUELIST, std::vector, as_value_list, vector); CASE_MOVE_MOVEABLE_TYPE( TypeTag::STRING, std::string, as_string, basic_string); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::SYMINT, SymInt, as_symint, SymInt); case TypeTag::NONE: clearToNone(); @@ -172,6 +176,9 @@ struct Value final { case TypeTag::STRING: payload.as_string.~basic_string(); break; + case TypeTag::SYMINT: + payload.as_symint.~SymInt(); + break; // Manually list out the types so that if a type here is added later and // not handled the compiler can catch it. case TypeTag::NONE: @@ -288,6 +295,8 @@ struct Value final { TypeTag::STRING, as_string); + SUPPORT_TRIVIALLY_MOVEABLE_TYPE(SymInt, SymInt, TypeTag::SYMINT, as_symint); + #undef SUPPORT_TRIVIALLY_COPYABLE_TYPE #undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE diff --git a/backends/vulkan/test/glsl/scalar_add_texture.glsl b/backends/vulkan/test/glsl/scalar_add_texture.glsl new file mode 100644 index 00000000000..aa2b22c81f9 --- /dev/null +++ b/backends/vulkan/test/glsl/scalar_add_texture.glsl @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +${layout_declare_tensor(0, "rw", "t_in", "float", "texture3d")} +${layout_declare_ubo(1, "uvec3", "extents")} +${layout_declare_ubo(2, "int", "scalar")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(pos, extents))) { + return; + } + + vec4 in_tex = imageLoad(t_in, pos); + imageStore(t_in, pos, imageLoad(t_in, pos) + float(scalar)); +} diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 2f9c3d22f57..a0bfefafa02 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1268,6 +1268,64 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { } } +TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) { + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + + std::vector sizes = {8, 64, 124}; + + // Build graph + + ValueRef scalar = graph.add_symint(1); + IOValueRef a = graph.add_input_tensor(sizes, vkapi::kFloat); + + IOValueRef out = {}; + out.value = a.value; + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR("scalar_add_texture"), + graph.create_global_wg_size(a.value), + graph.create_local_wg_size(a.value), + // Inputs and Outputs + {{out.value, vkapi::MemoryAccessType::WRITE}}, + // Shader params buffers + {graph.texture_limits_ubo(a.value), + graph.get_or_create_int_param_buffer(scalar)}, + // Specialization Constants + {}, + // Resizing Logic + nullptr, + {})); + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + for (float i = 5.0f; i < 30.0f; i += 10.0f) { + int scalar_val = i - 3.0f; + graph.set_symint(scalar, scalar_val); + + float val_a = i + 2.0f; + float val_out = val_a + scalar_val; + + fill_vtensor(graph, a, val_a); + + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + CHECK_VALUE(data_out, i, val_out); + } + } +} + #define CREATE_WEIGHT_TENSOR(name, sizes, dtype, val) \ std::vector data_##name(utils::multiply_integers(sizes)); \ std::fill(data_##name.begin(), data_##name.end(), val); \ From 35df318391cf8b0af33f05b41ff3a7957c355b3e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Sep 2024 13:11:45 -0700 Subject: [PATCH 3/7] [ET-VK] Add `TmpTensorVRef` struct to recycle temporary tensor memory ## Context Normally, tensor memory is planned during the export stage; tensors that do not overlap in lifetimes may share a memory allocation. However, memory planning requires knowledge of the lifetime of the tensors. However, some complex operators may not be able to perform all the necessary computations in one shader, or the implementation of the operator may require that some temporary tensors be created during the execution of the op. Since these temporary tensors are not visible to the memory planning algorithm, they will not be memory planned. This diff introduces the `TmpTensorVRef` object which facilitates memory sharing between temporary tensors. The design principle is that the lifetime of temporary tensors is restricted to the execution of the op within which they are created; thus, that knowledge can be used to implement memory planning. Please see the comments documentation of `TmpTensorVRef` for more details. Differential Revision: [D62144398](https://our.internmc.facebook.com/intern/diff/D62144398/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ComputeGraph.cpp | 66 +++++++++++++ backends/vulkan/runtime/graph/ComputeGraph.h | 80 +++++++++++++++ .../vulkan/runtime/graph/containers/Value.h | 5 + .../vulkan/test/vulkan_compute_api_test.cpp | 99 +++++++++++++++++++ 4 files changed, 250 insertions(+) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index a8f57f57d2a..92729ffb8ab 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -15,6 +15,8 @@ #include +#include + namespace vkcompute { // @@ -47,6 +49,70 @@ VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt) #undef VALUE_PTR_CLASS_IMPL +// +// TmpTensorVRef +// + +TmpTensorVRef::TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout) + : graph_p(graph_ptr), sobj_idx(-1), vref(kDummyValueRef) { + set_sobj_idx(); + vref = + graph_p->add_tensor(sizes, dtype, storage_type, memory_layout, sobj_idx); +} + +TmpTensorVRef::TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type) + : graph_p(graph_ptr), sobj_idx(-1), vref(kDummyValueRef) { + set_sobj_idx(); + vref = graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx); +} + +TmpTensorVRef::TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::GPUMemoryLayout memory_layout) + : graph_p(graph_ptr), sobj_idx(-1), vref(kDummyValueRef) { + set_sobj_idx(); + vref = graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx); +} + +TmpTensorVRef::TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype) + : graph_p(graph_ptr), sobj_idx(-1), vref(kDummyValueRef) { + set_sobj_idx(); + vref = graph_p->add_tensor(sizes, dtype, sobj_idx); +} + +TmpTensorVRef::~TmpTensorVRef() { + // Lifetime of this temporary tensor is expired; return the shared object to + // the pool, as long as the sobj index is valid + if (sobj_idx >= 0) { + graph_p->tmp_shared_object_idxs_.emplace(sobj_idx); + } +} + +void TmpTensorVRef::set_sobj_idx() { + // If no available temporary shared objects, request a new one to be created + if (graph_p->tmp_shared_object_idxs_.empty()) { + sobj_idx = graph_p->shared_objects_.size(); + } else { + // Get the first available shared object idx + sobj_idx = graph_p->tmp_shared_object_idxs_.top(); + graph_p->tmp_shared_object_idxs_.pop(); + } +} + // // ComputeGraph // diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index ac5e0d6c9d1..218e66f8e42 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -11,6 +11,7 @@ // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName #include +#include #include @@ -67,6 +68,78 @@ DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt); #undef DECL_VALUE_PTR_CLASS +// +// TmpTensorVRef +// + +/* + * This struct is used to recycle the memory of temporary tensors that are + * created during the execution of a node. Upon construction, this struct will + * check the `tmp_shared_object_idxs_` of the provided `ComputeGraph` instance + * if any shared objects are available; if not, then a new one is created. A + * tensor value is then added to the `ComputeGraph` instance with the requested + * specifications. Upon destruction, the shared object index of the temporary + * tensor is returned to `tmp_shared_object_idxs_`. + * + * Note that instances of this struct can be used as if they were `ValueRef` due + * to implementation of a custom casting operator. + * + * This class should only be used to create tensors whose lifetimes exist only + * in a well defined scope (i.e. within a function). + */ +struct TmpTensorVRef { + ComputeGraph* graph_p; + int64_t sobj_idx; + ValueRef vref; + + // + // Match all available overloads of `add_tensor` and `add_tensor_like` + // + + TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type, + const utils::GPUMemoryLayout memory_layout); + + TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::StorageType storage_type); + + TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype, + const utils::GPUMemoryLayout memory_layout); + + TmpTensorVRef( + ComputeGraph* const graph_ptr, + const std::vector& sizes, + const vkapi::ScalarType dtype); + + // No copy construction or assignment + TmpTensorVRef(TmpTensorVRef& other) = delete; + TmpTensorVRef& operator=(TmpTensorVRef& other) = delete; + + // No move construction or assignment + TmpTensorVRef(TmpTensorVRef&& other) = delete; + TmpTensorVRef& operator=(TmpTensorVRef&& other) = delete; + + // Custom cast to ValueRef + operator ValueRef() const { + return vref; + }; + + ~TmpTensorVRef(); + + private: + // Helper function to get new shared obj index + void set_sobj_idx(); +}; + // // ComputeGraph // @@ -94,7 +167,12 @@ class ComputeGraph final { vkapi::DescriptorPoolConfig execute_descriptor_counts_; std::unique_ptr context_; + std::vector shared_objects_; + // This stack is used by `TmpTensorVRef` instances to recycle shared objects + // for temporary tensors. See the comments of `TmpTensorVRef` for more details + std::stack tmp_shared_object_idxs_; + std::vector values_; std::vector param_ubos_; @@ -593,6 +671,8 @@ class ComputeGraph final { friend class BoolListPtr; friend class ValueListPtr; friend class SymIntPtr; + + friend class TmpTensorVRef; }; template diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 50a2b5e548c..8773f0c0b04 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -29,6 +29,11 @@ inline bool is_valid(ValueRef value_ref) { struct IOValueRef { ValueRef value; ValueRef staging; + + // Custom cast to ValueRef + operator ValueRef() const { + return value; + }; }; /* diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a0bfefafa02..a87e7c3d30f 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1518,6 +1518,105 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { } } +TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) { + GraphConfig config; + ComputeGraph graph(config); + + std::vector size_big = {8, 64, 124}; + std::vector size_small = {8, 1, 124}; + + // Build graph + + IOValueRef a = graph.add_input_tensor( + size_big, vkapi::kFloat, /*shared_object_idx = */ 0); + IOValueRef b = graph.add_input_tensor( + size_small, vkapi::kFloat, /*shared_object_idx = */ 1); + + IOValueRef out = {}; + + out.value = + graph.add_tensor(size_big, vkapi::kFloat, /*shared_object_idx = */ 2); + + // Perform the following compute + // + // a, b, out; + // { + // inter; + // { + // tmp = a + b + // tmp2 = tmp + a + // inter = tmp2 + b + // } + // { + // tmp = inter + b; + // tmp2 = tmp + a + // out = tmp2 + b; + // } + // } + { + TmpTensorVRef inter(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(inter.sobj_idx == 3); + { + TmpTensorVRef tmp(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp.sobj_idx == 4); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {a, b, kDummyValueRef, tmp}); + + TmpTensorVRef tmp2(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp2.sobj_idx == 5); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp, a, kDummyValueRef, tmp2}); + + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp2, b, kDummyValueRef, inter}); + } + { + TmpTensorVRef tmp(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp.sobj_idx == 4); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {inter, b, kDummyValueRef, tmp}); + + TmpTensorVRef tmp2(&graph, size_big, vkapi::kFloat); + EXPECT_TRUE(tmp2.sobj_idx == 5); + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp, a, kDummyValueRef, tmp2}); + + VK_GET_OP_FN("aten.add.Tensor") + (graph, {tmp2, b, kDummyValueRef, out}); + } + } + + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + for (float i = 5.0f; i < 30.0f; i += 10.0f) { + float val_a = i + 2.0f; + float val_b = i + 1.5f; + float val_tmp = val_a + val_b; + float val_tmp2 = val_tmp + val_a; + float val_inter = val_tmp2 + val_b; + float val_tmp_2 = val_inter + val_b; + float val_tmp2_2 = val_tmp_2 + val_a; + float val_out = val_tmp2_2 + val_b; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + CHECK_VALUE(data_out, i, val_out); + } + } +} + TEST(VulkanComputeGraphTest, test_large_graph) { auto build_start_time = std::chrono::system_clock::now(); GraphConfig config; From 1de43d2a23f1d61a0916e7bde1ef0a5b90ece8be Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 4 Sep 2024 15:04:02 -0700 Subject: [PATCH 4/7] [ET-VK][BE][ez] Enable automatic layout slot index incrementing ## Context Currently, in shaders we have to declare the binding slot that layout bindings will bind to explicitly, i.e. ``` ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_buffer(1, "r", "nchw_in", DTYPE)} ${layout_declare_ubo(2, "ivec4", "sizes")} ``` However, this can get a little tedious when making many layout declarations. This diff improves the situation by adding the `B` variable which will automatically increment the binding slot whenever a layout binding is declared. Now we can write ``` ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_buffer(B, "r", "nchw_in", DTYPE)} ${layout_declare_ubo(B, "ivec4", "sizes")} ``` I may make a follow up diff to change all layout declarations to use `B` across all shaders in the codebase later on. Differential Revision: [D62210119](https://our.internmc.facebook.com/intern/diff/D62210119/) [ghstack-poisoned] --- backends/vulkan/runtime/gen_vulkan_spv.py | 45 ++++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index f4ba98b31fd..6ee29d45f18 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -38,6 +38,10 @@ # Basic configuration settings for shaders DEFAULT_ENV: Dict[str, Any] = { "PRECISION": "highp", + # B is shorthand for "binding". This is used to automatically increment the + # layout binding index when declaring layout bindings. Note that a container + # type is used because integers are immutable in Python. + "B": [0], } # Establishes relationships between different tensor types and different GLSL types @@ -179,8 +183,14 @@ def get_access_qualifier(access_type: Optional[str]) -> str: raise AssertionError(f"Invalid access type: {access_type}") +def get_slot_val(slot: Union[int, List[int]]) -> int: + if isinstance(slot, list): + return slot[0] + return slot + + def layout_declare_buffer( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -192,15 +202,18 @@ def layout_declare_buffer( array_type = buffer_scalar_type(dtype) out_str = f""" -layout(set = 0, binding = {slot}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{ +layout(set = 0, binding = {get_slot_val(slot)}) buffer {precision} restrict {get_access_qualifier(access_type)} {var_name}Buffer {{ {array_type} {var_name}[]; }}; """ + + if isinstance(slot, list): + slot[0] = slot[0] + 1 return out_str def layout_declare_image( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -209,11 +222,16 @@ def layout_declare_image( ) -> str: image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype] image_type = TYPE_MAPPINGS["IMAGE_T"][image_ndim][dtype] - return f"layout(set = 0, binding = {slot}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};" + + ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}, {image_format}) uniform {precision} restrict {get_access_qualifier(access_type)} {image_type} {var_name};" + + if isinstance(slot, list): + slot[0] = slot[0] + 1 + return ret_str def layout_declare_sampler( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -222,11 +240,16 @@ def layout_declare_sampler( image_ndim: int = 3, ) -> str: sampler_type = TYPE_MAPPINGS["SAMPLER_T"][image_ndim][dtype] - return f"layout(set = 0, binding = {slot}) uniform {precision} {sampler_type} {var_name};" + + ret_str = f"layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} {sampler_type} {var_name};" + + if isinstance(slot, list): + slot[0] = slot[0] + 1 + return ret_str def layout_declare_tensor( - slot: int, + slot: Union[int, List[int]], access_type: str, var_name: str, dtype: str, @@ -262,7 +285,9 @@ def layout_declare_tensor( ) -def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str: +def layout_declare_ubo( + slot: Union[int, List[int]], *args, precision: str = "PRECISION" +) -> str: assert len(args) % 2 == 0 var_list = list(zip(args[::2], args[1::2])) @@ -272,12 +297,14 @@ def layout_declare_ubo(slot: int, *args, precision: str = "PRECISION") -> str: ubo_name += var_name + "_" out_str = f""" -layout(set = 0, binding = {slot}) uniform {precision} restrict readonly {ubo_name}UBO {{ +layout(set = 0, binding = {get_slot_val(slot)}) uniform {precision} restrict readonly {ubo_name}UBO {{ """ for type_name, var_name in var_list: out_str += f"{type_name} {var_name};\n" out_str += "};" + if isinstance(slot, list): + slot[0] = slot[0] + 1 return out_str From 99e51052001201410bde4267ac60e2bb47ebe877 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 4 Sep 2024 15:04:05 -0700 Subject: [PATCH 5/7] [ET-VK] Introduce axis mapping for no-copy permute of texture-backed tensors ## Context This diff introduces the `axis_mapping` field for `vTensors`, which can be used to implement no-copy permutes. The idea behind the axis mapping is that it is somewhat analogous to dim order for texture backed tensors. The axis mapping is normalized to 4 dimensions, similar to padded sizes. The first 3 elements indicates which of the (X,Y,Z) image texture axes the width, height, and channels dim of the tensor maps to. The final element indicates the WHCN index of the tensor dimension along which batches will be concatenated. The benefit of introducing axis mapping is twofold: 1. Permutes can be performed without any data copying by re-using a texture but updating the axis mapping. 2. Allows the memory layout of texture backed tensors to be more flexible, and optimize for performance or memory footprint by using unconventional axis mappings. Regarding the second point, we have found that adding length to a texture's Z axis is more costly than adding length to the texture's X or Y axes. Similarly, we have found that reading along the Z axis yeilds slightly lower throughput than reading along the X or Y axes. By introducing axis mapping, we can map the largest dimension to a texture's X axis instead of mapping it to the most intuitive texture axis (i.e. channels to Z axis). This can save a lot of texture memory and potentially improve compute shader latency as well. However, the pre-requisite of using texture mapping heavily is that the overhead introduced in calculating tensor indices and texture positions does not significantly increase compute shader latency. The impact of this will be investigated and shown in the following diffs. Note that this diff only introduces the `axis_mapping` field; Differential Revision: [D62210118](https://our.internmc.facebook.com/intern/diff/D62210118/) [ghstack-poisoned] --- .../vulkan/runtime/api/containers/Tensor.cpp | 118 ++++++++++++++---- .../vulkan/runtime/api/containers/Tensor.h | 65 ++++++++-- backends/vulkan/runtime/graph/ComputeGraph.h | 4 + 3 files changed, 148 insertions(+), 39 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 7b9d30ef658..f61c9022366 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -80,6 +80,31 @@ std::vector calculate_strides( return strides; } +/* + * Axis mapping is somewhat analogous to strides for texture backed tensors. + * + * The axis mapping is normalized to 4 dimensions, similar to the padded sizes. + * The first 3 values of the axis mapping indicate the (X,Y,Z) image texture + * axis that corresponds to the width, height, and channels dimension of the + * tensor. Thus the axis mapping can be considered to be in WHCN dimension + * order. + * + * The last value `axis_mapping.at(3)` indicates the WHCN index of the tensor + * dimension along which batches will be concatenated. To determine which image + * texture axis is used for the concatenation, a double lookup will need to be + * performed (axis_mapping.at(axis_mapping.at(3))). + * + * The axis mapping allows for permuted views of texture-backed tensors. + */ +std::vector default_axis_mapping() { + // Currently, all compute shaders have an assumption that the channels dim is + // used to combine with the batch dim of a tensor. However, once dim mapping + // is integrated into the tensor indexing logic for each compute shader, we + // can be more flexible with mapping the batch dim to different texture axes + // in order to improve performance or memory footprint. + return {0, 1, 2, 2}; +} + bool dim_order_is_valid(const std::vector& dim_order) { int64_t sum = 0; for (size_t i = 0; i < dim_order.size(); ++i) { @@ -137,30 +162,44 @@ std::vector calculate_padded_sizes( utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, + const std::vector& axis_mapping, const utils::GPUMemoryLayout memory_layout) { VK_CHECK_COND(padded_sizes.size() == 4); + VK_CHECK_COND(axis_mapping.size() == 4); - uint32_t N = utils::safe_downcast(padded_sizes.at(0)); - uint32_t C = utils::safe_downcast(padded_sizes.at(1)); - uint32_t H = utils::safe_downcast(padded_sizes.at(2)); - uint32_t W = utils::safe_downcast(padded_sizes.at(3)); + utils::uvec3 extents({1, 1, 1}); + // First three elements of axis_mapping indicate which (X,Y,Z) image axis the + // width, height, and channels dim of the tensor maps to. + for (int whcn_dim = 0; whcn_dim < 3; ++whcn_dim) { + const int64_t axis = axis_mapping.at(whcn_dim); + const int64_t dim = padded_sizes.size() - 1 - whcn_dim; + extents[axis] = utils::safe_downcast(padded_sizes.at(dim)); + } + + // axis_mapping[3] indicates the WHCN index of the dimension used for batch + // concatenation. Thus a double lookup is required to determine the image axis + // used for batch concatenation. + const int64_t concatted_whcn_dim = axis_mapping.at(3); + const int64_t batch_axis = axis_mapping.at(concatted_whcn_dim); + // Multiply the extents of the batch axis by the batch size. + extents[batch_axis] *= padded_sizes.at(0); switch (memory_layout) { case utils::kWidthPacked: - VK_CHECK_COND(W % 4 == 0); - W /= 4; + VK_CHECK_COND(extents[0] % 4 == 0); + extents[0] /= 4; break; case utils::kHeightPacked: - VK_CHECK_COND(H % 4 == 0); - H /= 4; + VK_CHECK_COND(extents[1] % 4 == 0); + extents[1] /= 4; break; case utils::kChannelsPacked: - VK_CHECK_COND(C % 4 == 0); - C /= 4; + VK_CHECK_COND(extents[2] % 4 == 0); + extents[2] /= 4; break; } - return {W, H, C * N}; + return extents; } // @@ -176,9 +215,10 @@ vTensor::vTensor( const bool allocate_memory) : dtype_(dtype), memory_layout_(memory_layout), - // Calculate tensor size metadata + // Calculate tensor metadata sizes_(sizes.begin(), sizes.end()), dim_order_(calculate_dim_order(sizes_.size(), memory_layout_)), + axis_mapping_(default_axis_mapping()), strides_(calculate_strides(sizes, dim_order_)), numel_(utils::multiply_integers(sizes_)), padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)}, @@ -189,12 +229,14 @@ vTensor::vTensor( sizes_uniform_(), strides_uniform_(), numel_uniform_(), + axis_mapping_uniform_(), texture_limits_uniform_(), // Construct Tensor storage storage_( context, storage_type, memory_layout_, + axis_mapping_, padded_sizes_, dtype_, allocate_memory) { @@ -222,6 +264,7 @@ vTensor::vTensor(const vTensor& other) // Copy tensor size metadata sizes_(other.sizes_.begin(), other.sizes_.end()), dim_order_(other.dim_order_.begin(), other.dim_order_.end()), + axis_mapping_(other.axis_mapping_.begin(), other.axis_mapping_.end()), strides_(other.strides_.begin(), other.strides_.end()), numel_(other.numel_), padded_sizes_{other.padded_sizes_.begin(), other.padded_sizes_.end()}, @@ -234,6 +277,7 @@ vTensor::vTensor(const vTensor& other) sizes_uniform_(), strides_uniform_(), numel_uniform_(), + axis_mapping_uniform_(), texture_limits_uniform_(), // Copy Tensor storage storage_(other.storage_) {} @@ -248,6 +292,7 @@ vTensor::vTensor( // Copy tensor size metadata sizes_(sizes.begin(), sizes.end()), dim_order_(dim_order.begin(), dim_order.end()), + axis_mapping_(default_axis_mapping()), strides_(calculate_strides(sizes_, dim_order_)), numel_(utils::multiply_integers(sizes_)), padded_sizes_{calculate_padded_sizes(sizes, memory_layout_)}, @@ -258,6 +303,7 @@ vTensor::vTensor( sizes_uniform_(), strides_uniform_(), numel_uniform_(), + axis_mapping_uniform_(), texture_limits_uniform_(), // Copy Tensor storage storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) { @@ -315,6 +361,14 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() { return vkapi::BufferBindInfo(strides_uniform_.buffer()); } +const vkapi::BufferBindInfo vTensor::axis_mapping_ubo() { + if (!axis_mapping_uniform_.buffer()) { + axis_mapping_uniform_ = + ParamsBuffer(storage_.context_, utils::make_ivec4(axis_mapping_)); + } + return vkapi::BufferBindInfo(axis_mapping_uniform_.buffer()); +} + const vkapi::BufferBindInfo vTensor::texture_limits_ubo() { if (!texture_limits_uniform_.buffer()) { texture_limits_uniform_ = ParamsBuffer(storage_.context_, texture_limits_); @@ -376,11 +430,7 @@ void vTensor::bind_allocation(const vkapi::Allocation& allocation) { } } -void vTensor::update_metadata( - const std::vector& new_sizes, - const std::vector& new_dim_order) { - sizes_ = new_sizes; - dim_order_ = new_dim_order; +void vTensor::update_metadata() { strides_ = calculate_strides(sizes_, dim_order_); // Only update the memory layout for buffer-backed tensors. Strides are // meaningless for texture-backed tensors and do not impact the memory layout. @@ -396,7 +446,7 @@ void vTensor::update_metadata( // Calculate the extents of the image texture that would have been required // for a tensor of the new sizes. utils::uvec3 virtual_extents = - calculate_image_extents(padded_sizes_, memory_layout_); + calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_); // Update the texture limits to reflect the new virtual extents. texture_limits_.limits = utils::ivec3{ @@ -407,15 +457,18 @@ void vTensor::update_metadata( if (sizes_uniform_.buffer()) { sizes_uniform_.update(utils::make_whcn_ivec4(sizes_)); } - if (texture_limits_uniform_.buffer()) { - texture_limits_uniform_.update(texture_limits_); - } if (strides_uniform_.buffer()) { strides_uniform_.update(utils::make_whcn_ivec4(unsqueezed_strides_)); } if (numel_uniform_.buffer()) { numel_uniform_.update(numel_); } + if (axis_mapping_uniform_.buffer()) { + axis_mapping_uniform_.update(utils::make_ivec4(axis_mapping_)); + } + if (texture_limits_uniform_.buffer()) { + texture_limits_uniform_.update(texture_limits_); + } } void vTensor::check_sizes(const std::vector& sizes) const { @@ -423,7 +476,7 @@ void vTensor::check_sizes(const std::vector& sizes) const { // For texture storage check that the current texture is large enough for // the new sizes of the tensor. utils::uvec3 virtual_extents = - calculate_image_extents(padded_sizes_, memory_layout_); + calculate_image_extents(padded_sizes_, axis_mapping_, memory_layout_); bool valid_resize = virtual_extents[0] <= image_extents()[0]; valid_resize = valid_resize && virtual_extents[1] <= image_extents()[1]; @@ -454,7 +507,9 @@ void vTensor::virtual_reconfigure( VK_CHECK_COND(dim_order_is_valid(new_dim_order)); check_sizes(new_sizes); - update_metadata(new_sizes, new_dim_order); + sizes_ = new_sizes; + dim_order_ = new_dim_order; + update_metadata(); } void vTensor::virtual_resize(const std::vector& new_sizes) { @@ -463,13 +518,16 @@ void vTensor::virtual_resize(const std::vector& new_sizes) { "new sizes cannot modify the dimensionality of the tensor "); check_sizes(new_sizes); - update_metadata(new_sizes, dim_order_); + sizes_ = new_sizes; + update_metadata(); } void vTensor::reallocate(const std::vector& new_sizes) { - update_metadata(new_sizes, dim_order_); + sizes_ = new_sizes; + update_metadata(); storage_.discard_and_reallocate( calculate_padded_sizes(new_sizes, memory_layout_), + axis_mapping_, memory_layout_, dtype_); } @@ -547,12 +605,16 @@ vTensorStorage::vTensorStorage( Context* const context, const utils::StorageType storage_type, const utils::GPUMemoryLayout gpu_memory_layout, + const std::vector& axis_mapping, const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory) : context_(context), storage_type_{storage_type}, - image_extents_(calculate_image_extents(padded_sizes, gpu_memory_layout)), + image_extents_(calculate_image_extents( + padded_sizes, + axis_mapping, + gpu_memory_layout)), buffer_length_{utils::multiply_integers(padded_sizes)}, buffer_offset_{0}, image_(allocate_image( @@ -665,6 +727,7 @@ bool vTensorStorage::is_copy_of(const vTensorStorage& other) const { void vTensorStorage::discard_and_reallocate( const std::vector& padded_sizes, + const std::vector& axis_mapping, const utils::GPUMemoryLayout gpu_memory_layout, const vkapi::ScalarType dtype) { const bool image_owns_memory = image_.owns_memory(); @@ -672,7 +735,8 @@ void vTensorStorage::discard_and_reallocate( flush(); - image_extents_ = calculate_image_extents(padded_sizes, gpu_memory_layout); + image_extents_ = + calculate_image_extents(padded_sizes, axis_mapping, gpu_memory_layout); image_ = allocate_image( context_, image_extents_, diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index d37628e4adc..70f363796fd 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -60,11 +60,11 @@ std::vector calculate_padded_sizes( const utils::GPUMemoryLayout memory_layout); /* - * Given the padded sizes of a tensor and the GPU memory layout, calculate the - * 3D image extents required to store the tensor data as an image texture. + * Calculate the image extents required of a texture backed tensor. */ utils::uvec3 calculate_image_extents( const std::vector& padded_sizes, + const std::vector& axis_mapping, const utils::GPUMemoryLayout memory_layout); struct LastAccess { @@ -90,7 +90,8 @@ class vTensorStorage final { Context* context, const utils::StorageType storage_type, const utils::GPUMemoryLayout gpu_memory_layout, - const std::vector& sizes, + const std::vector& axis_mapping, + const std::vector& padded_sizes, const vkapi::ScalarType dtype, const bool allocate_memory = true); @@ -159,6 +160,7 @@ class vTensorStorage final { void discard_and_reallocate( const std::vector& padded_sizes, + const std::vector& axis_mapping, const utils::GPUMemoryLayout gpu_memory_layout, const vkapi::ScalarType dtype); }; @@ -218,21 +220,58 @@ class vTensor final { vTensor& operator=(vTensor&& other) = default; private: + /* + * "Core" tensor metadata. They are the minimum amount of information required + * to construct a tensor. + */ + + // Whether the tensor has elements of type float, int, etc. vkapi::ScalarType dtype_; + // Describes which dimension is "tightly packed". For texture backed tensors, + // this describes which dimension is packed along a texel. For buffer backed + // tensors, this describes which dimension has a stride of 1 (i.e. is last in + // the dim order). utils::GPUMemoryLayout memory_layout_; - // sizes of the tensor in NCHW dimension order std::vector sizes_; + + /* + * "Layout" metadata. These describe with further detail how tensor data is + * laid out in memory. However, they are considered secondary to the "core" + * metadata members above because defaults can be assumed based on a given + * memory layout. When permuting the tensor without performing a copy, these + * metadata members are the ones that will be changed. All other metadata is + * derived from a combination of sizes, memory layout, and the below members. + */ + // dim order of the tensor; dimension indices are in NCHW dimension order // i.e. 0 is N, 1 is C, 2 is H, 3 is W for a 4D tensor. The dims with larger // strides precede the dims with smaller strides in the dim order. The last // dim is always the fastest moving dim with a stride of 1. std::vector dim_order_; + // Describes which axis of an image texture each dimension of the tensor maps + // to. The axis mapping allows texture based tensors to be permuted and + // transposed without modifying the underlying texture storage. For a more in + // depth explanation of axis mapping, see the `default_axis_mapping()` + // function. + std::vector axis_mapping_; + + /* + * The below can be consider "layout" metadata as well, but are derived from + * the above data members. + */ + // strides of the tensor in NCHW dimension order std::vector strides_; // Contains the number of elements in the tensor according to the canonical // sizes. size_t numel_; + + /* + * The below metadata members are derived from the above, and are typically + * to i.e. pass tensor metadata to compute shaders. + */ + // padded sizes of the tensor in NCHW dimension order. See the // calculate_padded_sizes() function for more context. Note that padded sizes // are only used for texture storage, and not for buffer storage. @@ -260,6 +299,7 @@ class vTensor final { ParamsBuffer sizes_uniform_; ParamsBuffer strides_uniform_; ParamsBuffer numel_uniform_; + ParamsBuffer axis_mapping_uniform_; ParamsBuffer texture_limits_uniform_; vTensorStorage storage_; @@ -365,14 +405,18 @@ class vTensor final { */ const vkapi::BufferBindInfo strides_ubo(); + /* + * Returns a GPU buffer containing the texture axis mapping for each dimension + * of the tensor, in WHCN dimension order. + */ + const vkapi::BufferBindInfo axis_mapping_ubo(); + /* * Returns a GPU buffer containing the virtual image extents of the tensor. * Since a tensor can be resized with the virtual_resize() function, this * GPU buffer contains the image extents of the tensor calculated using the * virtual_resize() function. This allows shaders to exit early if they are * working outside the limits of the texture. - * - * This buffer should only be used to */ const vkapi::BufferBindInfo texture_limits_ubo(); @@ -423,13 +467,10 @@ class vTensor final { private: /* - * Given new sizes and new strides of the dim order, update the sizes and dim - * order metadata of the vTensor. New strides are computed using the new sizes - * and new dim order. + * Assuming sizes, dim order, or axis mapping was modified, recompute all + * derived metadata and update metadata UBO with new values. */ - void update_metadata( - const std::vector& new_sizes, - const std::vector& new_dim_order); + void update_metadata(); /* * Check that tensor sizes are valid given the current storage resource's diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 210b03c4cad..afdc8290cdd 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -327,6 +327,10 @@ class ComputeGraph final { return values_.at(idx).toTensor().numel_ubo(); } + inline vkapi::BufferBindInfo axis_mapping_ubo(const ValueRef idx) { + return values_.at(idx).toTensor().axis_mapping_ubo(); + } + inline vkapi::BufferBindInfo texture_limits_ubo(const ValueRef idx) { return values_.at(idx).toTensor().texture_limits_ubo(); } From 7c1ff3b70803ae2a490cb18078a86974f5480303 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Wed, 4 Sep 2024 15:04:10 -0700 Subject: [PATCH 6/7] [ET-VK] Integrate axis mapping into staging <-> buffer transfer shaders ## Context Building on the previous diff, this diff integrates axis mapping into staging <-> buffer transfer shaders. Alternative versions of indexing utility functions are introduced to account for axis mapping. The impact of shader latency of using axis mapping on transfer shaders is examined in the next diff. Differential Revision: [D62210117](https://our.internmc.facebook.com/intern/diff/D62210117/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/image_to_nchw.glsl | 9 +- .../runtime/graph/ops/glsl/indexing_utils.h | 93 +++++++++++++++++++ .../ops/glsl/int8_image_to_nchw_noint8.glsl | 9 +- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 9 +- .../ops/glsl/nchw_to_int8_image_noint8.glsl | 15 +-- .../runtime/graph/ops/impl/Convolution.cpp | 2 +- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 8 +- backends/vulkan/test/utils/test_utils.cpp | 7 +- .../vulkan/test/vulkan_compute_api_test.cpp | 15 +-- 9 files changed, 136 insertions(+), 31 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index b51d5a3f6ed..8f113bd2cc2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -${layout_declare_buffer(0, "w", "nchw_out", DTYPE)} -${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} -${layout_declare_ubo(2, "ivec4", "sizes")} +${layout_declare_buffer(B, "w", "nchw_out", DTYPE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_ubo(B, "ivec4", "sizes")} +${layout_declare_ubo(B, "ivec4", "axis_mapping")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -51,7 +52,7 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) { void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim); + const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim); if (any(greaterThanEqual(tensor_idx, sizes))) { return; diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 21eadff0b36..b68a226c298 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -183,6 +183,42 @@ ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) { return tensor_idx; } +/* + * Derive (w,h,c,n) tensor indices from (x,y,z) texture position using axis + * mapping. + */ +ivec4 to_tensor_idx( + ivec3 pos, + ivec4 sizes, + const ivec4 axis_mapping, + const int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + // Packed dim contains 4 elements per texel, so moving 1 unit traverses 4 + // elements in the tensor. + pos[axis_mapping[packed_dim]] *= 4; + + ivec4 tensor_idx; + for (int dim = 0; dim < 4; ++dim) { + tensor_idx[dim] = pos[axis_mapping[dim]]; + } + + // Early return if batch is 1. No need to adjust index. + if (sizes[3] == 1) { + tensor_idx[3] = 0; + return tensor_idx; + } + + // Else, adjust the dim that's concatenated with batch. Note that the axis + // mapping for the batch dim indicates WHCN dim index of the dim that it is + // concatenated with, not a texture axis. + tensor_idx[3] /= sizes[axis_mapping[3]]; + tensor_idx[axis_mapping[3]] %= sizes[axis_mapping[3]]; + + return tensor_idx; +} + /* * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim * is packed along a texel @@ -199,6 +235,34 @@ ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) { return pos; } +/* + * Derive (x,y,z) texture position from (w,h,c,n) tensor indices using axis + * mapping. + */ +ivec3 to_texture_pos( + const ivec4 idx, + ivec4 sizes, + const ivec4 axis_mapping, + const int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + ivec3 pos; + for (int dim = 0; dim < 4; ++dim) { + pos[axis_mapping[dim]] = idx[dim]; + } + + // Adjust batch dim if needed + if (sizes.w > 1) { + pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes.w; + } + + // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4 + // tensor elements in that dim. + pos[axis_mapping[packed_dim]] /= 4; + return pos; +} + /* * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, which dim * is packed along a texel @@ -218,6 +282,35 @@ ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) { return pos; } +/* + * Derive (x,y,z,i) texel element position from the (w,h,c,n) tensor index using + * the axis mapping. + */ +ivec4 to_texture_elem_pos( + const ivec4 idx, + ivec4 sizes, + const ivec4 axis_mapping, + const int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + ivec4 pos; + for (int dim = 0; dim < 4; ++dim) { + pos[axis_mapping[dim]] = idx[dim]; + } + + // Adjust batch dim if needed + if (sizes.w > 1) { + pos[axis_mapping[axis_mapping[3]]] += idx.w * sizes.w; + } + + // Adjust packed dim. Moving 1 texel unit along the packed dim traverses 4 + // tensor elements in that dim. + pos[axis_mapping[packed_dim]] /= 4; + pos.w = idx[packed_dim] % 4; + return pos; +} + // // Texel Access and Storage // diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl index b1e3a0abdfe..3ef984bfc95 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl @@ -16,10 +16,11 @@ layout(std430) buffer; #extension GL_EXT_control_flow_attributes : require -${layout_declare_buffer(0, "w", "nchw_out", "int")} -${layout_declare_tensor(1, "r", "t_in", "int8", "texture3d")} -${layout_declare_ubo(2, "ivec4", "tensor_sizes")} -${layout_declare_ubo(3, "int", "out_numel")} +${layout_declare_buffer(B, "w", "nchw_out", "int")} +${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")} +${layout_declare_ubo(B, "ivec4", "tensor_sizes")} +${layout_declare_ubo(B, "ivec4", "axis_mapping")} +${layout_declare_ubo(B, "int", "out_numel")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index abe93904805..04b6a26cc44 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -21,9 +21,10 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_buffer(1, "r", "nchw_in", DTYPE)} -${layout_declare_ubo(2, "ivec4", "sizes")} +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_buffer(B, "r", "nchw_in", DTYPE)} +${layout_declare_ubo(B, "ivec4", "sizes")} +${layout_declare_ubo(B, "ivec4", "axis_mapping")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -53,7 +54,7 @@ VEC4_T read_texel(ivec4 tensor_idx) { void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 tensor_idx = to_tensor_idx(pos, sizes, packed_dim); + const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim); if (any(greaterThanEqual(tensor_idx, sizes))) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl index 378cf09d129..813a174d2a5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl @@ -16,9 +16,10 @@ layout(std430) buffer; #extension GL_EXT_control_flow_attributes : require -${layout_declare_tensor(0, "w", "t_out", "int8", "texture3d")} -${layout_declare_buffer(1, "r", "nchw_in", "int")} -${layout_declare_ubo(2, "ivec4", "tensor_sizes")} +${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")} +${layout_declare_buffer(B, "r", "nchw_in", "int")} +${layout_declare_ubo(B, "ivec4", "sizes")} +${layout_declare_ubo(B, "ivec4", "axis_mapping")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -36,7 +37,7 @@ int extend_sign(int x) { ivec4 read_texel(ivec4 tensor_idx) { const ivec4 buf_indices = get_texel_nchw_buffer_ixs( - tensor_idx, tensor_sizes, packed_dim); + tensor_idx, sizes, packed_dim); int shift = (1 << 8) - 1; ivec4 masks; @@ -51,7 +52,7 @@ ivec4 read_texel(ivec4 tensor_idx) { ivec4 out_tex = ivec4(0); [[unroll]] for (int i = 0; i < 4; ++i) { - if (tensor_idx[packed_dim] + i < tensor_sizes[packed_dim]) { + if (tensor_idx[packed_dim] + i < sizes[packed_dim]) { int in_texel = nchw_in[buf_indices[i] / 4]; int extracted_val = (in_texel & masks[i]) >> (8 * (buf_indices[i] % 4)); extracted_val = extend_sign(extracted_val); @@ -64,9 +65,9 @@ ivec4 read_texel(ivec4 tensor_idx) { void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 tensor_idx = to_tensor_idx(pos, tensor_sizes, packed_dim); + const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_mapping, packed_dim); - if (any(greaterThanEqual(tensor_idx, tensor_sizes))) { + if (any(greaterThanEqual(tensor_idx, sizes))) { return; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 74113197d46..dcdd2dccfa0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -106,7 +106,7 @@ ValueRef prepack_biases( graph.create_local_wg_size(v), vref, v, - {t->sizes_ubo()}, + {t->sizes_ubo(), t->axis_mapping_ubo()}, // Specialization constants {SV(t->packed_dim_whcn_idx())})); diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 9df5b73c1a1..6a759e0fd2e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -31,7 +31,8 @@ void add_staging_to_tensor_node( graph.strides_ubo(out_tensor), graph.numel_ubo(out_tensor)}); } else { - ubos.append(graph.sizes_ubo(out_tensor)); + ubos.append( + {graph.sizes_ubo(out_tensor), graph.axis_mapping_ubo(out_tensor)}); } graph.execute_nodes().emplace_back(new ExecuteNode( @@ -69,7 +70,8 @@ void add_tensor_to_staging_node( graph.strides_ubo(in_tensor), graph.numel_ubo(in_tensor)}); } else { - ubos.append(graph.sizes_ubo(in_tensor)); + ubos.append( + {graph.sizes_ubo(in_tensor), graph.axis_mapping_ubo(in_tensor)}); } // Normally, the image_to_nchw shader is structured so that each thread reads @@ -113,7 +115,7 @@ ValueRef prepack( if (graph.is_buffer_storage(v)) { ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)}); } else { - ubos.append(graph.sizes_ubo(v)); + ubos.append({graph.sizes_ubo(v), graph.axis_mapping_ubo(v)}); } graph.prepack_nodes().emplace_back(new PrepackNode( diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 4c2972419d0..7b794500436 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -85,7 +85,8 @@ void record_nchw_to_image_op( vkapi::PipelineStage::COMPUTE, vkapi::MemoryAccessType::WRITE), src_buffer, - v_dst.sizes_ubo()); + v_dst.sizes_ubo(), + v_dst.axis_mapping_ubo()); } void record_image_to_nchw_op( @@ -106,7 +107,8 @@ void record_image_to_nchw_op( 0, dst_buffer, v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE), - v_src.sizes_ubo()); + v_src.sizes_ubo(), + v_src.axis_mapping_ubo()); } void record_int8_image_to_nchw_noint8_op( @@ -127,6 +129,7 @@ void record_int8_image_to_nchw_noint8_op( dst_buffer.buffer(), v_src.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE), v_src.sizes_ubo(), + v_src.axis_mapping_ubo(), v_src.numel_ubo()); } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index aa48117129d..8bd4063c9cf 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1227,8 +1227,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {8, 64, 124}; - std::vector size_small = {8, 1, 124}; + std::vector size_big = {1, 8, 8}; + std::vector size_small = {1, 1, 8}; // Build graph @@ -1409,8 +1409,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { /*shared_object_idx = */ 4); // +2: t.sizes_ubo() for each staging shader + // +2: t.axis_mapping_ubo() for each staging shader // +2: staging buffer for each input tensor - EXPECT_TRUE(get_vma_allocation_count() == 4); + EXPECT_TRUE(get_vma_allocation_count() == 6); ValueRef c = graph.add_tensor( size_big, @@ -1427,8 +1428,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { // +2: alpha UBO, broadcast UBO for arithmetic shader // +1: t.sizes_ubo() uniform buffer for staging shader + // +1: t.axis_mapping_ubo() uniform buffer for staging shader // +1: staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 9); + EXPECT_TRUE(get_vma_allocation_count() == 12); ValueRef e = graph.add_tensor( size_big, @@ -1444,14 +1446,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { // +2: alpha UBO, broadcast UBO for arithmetic shader // +1: t.sizes_ubo() for staging shader + // +1: t.axis_mapping_ubo() for staging shader // +1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 13); + EXPECT_TRUE(get_vma_allocation_count() == 17); graph.prepare(); graph.encode_execute(); // +3: shared memory allocations for tensors - EXPECT_TRUE(get_vma_allocation_count() == 16); + EXPECT_TRUE(get_vma_allocation_count() == 20); // Run graph From 7535ad3ce33b36471b69c402d5bc8b04ec9e2e9d Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 5 Sep 2024 17:04:46 -0700 Subject: [PATCH 7/7] Update base for Update on "[ET-VK] Integrate axis mapping into staging <-> buffer transfer shaders" ## Context Building on the previous diff, this diff integrates axis mapping into staging <-> buffer transfer shaders. Alternative versions of indexing utility functions are introduced to account for axis mapping. The impact of shader latency of using axis mapping on transfer shaders is examined in the next diff. Differential Revision: [D62210117](https://our.internmc.facebook.com/intern/diff/D62210117/) [ghstack-poisoned] --- backends/vulkan/test/vulkan_compute_api_test.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 1112548b855..c7d20c38675 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1007,10 +1007,16 @@ TEST_F(VulkanComputeAPITest, print_object_sizes) { // The actual sizes of each object is dependent on the platform. However, we // can alert ourselves to any significant changes in the sizes of these // objects by checking the `sizeof()` the class against some loose thresholds. - EXPECT_TRUE(sizeof(vTensor) < 1800); - EXPECT_TRUE(sizeof(Value) < 2400); + + // Current known size on 64 bit system: 1824 B + EXPECT_TRUE(sizeof(vTensor) < 2000); + // Current known size on 64 bit system: 1840 B + EXPECT_TRUE(sizeof(Value) < 2200); + // Current known size on 64 bit system: 240 B EXPECT_TRUE(sizeof(StagingBuffer) < 500); + // Current known size on 64 bit system: 384 B EXPECT_TRUE(sizeof(ComputeGraph) < 500); + // Current known size on 64 bit system: 248 B EXPECT_TRUE(sizeof(ExecuteNode) < 500); }