backends/vulkan/op_registry.py: 116 additions & 57 deletions
@@ -420,75 +420,133 @@ def register_softmax_op():
     )


-@update_features(
-    [
-        exir_ops.edge.aten.mean.dim,
-        exir_ops.edge.aten.sum.dim_IntList,
-        exir_ops.edge.aten.amax.default,
-        exir_ops.edge.aten.amin.default,
-    ]
-)
-def register_reduce_op():
-    def check_reduce_node(node: torch.fx.Node) -> bool:
-        # Only one argument implies that the reduction is over the entire tensor, which
-        # is not supported yet.
-        if len(node.args) == 1:
-            return False
+def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
+    ndim = utils.ndim_of(node.args[0])
+    assert ndim is not None
+    dims_reduced = None
+    if len(node.args) >= 2:
+        dims_reduced = node.args[1]

-        dim_list = node.args[1]
-        # Only 1D and 2D reductions are supported at the moment.
-        if isinstance(dim_list, list) and len(dim_list) > 2:
-            return False
+    # If no dims were specified, reduce over all the dims of the tensor
+    if dims_reduced is None:
+        dims_reduced = list(range(ndim))

-        def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
-            for arg in node.args:
-                if isinstance(arg, bool):
-                    return arg
+    # Special case for reducing tensors with shape [1, N] - this is equivalent to
+    # reducing the last dim.
+    if utils.is_unsqueezed_vector(node) and ndim == 2:
+        dims_reduced = 1

-            # Assume false by default
-            return False
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
+        dims_reduced = dims_reduced[0]

-        keepdim = try_find_keepdim_arg(node)
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    assert isinstance(dims_reduced, (int, list, tuple))
+    return utils.normalize_dims(dims_reduced, ndim)

-        return True

-    def pick_io_storage_for_reduce(node: torch.fx.Node):
-        inputs_storage = utils.ANY_TEXTURE
-        outputs_storage = utils.ANY_TEXTURE

-        input_tensor = node.args[0]
-        ndim = input_tensor.meta["val"].ndim
-        dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
-            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)

-            possible_packed_dims = {0, 1, 2}
-            possible_packed_dims.discard(reduce_dim1_whcn)
-            possible_packed_dims.discard(reduce_dim2_whcn)

-            packed_dim = possible_packed_dims.pop()
-            assert packed_dim in [0, 1, 2]

-            if packed_dim == 0:
-                inputs_storage = utils.WIDTH_PACKED_TEXTURE
-                outputs_storage = utils.WIDTH_PACKED_TEXTURE
-            elif packed_dim == 1:
-                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
-                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
-            else:
-                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
-                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+def get_keepdim_setting(node: torch.fx.Node) -> bool:
+    for arg in node.args:
+        if isinstance(arg, bool):
+            return arg

+    # Assume false by default
+    return False


+def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
+    """
+    Checks if a reduction node is supported by the Vulkan backend's reduce per row
+    special case implementation.
+    """
+    input_ndim = utils.ndim_of(node.args[0])
+    assert input_ndim is not None
+    dims_reduced = get_dims_reduced(node)

+    return dims_reduced == input_ndim - 1


+def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
+    dims_reduced = get_dims_reduced(node)
+    # Only 1D and 2D reductions are supported at the moment.
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2:
+        return False

+    keepdim = get_keepdim_setting(node)
+    # keepdim = False is not supported yet by the general implementation
+    if isinstance(keepdim, bool) and not keepdim:
+        return False

+    return True


+def is_reduce_node_supported(node: torch.fx.Node) -> bool:
+    # 0-dim output is unsupported at the moment
+    if utils.ndim_of(node) == 0:
+        return False

+    return is_reduce_node_supported_by_per_row_impl(
+        node
+    ) or is_reduce_node_supported_by_general_impl(node)


+def pick_storage_for_reduce(node: torch.fx.Node):
+    inputs_storage = utils.NO_STORAGE
+    outputs_storage = utils.NO_STORAGE

+    ndim = utils.ndim_of(node.args[0])
+    dim_list = node.args[1]

+    if is_reduce_node_supported_by_general_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
+        outputs_storage = inputs_storage

+    # For 1D reductions of the last dim, a special reduce per row case is
+    # implemented for buffer-backed tensors.
+    if is_reduce_node_supported_by_per_row_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
+        outputs_storage = inputs_storage
+        return inputs_storage, outputs_storage

+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
+    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
+        # pyre-ignore[6]
+        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+        # pyre-ignore[6]
+        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)

+        possible_packed_dims = {0, 1, 2}
+        possible_packed_dims.discard(reduce_dim1_whcn)
+        possible_packed_dims.discard(reduce_dim2_whcn)

+        packed_dim = possible_packed_dims.pop()
+        assert packed_dim in [0, 1, 2]

+        if packed_dim == 0:
+            inputs_storage = utils.WIDTH_PACKED_TEXTURE
+            outputs_storage = utils.WIDTH_PACKED_TEXTURE
+        elif packed_dim == 1:
+            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+        else:
+            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+            outputs_storage = utils.CHANNELS_PACKED_TEXTURE

+    return inputs_storage, outputs_storage


+@update_features(
+    [
+        exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.amin.default,
+    ]
+)
+def register_reduce_op():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_reduce_node,
-        pick_io_storage_fn=pick_io_storage_for_reduce,
+        are_node_inputs_supported_fn=is_reduce_node_supported,
+        pick_io_storage_fn=pick_storage_for_reduce,
     )


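Note on the per-row special case above: utils.normalize_dims is not shown in this diff. The sketch below is an assumption about its behavior (wrap negative dims into range, keep the int-vs-list shape of the input), included only to show why is_reduce_node_supported_by_per_row_impl can compare the result against ndim - 1. These names and semantics are illustrative, not the actual ExecuTorch helpers.

from typing import List, Union

def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
    # Assumed behavior: map negative dims into [0, ndim), preserving whether
    # the input was a single int or a list of ints.
    if isinstance(dims, int):
        return dims % ndim
    return [d % ndim for d in dims]

# A sum over dim -1 of a rank-3 tensor normalizes to 2 == ndim - 1, so it is
# routed to the reduce-per-row implementation for contiguous buffers.
assert normalize_dims(-1, 3) == 2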
@@ -515,6 +573,7 @@ def register_2d_pool_op():
 def register_convolution_op():
     def check_conv_node(node: torch.fx.Node) -> bool:
         x = node.args[0]
+        assert isinstance(x, torch.fx.Node)
         x_shape = x.meta["val"].size()
         # 4-D input implies 2D convolution
         if len(x_shape) == 4:
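To make the 2D-reduction storage logic in pick_storage_for_reduce above concrete: the packed (vectorized) texture dim must survive the reduction, so both reduced dims are mapped from NCHW to WHCN order and discarded, leaving exactly one candidate. The helper below reimplements the assumed mapping locally for illustration; the real definition of utils.nchw_dim_to_whcn_dim is not part of this diff.

def nchw_dim_to_whcn_dim(dim: int, ndim: int) -> int:
    # Assumed semantics: WHCN counts from the innermost dim and NCHW from the
    # outermost, so the index is simply reversed. For ndim = 4:
    # W = 3 -> 0, H = 2 -> 1, C = 1 -> 2, N = 0 -> 3.
    return ndim - 1 - dim

def pick_packed_dim(dim_list: list, ndim: int) -> int:
    candidates = {0, 1, 2}  # W, H, C
    for d in dim_list:
        candidates.discard(nchw_dim_to_whcn_dim(d % ndim, ndim))
    assert len(candidates) == 1
    return candidates.pop()

# Reducing H and W (dims 2 and 3) of an NCHW tensor leaves C packed, so the
# node is constrained to CHANNELS_PACKED_TEXTURE storage.
assert pick_packed_dim([2, 3], 4) == 2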
backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh: 94 additions & 0 deletions (new file)
@@ -0,0 +1,94 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#ifndef REDUCE_OP_DEFS_GLSLH
#define REDUCE_OP_DEFS_GLSLH

struct Accum {
  T val;
  uint idx;
  uint count;
};

void init_accum(out Accum accum, T val, uint idx) {
  accum.val = val;
  accum.idx = idx;
  accum.count = 1;
}

void init_accum_zero(out Accum accum) {
  accum.val = T(0);
  accum.idx = 0;
  accum.count = 0;
}

// Sum / Mean

void update_accum_sum(inout Accum accum, T val, uint idx) {
  accum.val += val;
  accum.count += 1;
}

void merge_accum_sum(inout Accum accum, const Accum other) {
  accum.val += other.val;
  accum.count += other.count;
}

void postprocess_accum_mean(inout Accum accum) {
  accum.val /= T(accum.count);
}

// Amax (maximum value)

void update_accum_amax(inout Accum accum, T val, uint idx) {
  if (val > accum.val) {
    accum.val = val;
    accum.idx = idx;
  }
  // On ties, select the lower index
  if (val == accum.val && idx < accum.idx) {
    accum.idx = idx;
  }
}

void merge_accum_amax(inout Accum accum, const Accum other) {
  if (other.val > accum.val) {
    accum.val = other.val;
    accum.idx = other.idx;
  }
  // On ties, select the lower index
  if (other.val == accum.val && other.idx < accum.idx) {
    accum.idx = other.idx;
  }
}

// Amin (minimum value)

void update_accum_amin(inout Accum accum, T val, uint idx) {
  if (val < accum.val) {
    accum.val = val;
    accum.idx = idx;
  }
  // On ties, select the lower index
  if (val == accum.val && idx < accum.idx) {
    accum.idx = idx;
  }
}

void merge_accum_amin(inout Accum accum, const Accum other) {
  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
    accum.val = other.val;
    accum.idx = other.idx;
  }
  // On ties, select the lower index
  if (other.val == accum.val && other.idx < accum.idx) {
    accum.idx = other.idx;
  }
}

#endif // REDUCE_OP_DEFS_GLSLH
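As a quick sanity model of how these accumulators compose in a parallel reduction, here is the sum/mean contract rewritten in Python (illustrative only; it mirrors the GLSL above and is not an ExecuTorch API):

from dataclasses import dataclass

@dataclass
class Accum:
    val: float = 0.0
    idx: int = 0
    count: int = 0

def update_sum(acc: Accum, val: float, idx: int) -> None:
    # update_accum_sum: accumulate the value and the element count
    # (idx is unused here; it exists to keep one signature across all modes).
    acc.val += val
    acc.count += 1

def merge_sum(acc: Accum, other: Accum) -> None:
    # merge_accum_sum: combine partial results from two threads/workgroups.
    acc.val += other.val
    acc.count += other.count

def postprocess_mean(acc: Accum) -> None:
    # postprocess_accum_mean: divide once at the end to turn a sum into a mean.
    acc.val /= acc.count

# Two "threads" each reduce half of a row, then their partials are merged.
row = [1.0, 2.0, 3.0, 4.0]
a, b = Accum(), Accum()
for i, v in enumerate(row[:2]):
    update_sum(a, v, i)
for i, v in enumerate(row[2:], start=2):
    update_sum(b, v, i)
merge_sum(a, b)
postprocess_mean(a)
assert a.val == 2.5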