diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_rgba.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_rgba.glsl new file mode 100644 index 00000000000..b3e2f856974 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_rgba.glsl @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION highp + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, rgba8) uniform PRECISION restrict writeonly image2D rgba_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D t_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict readonly limits_UBO { + ivec3 limits; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, limits))) { + return; + } + + imageStore(rgba_out, pos.xy, texelFetch(t_in, pos, 0)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/rgba_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/rgba_to_image.glsl new file mode 100644 index 00000000000..9e764d9e88f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/rgba_to_image.glsl @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION highp + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, rgba32f) uniform PRECISION restrict writeonly image3D t_out; +layout(set = 0, binding = 1) uniform PRECISION sampler2D rgba_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict readonly limits_UBO { + ivec3 limits; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, limits))) { + return; + } + + imageStore(t_out, pos, texelFetch(rgba_in, pos.xy, 0)); +} diff --git a/backends/vulkan/runtime/utils/AhbUtils.h b/backends/vulkan/runtime/utils/AhbUtils.h index 02e5e3a2b77..d8f2ceceb96 100644 --- a/backends/vulkan/runtime/utils/AhbUtils.h +++ b/backends/vulkan/runtime/utils/AhbUtils.h @@ -12,6 +12,8 @@ #include +#include +#include #include namespace vkcompute { @@ -181,5 +183,47 @@ vkapi::VulkanImage create_image_from_ahb( context->device(), image_props, image, image_view, sampler); } +void add_rgba_to_image_node( + ComputeGraph& graph, + const ValueRef in_rgba, + const ValueRef out_tensor) { + std::string kernel_name("rgba_to_image"); + kernel_name.reserve(kShaderNameReserve); + + const auto global_wg_size = graph.create_global_wg_size(out_tensor); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Input and Outputs + {{out_tensor, vkapi::MemoryAccessType::WRITE}, + {in_rgba, vkapi::MemoryAccessType::READ}}, + // Parameter Buffers + {graph.logical_limits_ubo(out_tensor)})); +} + +void add_image_to_rgba_node( + ComputeGraph& graph, + const ValueRef in_tensor, + const ValueRef out_rgba) { + std::string kernel_name("image_to_rgba"); + kernel_name.reserve(kShaderNameReserve); + + const auto global_wg_size = graph.create_global_wg_size(out_rgba); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Input and Outputs + {{out_rgba, vkapi::MemoryAccessType::WRITE}, + {in_tensor, vkapi::MemoryAccessType::READ}}, + // Parameter Buffers + {graph.logical_limits_ubo(out_rgba)})); +} + } // namespace utils } // namespace vkcompute