diff --git a/backends/vulkan/runtime/graph/ops/glsl/tan.glsl b/backends/vulkan/runtime/graph/ops/glsl/tan.glsl new file mode 100644 index 00000000000..876cd43ad08 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/tan.glsl @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type(STORAGE)} + +#include "indexing_utils.h" + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} +$if STORAGE == "buffer": + ${layout_declare_ubo(2, "int", "numel")} +$else: + ${layout_declare_ubo(2, "ivec3", "out_limits")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "activations.h" + +#ifdef USING_BUFFER + +void main() { + const int i = int(gl_GlobalInvocationID.x); + if (i >= numel) { + return; + } + + float in_val = float(t_in[i]); + t_out[i] = T(tan(in_val)); +} + +#else + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + VEC4_T in_texel = texelFetch(t_in, pos, 0); + imageStore(t_out, pos, VEC4_T(tan(in_texel))); +} + +#endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/tan.yaml b/backends/vulkan/runtime/graph/ops/glsl/tan.yaml new file mode 100644 index 00000000000..ad0755bfad0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/tan.yaml @@ -0,0 +1,13 @@ +tan: + parameter_names_with_default_values: + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + STORAGE: + - VALUE: texture3d + - VALUE: buffer + shader_variants: + - NAME: tan diff --git a/backends/vulkan/runtime/graph/ops/impl/Tan.cpp b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp new file mode 100644 index 00000000000..89c4a4d408f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Tan.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +using namespace utils; + +void resize_tan_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + (void)extra_args; + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr self = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(self->sizes()); +} + +void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) { + std::string kernel_name = "tan"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + + vkapi::ParamsBindList ubos({}); + ubos.append({graph.logical_limits_ubo(out)}); + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + graph.create_global_wg_size(out), + graph.create_local_wg_size(out), + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Shader params buffers + ubos, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {}, + // Resizing Logic + resize_tan_node)); +} + +void tan(ComputeGraph& graph, const std::vector& args) { + return add_tan_node(graph, args[0], args[1]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.tan.default, tan); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 4a12f16bbf9..f7abf5c2e9b 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1164,6 +1164,22 @@ def get_unary_ops_inputs(): return test_suite +# separate test suite from unary_ops for learning purposes +@register_test_suite("aten.tan.default") +def get_tan_inputs(): + test_suite = VkTestSuite( + [ + (M1,), + (M1, M2), + (S1, M1, M2), + (S1, S2, S2, M2), + ] + ) + test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"] + test_suite.dtypes = ["at::kFloat", "at::kHalf"] + return test_suite + + @register_test_suite("aten._native_batch_norm_legit_no_training.default") def get_native_batch_norm_inputs(): Test = namedtuple(