From f658a42eab7c0536f7f8e65600d72ea56e6c24ac Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 2 Jan 2025 13:05:28 -0800 Subject: [PATCH] [ET-VK][ez] Fix undefined behaviour in ambiguous `ParamsBuffer` constructor ## Context I discovered this bug when trying to execute the `vulkan_compute_api_test` binary on Windows. Almost all the tests were failing, with compute shaders producing incorrect results. After bisecting the change, it turns out the culprit is https://github.com/pytorch/executorch/pull/7015. The diff introduced an alternative templated constructor for `ParamsBuffer` which would initialize an empty UBO with a specified size instead of wrapping a pre-existing object. The issue is that these constructors are ambiguous because they both are template constructors and both only accept one argument. Therefore, the original constructor would be called when certain callsites intended to call the new constructor. This results in a UBO being created with an incorrect size, and resulted in the tensor's metadata being passed incorrectly into a compute shader. To fix, I added a dummy argument into the new constructor for disambiguation purposes. I also changed it so that it's not templated, since there's no reason for it to be templated. Differential Revision: [D67770791](https://our.internmc.facebook.com/intern/diff/D67770791/) ghstack-source-id: 260031108 Pull Request resolved: https://github.com/pytorch/executorch/pull/7478 --- backends/vulkan/runtime/api/containers/ParamsBuffer.h | 5 +++-- backends/vulkan/runtime/api/containers/Tensor.cpp | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.h b/backends/vulkan/runtime/api/containers/ParamsBuffer.h index fe157c5e014..ecc07892cf7 100644 --- a/backends/vulkan/runtime/api/containers/ParamsBuffer.h +++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.h @@ -31,8 +31,9 @@ class ParamsBuffer final { vulkan_buffer_( context_p_->adapter_ptr()->vma().create_params_buffer(block)) {} - template - ParamsBuffer(Context* context_p, const VkDeviceSize nbytes) + // The last bool argument, though unused, is required to disambiguate this + // constructor from the one above. + ParamsBuffer(Context* context_p, const VkDeviceSize nbytes, const bool unused) : context_p_(context_p), vulkan_buffer_( context_p_->adapter_ptr()->vma().create_uniform_buffer(nbytes)) {} diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 21b0ee4b176..92e310d36de 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -659,7 +659,7 @@ utils::GPUMemoryLayout vTensor::estimate_memory_layout() const { const vkapi::BufferBindInfo vTensor::sizes_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (sizes_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -674,7 +674,7 @@ const vkapi::BufferBindInfo vTensor::sizes_ubo() { const vkapi::BufferBindInfo vTensor::strides_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (unsqueezed_strides_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -691,7 +691,7 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() { const vkapi::BufferBindInfo vTensor::logical_limits_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (logical_limits_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND( @@ -707,7 +707,7 @@ const vkapi::BufferBindInfo vTensor::logical_limits_ubo() { const vkapi::BufferBindInfo vTensor::numel_ubo() { if (!uniforms_.buffer()) { - uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize); + uniforms_ = ParamsBuffer(storage_.context_, kMaxUniformBufferSize, true); } if (numel_uniform_offset_ == kUniformOffsetUnset) { VK_CHECK_COND(