Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backends/vulkan/_passes/remove_redundant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.aten.expand_copy.default,
# copy.default(self, src): no-op when src dtype/shape matches self.
exir_ops.edge.aten.copy.default,
}

# For these ops the meaningful input is args[1] (src), not args[0] (self).
_src_arg1_ops: Set[OpType] = {
exir_ops.edge.aten.copy.default,
}

def __init__(self) -> None:
Expand All @@ -41,7 +48,8 @@ def _should_remove(self, node: torch.fx.Node) -> bool:
if node.target not in self.redundant_ops:
return False

orig_node = node.args[0]
src_arg_idx = 1 if node.target in self._src_arg1_ops else 0
orig_node = node.args[src_arg_idx]
assert isinstance(orig_node, torch.fx.Node)

src_dtype = orig_node.meta["val"].dtype
Expand All @@ -61,7 +69,8 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
if not self._should_remove(node):
continue

node.replace_all_uses_with(node.args[0])
src_arg_idx = 1 if node.target in self._src_arg1_ops else 0
node.replace_all_uses_with(node.args[src_arg_idx])

graph_module.graph.eliminate_dead_code()

Expand Down
50 changes: 49 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def update_features_impl(op: OpKey):
# Guard and assert ops
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
# copy.default is a no-op when src dtype matches dst dtype; removed by
# RemoveRedundantOpsTransform before execution.
exir_ops.edge.aten.copy.default,
]
)
def register_ephemeral_ops():
Expand Down Expand Up @@ -231,17 +234,46 @@ def register_clamp():
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.pow.Tensor_Tensor,
]
)
def register_binaryop_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
supports_resize=True,
supports_highdim=True,
)


@update_features(
[
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.ge.Tensor,
]
)
def register_binaryop_cpp_ops():
def register_comparison_ops():
return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.BOOL_T,
supports_resize=True,
supports_highdim=True,
)


# =============================================================================
# BinaryOp.cpp (bitwise)
# =============================================================================


@update_features(exir_ops.edge.aten.bitwise_and.Tensor)
def register_bitwise_and():
    """Build the OpFeatures entry registered for ``aten.bitwise_and.Tensor``.

    The entry sets ``inputs_storage`` to ``utils.ANY_STORAGE`` and
    ``inputs_dtypes`` to ``utils.BOOL_T``, and turns on both
    ``supports_resize`` and ``supports_highdim``.
    """
    features = OpFeatures(
        supports_highdim=True,
        supports_resize=True,
        inputs_dtypes=utils.BOOL_T,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features
Expand Down Expand Up @@ -673,6 +705,7 @@ def register_argreduce_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
inputs_dtypes=utils.FP_T,
outputs_dtypes=utils.INT_T,
supports_resize=True,
supports_highdim=True,
are_node_inputs_supported_fn=is_reduce_node_supported,
Expand Down Expand Up @@ -1157,6 +1190,21 @@ def register_index_select():
)


# =============================================================================
# Where.cpp
# =============================================================================


@update_features(exir_ops.edge.aten.where.self)
def register_where():
    """Build the OpFeatures entry registered for ``aten.where.self``.

    ``inputs_dtypes`` is given per argument, in order:
    ``utils.BOOL_T``, ``utils.FP_T``, ``utils.FP_T``. The output dtype set is
    ``utils.FP_T``, inputs use ``utils.ANY_STORAGE``, and ``supports_resize``
    is enabled. ``supports_highdim`` is not passed, so it keeps whatever
    default ``OpFeatures`` defines.
    """
    per_input_dtypes = [utils.BOOL_T, utils.FP_T, utils.FP_T]
    features = OpFeatures(
        supports_resize=True,
        outputs_dtypes=utils.FP_T,
        inputs_dtypes=per_input_dtypes,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features


# =============================================================================
# Arange.cpp
# =============================================================================
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,11 @@ binary_op:
- VALUE: half
- VALUE: float
- VALUE: int32
- NAME: binary_bitwise_and
OPERATOR: X & Y
generate_variant_forall:
STORAGE:
- VALUE: buffer
- VALUE: texture3d
DTYPE:
- VALUE: uint8
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void main() {
#endif

#ifdef OUTPUT_IS_INDICES
t_out[out_bufi] = int(0); // int(local_accum.idx);
t_out[out_bufi] = int(local_accum.idx);
#else
t_out[out_bufi] = convert_to_T(local_accum.val);
#endif
Expand Down
101 changes: 62 additions & 39 deletions backends/vulkan/runtime/graph/ops/glsl/where.glsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// where.glsl

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand All @@ -8,7 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/


#version 450 core

${define_required_extensions(STORAGE, DTYPE)}
Expand All @@ -24,44 +21,50 @@ ${define_active_storage_type(STORAGE)}

layout(std430) buffer;

#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_condition", "bool", STORAGE)}
${layout_declare_tensor(B, "r", "t_self", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}


#include "indexing_utils.h"

$if STORAGE == "buffer":
${layout_declare_ubo(B, "int", "out_numl")}
${layout_declare_ubo(B, "ivec4", "out_strides")}
${layout_declare_ubo(B, "ivec4", "cond_strides")}
${layout_declare_ubo(B, "ivec4", "self_strides")}
${layout_declare_ubo(B, "ivec4", "other_strides")}
${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "condp")}
${layout_declare_ubo(B, "BufferMetadata", "selfp")}
${layout_declare_ubo(B, "BufferMetadata", "otherp")}
$else:
${layout_declare_ubo(B, "ivec3", "out_limits")}

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_DIM_ORDER")}

const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "condp")}
${layout_declare_ubo(B, "TextureMetadata", "selfp")}
${layout_declare_ubo(B, "TextureMetadata", "otherp")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#ifdef USING_BUFFER

void main() {
int out_bufi = int(gl_GlobalInvocationID.x);
if (out_bufi >= out_numl) {
const uint out_bufi = gl_GlobalInvocationID.x;
if (out_of_bounds(out_bufi, outp)) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);
TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);

TensorIndex cond_tidx = out_tidx;
clamp_tensor_idx(condp, cond_tidx);

const int cond_bufi = tidx_to_bufi(out_tidx, cond_strides);
const int self_bufi = tidx_to_bufi(out_tidx, self_strides);
const int other_bufi = tidx_to_bufi(out_tidx, other_strides);
TensorIndex self_tidx = out_tidx;
clamp_tensor_idx(selfp, self_tidx);

COND_T cond = t_condition[cond_bufi] ;
TensorIndex other_tidx = out_tidx;
clamp_tensor_idx(otherp, other_tidx);

const uint cond_bufi = tensor_idx_to_linear_idx(condp, cond_tidx);
const uint self_bufi = tensor_idx_to_linear_idx(selfp, self_tidx);
const uint other_bufi = tensor_idx_to_linear_idx(otherp, other_tidx);

COND_T cond = t_condition[cond_bufi];
T v_self = t_self[self_bufi];
T v_other = t_other[other_bufi];

Expand All @@ -72,29 +75,49 @@ void main() {
}
}

#else // !USING_BUFFER
#else // USING_TEXTURE

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);


if (any(greaterThanEqual(pos, out_limits))) {
if (out_of_bounds(out_pos, outp)) {
return;
}

vec4 cond = load_texel(t_condition, pos);
VEC4_T selftex = load_texel(t_self, pos);
VEC4_T othertex = load_texel(t_other, pos);

VEC4_T outtex;

for (int idx = 0; idx < 4; ++idx) {
if (cond[idx] == 1) {
outtex[idx] = selftex[idx];
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);

VEC4_T outtex = VEC4_T(0);

int limit = min(
4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]);
for (int comp = 0; comp < limit; comp++) {
TensorIndex4D cond_tidx;
cond_tidx.data = min(out_tidx.data, condp.sizes - 1);
TextureElementIndex cond_elem =
tensor4d_idx_to_texture_element_idx_simple(condp, cond_tidx);
uint cond_val = texelFetch(t_condition, cond_elem.pos, 0)[cond_elem.comp];

TensorIndex4D self_tidx;
self_tidx.data = min(out_tidx.data, selfp.sizes - 1);
TextureElementIndex self_elem =
tensor4d_idx_to_texture_element_idx_simple(selfp, self_tidx);
VEC4_T self_texel = texelFetch(t_self, self_elem.pos, 0);

TensorIndex4D other_tidx;
other_tidx.data = min(out_tidx.data, otherp.sizes - 1);
TextureElementIndex other_elem =
tensor4d_idx_to_texture_element_idx_simple(otherp, other_tidx);
VEC4_T other_texel = texelFetch(t_other, other_elem.pos, 0);

if (cond_val > 0) {
outtex[comp] = self_texel[self_elem.comp];
} else {
outtex[idx] = othertex[idx];
outtex[comp] = other_texel[other_elem.comp];
}

out_tidx.data[outp.packed_dim]++;
}
write_texel(t_out, pos, outtex);

imageStore(t_out, out_pos, outtex);
}
#endif // !USING_BUFFER
#endif
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ DEFINE_BINARY_OP_FN(lt);
DEFINE_BINARY_OP_FN(le);
DEFINE_BINARY_OP_FN(gt);
DEFINE_BINARY_OP_FN(ge);
DEFINE_BINARY_OP_FN(bitwise_and);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.add.Tensor, add);
Expand All @@ -212,6 +213,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.le.Tensor, le);
VK_REGISTER_OP(aten.gt.Tensor, gt);
VK_REGISTER_OP(aten.ge.Tensor, ge);
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
}

} // namespace vkcompute
55 changes: 10 additions & 45 deletions backends/vulkan/runtime/graph/ops/impl/Where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,13 @@ void resize_where_node(
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
const ValueRef out = args.at(0).refs.at(0);
const ValueRef in = args.at(1).refs.at(0);
const ValueRef self = args.at(1).refs.at(1);

const std::vector<int64_t> in_sizes = graph->sizes_of(in);
graph->virtual_resize(out, in_sizes);
const std::vector<int64_t> self_sizes = graph->sizes_of(self);
graph->virtual_resize(out, self_sizes);
}

// Appends a DynamicDispatchNode for the `where` kernel to the graph's list of
// execute nodes.
//
// The kernel name starts as "where" and is extended with suffixes derived from
// the storage type and dtype of `out`.
//
// graph: compute graph that receives the new node.
// cond, self, other: values bound to the kernel with read access.
// out: value bound to the kernel with write access.
void add_where_texture_node(
ComputeGraph& graph,
const ValueRef cond,
const ValueRef self,
const ValueRef other,
const ValueRef out) {
std::string kernel_name = "where";

// Select the shader variant from the output's storage type and dtype.
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
// Parameter buffers
// NOTE(review): the limits UBO is taken from `self`, not `out`; presumably
// the two always share the same extents on this path - confirm.
{graph.logical_limits_ubo(self)},
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(out)},
// Resize Arguments
{},
// Resizing Logic
resize_where_node));
}

void add_where_buffer_node(
void add_where_node(
ComputeGraph& graph,
const ValueRef cond,
const ValueRef self,
Expand All @@ -69,11 +39,10 @@ void add_where_buffer_node(
add_dtype_suffix(kernel_name, graph.dtype_of(out));

vkapi::ParamsBindList ubos = {
graph.numel_ubo(out),
graph.strides_ubo(out),
graph.strides_ubo(cond),
graph.strides_ubo(self),
graph.strides_ubo(other)};
graph.meta_ubo(out),
graph.meta_ubo(cond),
graph.meta_ubo(self),
graph.meta_ubo(other)};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
Expand All @@ -87,7 +56,7 @@ void add_where_buffer_node(
// Push Constants
{},
// Specialization Constants
{graph.hashed_layout_of(out)},
{},
// Resize Arguments
{},
// Resizing Logic
Expand All @@ -100,11 +69,7 @@ void where(ComputeGraph& graph, const std::vector<ValueRef>& args) {
const ValueRef self = args[args_i++];
const ValueRef other = args[args_i++];
const ValueRef out = args[args_i++];
if (graph.is_buffer_storage(out)) {
add_where_buffer_node(graph, cond, self, other, out);
} else {
add_where_texture_node(graph, cond, self, other, out);
}
add_where_node(graph, cond, self, other, out);
}

REGISTER_OPERATORS {
Expand Down
Loading
Loading