Skip to content
60 changes: 60 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/tan.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type(STORAGE)}

#include "indexing_utils.h"

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
$if STORAGE == "buffer":
${layout_declare_ubo(2, "int", "numel")}
$else:
${layout_declare_ubo(2, "ivec3", "out_limits")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "activations.h"

#ifdef USING_BUFFER

// Buffer-storage entry point: each invocation handles the single element at
// flat index gl_GlobalInvocationID.x.
void main() {
  const int idx = int(gl_GlobalInvocationID.x);

  // Invocations whose index is at or past numel do nothing.
  if (idx < numel) {
    // Evaluate tan() on a float copy of the input, then convert to the
    // buffer's scalar type T for the store.
    t_out[idx] = T(tan(float(t_in[idx])));
  }
}

#else

// Texture-storage entry point: each invocation handles the texel located at
// gl_GlobalInvocationID.
void main() {
  const ivec3 coord = ivec3(gl_GlobalInvocationID);

  // Only positions strictly below out_limits on every axis are processed.
  if (all(lessThan(coord, out_limits))) {
    // tan() is applied component-wise to the fetched texel.
    const VEC4_T texel = texelFetch(t_in, coord, 0);
    imageStore(t_out, coord, VEC4_T(tan(texel)));
  }
}

#endif
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/tan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Codegen parameters for the tan shader template; DTYPE and STORAGE are the
# template variables that tan.glsl substitutes.
tan:
# Values used for each template variable when no variant overrides it.
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
# NOTE(review): presumably one shader is generated for every combination of
# the DTYPE and STORAGE values listed below; confirm against the codegen script.
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
STORAGE:
- VALUE: texture3d
- VALUE: buffer
# Base name of the generated shader(s).
shader_variants:
- NAME: tan
64 changes: 64 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Tan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

using namespace utils;

// Resize callback for the tan dispatch: gives the output tensor the same
// sizes as the input tensor.
void resize_tan_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  // This op carries no extra resize arguments.
  static_cast<void>(extra_args);

  // Group 0 is the output and group 1 the input, matching the order in which
  // add_tan_node lists them for the dispatch.
  const auto out_ref = args[0].refs[0];
  const auto in_ref = args[1].refs[0];

  vTensorPtr dst = graph->get_tensor(out_ref);
  vTensorPtr src = graph->get_tensor(in_ref);
  dst->virtual_resize(src->sizes());
}

// Appends a dispatch to the graph that computes tan() elementwise, reading
// `in` and writing `out`.
//
// The shader is chosen by name: "tan" plus a dtype suffix and a storage-type
// suffix, both derived from `out`.
void add_tan_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {
  std::string kernel_name = "tan";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));

  // Binding 2 of the shader depends on the storage type: the buffer variant
  // declares `int numel`, while the texture variant declares
  // `ivec3 out_limits`. Bind the UBO that matches the variant being
  // dispatched so the bounds check in the shader reads the right value.
  vkapi::ParamsBindList ubos({});
  if (graph.is_buffer_storage(out)) {
    ubos.append({graph.numel_ubo(out)});
  } else {
    ubos.append({graph.logical_limits_ubo(out)});
  }

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      graph.create_global_wg_size(out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      ubos,
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      resize_tan_node));
}

// Operator entry point: unpacks the positional arguments (input first, then
// output) and records the tan dispatch on the graph.
void tan(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef in = args[0];
  const ValueRef out = args[1];
  add_tan_node(graph, in, out);
}

// Associates the operator name aten.tan.default with the tan() function.
// NOTE(review): the registration mechanics come from OperatorRegistry.h,
// which is not visible here.
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.tan.default, tan);
}

} // namespace vkcompute
16 changes: 16 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,22 @@ def get_unary_ops_inputs():
return test_suite


# separate test suite from unary_ops for learning purposes
@register_test_suite("aten.tan.default")
def get_tan_inputs():
    """Build the test suite for aten.tan.default.

    Covers one input shape of each rank from 1 to 4, run with both texture
    and buffer storage and with float and half dtypes.
    """
    input_shapes = [
        (M1,),
        (M1, M2),
        (S1, M1, M2),
        (S1, S2, S2, M2),
    ]
    tan_suite = VkTestSuite(input_shapes)
    tan_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
    tan_suite.dtypes = ["at::kFloat", "at::kHalf"]
    return tan_suite


@register_test_suite("aten._native_batch_norm_legit_no_training.default")
def get_native_batch_norm_inputs():
Test = namedtuple(
Expand Down
Loading