diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 4ff0f9e93d6..ceb95f3a304 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -390,18 +390,20 @@ bool maybe_resize_input( const size_t input_i, executorch::aten::Tensor& et_tensor) { ValueRef in_tensor_ref = graph->inputs()[input_i].value; - vTensorPtr in_tensor = graph->get_tensor(in_tensor_ref); + + const std::vector in_tensor_vk_sizes = + graph->sizes_of(in_tensor_ref); ET_CHECK_MSG( - et_tensor.dim() == in_tensor->sizes().size(), + et_tensor.dim() == in_tensor_vk_sizes.size(), "Cannot resize input tensor: old ndim %zu does not match new ndim %zu", - static_cast(in_tensor->sizes().size()), + static_cast(in_tensor_vk_sizes.size()), static_cast(et_tensor.dim())); bool should_resize = false; std::vector new_sizes(et_tensor.dim()); for (size_t i = 0; i < et_tensor.dim(); i++) { - if (in_tensor->sizes()[i] != et_tensor.sizes()[i]) { + if (in_tensor_vk_sizes[i] != et_tensor.sizes()[i]) { should_resize = true; } new_sizes.at(i) = et_tensor.sizes()[i]; @@ -411,10 +413,11 @@ bool maybe_resize_input( graph->resize_input(input_i, new_sizes); } + const size_t in_tensor_vk_numel = graph->numel_of(in_tensor_ref); ET_CHECK_MSG( - in_tensor->numel() == et_tensor.numel(), + in_tensor_vk_numel == et_tensor.numel(), "Vulkan tensor numel %zu does not match ET tensor numel %zu", - static_cast(in_tensor->numel()), + static_cast(in_tensor_vk_numel), static_cast(et_tensor.numel())); return should_resize; @@ -445,12 +448,14 @@ void maybe_resize_output( const size_t output_i, executorch::aten::Tensor& et_tensor) { ValueRef out_tensor_ref = graph->outputs()[output_i].value; - vTensorPtr out_tensor = graph->get_tensor(out_tensor_ref); + + const std::vector out_tensor_vk_sizes = + graph->sizes_of(out_tensor_ref); executorch::aten::SizesType new_output_size[kTensorDimensionLimit]; - size_t ndim = out_tensor->sizes().size(); + size_t ndim = out_tensor_vk_sizes.size(); for (int i = 0; i < ndim; ++i) { - new_output_size[i] = out_tensor->sizes()[i]; + new_output_size[i] = out_tensor_vk_sizes[i]; } executorch::aten::ArrayRef output_size{ diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 7775165bc68..7bc00e128e5 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -704,6 +704,38 @@ utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) { return create_local_wg_size(create_global_wg_size(idx)); } +void ComputeGraph::bind_tensor_to_descriptor_set( + const ValueRef ref, + vkapi::PipelineBarrier& pipeline_barrier, + const vkapi::MemoryAccessFlags access_type, + vkapi::DescriptorSet& descriptor_set, + const uint32_t idx) { + vTensorPtr tensor = get_tensor(ref); + if (tensor->buffer()) { + vkapi::VulkanBuffer& buffer = tensor->buffer( + pipeline_barrier, vkapi::PipelineStage::COMPUTE, access_type); + descriptor_set.bind(idx, buffer); + } else { + vkapi::VulkanImage& image = tensor->image( + pipeline_barrier, vkapi::PipelineStage::COMPUTE, access_type); + descriptor_set.bind(idx, image); + } +} + +void ComputeGraph::bind_value_to_descriptor_set( + const ValueRef ref, + vkapi::PipelineBarrier& pipeline_barrier, + const vkapi::MemoryAccessFlags access_type, + vkapi::DescriptorSet& descriptor_set, + const uint32_t idx) { + if (val_is_tensor(ref)) { + bind_tensor_to_descriptor_set( + ref, pipeline_barrier, access_type, descriptor_set, 
idx); + } else if (val_is_staging(ref)) { + descriptor_set.bind(idx, get_staging(ref)->buffer()); + } +} + void ComputeGraph::copy_into_staging( const ValueRef idx, const void* data, @@ -891,6 +923,17 @@ void ComputeGraph::execute() { execute_count_++; } +void ComputeGraph::virtual_clone(const ValueRef dst, const ValueRef src) { + get_tensor(dst)->virtual_clone(*get_tensor(src)); +} + +void ComputeGraph::virtual_transpose( + const ValueRef tensor, + const int64_t dim0, + const int64_t dim1) { + get_tensor(tensor)->virtual_transpose(dim0, dim1); +} + void ComputeGraph::resize_input( const int64_t idx, const std::vector& new_sizes) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 886e2c5ccea..3bef6a2f95a 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -319,6 +319,10 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().numel(); } + inline size_t staging_buffer_numel_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().staging_buffer_numel(); + } + inline utils::StorageType storage_type_of(const ValueRef idx) const { return values_.at(idx).toConstTensor().storage_type(); } @@ -832,6 +836,20 @@ class ComputeGraph final { */ utils::uvec3 create_local_wg_size(const ValueRef idx); + void bind_tensor_to_descriptor_set( + const ValueRef ref, + vkapi::PipelineBarrier& pipeline_barrier, + const vkapi::MemoryAccessFlags accessType, + vkapi::DescriptorSet& descriptor_set, + const uint32_t idx); + + void bind_value_to_descriptor_set( + const ValueRef ref, + vkapi::PipelineBarrier& pipeline_barrier, + const vkapi::MemoryAccessFlags access_type, + vkapi::DescriptorSet& descriptor_set, + const uint32_t idx); + // // Input/Output // @@ -890,14 +908,27 @@ class ComputeGraph final { void execute(); + // + // Tensor View + // + + void virtual_clone(const ValueRef dst, const ValueRef src); + + void virtual_transpose( + const ValueRef tensor, + const int64_t dim0, + const int64_t dim1); + // // Dynamic Shape support // void resize_input(const int64_t idx, const std::vector& new_sizes); + void virtual_resize( const ValueRef idx, const std::vector& new_sizes); + void propagate_resize(); // diff --git a/backends/vulkan/runtime/graph/ops/BlitNode.cpp b/backends/vulkan/runtime/graph/ops/BlitNode.cpp index 03ee4caa51a..de1ad596069 100644 --- a/backends/vulkan/runtime/graph/ops/BlitNode.cpp +++ b/backends/vulkan/runtime/graph/ops/BlitNode.cpp @@ -26,11 +26,9 @@ BlitNode::BlitNode( } void BlitNode::encode(ComputeGraph* graph) { - auto src_tensor = graph->get_tensor(src_); - auto dst_tensor = graph->get_tensor(dst_); VK_CHECK_COND( - src_tensor->storage_type() != utils::kBuffer && - dst_tensor->storage_type() != utils::kBuffer, + graph->storage_type_of(src_) != utils::kBuffer && + graph->storage_type_of(dst_) != utils::kBuffer, "BlitNode: Only texture backed tensors are supported."); api::Context* const context = graph->context(); @@ -41,18 +39,18 @@ void BlitNode::encode(ComputeGraph* graph) { // Hack to get timing data for non shader op std::string kernel_name("Blit_"); kernel_name.reserve(32); - kernel_name += vkapi::to_string(src_tensor->dtype()); + kernel_name += vkapi::to_string(graph->dtype_of(src_)); kernel_name += "_to_"; - kernel_name += vkapi::to_string(dst_tensor->dtype()); + kernel_name += vkapi::to_string(graph->dtype_of(dst_)); context->report_shader_dispatch_start( kernel_name, utils::uvec3(), utils::WorkgroupSize(), node_id_); 
context->register_blit( pipeline_barrier, - src_tensor->image( + graph->get_tensor(src_)->image( pipeline_barrier, vkapi::PipelineStage::TRANSFER, vkapi::kRead), - dst_tensor->image( + graph->get_tensor(dst_)->image( pipeline_barrier, vkapi::PipelineStage::TRANSFER, vkapi::kWrite)); context->report_shader_dispatch_end(); diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 05729172420..c8220df837b 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -18,9 +18,8 @@ namespace vkcompute { vkapi::ShaderInfo get_noop_shader(ComputeGraph& graph, const ValueRef packed) { std::string noop_shader_name("no_op"); - vTensorPtr t_packed = graph.get_tensor(packed); - add_dtype_suffix(noop_shader_name, *t_packed); - add_storage_type_suffix(noop_shader_name, *t_packed); + add_dtype_suffix(noop_shader_name, graph.dtype_of(packed)); + add_storage_type_suffix(noop_shader_name, graph.storage_type_of(packed)); return VK_KERNEL_FROM_STR(noop_shader_name); } @@ -48,13 +47,13 @@ PrepackNode::PrepackNode( } api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) { - vTensorPtr packed = graph->get_tensor(packed_); - - // If no TensorRef is provided, create a staging buffer of zeros according to - // the vkapi::vTensor metadata. + // If no TensorRef is provided, create a staging buffer of zeros based on the + // Tensor metadata. if (graph->val_is_none(tref_)) { - size_t numel = utils::multiply_integers(packed->sizes()); - api::StagingBuffer staging(graph->context(), packed->dtype(), numel); + const std::vector packed_sizes = graph->sizes_of(packed_); + size_t numel = utils::multiply_integers(packed_sizes); + api::StagingBuffer staging( + graph->context(), graph->dtype_of(packed_), numel); staging.set_staging_zeros(); return staging; } @@ -80,7 +79,6 @@ void PrepackNode::encode(ComputeGraph* graph) { context->check_device_capabilities(shader_); - vTensorPtr packed = graph->get_tensor(packed_); api::StagingBuffer staging = create_staging_buffer(graph); std::unique_lock cmd_lock = context->dispatch_lock(); @@ -101,8 +99,8 @@ void PrepackNode::encode(ComputeGraph* graph) { shader_, local_workgroup_size_, spec_vars_, push_constants_offset); uint32_t idx = 0; - bind_tensor_to_descriptor_set( - *packed, + graph->bind_tensor_to_descriptor_set( + packed_, pipeline_barrier, vkapi::MemoryAccessType::WRITE, descriptor_set, @@ -128,8 +126,8 @@ void PrepackNode::encode(ComputeGraph* graph) { vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( noop_shader_, utils::WorkgroupSize(1, 1, 1)); - bind_tensor_to_descriptor_set( - *packed, + graph->bind_tensor_to_descriptor_set( + packed_, pipeline_barrier, vkapi::MemoryAccessType::READ, descriptor_set, diff --git a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp index 490def4860a..ebfadbb05cb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arange.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arange.cpp @@ -20,22 +20,22 @@ void resize_arange_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); + const ValueRef out = args.at(0).refs.at(0); int start_val = 0; int step_val = 1; - if (!graph->val_is_none(extra_args[0])) { - start_val = graph->extract_scalar(extra_args[0]); + if (!graph->val_is_none(extra_args.at(0))) { + start_val = graph->extract_scalar(extra_args.at(0)); } - 
int end_val = graph->extract_scalar(extra_args[1]); - if (!graph->val_is_none(extra_args[2])) { - step_val = graph->extract_scalar(extra_args[2]); + const int end_val = graph->extract_scalar(extra_args.at(1)); + if (!graph->val_is_none(extra_args.at(2))) { + step_val = graph->extract_scalar(extra_args.at(2)); } - std::vector out_sizes = { + const std::vector out_sizes = { utils::div_up(end_val - start_val, step_val)}; - out->virtual_resize(out_sizes); + graph->virtual_resize(out, out_sizes); } void check_arange_input( @@ -82,11 +82,9 @@ void add_arange_node( } } - vTensorPtr t_out = graph.get_tensor(out); - std::string kernel_name("arange"); kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -96,7 +94,7 @@ void add_arange_node( // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers - {t_out->sizes_ubo(), + {graph.sizes_ubo(out), graph.create_params_buffer(start_val), graph.create_params_buffer(step_val)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp index 81cbd62d90c..dcadcf80e42 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -46,44 +46,42 @@ void add_native_batch_norm_node( ValueRef var_ref, ValueRef eps_ref, ValueRef out_tuple_ref) { - std::vector in_sizes = graph.get_tensor(in_ref)->sizes(); - std::vector out_sizes = graph.get_tensor(in_ref)->sizes(); + const std::vector in_sizes = graph.sizes_of(in_ref); + const std::vector out_sizes = graph.sizes_of(in_ref); VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor"); VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor"); // Only the first element of the return value is propagated. The remaining 2 // elements are zero-size dummy tensor. 
- ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0); + const ValueRef out_ref = graph.get_value_list(out_tuple_ref)->at(0); - utils::StorageType stype = graph.storage_type_of(out_ref); + const utils::StorageType stype = graph.storage_type_of(out_ref); - int64_t num_channels = dim_at(in_sizes); + const int64_t num_channels = dim_at(in_sizes); - ValueRef arg_weight = + const ValueRef arg_weight = check_and_prepack_arg(graph, weight_ref, stype, num_channels, "weight"); - ValueRef arg_bias = + const ValueRef arg_bias = check_and_prepack_arg(graph, bias_ref, stype, num_channels, "bias"); - ValueRef arg_mean = + const ValueRef arg_mean = check_and_prepack_arg(graph, mean_ref, stype, num_channels, "mean"); - ValueRef arg_var = + const ValueRef arg_var = check_and_prepack_arg(graph, var_ref, stype, num_channels, "var"); - float epsilon = graph.extract_scalar(eps_ref); - - vTensorPtr t_in = graph.get_tensor(in_ref); + const float epsilon = graph.extract_scalar(eps_ref); VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref"); - vTensorPtr t_out = graph.get_tensor(out_ref); + const std::vector out_tensor_sizes = graph.sizes_of(out_ref); VK_CHECK_COND( - dim_at(t_out->sizes()) == num_channels, + dim_at(out_tensor_sizes) == num_channels, "out channel must match in channel"); std::string kernel_name = "batchnorm"; - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out_ref)); - int32_t num_texel_per_batch = - utils::div_up_4((dim_at(t_in->sizes()))); + const int32_t num_texel_per_batch = + utils::div_up_4((dim_at(in_sizes))); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -92,7 +90,7 @@ void add_native_batch_norm_node( graph.create_local_wg_size(out_ref), {{out_ref, vkapi::kWrite}, {{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, vkapi::kRead}}, - {t_out->logical_limits_ubo(), + {graph.logical_limits_ubo(out_ref), graph.create_params_buffer(epsilon), graph.create_params_buffer(num_texel_per_batch)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 18a1aacf323..6e9baafd45f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -19,13 +19,20 @@ namespace vkcompute { void check_binary_op_args( - const api::vTensor& self, - const api::vTensor& other, - const api::vTensor& out) { - VK_CHECK_COND(check_same_packed_dim(self, other, out)); + ComputeGraph& graph, + const ValueRef self, + const ValueRef other, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(self) == graph.packed_dim_of(other)); + VK_CHECK_COND(graph.packed_dim_of(self) == graph.packed_dim_of(out)); + + const std::vector self_sizes = graph.sizes_of(self); + const std::vector other_sizes = graph.sizes_of(other); + const std::vector out_sizes = graph.sizes_of(out); + std::vector broadcasted_sizes = - calculate_broadcasted_output_size(self, other); - VK_CHECK_COND(out.sizes() == broadcasted_sizes); + calculate_broadcasted_output_size(self_sizes, other_sizes); + VK_CHECK_COND(out_sizes == broadcasted_sizes); } void resize_binary_op_node( @@ -33,16 +40,18 @@ void resize_binary_op_node( const std::vector& args, const std::vector& resize_args) { (void)resize_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); + const ValueRef out = args.at(0).refs.at(0); // TODO(T183442143): Verify tensors are broadcastable. 
- vTensorPtr self = graph->get_tensor(args[1].refs[0]); - vTensorPtr other = graph->get_tensor(args[1].refs[1]); + const ValueRef self = args.at(1).refs.at(0); + const ValueRef other = args.at(1).refs.at(1); - std::vector new_out_sizes = - calculate_broadcasted_output_size(*self, *other); + const std::vector self_sizes = graph->sizes_of(self); + const std::vector other_sizes = graph->sizes_of(other); + const std::vector new_out_sizes = + calculate_broadcasted_output_size(self_sizes, other_sizes); - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } void add_binary_op_texture_node( @@ -55,11 +64,7 @@ void add_binary_op_texture_node( ValueRef arg1 = prepack_standard_like(graph, in1, out, true); ValueRef arg2 = prepack_standard_like(graph, in2, out, true); - vTensorPtr t_in1 = graph.get_tensor(arg1); - vTensorPtr t_in2 = graph.get_tensor(arg2); - vTensorPtr t_out = graph.get_tensor(out); - - check_binary_op_args(*t_in1, *t_in2, *t_out); + check_binary_op_args(graph, arg1, arg2, out); float alpha_val = 1.0f; // String is checked since floor_div passes in an unused string argument in @@ -71,12 +76,12 @@ void add_binary_op_texture_node( const struct BinaryOpsParams { const utils::ivec2 broadcast_params; const float alpha_val; - } binary_ops_params{create_broadcast_params(*t_in1, *t_in2), alpha_val}; + } binary_ops_params{create_broadcast_params(graph, arg1, arg2), alpha_val}; std::string kernel_name("binary_"); kernel_name.reserve(kShaderNameReserve); kernel_name += op_name; - add_storage_type_suffix(kernel_name, *t_out); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(in1)); graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -94,7 +99,9 @@ void add_binary_op_texture_node( graph.sizes_pc_of(arg2), PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}}, // Specialization Constants - {t_out->hashed_layout(), t_in1->hashed_layout(), t_in2->hashed_layout()}, + {graph.hashed_layout_of(out), + graph.hashed_layout_of(arg1), + graph.hashed_layout_of(arg2)}, // Resize Args {}, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index fcbac2df0fc..04e74af4e0c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -24,12 +24,12 @@ void resize_clone_node( const std::vector& args, const std::vector& resize_args) { (void)resize_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); // TODO: support for when dimensionality doesn't match, i.e. clone is used to // implement squeeze. 
- if (out->dim() == in->dim()) { - out->virtual_resize(in->sizes()); + if (graph->dim_of(out) == graph->dim_of(in)) { + graph->virtual_resize(out, graph->sizes_of(in)); } } @@ -37,10 +37,8 @@ void add_clone_node( ComputeGraph& graph, const ValueRef in, const ValueRef out) { - vTensorPtr t_out = graph.get_tensor(out); - std::string kernel_name = "clone"; - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -50,7 +48,7 @@ void add_clone_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter Buffers - {t_out->logical_limits_ubo()}, + {graph.logical_limits_ubo(out)}, // Push Constants {}, // Specialization Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index d85bd9d841e..25b4d85be68 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -23,19 +23,20 @@ void resize_conv2d_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); - size_t ndim = self->sizes().size(); + size_t ndim = graph->dim_of(self); std::vector new_out_sizes(ndim); - const bool transposed = graph->get_bool(extra_args[4]); + const bool transposed = graph->get_bool(extra_args.at(4)); + std::vector self_sizes = graph->sizes_of(self); // Batch, Channel if (ndim == 4) { - new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4); + new_out_sizes.at(ndim - 4) = self_sizes.at(ndim - 4); } - TensorRefPtr weight_ref = graph->get_tref(extra_args[0]); + TensorRefPtr weight_ref = graph->get_tref(extra_args.at(0)); const auto& weight_sizes = weight_ref->sizes; new_out_sizes.at(ndim - 3) = transposed ? 
weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4); @@ -43,44 +44,44 @@ void resize_conv2d_node( // Height, Width const auto& new_out_sizes_hw = calc_out_sizes_hw( *graph, - self->sizes(), - extra_args[0], + self_sizes, + extra_args.at(0), /*kernel_size_only = */ false, - {extra_args[1], extra_args[2], extra_args[3], extra_args[5]}, + {extra_args.at(1), extra_args.at(2), extra_args.at(3), extra_args.at(5)}, transposed); new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0); new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1); - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } void resize_conv1d_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); - TensorRefPtr weight_ref = graph->get_tref(extra_args[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + TensorRefPtr weight_ref = graph->get_tref(extra_args.at(0)); - int64_t stride_size = graph->get_int_list(extra_args[1])->at(0); - int64_t padding_size = graph->get_int_list(extra_args[2])->at(0); - int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0); + const int64_t stride_size = graph->get_int_list(extra_args.at(1))->at(0); + const int64_t padding_size = graph->get_int_list(extra_args.at(2))->at(0); + const int64_t dilation_size = graph->get_int_list(extra_args.at(3))->at(0); const std::vector& weight_sizes = weight_ref->sizes; - const std::vector& in_sizes = self->sizes(); - size_t ndim = in_sizes.size(); + const std::vector in_sizes = graph->sizes_of(self); + const size_t ndim = in_sizes.size(); std::vector new_out_sizes(ndim); - int64_t kernel_size = weight_sizes.at(2); - int64_t in_length = in_sizes.at(2); + const int64_t kernel_size = weight_sizes.at(2); + const int64_t in_length = in_sizes.at(2); new_out_sizes.at(0) = in_sizes.at(0); new_out_sizes.at(1) = weight_sizes.at(0); new_out_sizes.at(2) = calc_out_size( in_length, kernel_size, stride_size, padding_size, dilation_size, false); - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } ValueRef prepack_biases( @@ -95,9 +96,8 @@ ValueRef prepack_biases( ValueRef v = graph.add_tensor( {out_channels}, graph.dtype_of(weight), storage_type, memory_layout); - vTensorPtr t = graph.get_tensor(v); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*t); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(graph, v); graph.prepack_nodes().emplace_back(new PrepackNode( graph, @@ -108,7 +108,7 @@ ValueRef prepack_biases( v, {}, // Specialization constants - {t->hashed_layout()}, + {graph.hashed_layout_of(v)}, {graph.sizes_pc_of(v)})); return v; @@ -123,7 +123,7 @@ enum class Conv2dMethod : uint8_t { vkapi::ShaderInfo get_conv2d_shader( ComputeGraph& graph, - const api::vTensor& t_out, + const ValueRef out, const bool prepack_weights, const Conv2dMethod method, const ValueRef weight, @@ -167,7 +167,7 @@ vkapi::ShaderInfo get_conv2d_shader( } else if (clamp_out) { kernel_name += "_clamp"; } - add_dtype_suffix(kernel_name, t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); return VK_KERNEL_FROM_STR(kernel_name); } @@ -206,10 +206,9 @@ ValueRef prepack_weights( graph.dtype_of(vref), utils::kTexture2D, utils::kChannelsPacked); - vTensorPtr t = graph.get_tensor(v); vkapi::ShaderInfo shader = - get_conv2d_shader(graph, *t, /*prepack_weights = */ true, method, vref); + get_conv2d_shader(graph, v, /*prepack_weights = */ 
true, method, vref); const auto original_sizes_pc = utils::make_ivec4(original_sizes, /*reverse = */ true); @@ -222,16 +221,19 @@ ValueRef prepack_weights( v, {}, // Specialization constants - {SV(t->packed_dim())}, + {graph.packed_dim_of(v)}, {graph.sizes_pc_of(v), PushConstantDataInfo(&original_sizes_pc, sizeof(original_sizes_pc))})); return v; } -void check_conv_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); +void check_conv_args( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } struct Conv2dParams final { @@ -365,12 +367,12 @@ void add_conv2d_node( /* storage_type = */ utils::kTexture2D, /* memory_layout = */ utils::kWidthPacked); - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - if (t_in->sizes().at(0) > 1) { + const std::vector in_sizes = graph.sizes_of(in); + if (in_sizes.at(0) > 1) { VK_THROW("conv2d: input batch size > 1 is not supported yet!"); } - check_conv_args(*t_in, *t_out); + + check_conv_args(graph, in, out); Kernel2dParams kernel_params = create_kernel2d_params( graph, @@ -396,7 +398,7 @@ void add_conv2d_node( vkapi::ShaderInfo shader = get_conv2d_shader( graph, - *t_out, + out, /*prepack_weights = */ false, method, weight_data, @@ -476,8 +478,8 @@ void add_conv2d_node( }; } else { param_buffers = { - t_out->logical_limits_ubo(), - t_in->sizes_ubo(), + graph.logical_limits_ubo(out), + graph.sizes_ubo(in), graph.create_params_buffer(kernel_params), graph.create_params_buffer(extra_params), graph.create_params_buffer(out_params), @@ -540,17 +542,13 @@ void add_conv1d_node( out_max_val = graph.extract_scalar(out_max); } - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_weight = graph.get_tensor(arg_weight); - vTensorPtr t_bias = graph.get_tensor(arg_bias); - vTensorPtr t_out = graph.get_tensor(out); const int64_t groups_val = graph.get_int(groups); - std::vector in_sizes = t_in->sizes(); - std::vector weight_sizes = t_weight->sizes(); - std::vector out_sizes = t_out->sizes(); + const std::vector in_sizes = graph.sizes_of(in); + const std::vector weight_sizes = graph.sizes_of(arg_weight); + const std::vector out_sizes = graph.sizes_of(out); - check_conv_args(*t_in, *t_out); + check_conv_args(graph, in, out); const int32_t in_channels = in_sizes.at(1); const int32_t out_channels = weight_sizes.at(0); @@ -587,7 +585,7 @@ void add_conv1d_node( } kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -598,18 +596,18 @@ void add_conv1d_node( {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers { - t_out->logical_limits_ubo(), - t_in->sizes_ubo(), + graph.logical_limits_ubo(out), + graph.sizes_ubo(in), graph.create_params_buffer(kernel_params), graph.create_params_buffer(out_params), }, // Push Constants {}, // Specialization Constants - {t_out->hashed_layout(), - t_in->hashed_layout(), - t_weight->hashed_layout(), - t_bias->hashed_layout()}, + {graph.hashed_layout_of(out), + graph.hashed_layout_of(in), + graph.hashed_layout_of(arg_weight), + graph.hashed_layout_of(arg_bias)}, // Resize Args {weight, stride, padding, dilation}, // Resizing Logic @@ -617,7 +615,7 @@ void 
add_conv1d_node( } void conv(ComputeGraph& graph, const std::vector& args) { - int64_t in_ndim = graph.get_tensor(args[0])->sizes().size(); + int64_t in_ndim = graph.dim_of(args[0]); if (in_ndim == 4) { if (args.size() == 10) { // ordinary conv2d diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index c4f37bd9386..27e8c81ba9e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -28,13 +28,10 @@ void add_copy_offset_node( const ValueRef out, bool calc_out_pos_using_src_chnl, bool calc_in_pos_using_dst_chnl) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - std::string kernel_name = "copy_offset"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); - add_storage_type_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); auto shader = VK_KERNEL_FROM_STR(kernel_name); @@ -75,27 +72,27 @@ void add_copy_packed_dim_offset_node( const ivec4& src_offset, const ivec4& dst_offset, const ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - // Check the packed dimension is same for both tensors, also check if the // packed dimension is Width or Height. Since the function does not support // channel packing. VK_CHECK_COND( - check_same_packed_dim(*t_in, *t_out) && - (check_packed_dim_is(*t_in, WHCN::kWidthDim) || - check_packed_dim_is(*t_in, WHCN::kHeightDim))); + graph.packed_dim_of(in) == graph.packed_dim_of(out) && + (graph.packed_dim_of(in) == WHCN::kWidthDim || + graph.packed_dim_of(in) == WHCN::kHeightDim)); std::string kernel_name = "copy_packed_dim_offset"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + const std::vector in_sizes = graph.sizes_of(in); + const std::vector out_sizes = graph.sizes_of(out); // A copy of range with the last element set to batch size of the input tensor ivec4 final_range = { - range[0], range[1], range[2], dim_at(t_in->sizes(), kBatch4D)}; - ivec3 global_wg_size = t_out->logical_limits(); + range[0], range[1], range[2], dim_at(in_sizes, kBatch4D)}; + ivec3 global_wg_size = graph.logical_limits_of(out); - const auto packed_dim = t_in->packed_dim(); + const auto packed_dim = graph.packed_dim_of(in); // The starting offset in a texel where this tensor will start copying from const auto src_lane_offset = src_offset[packed_dim] & 0x3; // The starting offset in a texel where this tensor will start copying to @@ -106,16 +103,14 @@ void add_copy_packed_dim_offset_node( // remaining lanes from current source Hence (4 - src_lane_offset) is added // to tensor size in packed dimension const auto src_packed_size = utils::div_up_4( - (4 - src_lane_offset) + - dim_at(t_out->sizes(), normalize_to_dim_index(*t_out, packed_dim))); + (4 - src_lane_offset) + utils::val_at(-packed_dim, out_sizes)); // The total packed texels this tensor will be copied to // The first texel of tensor data in packed dimension will be copied to // remaining lanes from previous write Hence (4 - dst_lane_offset) is added // to tensor size in packed dimension const auto dst_packed_size = utils::div_up_4( - (4 - dst_lane_offset) + - dim_at(t_in->sizes(), normalize_to_dim_index(*t_in, packed_dim))); + (4 - dst_lane_offset) + utils::val_at(-packed_dim, in_sizes)); // If the starting 
src offset is not 0, and the total packed texels is // greater than the source texel range @@ -169,20 +164,17 @@ void add_copy_channel_offset_node( int32_t src_channel_offset, int32_t dst_channel_offset, const ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - // Likely need to prepad these numbers. - std::vector in_sizes = t_in->sizes(); - std::vector out_sizes = t_out->sizes(); + const std::vector in_sizes = graph.sizes_of(in); + const std::vector out_sizes = graph.sizes_of(out); - VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim)); + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); // NOTE: This function should be able to support 1d and 2d tensors when // range=1, src_offset=dst_offset=1. - VK_CHECK_COND(t_in->dim() >= 3, "Src dim should be at least 3"); - VK_CHECK_COND(t_out->dim() >= 3, "Dst dim should be at least 3"); + VK_CHECK_COND(graph.dim_of(in) >= 3, "Src dim should be at least 3"); + VK_CHECK_COND(graph.dim_of(out) >= 3, "Dst dim should be at least 3"); VK_CHECK_COND( dim_at(in_sizes) >= src_channel_offset + channel_range, @@ -212,7 +204,7 @@ void add_copy_channel_offset_node( std::string kernel_name = "copy_channel_offset"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); int32_t out_channels = dim_at(out_sizes); diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 61fd76145a4..0822dcb05f3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -23,10 +23,11 @@ void resize_dequantize_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - out->virtual_resize(in->sizes()); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); } utils::uvec3 dequantize_per_channel_local_wg_size( diff --git a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp index 85c80e01c27..b5a2f20cf4b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Embedding.cpp @@ -23,15 +23,16 @@ using utils::GPUMemoryLayout; using utils::StorageType; void check_embedding_args( - const api::vTensor& weight, - const api::vTensor& in, - const api::vTensor& out) { + ComputeGraph& graph, + const ValueRef weight, + const ValueRef in, + const ValueRef out) { // The packing logic may not be trivial here. Input and output are Channel // Packed, which is default for the Vulkan backend. However, weight vector is // height-packed instead of channel-packed for space reason. 
- VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kHeightDim)); - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); + VK_CHECK_COND(graph.packed_dim_of(weight) == WHCN::kHeightDim); + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } void add_embedding_node( @@ -39,15 +40,11 @@ void add_embedding_node( ValueRef weight, ValueRef in, ValueRef out) { - vTensorPtr t_weight = graph.get_tensor(weight); - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - - check_embedding_args(*t_weight, *t_in, *t_out); + check_embedding_args(graph, weight, in, out); std::string kernel_name = "embedding"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -56,14 +53,14 @@ void add_embedding_node( graph.create_local_wg_size(out), {{out, vkapi::kWrite}, {{in, weight}, vkapi::kRead}}, { - t_out->sizes_ubo(), + graph.sizes_ubo(out), }, // Push Constants {}, // Specialization Constants - {t_out->hashed_layout(), - t_in->hashed_layout(), - t_weight->hashed_layout()}, + {graph.hashed_layout_of(out), + graph.hashed_layout_of(in), + graph.hashed_layout_of(weight)}, // Resize Args {}, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp index 04aac2484ac..6679bfe32f5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Flip.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Flip.cpp @@ -15,9 +15,12 @@ namespace vkcompute { -void check_flip_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); +void check_flip_args( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } void resize_flip_node( @@ -25,10 +28,10 @@ void resize_flip_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - out->virtual_resize(in->sizes()); + graph->virtual_resize(out, graph->sizes_of(in)); } utils::ivec4 create_whcn_bitmap( @@ -48,15 +51,13 @@ void add_flip_node( const ValueRef in, const std::vector& dim_list, const ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - check_flip_args(*t_in, *t_out); + check_flip_args(graph, in, out); - const auto dim_bitmap = create_whcn_bitmap(dim_list, t_in->dim()); + const auto dim_bitmap = create_whcn_bitmap(dim_list, graph.dim_of(in)); std::string kernel_name("flip"); kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index 3ed18445463..2fa22312745 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -19,30 +19,28 @@ void resize_full_node( ComputeGraph* 
graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); + const ValueRef out = args.at(0).refs.at(0); std::vector out_sizes; - if (graph->val_is_tensor(extra_args[0])) { - out_sizes = graph->get_tensor(extra_args[0])->sizes(); + if (graph->val_is_tensor(extra_args.at(0))) { + out_sizes = graph->sizes_of(extra_args.at(0)); } else { - out_sizes = *graph->get_int_list(extra_args[0]); + out_sizes = *graph->get_int_list(extra_args.at(0)); } - out->virtual_resize(out_sizes); + graph->virtual_resize(out, out_sizes); } -// size_or_in is IntListPtr when op is full and vTensorPtr if op is full_like void add_full_node( ComputeGraph& graph, const ValueRef size_or_in, const ValueRef fill_value, const ValueRef out) { float fill_value_val = graph.extract_scalar(fill_value); - vTensorPtr t_out = graph.get_tensor(out); std::string kernel_name("full"); kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -52,11 +50,11 @@ void add_full_node( // Inputs and Outputs {{out, vkapi::kWrite}}, // Shader params buffers - {t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)}, + {graph.sizes_ubo(out), graph.create_params_buffer(fill_value_val)}, // Push Constants {}, // Specialization Constants - {SV(t_out->packed_dim())}, + {graph.packed_dim_of(out)}, // Resize Args {size_or_in}, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp index 0624020c872..620613fdfb8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp @@ -23,13 +23,13 @@ void resize_grid_priors_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(extra_args[0]); - std::vector in_sizes = in->sizes(); - int64_t height = in_sizes.at(in_sizes.size() - 2); - int64_t width = in_sizes.at(in_sizes.size() - 1); - std::vector sizes = {height * width, 2}; - out->virtual_resize(sizes); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = extra_args.at(0); + const std::vector in_sizes = graph->sizes_of(in); + const int64_t height = in_sizes.at(in_sizes.size() - 2); + const int64_t width = in_sizes.at(in_sizes.size() - 1); + const std::vector sizes = {height * width, 2}; + graph->virtual_resize(out, sizes); } void add_grid_priors_node( @@ -38,16 +38,14 @@ void add_grid_priors_node( const ValueRef& stride_ref, const ValueRef& offset_ref, const ValueRef& out) { - vTensorPtr t_out = graph.get_tensor(out); - vTensorPtr t_in = graph.get_tensor(in); - int32_t stride = graph.extract_scalar(stride_ref); - float offset = graph.extract_scalar(offset_ref); + const int32_t stride = graph.extract_scalar(stride_ref); + const float offset = graph.extract_scalar(offset_ref); std::string kernel_name = "grid_priors"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); - GridPriorsParam param = {stride, offset}; + const GridPriorsParam param = {stride, offset}; graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -59,8 +57,8 @@ void add_grid_priors_node( }, // Shader params buffers { - t_in->sizes_ubo(), - t_out->sizes_ubo(), + graph.sizes_ubo(in), + 
graph.sizes_ubo(out), graph.create_params_buffer(param), }, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp index 8d2a848b0c4..368b95c9d3b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/GroupNorm.cpp @@ -17,14 +17,6 @@ namespace vkcompute { -std::vector calc_group_norm_mean_sizes( - api::vTensor& self, - const int64_t group) { - const std::vector& input_sizes = self.sizes(); - const int64_t N = input_sizes.at(0); - return {N, group}; -} - utils::uvec3 group_norm_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, diff --git a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp index 8203829c50f..86faabd48d5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp @@ -18,12 +18,13 @@ namespace vkcompute { void check_index_select_args( - const api::vTensor& in, - const api::vTensor& idx, - const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(idx, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); + ComputeGraph& graph, + const ValueRef in, + const ValueRef idx, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(idx) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } void add_index_select_channel_node( @@ -31,15 +32,11 @@ void add_index_select_channel_node( ValueRef in, ValueRef idx, ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_idx = graph.get_tensor(idx); - vTensorPtr t_out = graph.get_tensor(out); - - check_index_select_args(*t_in, *t_idx, *t_out); + check_index_select_args(graph, in, idx, out); std::string kernel_name = "index_select_channel"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -47,7 +44,7 @@ void add_index_select_channel_node( graph.create_global_wg_size(out), graph.create_local_wg_size(out), {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, - {t_out->sizes_ubo(), t_in->sizes_ubo()}, + {graph.sizes_ubo(out), graph.sizes_ubo(in)}, // Push Constants {}, // Specialization Constants @@ -64,14 +61,16 @@ struct IndexSelectParams final { }; IndexSelectParams create_index_select_params( + ComputeGraph& graph, const int64_t dim_idx, - const api::vTensor& in) { + const ValueRef in) { if (dim_idx == kWidth4D) { return {0, 1}; } else if (dim_idx == kHeight4D) { return {1, 1}; } else if (dim_idx == kBatch4D) { - int64_t n_channels = dim_at(in.sizes(), kChannel4D); + const std::vector in_sizes = graph.sizes_of(in); + int64_t n_channels = dim_at(in_sizes, kChannel4D); int64_t stride = utils::div_up_4(n_channels); return {2, static_cast(stride)}; } else { @@ -85,17 +84,13 @@ void add_index_select_node( const int64_t dim_idx, ValueRef idx, ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_idx = graph.get_tensor(idx); - vTensorPtr t_out = graph.get_tensor(out); + check_index_select_args(graph, in, idx, out); - check_index_select_args(*t_in, *t_idx, *t_out); - - IndexSelectParams params = create_index_select_params(dim_idx, *t_in); + IndexSelectParams params = 
create_index_select_params(graph, dim_idx, in); std::string kernel_name = "index_select"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -103,7 +98,7 @@ void add_index_select_node( graph.create_global_wg_size(out), graph.create_local_wg_size(out), {{out, vkapi::kWrite}, {{in, idx}, vkapi::kRead}}, - {t_out->sizes_ubo(), graph.create_params_buffer(params)}, + {graph.sizes_ubo(out), graph.create_params_buffer(params)}, // Push Constants {}, // Specialization Constants @@ -115,10 +110,12 @@ void add_index_select_node( } int64_t get_dim_idx(ComputeGraph& graph, ValueRef in, ValueRef dim_ref) { - vTensorPtr t_in = graph.get_tensor(in); int64_t dim = graph.extract_scalar(dim_ref); - dim = normalize(dim, t_in->dim()); - return normalize_to_dim_index(*t_in, dim); + const int64_t ndim = graph.dim_of(in); + dim = normalize(dim, ndim); + + // Convert to DimIndex - this replicates normalize_to_dim_index logic + return dim < 0 ? dim : dim - ndim; } void index_select(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 14ed9c84a32..a58444a7830 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -54,29 +54,31 @@ void resize_addmm_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - vTensorPtr self = graph->get_tensor(args[1].refs[2]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); - bool mat2_is_transposed = graph->get_bool(extra_args[0]); + const bool mat2_is_transposed = graph->get_bool(extra_args.at(0)); - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes()) - : utils::val_at(-1, mat2->sizes()); + const std::vector mat1_sizes = graph->sizes_of(mat1); + const std::vector mat2_sizes = graph->sizes_of(mat2); + + const int out_cols = utils::val_at(-2, mat1_sizes); + const int out_rows = mat2_is_transposed ? 
utils::val_at(-2, mat2_sizes) + : utils::val_at(-1, mat2_sizes); std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { + if (mat1_sizes.size() == 2) { new_out_sizes.resize(2); new_out_sizes.at(0) = out_cols; new_out_sizes.at(1) = out_rows; } else { - new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(0) = mat1_sizes.at(0); new_out_sizes.at(1) = out_cols; new_out_sizes.at(2) = out_rows; } - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } struct Params final { diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 73a625f3adf..0f5556060a2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -39,22 +39,25 @@ void resize_matmul_node( ComputeGraph* graph, const std::vector& args, const std::vector& resize_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + bool mat2_is_transposed = graph->get_bool(resize_args.at(0)); - bool mat2_is_transposed = graph->get_bool(resize_args[0]); + const std::vector mat1_sizes = graph->sizes_of(mat1); + const std::vector mat2_sizes = graph->sizes_of(mat2); - const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes()) - : utils::val_at(-1, mat2->sizes()); + const int out_cols = utils::val_at(-2, mat1_sizes); + const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2_sizes) + : utils::val_at(-1, mat2_sizes); - const int64_t out_dim = out->dim(); - std::vector new_out_sizes(mat1->sizes()); + const int64_t out_dim = graph->dim_of(out); + std::vector new_out_sizes(mat1_sizes); new_out_sizes.at(out_dim - 1) = out_rows; new_out_sizes.at(out_dim - 2) = out_cols; - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } /** diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 100d6e33931..99f945da535 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -18,10 +18,10 @@ namespace vkcompute { std::vector calc_out_mean_sizes( - api::vTensor& self, + const std::vector& self_sizes, int64_t normalized_shape_dim) { - std::vector output_size = self.sizes(); - int64_t self_dim = self.sizes().size(); + std::vector output_size = self_sizes; + int64_t self_dim = self_sizes.size(); for (int64_t i = 0; i < normalized_shape_dim; ++i) { output_size.at(self_dim - i - 1) = 1; } @@ -32,20 +32,21 @@ void resize_native_layer_norm_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mean = graph->get_tensor(args[0].refs[1]); - vTensorPtr rstd = graph->get_tensor(args[0].refs[2]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); - std::vector in_sizes = in->sizes(); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mean = args.at(0).refs.at(1); + const ValueRef rstd = args.at(0).refs.at(2); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); - const auto normalized_shape_dim = 
graph->get_int_list(extra_args[0])->size(); + const auto normalized_shape_dim = + graph->get_int_list(extra_args.at(0))->size(); - std::vector mean_size = - calc_out_mean_sizes(*in, normalized_shape_dim); + const std::vector mean_size = + calc_out_mean_sizes(in_sizes, normalized_shape_dim); - out->virtual_resize(in_sizes); - mean->virtual_resize(mean_size); - rstd->virtual_resize(mean_size); + graph->virtual_resize(out, in_sizes); + graph->virtual_resize(mean, mean_size); + graph->virtual_resize(rstd, mean_size); } void add_native_layer_norm_node( @@ -74,16 +75,17 @@ void add_native_layer_norm_node( ValueRef arg_bias = prepack_standard_like(graph, bias_data, in); const auto out_val = graph.get_value_list(out); - vTensorPtr t_out = graph.get_tensor(out_val->at(0)); - vTensorPtr t_mean = graph.get_tensor(out_val->at(1)); - vTensorPtr t_input = graph.get_tensor(in); + const ValueRef out_tensor = out_val->at(0); + const ValueRef mean_tensor = out_val->at(1); + const ValueRef rstd_tensor = out_val->at(2); + float epsilon = graph.extract_scalar(eps); - VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out)); + VK_CHECK_COND(check_same_packed_dim(graph, in, out_tensor)); - std::vector in_sizes = t_input->sizes(); + const std::vector in_sizes = graph.sizes_of(in); - utils::uvec3 global_size = t_out->logical_limits(); + utils::uvec3 global_size = graph.logical_limits_of(out_tensor); utils::uvec3 local_size; // Since the shader sets shared memory scale factor > 1, if dispatch is @@ -100,7 +102,7 @@ void add_native_layer_norm_node( std::string kernel_name("native_layer_norm"); kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -108,20 +110,20 @@ void add_native_layer_norm_node( global_size, local_size, // Inputs and Outputs - {{{out_val->at(0), out_val->at(1), out_val->at(2)}, vkapi::kWrite}, + {{{out_tensor, mean_tensor, rstd_tensor}, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers {}, // Push Constants { - graph.logical_limits_pc_of(out_val->at(0)), - graph.sizes_pc_of(out_val->at(0)), + graph.logical_limits_pc_of(out_tensor), + graph.sizes_pc_of(out_tensor), PushConstantDataInfo(&epsilon, sizeof(epsilon)), }, // Specialization Constants { - t_input->hashed_layout(), - t_out->hashed_layout(), + graph.hashed_layout_of(in), + graph.hashed_layout_of(out_tensor), }, // Resize Args {normalized_shape}, diff --git a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp index 8f3ba7532a9..a10984eac78 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pad.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pad.cpp @@ -41,17 +41,17 @@ void resize_constant_pad_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); - IntListPtr pad_vec = graph->get_int_list(extra_args[0]); - std::vector in_size = self->sizes(); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + const IntListPtr pad_vec = graph->get_int_list(extra_args.at(0)); + std::vector in_size = graph->sizes_of(self); int dim = in_size.size() - 1; for (int i = 0; i < pad_vec->size(); i += 2) { in_size.at(dim) += pad_vec->at(i) + pad_vec->at(i + 1); dim--; } - out->virtual_resize(in_size); + graph->virtual_resize(out, in_size); } void 
add_constant_pad_nd_node( @@ -60,22 +60,20 @@ void add_constant_pad_nd_node( const ValueRef& pad, const ValueRef& fill_value, const ValueRef& out) { - float fill_value_val = graph.extract_scalar(fill_value); - IntListPtr pad_vec = graph.get_int_list(pad); - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); + const float fill_value_val = graph.extract_scalar(fill_value); + const IntListPtr pad_vec = graph.get_int_list(pad); std::string kernel_name = ""; - PadParam pad_param = creat_pad_param(*pad_vec); + const PadParam pad_param = creat_pad_param(*pad_vec); if (pad_vec->size() <= 4) { kernel_name = "pad_height_width"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); } else { kernel_name = "pad_channel"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); } graph.execute_nodes().emplace_back(new DispatchNode( @@ -86,8 +84,8 @@ void add_constant_pad_nd_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers - {t_out->sizes_ubo(), - t_in->sizes_ubo(), + {graph.sizes_ubo(out), + graph.sizes_ubo(in), graph.create_params_buffer(pad_param), graph.create_params_buffer(fill_value_val)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index e8afafa9a45..e74b9ec96a7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -17,44 +17,48 @@ namespace vkcompute { -void check_pool2d_args(const api::vTensor& in, const api::vTensor& out) { - VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim)); - VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim)); +void check_pool2d_args( + ComputeGraph& graph, + const ValueRef in, + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(in) == WHCN::kChannelsDim); + VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kChannelsDim); } void resize_pool2d_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - bool is_max_pool2d = extra_args[3] != kDummyValueRef; + bool is_max_pool2d = extra_args.at(3) != kDummyValueRef; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); - size_t ndim = self->sizes().size(); + const std::vector self_sizes = graph->sizes_of(self); + size_t ndim = self_sizes.size(); std::vector new_out_sizes(ndim); // Batch, Channel if (ndim == 4) { - new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4); + new_out_sizes.at(ndim - 4) = self_sizes.at(ndim - 4); } - new_out_sizes.at(ndim - 3) = self->sizes().at(ndim - 3); + new_out_sizes.at(ndim - 3) = self_sizes.at(ndim - 3); // Height, Width const auto& new_out_sizes_hw = calc_out_sizes_hw( *graph, - self->sizes(), - extra_args[0], + self_sizes, + extra_args.at(0), /*kernel_size_only = */ true, - {extra_args[1], extra_args[2], extra_args[3], extra_args[4]}); + {extra_args.at(1), extra_args.at(2), extra_args.at(3), extra_args.at(4)}); new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0); new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1); - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); if (is_max_pool2d) { - vTensorPtr indices = graph->get_tensor(args[0].refs[1]); - 
indices->virtual_resize(new_out_sizes); + const ValueRef indices = args.at(0).refs.at(1); + graph->virtual_resize(indices, new_out_sizes); } } @@ -71,18 +75,16 @@ void add_max_pool2d_node( const ValueRef dilation, const ValueRef ceil_mode, const ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - const auto out_val = graph.get_value_list(out); - vTensorPtr t_out = graph.get_tensor(out_val->at(0)); + const ValueRef out_tensor = out_val->at(0); - check_pool2d_args(*t_in, *t_out); + check_pool2d_args(graph, in, out_tensor); - utils::uvec3 global_size = t_out->logical_limits(); + utils::uvec3 global_size = graph.logical_limits_of(out_tensor); utils::uvec3 local_size = adaptive_work_group_size(global_size); std::string kernel_name("max_pool2d"); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor)); Kernel2dParams kernel_params = create_kernel2d_params( graph, @@ -101,8 +103,8 @@ void add_max_pool2d_node( {{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers { - t_out->logical_limits_ubo(), - t_in->sizes_ubo(), + graph.logical_limits_ubo(out_tensor), + graph.sizes_ubo(in), graph.create_params_buffer(kernel_params), }, // Push Constants @@ -150,16 +152,13 @@ void add_avg_pool2d_node( const ValueRef count_include_pad, const ValueRef divisor_override, const ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - - check_pool2d_args(*t_in, *t_out); + check_pool2d_args(graph, in, out); - utils::uvec3 global_size = t_out->logical_limits(); + utils::uvec3 global_size = graph.logical_limits_of(out); utils::uvec3 local_size = adaptive_work_group_size(global_size); std::string kernel_name("avg_pool2d"); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); Kernel2dParams kernel_params = create_kernel2d_params(graph, kernel_size, stride, padding); @@ -175,8 +174,8 @@ void add_avg_pool2d_node( // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers - {t_out->logical_limits_ubo(), - t_in->sizes_ubo(), + {graph.logical_limits_ubo(out), + graph.sizes_ubo(in), graph.create_params_buffer(kernel_params), graph.create_params_buffer(divisor_params)}, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 92719505a0f..d4d0ba30293 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -23,10 +23,11 @@ void resize_quantize_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - out->virtual_resize(in->sizes()); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); } utils::uvec3 quantize_per_channel_local_wg_size( diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 07502a7a107..05a300bee4c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -55,30 +55,33 @@ void resize_linear_qcsnw_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = 
graph->get_tensor(args[1].refs[0]); - vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef qmat2 = args.at(1).refs.at(1); - const int out_cols = utils::val_at(-2, mat1->sizes()); - int out_rows = utils::val_at(-1, qmat2->sizes()); + const std::vector mat1_sizes = graph->sizes_of(mat1); + const std::vector qmat2_sizes = graph->sizes_of(qmat2); + + const int out_cols = utils::val_at(-2, mat1_sizes); + int out_rows = utils::val_at(-1, qmat2_sizes); // Byte dtype suggests 4-bit quantization in which case the weight tensor is // packed with 2 values per byte. - if (qmat2->dtype() == vkapi::kByte) { + if (graph->dtype_of(qmat2) == vkapi::kByte) { out_rows *= 2; } std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { + if (mat1_sizes.size() == 2) { new_out_sizes.resize(2); new_out_sizes.at(0) = out_cols; new_out_sizes.at(1) = out_rows; } else { - new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(0) = mat1_sizes.at(0); new_out_sizes.at(1) = out_cols; new_out_sizes.at(2) = out_rows; } - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } void add_linear_qcs8w_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp index 728d38c3e2d..e3443ca34e6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp @@ -85,25 +85,28 @@ void resize_linear_qta8a_qga4w_node( const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); - vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); + + const std::vector mat1_sizes = graph->sizes_of(mat1); + const std::vector mat2_sizes = graph->sizes_of(mat2); - const int64_t out_cols = utils::val_at(-2, mat1->sizes()); - const int64_t out_rows = utils::val_at(-1, mat2->sizes()) * 2; + const int64_t out_cols = utils::val_at(-2, mat1_sizes); + const int64_t out_rows = utils::val_at(-1, mat2_sizes) * 2; std::vector new_out_sizes(3); - if (mat1->sizes().size() == 2) { + if (mat1_sizes.size() == 2) { new_out_sizes.resize(2); new_out_sizes.at(0) = out_cols; new_out_sizes.at(1) = out_rows; } else { - new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(0) = mat1_sizes.at(0); new_out_sizes.at(1) = out_cols; new_out_sizes.at(2) = out_rows; } - out->virtual_resize(new_out_sizes); + graph->virtual_resize(out, new_out_sizes); } /** diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp index c0fd442ec50..38b8c51576c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp @@ -22,14 +22,15 @@ void resize_reduce_node( ComputeGraph* graph, const std::vector& args, const std::vector& resize_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - int32_t reduce_dim_nchw = graph->extract_scalar(resize_args.at(0)); + const int32_t reduce_dim_nchw = + graph->extract_scalar(resize_args.at(0)); - std::vector 
new_sizes = in->sizes(); + std::vector new_sizes = graph->sizes_of(in); new_sizes.at(normalize(reduce_dim_nchw, new_sizes.size())) = 1; - out->virtual_resize(new_sizes); + graph->virtual_resize(out, new_sizes); } utils::uvec3 reduce_global_wg_size( diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index f472e4dad0d..d7a2b7a8ca2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -20,39 +20,43 @@ namespace vkcompute { namespace { void check_args( - const api::vTensor& in, + ComputeGraph& graph, + const ValueRef in, const std::vector& repeats, - const api::vTensor& out) { - VK_CHECK_COND(check_same_packed_dim(in, out)); + const ValueRef out) { + VK_CHECK_COND(graph.packed_dim_of(in) == graph.packed_dim_of(out)); - VK_CHECK_COND(in.storage_type() == out.storage_type()); - if (in.storage_type() == utils::kTexture2D) { - VK_CHECK_COND(in.dim() <= 2); + VK_CHECK_COND(graph.storage_type_of(in) == graph.storage_type_of(out)); + if (graph.storage_type_of(in) == utils::kTexture2D) { + VK_CHECK_COND(graph.dim_of(in) <= 2); } - int64_t in_dim = in.dim(); + const int64_t in_dim = graph.dim_of(in); VK_CHECK_COND( in_dim <= repeats.size(), "Input tensor dim size must be not greater than the repeat argument's size"); + const std::vector in_sizes = graph.sizes_of(in); + const std::vector out_sizes = graph.sizes_of(out); + VK_CHECK_COND( - dim_at(in.sizes()) * dim_at(repeats) == - dim_at(out.sizes()), + dim_at(in_sizes) * dim_at(repeats) == + dim_at(out_sizes), "Output's width doesn't match input's width * repeat count"); VK_CHECK_COND( - dim_at(in.sizes()) * dim_at(repeats) == - dim_at(out.sizes()), + dim_at(in_sizes) * dim_at(repeats) == + dim_at(out_sizes), "Output's height doesn't match input's height * repeat count"); VK_CHECK_COND( - dim_at(in.sizes()) * dim_at(repeats) == - dim_at(out.sizes()), + dim_at(in_sizes) * dim_at(repeats) == + dim_at(out_sizes), "Output's channel doesn't match input's channel * repeat count"); VK_CHECK_COND( - dim_at(in.sizes()) * dim_at(repeats) == - dim_at(out.sizes()), + dim_at(in_sizes) * dim_at(repeats) == + dim_at(out_sizes), "Output's batch doesn't match input's batch * repeat count"); } @@ -65,15 +69,14 @@ void add_repeat_node( ValueRef out) { const std::vector repeats = *(graph.get_int_list(repeats_ref)); - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr t_out = graph.get_tensor(out); - check_args(*t_in, repeats, *t_out); + check_args(graph, in, repeats, out); + const std::vector in_sizes = graph.sizes_of(in); const utils::ivec4 src_dims{ - dim_at(t_in->sizes()), - dim_at(t_in->sizes()), - dim_at(t_in->sizes()), - dim_at(t_in->sizes())}; + dim_at(in_sizes), + dim_at(in_sizes), + dim_at(in_sizes), + dim_at(in_sizes)}; const utils::ivec4 dst_repeats{ dim_at(repeats), dim_at(repeats), @@ -82,10 +85,10 @@ void add_repeat_node( std::string kernel_name = "repeat"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); // A copy of range with the last element set to batch size of the input tensor - const utils::ivec3 wg_size = t_out->logical_limits(); + const utils::ivec3 wg_size = graph.logical_limits_of(out); const auto shader = VK_KERNEL_FROM_STR(kernel_name); diff --git a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp index 5bfadf43160..ae2aeec10bf 100644 --- 
a/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.cpp @@ -20,17 +20,17 @@ void resize_repeat_interleave_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - const int64_t nrepeats = graph->extract_scalar(extra_args[0]); - int64_t repeat_dim = graph->extract_scalar(extra_args[1]); + const int64_t nrepeats = graph->extract_scalar(extra_args.at(0)); + int64_t repeat_dim = graph->extract_scalar(extra_args.at(1)); - std::vector new_sizes = in->sizes(); + std::vector new_sizes = graph->sizes_of(in); repeat_dim = normalize(repeat_dim, new_sizes.size()); new_sizes.at(repeat_dim) *= nrepeats; - out->virtual_resize(new_sizes); + graph->virtual_resize(out, new_sizes); } void add_repeat_interleave_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 6057f1e183a..b194524c94e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -33,7 +33,7 @@ void resize_sdpa_out( int arg_idx = 0; const ValueRef q_projected = extra_args[arg_idx++]; const ValueRef out = extra_args[arg_idx++]; - graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected)); + graph->virtual_resize(out, graph->sizes_of(q_projected)); } void resize_flash_attention_out( @@ -49,7 +49,7 @@ void resize_flash_attention_out( const ValueRef q_projected = args.at(1).refs.at(0); // Resize output to match query dimensions - graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected)); + graph->virtual_resize(out, graph->sizes_of(q_projected)); } // Flash Attention implementation using single compute shader @@ -338,7 +338,7 @@ void resize_cache_slice_view_node( std::vector slice_sizes = get_cache_slice_sizes( *graph, extra_args[0], extra_args[1], extra_args[2]); - graph->get_tensor(extra_args[3])->virtual_resize(slice_sizes); + graph->virtual_resize(extra_args[3], slice_sizes); } void add_cache_slice_view_node( @@ -353,7 +353,7 @@ void add_cache_slice_view_node( // Initialize the slice to the maximum possible size to start slice_sizes.at(1) = max_seq_len; - graph.get_tensor(cache_sliced)->virtual_resize(slice_sizes); + graph.virtual_resize(cache_sliced, slice_sizes); graph.execute_nodes().emplace_back(new ExecuteNode( resize_cache_slice_view_node, @@ -489,7 +489,7 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { std::vector attn_weight_sizes = attn_weight_full_sizes; attn_weight_sizes.at(2) = graph.size_at(2, q_transposed); attn_weight_sizes.at(3) = graph.size_at(2, k_transposed); - graph.get_tensor(attn_weight)->virtual_resize(attn_weight_sizes); + graph.virtual_resize(attn_weight, attn_weight_sizes); // Calculate attention weight, which is a matmul of Q and K const ValueRef mat2_is_transposed = graph.add_scalar(false); @@ -502,7 +502,7 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { TmpTensor attn_weight_softmax( &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed)); - graph.get_tensor(attn_weight_softmax)->virtual_resize(attn_weight_sizes); + graph.virtual_resize(attn_weight_softmax, attn_weight_sizes); add_softmax_node(graph, attn_weight, width, attn_weight_softmax, false); // Calculate final output diff --git 
a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp index e37ef66434b..5e645e29e3d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp @@ -67,11 +67,11 @@ void resize_softmax_node( const std::vector& args, const std::vector& resize_args) { (void)resize_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - std::vector in_sizes = in->sizes(); - out->virtual_resize(in_sizes); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); } void add_softmax_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index 8002dadc538..f87af08ee69 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -23,23 +23,22 @@ void add_split_with_sizes_default_node( const std::vector& split_sizes, int64_t dim, ValueRef out_list_ref) { - vTensorPtr t_in = graph.get_tensor(in); + const ValueListPtr out_list = graph.get_value_list(out_list_ref); - ValueListPtr out_list = graph.get_value_list(out_list_ref); - - DimIndex dim_index = normalize_to_dim_index(*t_in, dim); + const int64_t input_ndim = graph.dim_of(in); + const DimIndex dim_index = dim < 0 ? static_cast(dim) + : static_cast(dim - input_ndim); VK_CHECK_COND(out_list->size() == split_sizes.size()); for (int split_idx = 0; split_idx < split_sizes.size(); split_idx++) { - int64_t split_size = split_sizes[split_idx]; - ValueRef out_ref = (*out_list)[split_idx]; + const int64_t split_size = split_sizes.at(split_idx); + const ValueRef out_ref = out_list->at(split_idx); - vTensorPtr t_out = graph.get_tensor(out_ref); - VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size); + VK_CHECK_COND(dim_at(graph.sizes_of(out_ref), dim_index) == split_size); } - const auto packed_dim = t_in->packed_dim(); + const auto packed_dim = graph.packed_dim_of(in); const auto packed_dim_index = static_cast(kWidth4D - packed_dim); // Index of dimension to be concatenated in (w, h, c * b) coordinate system @@ -53,15 +52,14 @@ void add_split_with_sizes_default_node( // if splitting channels if (is_splitting_channel) { // set source offset w as channel size of the input tensor - src_offset[3] = dim_at(t_in->sizes(), kChannel4D); + src_offset[3] = dim_at(graph.sizes_of(in), kChannel4D); } for (ValueRef out_ref : *out_list) { // Doesn't need to use split_size since we have already verified that the // output tensor's size matches with the split_size. - vTensorPtr t_out = graph.get_tensor(out_ref); - const auto out_channel_size = dim_at(t_out->sizes(), kChannel4D); - utils::ivec3 range = t_out->logical_limits(); + const auto out_channel_size = dim_at(graph.sizes_of(out_ref), kChannel4D); + const utils::ivec3 range = graph.logical_limits_of(out_ref); if (dim_index == packed_dim_index) { // if splitting channels, use add_copy_channel_offset_node function as @@ -79,7 +77,8 @@ void add_split_with_sizes_default_node( dst_offset[3] = is_splitting_channel ? 
out_channel_size : 0; add_copy_packed_dim_offset_node( graph, in, range, src_offset, dst_offset, out_ref); - src_offset[dim_xyz_index] += dim_at(t_out->sizes(), packed_dim_index); + src_offset[dim_xyz_index] += + dim_at(graph.sizes_of(out_ref), packed_dim_index); } } else { // set destination offset w as channel size of the output tensor if @@ -117,13 +116,14 @@ void add_split_tensor_node( ValueRef split_size_ref, ValueRef dim_ref, ValueRef out) { - int64_t split_size = graph.extract_scalar(split_size_ref); - int64_t dim = graph.extract_scalar(dim_ref); - - vTensorPtr t_in = graph.get_tensor(in); - DimIndex dim_index = normalize_to_dim_index(*t_in, dim); - int64_t size = dim_at(*t_in, dim_index); - std::vector split_sizes(size / split_size, split_size); + const int64_t split_size = graph.extract_scalar(split_size_ref); + const int64_t dim = graph.extract_scalar(dim_ref); + + const int64_t input_ndim = graph.dim_of(in); + const DimIndex dim_index = dim < 0 ? static_cast(dim) + : static_cast(dim - input_ndim); + const int64_t size = dim_at(graph.sizes_of(in), dim_index); + const std::vector split_sizes(size / split_size, split_size); add_split_with_sizes_default_node(graph, in, split_sizes, dim, out); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index bfaad716059..5faeae3e21b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -27,7 +27,7 @@ void add_staging_to_tensor_node( VK_CHECK_COND(graph.val_is_staging(in_staging)); vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( - *graph.get_tensor(out_tensor), graph.int8_buffers_enabled()); + graph, out_tensor, graph.int8_buffers_enabled()); std::vector pcs; if (graph.is_buffer_storage(out_tensor)) { @@ -73,7 +73,7 @@ vkapi::ShaderInfo get_tensor_to_staging_shader( (void)resize_args; const ValueRef in_tensor = args.at(1).refs.at(0); return get_tensor_to_nchw_shader( - *graph->get_tensor(in_tensor), graph->int8_buffers_enabled()); + *graph, in_tensor, graph->int8_buffers_enabled()); } utils::uvec3 tensor_to_staging_global_wg_size( @@ -110,8 +110,8 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( - *graph.get_tensor(in_tensor), graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = + get_tensor_to_nchw_shader(graph, in_tensor, graph.int8_buffers_enabled()); std::vector pcs; if (graph.is_buffer_storage(in_tensor)) { @@ -151,8 +151,8 @@ void add_prepack_standard_node( const ValueRef tensor_data, const ValueRef tensor, const bool transpose_hw = false) { - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( - *graph.get_tensor(tensor), graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = + get_nchw_to_tensor_shader(graph, tensor, graph.int8_buffers_enabled()); std::vector pcs; if (graph.is_buffer_storage(tensor)) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp index 89c4a4d408f..307f774de5e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp @@ -20,10 +20,11 @@ void resize_tan_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = 
args.at(1).refs.at(0); - out->virtual_resize(self->sizes()); + const std::vector self_sizes = graph->sizes_of(self); + graph->virtual_resize(out, self_sizes); } void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) { diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index d1145a925d4..b7e0218823a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -19,10 +19,10 @@ void resize_to_copy_op_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); - out->virtual_resize(self->sizes()); + graph->virtual_resize(out, graph->sizes_of(self)); } void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Transpose.cpp b/backends/vulkan/runtime/graph/ops/impl/Transpose.cpp index 8501d085bc8..b797536d817 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Transpose.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Transpose.cpp @@ -23,16 +23,16 @@ void resize_transpose_view_node( const std::vector& args, const std::vector& extra_args) { (void)args; - vTensorPtr out = graph->get_tensor(extra_args[0]); - vTensorPtr in = graph->get_tensor(extra_args[1]); + const ValueRef out = extra_args.at(0); + const ValueRef in = extra_args.at(1); - const int64_t dim0 = graph->extract_scalar(extra_args[2]); - const int64_t dim1 = graph->extract_scalar(extra_args[3]); + const int64_t dim0 = graph->extract_scalar(extra_args.at(2)); + const int64_t dim1 = graph->extract_scalar(extra_args.at(3)); - std::vector new_sizes = in->sizes(); + std::vector new_sizes = graph->sizes_of(in); // Transpose the resized input sizes std::iter_swap(new_sizes.begin() + dim0, new_sizes.begin() + dim1); - out->virtual_resize(new_sizes); + graph->virtual_resize(out, new_sizes); } void check_transpose_view_args( @@ -62,9 +62,8 @@ void add_transpose_view_node( const int64_t dim1 = graph.extract_scalar(dim1_ref); check_transpose_view_args(graph, input_ref, dim0, dim1, out_ref); - const vTensorPtr in = graph.get_tensor(input_ref); - graph.get_tensor(out_ref)->virtual_clone(*in); - graph.get_tensor(out_ref)->virtual_transpose(dim0, dim1); + graph.virtual_clone(out_ref, input_ref); + graph.virtual_transpose(out_ref, dim0, dim1); graph.execute_nodes().emplace_back(new ExecuteNode( resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref})); diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 085e8559980..9830a8e8784 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -26,10 +26,11 @@ void resize_unary_op_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); - out->virtual_resize(self->sizes()); + const std::vector self_sizes = graph->sizes_of(self); + graph->virtual_resize(out, self_sizes); } void add_unary_op_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp 
b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp index d098ed94c7f..ed9fef61a78 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Upsample.cpp @@ -22,12 +22,12 @@ void resize_upsample_nearest2d_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr self = graph->get_tensor(args[1].refs[0]); - std::vector out_sizes = self->sizes(); // NCHW + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + std::vector out_sizes = graph->sizes_of(self); // NCHW - const ValueRef output_sizes = extra_args[0]; // HW - const ValueRef scale_factors = extra_args[1]; // HW + const ValueRef output_sizes = extra_args.at(0); // HW + const ValueRef scale_factors = extra_args.at(1); // HW if (!graph->val_is_none(output_sizes)) { IntListPtr output_size_ref = graph->get_int_list(output_sizes); out_sizes.at(2) = output_size_ref->at(0); @@ -38,7 +38,7 @@ void resize_upsample_nearest2d_node( out_sizes.at(3) *= scales->at(1); } - out->virtual_resize(out_sizes); + graph->virtual_resize(out, out_sizes); } void add_upsample_nearest2d_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/Var.cpp b/backends/vulkan/runtime/graph/ops/impl/Var.cpp index 41fdc41e982..106a6fd6d9a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Var.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Var.cpp @@ -19,16 +19,17 @@ void resize_var_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - int dim = extra_args[0]; + const int dim = extra_args.at(0); - std::vector new_sizes = in->sizes(); + std::vector new_sizes = graph->sizes_of(in); if (!new_sizes.empty()) { new_sizes.at(normalize(dim, new_sizes.size())) = 1; } - out->virtual_resize(new_sizes); + + graph->virtual_resize(out, new_sizes); } void add_var_buffer_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 9dbe79faebb..cb868acf7e9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -44,15 +44,19 @@ void resize_view_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); - if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) { - out->virtual_resize(in->sizes()); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + if (extra_args.at(0) == kDummyValueRef || + graph->val_is_none(extra_args.at(0))) { + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); } else { std::vector view_sizes = - graph->extract_int_or_symint_list(extra_args[0]); - std::vector out_sizes = compute_out_sizes(in->sizes(), view_sizes); - out->virtual_resize(out_sizes); + graph->extract_int_or_symint_list(extra_args.at(0)); + const std::vector in_sizes = graph->sizes_of(in); + const std::vector out_sizes = + compute_out_sizes(in_sizes, view_sizes); + graph->virtual_resize(out, out_sizes); } } @@ -61,12 +65,9 @@ void add_view_node( ValueRef in, ValueRef sizes, ValueRef out) { - vTensorPtr t_in = graph.get_tensor(in); - vTensorPtr 
t_out = graph.get_tensor(out); - std::string kernel_name = "view"; kernel_name.reserve(kShaderNameReserve); - add_dtype_suffix(kernel_name, *t_out); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -81,7 +82,7 @@ void add_view_node( // Push Constants {{graph.sizes_pc_of(out), graph.sizes_pc_of(in)}}, // Specialization Constants - {SV(t_in->packed_dim()), SV(t_out->packed_dim())}, + {graph.packed_dim_of(in), graph.packed_dim_of(out)}, // Resize Args {sizes}, // Resizing Logic diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp index ea610b1fe74..1868d3b872e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Where.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -19,11 +19,11 @@ void resize_where_node( const std::vector& args, const std::vector& extra_args) { (void)extra_args; - vTensorPtr out = graph->get_tensor(args[0].refs[0]); - vTensorPtr in = graph->get_tensor(args[1].refs[0]); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); - std::vector in_sizes = in->sizes(); - out->virtual_resize(in_sizes); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); } void add_where_texture_node( diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h index 4bd8e9b900b..5ed07dece38 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h @@ -31,11 +31,6 @@ constexpr DimIndex kHeight4D = DimIndex::DIM_2ND_LAST; constexpr DimIndex kChannel4D = DimIndex::DIM_3RD_LAST; constexpr DimIndex kBatch4D = DimIndex::DIM_4TH_LAST; -inline DimIndex normalize_to_dim_index(const api::vTensor& v_in, int32_t dim) { - return dim < 0 ? 
static_cast(dim) - : static_cast(dim - v_in.dim()); -} - /* * Semantic dimension names for a 1D tensor */ @@ -83,15 +78,6 @@ int32_t dim_at(const std::vector& sizes) { return dim_at(sizes, DI); } -template -int32_t dim_at(const api::vTensor& v_in) { - return dim_at(v_in.sizes(), DI); -} - -inline int32_t dim_at(const api::vTensor& v_in, DimIndex dim_index) { - return dim_at(v_in.sizes(), dim_index); -} - inline std::ostream& operator<<(std::ostream& os, DimIndex dim_index) { switch (dim_index) { case kWidth4D: diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp index 2bcf2a3842f..a52572289a4 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.cpp @@ -15,15 +15,14 @@ namespace vkcompute { // std::vector calculate_broadcasted_output_size( - const api::vTensor& t1, - const api::vTensor& t2) { - std::vector out_sizes( - std::max(t1.sizes().size(), t2.sizes().size())); + const std::vector& sizes1, + const std::vector& sizes2) { + std::vector out_sizes(std::max(sizes1.size(), sizes2.size())); // Match the sizes in reverse because sizes are in NCHW order for (int i = -1; i >= -out_sizes.size(); --i) { out_sizes.at(out_sizes.size() + i) = - std::max(utils::val_at(i, t1.sizes()), utils::val_at(i, t2.sizes())); + std::max(utils::val_at(i, sizes1), utils::val_at(i, sizes2)); } return out_sizes; @@ -33,30 +32,6 @@ std::vector calculate_broadcasted_output_size( // Tensor property checking functions // -bool check_ndim_is(const api::vTensor& t, size_t ndim) { - return t.sizes().size() == ndim; -} - -bool check_same_sizes_at( - const api::vTensor& t1, - const int64_t d1, - const api::vTensor& t2, - const int64_t d2) { - return utils::val_at(d1, t1.sizes()) == utils::val_at(d2, t2.sizes()); -} - -bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim) { - return t.packed_dim() == packed_dim; -} - -bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2) { - return t1.sizes().size() == t2.sizes().size(); -} - -bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2) { - return t1.packed_dim() == t2.packed_dim(); -} - bool check_same_packed_dim( ComputeGraph& graph, const ValueRef in, @@ -64,42 +39,38 @@ bool check_same_packed_dim( return graph.packed_dim_of(in) == graph.packed_dim_of(out); } -bool check_same_packed_dim( - const api::vTensor& t1, - const api::vTensor& t2, - const api::vTensor& t3) { - if (t1.packed_dim() != t2.packed_dim()) { - return false; - } - return (t1.packed_dim() == t3.packed_dim()); -} - // // Broadcast flag functions // bool is_packed_dim_broadcasted( - const api::vTensor& sndr, - const api::vTensor& rcvr) { + ComputeGraph& graph, + const ValueRef sndr, + const ValueRef rcvr) { // We assume that the tensors are broadcastable. If values aren't equal at // some index, then the value of rcvr is 1 and hence should be broadcasted. 
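  // Illustrative example (hypothetical sizes, not part of this change): if sndr
  // is width-packed with sizes {1, 8, 4, 6} and rcvr has sizes {1, 8, 4, 1}, the
  // switch below compares val_at(-1, sndr_sizes) = 6 against
  // val_at(-1, rcvr_sizes) = 1 and returns true, i.e. rcvr is broadcasted along
  // the packed (width) dimension.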
- switch (sndr.packed_dim()) { + const std::vector sndr_sizes = graph.sizes_of(sndr); + const std::vector rcvr_sizes = graph.sizes_of(rcvr); + + switch (graph.packed_dim_of(sndr)) { case WHCN::kChannelsDim: - return utils::val_at(-3, sndr.sizes()) > utils::val_at(-3, rcvr.sizes()); + return utils::val_at(-3, sndr_sizes) > utils::val_at(-3, rcvr_sizes); case WHCN::kHeightDim: - return utils::val_at(-2, sndr.sizes()) > utils::val_at(-2, rcvr.sizes()); + return utils::val_at(-2, sndr_sizes) > utils::val_at(-2, rcvr_sizes); case WHCN::kWidthDim: - return utils::val_at(-1, sndr.sizes()) > utils::val_at(-1, rcvr.sizes()); + return utils::val_at(-1, sndr_sizes) > utils::val_at(-1, rcvr_sizes); default: VK_THROW("Invalid packed dim"); } } utils::ivec2 create_broadcast_params( - const api::vTensor& t1, - const api::vTensor& t2) { + ComputeGraph& graph, + const ValueRef t1, + const ValueRef t2) { return utils::make_ivec2( - {is_packed_dim_broadcasted(t2, t1), is_packed_dim_broadcasted(t1, t2)}); + {is_packed_dim_broadcasted(graph, t2, t1), + is_packed_dim_broadcasted(graph, t1, t2)}); } // diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h index 3b61083069e..b62bf661995 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h @@ -18,44 +18,31 @@ namespace vkcompute { // std::vector calculate_broadcasted_output_size( - const api::vTensor& t1, - const api::vTensor& t2); + const std::vector& sizes1, + const std::vector& sizes2); // // Tensor property checking functions // -bool check_ndim_is(const api::vTensor& t, size_t ndim); - -bool check_same_ndim(const api::vTensor& t1, const api::vTensor& t2); - -bool check_same_sizes_at( - const api::vTensor& t1, - int64_t d1, - const api::vTensor& t2, - int64_t d2); - -bool check_packed_dim_is(const api::vTensor& t, const int32_t packed_dim); - -bool check_same_packed_dim(const api::vTensor& t1, const api::vTensor& t2); - bool check_same_packed_dim( ComputeGraph& graph, const ValueRef in, const ValueRef out); -bool check_same_packed_dim( - const api::vTensor& t1, - const api::vTensor& t2, - const api::vTensor& t3); - // // Broadcast flag functions // +bool is_packed_dim_broadcasted( + ComputeGraph& graph, + const ValueRef sndr, + const ValueRef rcvr); + utils::ivec2 create_broadcast_params( - const api::vTensor& t1, - const api::vTensor& t2); + ComputeGraph& graph, + const ValueRef t1, + const ValueRef t2); // // Work group size calculation functions diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp index b3a72e27c43..e829f355fe2 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp @@ -10,23 +10,6 @@ namespace vkcompute { -void bind_tensor_to_descriptor_set( - api::vTensor& tensor, - vkapi::PipelineBarrier& pipeline_barrier, - const vkapi::MemoryAccessFlags accessType, - vkapi::DescriptorSet& descriptor_set, - const uint32_t idx) { - if (tensor.buffer()) { - vkapi::VulkanBuffer& buffer = tensor.buffer( - pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType); - descriptor_set.bind(idx, buffer); - } else { - vkapi::VulkanImage& image = tensor.image( - pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType); - descriptor_set.bind(idx, image); - } -} - uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, const std::vector& args, 
@@ -36,19 +19,8 @@ uint32_t bind_values_to_descriptor_set( uint32_t idx = base_idx; for (auto& arg : args) { for (auto& ref : arg.refs) { - if (graph->val_is_tensor(ref)) { - bind_tensor_to_descriptor_set( - *(graph->get_tensor(ref)), - pipeline_barrier, - arg.access, - descriptor_set, - idx++); - } else if (graph->val_is_staging(ref)) { - bind_staging_to_descriptor_set( - *(graph->get_staging(ref)), descriptor_set, idx++); - } else { - VK_THROW("Unsupported type: ", graph->get_val_type(ref)); - } + graph->bind_value_to_descriptor_set( + ref, pipeline_barrier, arg.access, descriptor_set, idx++); } } return idx; diff --git a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h index 671a18f7e91..307bec154f3 100644 --- a/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h @@ -16,13 +16,6 @@ namespace vkcompute { // For objects in the graph // -void bind_tensor_to_descriptor_set( - api::vTensor& tensor, - vkapi::PipelineBarrier& pipeline_barrier, - const vkapi::MemoryAccessFlags accessType, - vkapi::DescriptorSet& descriptor_set, - const uint32_t idx); - uint32_t bind_values_to_descriptor_set( ComputeGraph* graph, const std::vector& args, diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 6388a8ad091..231e6d0c7f6 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -26,12 +26,6 @@ void add_storage_type_suffix( } } -void add_storage_type_suffix( - std::string& kernel_name, - const api::vTensor& tensor) { - return add_storage_type_suffix(kernel_name, tensor.storage_type()); -} - void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { case vkapi::kDouble: @@ -75,23 +69,6 @@ void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { } } -void add_dtype_suffix(std::string& kernel_name, const api::vTensor& tensor) { - return add_dtype_suffix(kernel_name, tensor.dtype()); -} - -void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor) { - switch (tensor.storage_type()) { - case utils::kTexture3D: - kernel_name += "_3d"; - break; - case utils::kTexture2D: - kernel_name += "_2d"; - break; - default: - break; - } -} - void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) { switch (packed_dim) { case WHCN::kWidthDim: @@ -108,10 +85,4 @@ void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim) { } } -void add_packed_dim_suffix( - std::string& kernel_name, - const api::vTensor& tensor) { - return add_packed_dim_suffix(kernel_name, tensor.packed_dim()); -} - } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h index 10084054964..4a2fddb5cf2 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -19,19 +19,11 @@ constexpr size_t kShaderNameReserve = 64u; void add_storage_type_suffix( std::string& kernel_name, const utils::StorageType storage_type); -void add_storage_type_suffix( - std::string& kernel_name, - const api::vTensor& tensor); void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype); -void add_dtype_suffix(std::string& kernel_name, const api::vTensor& tensor); void 
add_ndim_suffix(std::string& kernel_name, const size_t ndim); -void add_ndim_suffix(std::string& kernel_name, const api::vTensor& tensor); void add_packed_dim_suffix(std::string& kernel_name, const int32_t packed_dim); -void add_packed_dim_suffix( - std::string& kernel_name, - const api::vTensor& tensor); } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index ea3ae0fa1c3..904b91965d6 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -21,29 +21,33 @@ bool is_bitw8(vkapi::ScalarType dtype) { } vkapi::ShaderInfo get_nchw_to_tensor_shader( - const api::vTensor& v_dst, + ComputeGraph& graph, + const ValueRef dst, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); - if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer && + const vkapi::ScalarType dst_dtype = graph.dtype_of(dst); + const utils::StorageType dst_storage_type = graph.storage_type_of(dst); + + if (is_bitw8(dst_dtype) && dst_storage_type != utils::kBuffer && !int8_buffer_enabled) { kernel_name = "nchw_to_bitw8_image_nobitw8buffer"; if (!push_constant_variant) { kernel_name += "_no_pc"; } - add_storage_type_suffix(kernel_name, v_dst); - add_dtype_suffix(kernel_name, v_dst); + add_storage_type_suffix(kernel_name, dst_storage_type); + add_dtype_suffix(kernel_name, dst_dtype); return VK_KERNEL_FROM_STR(kernel_name); } - if (v_dst.storage_type() == utils::kBuffer) { + if (dst_storage_type == utils::kBuffer) { kernel_name = "nchw_to_buffer"; if (!push_constant_variant) { kernel_name += "_no_pc"; } - add_dtype_suffix(kernel_name, v_dst); + add_dtype_suffix(kernel_name, dst_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -51,36 +55,40 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (!push_constant_variant) { kernel_name += "_no_pc"; } - add_storage_type_suffix(kernel_name, v_dst); - add_dtype_suffix(kernel_name, v_dst); + add_storage_type_suffix(kernel_name, dst_storage_type); + add_dtype_suffix(kernel_name, dst_dtype); return VK_KERNEL_FROM_STR(kernel_name); } vkapi::ShaderInfo get_tensor_to_nchw_shader( - const api::vTensor& v_src, + ComputeGraph& graph, + const ValueRef src, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); - if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer && + const vkapi::ScalarType src_dtype = graph.dtype_of(src); + const utils::StorageType src_storage_type = graph.storage_type_of(src); + + if (is_bitw8(src_dtype) && src_storage_type != utils::kBuffer && !int8_buffer_enabled) { kernel_name = "bitw8_image_to_nchw_nobitw8buffer"; if (!push_constant_variant) { kernel_name += "_no_pc"; } - add_storage_type_suffix(kernel_name, v_src); - add_dtype_suffix(kernel_name, v_src); + add_storage_type_suffix(kernel_name, src_storage_type); + add_dtype_suffix(kernel_name, src_dtype); return VK_KERNEL_FROM_STR(kernel_name); } - if (v_src.storage_type() == utils::kBuffer) { + if (src_storage_type == utils::kBuffer) { kernel_name = "buffer_to_nchw"; if (!push_constant_variant) { kernel_name += "_no_pc"; } - add_dtype_suffix(kernel_name, v_src); + add_dtype_suffix(kernel_name, src_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -88,8 +96,8 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (!push_constant_variant) { kernel_name += "_no_pc"; } - 
add_storage_type_suffix(kernel_name, v_src); - add_dtype_suffix(kernel_name, v_src); + add_storage_type_suffix(kernel_name, src_storage_type); + add_dtype_suffix(kernel_name, src_dtype); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 9e6b61d6cd8..71c92b833b7 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -13,11 +13,13 @@ namespace vkcompute { vkapi::ShaderInfo get_nchw_to_tensor_shader( - const api::vTensor& v_dst, + ComputeGraph& graph, + const ValueRef dst, bool int8_buffer_enabled = true, bool push_constant_variant = true); vkapi::ShaderInfo get_tensor_to_nchw_shader( - const api::vTensor& v_src, + ComputeGraph& graph, + const ValueRef src, bool int8_buffer_enabled = true, bool push_constant_variant = true); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 22725a46100..5efcfc1ffb2 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1137,7 +1137,7 @@ def get_repeat_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] - test_suite_2d.storage_types = ["utils::kTexture2D"] + test_suite_2d.storage_types = ["utils::kTexture3D"] test_suite_2d.data_gen = "make_seq_tensor" test_suite_2d.dtypes = ["at::kFloat"] test_suite_2d.test_name_suffix = "2d" diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index 4fba14ca16e..490044340d6 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -549,15 +549,13 @@ def virtual_resize(self, ref: ValueRefList) -> str: return "" if ref.src_cpp_type == AT_TENSOR: - ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" - ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + ret_str = f"{self.graph}{self.dot}virtual_resize({ref.name}.value, " + ret_str += f"{ref.src_cpp_name}.sizes().vec());\n" elif ref.src_cpp_type == AT_TENSOR_LIST: ret_str = "" ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" - ret_str += ( - f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)" - ) - ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n" + ret_str += f" {self.graph}{self.dot}virtual_resize({ref.name}_io_value_refs[i].value, " + ret_str += f"{ref.src_cpp_name}[i].sizes().vec());\n" ret_str += "}\n" else: raise AssertionError(f"{ref.src_cpp_type} not expected") diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index faa0e7d0c47..c026c1364fa 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -14,9 +14,88 @@ #include #include +#include using namespace vkcompute; +bool is_bitw8(vkapi::ScalarType dtype) { + return dtype == vkapi::kByte || dtype == vkapi::kChar || + dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8; +} + +vkapi::ShaderInfo get_nchw_to_tensor_shader( + const api::vTensor& v_dst, + bool int8_buffer_enabled, + bool push_constant_variant) { + std::string kernel_name; + kernel_name.reserve(kShaderNameReserve); + + if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer && + !int8_buffer_enabled) { + kernel_name = "nchw_to_bitw8_image_nobitw8buffer"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + 
} + add_storage_type_suffix(kernel_name, v_dst.storage_type()); + add_dtype_suffix(kernel_name, v_dst.dtype()); + return VK_KERNEL_FROM_STR(kernel_name); + } + + if (v_dst.storage_type() == utils::kBuffer) { + kernel_name = "nchw_to_buffer"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + } + add_dtype_suffix(kernel_name, v_dst.dtype()); + return VK_KERNEL_FROM_STR(kernel_name); + } + + kernel_name = "nchw_to_image"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + } + add_storage_type_suffix(kernel_name, v_dst.storage_type()); + add_dtype_suffix(kernel_name, v_dst.dtype()); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +vkapi::ShaderInfo get_tensor_to_nchw_shader( + const api::vTensor& v_src, + bool int8_buffer_enabled, + bool push_constant_variant) { + std::string kernel_name; + kernel_name.reserve(kShaderNameReserve); + + if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer && + !int8_buffer_enabled) { + kernel_name = "bitw8_image_to_nchw_nobitw8buffer"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + } + add_storage_type_suffix(kernel_name, v_src.storage_type()); + add_dtype_suffix(kernel_name, v_src.dtype()); + return VK_KERNEL_FROM_STR(kernel_name); + } + + if (v_src.storage_type() == utils::kBuffer) { + kernel_name = "buffer_to_nchw"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + } + add_dtype_suffix(kernel_name, v_src.dtype()); + return VK_KERNEL_FROM_STR(kernel_name); + } + + kernel_name = "image_to_nchw"; + if (!push_constant_variant) { + kernel_name += "_no_pc"; + } + add_storage_type_suffix(kernel_name, v_src.storage_type()); + add_dtype_suffix(kernel_name, v_src.dtype()); + + return VK_KERNEL_FROM_STR(kernel_name); +} // // Operator Recording Functions // @@ -121,8 +200,8 @@ void record_bitw8_image_to_nchw_nobitw8buffer_op( utils::uvec3 global_wg_size = {buffer_len, 1, 1}; std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer_no_pc"; - add_storage_type_suffix(kernel_name, v_src); - add_dtype_suffix(kernel_name, v_src); + add_storage_type_suffix(kernel_name, v_src.storage_type()); + add_dtype_suffix(kernel_name, v_src.dtype()); context->submit_compute_job( VK_KERNEL_FROM_STR(kernel_name), @@ -145,7 +224,7 @@ void record_binary_op( api::vTensor& v_in2, api::vTensor& v_dst) { std::string kernel_name = "binary_" + op_name + "_nobroadcast__test"; - add_dtype_suffix(kernel_name, v_dst); + add_dtype_suffix(kernel_name, v_dst.dtype()); vkapi::PipelineBarrier pipeline_barrier{}; vkapi::SpecVarList specialization_constants = {}; @@ -236,7 +315,7 @@ void record_scalar_add_buffer( vkapi::PipelineBarrier pipeline_barrier{}; vkapi::SpecVarList specialization_constants = {SV(offset)}; std::string kernel = "scalar_add_buffer"; - add_dtype_suffix(kernel, v_ten); + add_dtype_suffix(kernel, v_ten.dtype()); api::context()->submit_compute_job( VK_KERNEL_FROM_STR(kernel), pipeline_barrier, @@ -398,10 +477,9 @@ void fill_vtensor( const IOValueRef idx, float val, bool iota) { - vTensorPtr t = graph.get_tensor(idx.value); - std::vector data(t->numel()); - if (t->storage_type() != utils::kBuffer) { - data.resize(t->staging_buffer_numel()); + std::vector data(graph.numel_of(idx.value)); + if (graph.storage_type_of(idx.value) != utils::kBuffer) { + data.resize(graph.staging_buffer_numel_of(idx.value)); } if (iota) { std::iota(data.begin(), data.end(), val); @@ -489,13 +567,12 @@ void execute_graph_and_check_output( for (size_t i = 0; i < graph.outputs().size(); ++i) { IOValueRef out_ioval = graph.outputs().at(i); - vTensorPtr 
t_out = graph.get_tensor(out_ioval.value); - - std::vector output_data(t_out->staging_buffer_numel()); + std::vector output_data( + graph.staging_buffer_numel_of(out_ioval.value)); graph.copy_from_staging( out_ioval.staging, output_data.data(), output_data.size()); - for (size_t j = 0; j < t_out->numel(); ++j) { + for (size_t j = 0; j < graph.numel_of(out_ioval.value); ++j) { CHECK_VALUE(output_data, j, expected_outputs.at(i)); } } diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 0f0d2647792..1fd40b6f815 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -214,9 +214,7 @@ inline int64_t get_buf_idx( vkcompute::ComputeGraph& graph, vkcompute::IOValueRef ref, const std::vector& tensor_coor) { - vkcompute::vTensorPtr vten_ptr = graph.get_tensor(ref.value); - - const std::vector& sizes = vten_ptr->sizes(); + const std::vector& sizes = graph.sizes_of(ref.value); int64_t c = vkcompute::dim_at(sizes); int64_t h = vkcompute::dim_at(sizes); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 82df7e7d96f..f99552ceee1 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -498,7 +498,7 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) { vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); std::string kernel_name("fill_texture__test"); - add_dtype_suffix(kernel_name, a); + add_dtype_suffix(kernel_name, a.dtype()); struct Params final { utils::ivec3 size; @@ -1014,9 +1014,8 @@ TEST_F(VulkanComputeAPITest, texture_virtual_resize) { // Compute Graph Tests // -#define EXTRACT_TENSOR(name) \ - std::vector data_##name( \ - graph.get_tensor(name.value)->staging_buffer_numel()); \ +#define EXTRACT_TENSOR(name) \ + std::vector data_##name(graph.staging_buffer_numel_of(name.value)); \ graph.copy_from_staging(name.staging, data_##name.data(), data_##name.size()); // The purpose of this test is simply to track the size of various classes over @@ -1041,8 +1040,8 @@ TEST_F(VulkanComputeAPITest, print_object_sizes) { EXPECT_TRUE(sizeof(Value) < 56); // Current known size on 64 bit system: 120 B EXPECT_TRUE(sizeof(StagingBuffer) < 500); - // Current known size on 64 bit system: 384 B - EXPECT_TRUE(sizeof(ComputeGraph) < 500); + // Current known size on 64 bit system: 512 B + EXPECT_TRUE(sizeof(ComputeGraph) < 600); // Current known size on 64 bit system: 248 B EXPECT_TRUE(sizeof(DispatchNode) < 500); } @@ -1193,7 +1192,7 @@ TEST(VulkanComputeGraphTest, test_zero_dim_tensor) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); ++i) { CHECK_VALUE(data_out, i, val_c); } } @@ -1233,7 +1232,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_buffer) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); ++i) { CHECK_VALUE(data_out, i, expected_val); } } @@ -1320,7 +1319,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); ++i) { CHECK_VALUE(data_out, i, val_c); } } @@ -1382,7 +1381,7 @@ 
TEST(VulkanComputeGraphTest, test_simple_graph_with_symint) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); i++) { CHECK_VALUE(data_out, i, val_out); } } @@ -1445,7 +1444,7 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); ++i) { CHECK_VALUE(data_out, i, val_out); } @@ -1531,9 +1530,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}}; for (auto& new_sizes : new_sizes_list) { - graph.get_tensor(a.value)->virtual_resize(new_sizes); - graph.get_tensor(b.value)->virtual_resize(new_sizes); - graph.get_tensor(d.value)->virtual_resize(new_sizes); + graph.virtual_resize(a.value, new_sizes); + graph.virtual_resize(b.value, new_sizes); + graph.virtual_resize(d.value, new_sizes); graph.propagate_resize(); float val_a = new_sizes[1] + 4.0f; @@ -1551,7 +1550,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); i++) { + for (size_t i = 0; i < graph.numel_of(out.value); i++) { CHECK_VALUE(data_out, i, val_out); } } @@ -1566,7 +1565,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { graph.propagate_resize(); // Check output shape - EXPECT_TRUE(graph.get_tensor(out.value)->sizes() == new_sizes); + EXPECT_TRUE(graph.sizes_of(out.value) == new_sizes); float val_a = new_sizes[1] + 6.0f; float val_b = new_sizes[2] + 2.5f; @@ -1583,7 +1582,7 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); i++) { + for (size_t i = 0; i < graph.numel_of(out.value); i++) { CHECK_VALUE(data_out, i, val_out); } } @@ -1681,7 +1680,7 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) { EXTRACT_TENSOR(out); // Sanity check that the values are correct - for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) { + for (size_t i = 0; i < graph.numel_of(out.value); ++i) { CHECK_VALUE(data_out, i, val_out); } } @@ -1767,7 +1766,7 @@ TEST(VulkanComputeGraphTest, test_large_graph) { auto inference_time = std::chrono::duration_cast( inference_end_time - inference_start_time); - for (int i = 0; i < graph.get_tensor(out.value)->numel(); i++) { + for (int i = 0; i < graph.numel_of(out.value); i++) { CHECK_VALUE(data_out, i, val_e); } @@ -2282,7 +2281,7 @@ TEST(VulkanComputeGraphTest, test_view_change_packing) { // The extracted data is a flattened nchw buffer. Hence, should expect the // all elements inside the out array to match the index. 
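  // For example (illustrative): a small tensor filled with iota values 0..5
  // copies out as the flattened NCHW buffer {0, 1, 2, 3, 4, 5}, so
  // data_out[i] == i should hold for every index checked below.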
-  for (int i = 0; i < graph.get_tensor(out.value)->numel(); i++) {
+  for (int i = 0; i < graph.numel_of(out.value); i++) {
     CHECK_VALUE(data_out, i, i);
   }
 }
@@ -2317,7 +2316,7 @@ void run_from_gpu_test(
   vTensor vten = vTensor(context(), sizes, dtype, storage_type, memory_layout);
 
   std::string kernel_name("idx_fill_texture");
-  add_dtype_suffix(kernel_name, vten);
+  add_dtype_suffix(kernel_name, vten.dtype());
 
   int32_t offset = -50;
@@ -2432,9 +2431,7 @@ void compute_graph_round_trip_test(
 
   graph.prepare();
 
-  vTensorPtr tensor = graph.get_tensor(r_tensor);
-
-  std::vector<T> data_in(tensor->numel());
+  std::vector<T> data_in(graph.numel_of(r_tensor));
   for (int i = 0; i < data_in.size(); i++) {
     data_in[i] = T(i * -1);
   }
@@ -2442,7 +2439,7 @@ void compute_graph_round_trip_test(
 
   graph.execute();
 
-  std::vector<T> data_out(tensor->staging_buffer_numel());
+  std::vector<T> data_out(graph.staging_buffer_numel_of(r_tensor));
   graph.copy_from_staging(r_staging_out, data_out.data(), data_out.size());
 
   for (int i = 0; i < data_in.size(); i++) {
@@ -2740,94 +2737,6 @@ TEST(VulkanComputeGraphOpsTest, test_graph_resize_reencode) {
       utils::kWidthPacked);
 }
 
-void test_max_pool2d(
-    const std::vector<int64_t>& in_size,
-    const int64_t base_val,
-    std::vector<int64_t>& kernel) {
-  GraphConfig config;
-  ComputeGraph graph(config);
-
-  // Build graph
-
-  std::vector<int64_t> out_size(in_size);
-  int h = in_size.size() - 2;
-  int w = in_size.size() - 1;
-  out_size[h] = in_size[h] - kernel[0] + 1;
-  out_size[w] = in_size[w] - kernel[1] + 1;
-
-  IOValueRef in_ioval = graph.add_input_tensor(
-      in_size, vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
-  IOValueRef out_ioval;
-  out_ioval.value = graph.add_tensor(
-      out_size, vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
-  IOValueRef idx_ioval;
-  idx_ioval.value = graph.add_tensor(
-      out_size, vkapi::kInt, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
-  ValueRef out = graph.add_value_list({out_ioval.value, idx_ioval.value});
-
-  std::vector<int64_t> kernel_copy(kernel);
-  VK_GET_OP_FN("aten.max_pool2d_with_indices.default")
-  (graph,
-   {in_ioval.value,
-    graph.add_scalar_list(std::move(kernel)),
-    graph.add_scalar_list({1, 1}),
-    graph.add_scalar_list({0, 0}),
-    graph.add_scalar_list({1, 1}),
-    graph.add_scalar(false),
-    out});
-
-  out_ioval.staging = graph.set_output_tensor(out_ioval.value);
-  idx_ioval.staging = graph.set_output_tensor(idx_ioval.value);
-
-  graph.prepare();
-
-  graph.prepack();
-
-  // Run graph
-
-  fill_vtensor(graph, graph.inputs().at(0), base_val, /*iota = */ true);
-
-  vTensorPtr t_in = graph.get_tensor(in_ioval.value);
-  std::vector<float> input_data(t_in->staging_buffer_numel());
-  graph.copy_from_staging(
-      in_ioval.staging, input_data.data(), input_data.size());
-
-  graph.execute();
-
-  vTensorPtr t_out = graph.get_tensor(out_ioval.value);
-  std::vector<float> output_data(t_out->staging_buffer_numel());
-  graph.copy_from_staging(
-      out_ioval.staging, output_data.data(), output_data.size());
-  vTensorPtr t_idx = graph.get_tensor(idx_ioval.value);
-  std::vector<int> index_data(t_idx->staging_buffer_numel());
-  graph.copy_from_staging(
-      idx_ioval.staging, index_data.data(), index_data.size());
-
-  // Check results
-
-  int h_offset = kernel_copy[0] - 1;
-  int w_offset = kernel_copy[1] - 1;
-  int h_out = utils::val_at(-2, t_out->sizes());
-  int w_out = utils::val_at(-1, t_out->sizes());
-  int w_in = utils::val_at(-1, t_in->sizes());
-  for (size_t i = 0; i < h_out; ++i) {
-    for (size_t j = 0; j < w_out; ++j) {
-      size_t idx_out = i * w_out + j;
-      size_t idx_in = (i + h_offset) * w_in + (j + w_offset);
-      CHECK_VALUE(index_data, idx_out, idx_in);
-      CHECK_VALUE(output_data, idx_out, input_data[idx_in]);
-    }
-  }
-}
-
-TEST(VulkanComputeGraphOpsTest, max_pool2d_smoke_test) {
-  std::vector<int64_t> kernel = {2, 3};
-  test_max_pool2d(
-      /*in_size = */ {1, 4, 6},
-      /*base_val = */ 10.0f,
-      kernel);
-}
-
 void test_grid_priors(
     std::vector<int64_t> input_sizes,
     std::vector<int64_t> output_sizes,
@@ -2861,20 +2770,19 @@ void test_grid_priors(
   graph.prepack();
 
-  vTensorPtr t_in = graph.get_tensor(in.value);
-  vTensorPtr t_out = graph.get_tensor(out.value);
   // Resize input
   graph.propagate_resize();
 
   // run graph
   graph.execute();
 
-  std::vector<float> output_data(t_out->staging_buffer_numel());
+  std::vector<float> output_data(graph.staging_buffer_numel_of(out.value));
   graph.copy_from_staging(out.staging, output_data.data(), output_data.size());
 
   // check results
-  int h_out = utils::val_at(-2, t_out->sizes());
-  int w_out = utils::val_at(-1, t_out->sizes());
+  std::vector<int64_t> out_sizes = graph.sizes_of(out.value);
+  int h_out = utils::val_at(-2, out_sizes);
+  int w_out = utils::val_at(-1, out_sizes);
   for (size_t i = 0; i < h_out; ++i) {
     for (size_t j = 0; j < w_out; ++j) {
       size_t idx_out = i * w_out + j;
@@ -3151,7 +3059,7 @@ void resize_dynamic_dispatch_node(
   std::vector<int64_t> out_sizes = graph->sizes_of(mat1);
   out_sizes.at(out_sizes.size() - 2) = 1;
 
-  graph->get_tensor(out)->virtual_resize(out_sizes);
+  graph->virtual_resize(out, out_sizes);
 }
 
 void add_dynamic_dispatch_test_node(