diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b47a8f383a0..92a98ce82a7 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -420,75 +420,133 @@ def register_softmax_op():
     )
 
 
-@update_features(
-    [
-        exir_ops.edge.aten.mean.dim,
-        exir_ops.edge.aten.sum.dim_IntList,
-        exir_ops.edge.aten.amax.default,
-        exir_ops.edge.aten.amin.default,
-    ]
-)
-def register_reduce_op():
-    def check_reduce_node(node: torch.fx.Node) -> bool:
-        # Only one argument implies that the reduction is over the entire tensor, which
-        # is not supported yet.
-        if len(node.args) == 1:
-            return False
+def get_dims_reduced(node: torch.fx.Node) -> Union[int, List[int]]:
+    ndim = utils.ndim_of(node.args[0])
+    assert ndim is not None
+    dims_reduced = None
+    if len(node.args) >= 2:
+        dims_reduced = node.args[1]
 
-        dim_list = node.args[1]
-        # Only 1D and 2D reductions are supported at the moment.
-        if isinstance(dim_list, list) and len(dim_list) > 2:
-            return False
+    # If dim_list is None, return a list containing all the dims of the tensor
+    if dims_reduced is None:
+        dims_reduced = list(range(ndim))
 
-        def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
-            for arg in node.args:
-                if isinstance(arg, bool):
-                    return arg
+    # Special case for reducing tensors with shape [1, N] - this is equivalent to
+    # reducing the last dim.
+    if utils.is_unsqueezed_vector(node) and ndim == 2:
+        dims_reduced = 1
 
-            # Assume false by default
-            return False
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) == 1:
+        dims_reduced = dims_reduced[0]
 
-        keepdim = try_find_keepdim_arg(node)
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    assert isinstance(dims_reduced, (int, list, tuple))
+    return utils.normalize_dims(dims_reduced, ndim)
 
-        return True
 
-    def pick_io_storage_for_reduce(node: torch.fx.Node):
-        inputs_storage = utils.ANY_TEXTURE
-        outputs_storage = utils.ANY_TEXTURE
-
-        input_tensor = node.args[0]
-        ndim = input_tensor.meta["val"].ndim
-        dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
-            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
-
-            possible_packed_dims = {0, 1, 2}
-            possible_packed_dims.discard(reduce_dim1_whcn)
-            possible_packed_dims.discard(reduce_dim2_whcn)
-
-            packed_dim = possible_packed_dims.pop()
-            assert packed_dim in [0, 1, 2]
-
-            if packed_dim == 0:
-                inputs_storage = utils.WIDTH_PACKED_TEXTURE
-                outputs_storage = utils.WIDTH_PACKED_TEXTURE
-            elif packed_dim == 1:
-                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
-                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
-            else:
-                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
-                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+def get_keepdim_setting(node: torch.fx.Node) -> bool:
+    for arg in node.args:
+        if isinstance(arg, bool):
+            return arg
+
+    # Assume false by default
+    return False
+
+
+def is_reduce_node_supported_by_per_row_impl(node: torch.fx.Node) -> bool:
+    """
+    Checks if a reduction node is supported by the Vulkan backend's reduce per row
+    special case implementation.
+    """
+    input_ndim = utils.ndim_of(node.args[0])
+    assert input_ndim is not None
+    dims_reduced = get_dims_reduced(node)
+
+    return dims_reduced == input_ndim - 1
+
+
+def is_reduce_node_supported_by_general_impl(node: torch.fx.Node) -> bool:
+    dims_reduced = get_dims_reduced(node)
+    # Only 1D and 2D reductions are supported at the moment.
+    if isinstance(dims_reduced, (list, tuple)) and len(dims_reduced) > 2:
+        return False
+
+    keepdim = get_keepdim_setting(node)
+    # keepdim = False is not supported yet for general implementation
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
+
+    return True
+
+
+def is_reduce_node_supported(node: torch.fx.Node) -> bool:
+    # 0-dim output unsupported at the moment
+    if utils.ndim_of(node) == 0:
+        return False
+
+    return is_reduce_node_supported_by_per_row_impl(
+        node
+    ) or is_reduce_node_supported_by_general_impl(node)
+
+
+def pick_storage_for_reduce(node: torch.fx.Node):
+    inputs_storage = utils.NO_STORAGE
+    outputs_storage = utils.NO_STORAGE
+
+    ndim = utils.ndim_of(node.args[0])
+    dim_list = node.args[1]
+
+    if is_reduce_node_supported_by_general_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.ANY_TEXTURE)
+        outputs_storage = inputs_storage
+
+    # For 1D reductions of the last dim, a special reduce per row case is implemented
+    # for buffer backed tensors.
+    if is_reduce_node_supported_by_per_row_impl(node):
+        inputs_storage = inputs_storage.make_union(utils.CONTIGUOUS_BUFFER)
+        outputs_storage = inputs_storage
 
         return inputs_storage, outputs_storage
 
+    # For 2D reductions, the packed dimension cannot be one of the reduced dims
+    if isinstance(dim_list, (list, tuple)) and len(dim_list) == 2:
+        # pyre-ignore[6]
+        reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+        # pyre-ignore[6]
+        reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+        possible_packed_dims = {0, 1, 2}
+        possible_packed_dims.discard(reduce_dim1_whcn)
+        possible_packed_dims.discard(reduce_dim2_whcn)
+
+        packed_dim = possible_packed_dims.pop()
+        assert packed_dim in [0, 1, 2]
+
+        if packed_dim == 0:
+            inputs_storage = utils.WIDTH_PACKED_TEXTURE
+            outputs_storage = utils.WIDTH_PACKED_TEXTURE
+        elif packed_dim == 1:
+            inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+        else:
+            inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+            outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+    return inputs_storage, outputs_storage
+
+
+@update_features(
+    [
+        exir_ops.edge.aten.mean.dim,
+        exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.amax.default,
+        exir_ops.edge.aten.amin.default,
+    ]
+)
+def register_reduce_op():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_reduce_node,
-        pick_io_storage_fn=pick_io_storage_for_reduce,
+        are_node_inputs_supported_fn=is_reduce_node_supported,
+        pick_io_storage_fn=pick_storage_for_reduce,
     )
 
 
@@ -515,6 +573,7 @@ def register_2d_pool_op():
 def register_convolution_op():
     def check_conv_node(node: torch.fx.Node) -> bool:
         x = node.args[0]
+        assert isinstance(x, torch.fx.Node)
         x_shape = x.meta["val"].size()
         # 4-D input implies 2D convolution
         if len(x_shape) == 4:
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh
new file mode 100644
index 00000000000..bbfb991808f
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_op_defs.glslh
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef REDUCE_OP_DEFS_GLSLH
+#define REDUCE_OP_DEFS_GLSLH
+
+struct Accum {
+  T val;
+  uint idx;
+  uint count;
+};
+
+void init_accum(out Accum accum, T val, uint idx) {
+  accum.val = val;
+  accum.idx = idx;
+  accum.count = 1;
+}
+
+void init_accum_zero(out Accum accum) {
+  accum.val = T(0);
+  accum.idx = 0;
+  accum.count = 0;
+}
+
+// Sum / Mean
+
+void update_accum_sum(inout Accum accum, T val, uint idx) {
+  accum.val += val;
+  accum.count += 1;
+}
+
+void merge_accum_sum(inout Accum accum, const Accum other) {
+  accum.val += other.val;
+  accum.count += other.count;
+}
+
+void postprocess_accum_mean(inout Accum accum) {
+  accum.val /= T(accum.count);
+}
+
+// Amax (maximum value)
+
+void update_accum_amax(inout Accum accum, T val, uint idx) {
+  if (val > accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // For equivalence, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amax(inout Accum accum, const Accum other) {
+  if (other.val > accum.val) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // For equivalence, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+// Amin (minimum value)
+
+void update_accum_amin(inout Accum accum, T val, uint idx) {
+  if (val < accum.val) {
+    accum.val = val;
+    accum.idx = idx;
+  }
+  // For equivalence, select the lower index
+  if (val == accum.val && idx < accum.idx) {
+    accum.idx = idx;
+  }
+}
+
+void merge_accum_amin(inout Accum accum, const Accum other) {
+  if (other.count > 0 && (accum.count == 0 || other.val < accum.val)) {
+    accum.val = other.val;
+    accum.idx = other.idx;
+  }
+  // For equivalence, select the lower index
+  if (other.val == accum.val && other.idx < accum.idx) {
+    accum.idx = other.idx;
+  }
+}
+
+#endif // REDUCE_OP_DEFS_GLSLH
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl
new file mode 100644
index 00000000000..d1574a67a62
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${texel_load_component_type(DTYPE, "buffer")}
+
+#define NUM_OUTPUTS_PER_WG 1
+#define NUM_WORKERS_PER_OUTPUT 64
+
+${define_active_storage_type("buffer")}
+${define_required_extensions(DTYPE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+#include "reduce_op_defs.glslh"
+
+$if OUTPUT_IS_INDICES:
+  ${layout_declare_tensor(B, "w", "t_out", "int", "buffer")}
+$else:
+  ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+
+${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// Shared memory for cooperative reduction
+shared Accum shared_values[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
+
+#define init_fn ${INIT_ACCUM_FN}
+#define update_fn ${UPDATE_ACCUM_FN}
+#define merge_fn ${MERGE_ACCUM_FN}
+
+$if POSTPROCESS_ACCUM_FN != "none":
+  #define postprocess_fn ${POSTPROCESS_ACCUM_FN}
+
+$if OOB_INIT_MODE == "zero":
+  #define OOB_INIT_MODE 0
+$else:
+  #define OOB_INIT_MODE 1
+
+$if OUTPUT_IS_INDICES:
+  #define OUTPUT_IS_INDICES
+
+#extension GL_EXT_debug_printf : require
+
+void main() {
+  const uint out_bufi = gl_GlobalInvocationID.y;
+
+  if (out_of_bounds(out_bufi, outp)) {
+    return;
+  }
+
+  // Local indices
+  const uint worker_id = gl_LocalInvocationID.x;
+  const uint output_id = gl_LocalInvocationID.y;
+
+  const uint in_bufi_base = out_bufi * width(inp);
+
+  Accum local_accum;
+  // Initialize accumulator with the first element being processed
+  if (worker_id < width(inp)) {
+    const uint in_bufi = in_bufi_base + worker_id;
+    init_fn(local_accum, t_in[in_bufi], worker_id);
+  }
+  // For out of bounds case, initialization depends on reduction op
+  else {
+#if OOB_INIT_MODE == 0
+    // Init with a zero value
+    init_accum_zero(local_accum);
+#else
+    // Init with the first value (i.e. amin, amax)
+    init_fn(local_accum, t_in[in_bufi_base], 0);
+#endif
+  }
+
+  for (uint x = worker_id + NUM_WORKERS_PER_OUTPUT; x < width(inp);
+       x += NUM_WORKERS_PER_OUTPUT) {
+    update_fn(local_accum, t_in[in_bufi_base + x], x);
+  }
+
+  shared_values[output_id][worker_id] = local_accum;
+
+  memoryBarrierShared();
+  barrier();
+
+  for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
+    if (worker_id < i) {
+      merge_fn(
+          shared_values[output_id][worker_id],
+          shared_values[output_id][worker_id + i]);
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+
+  if (worker_id == 0) {
+    local_accum = shared_values[output_id][0];
+#ifdef postprocess_fn
+    postprocess_fn(local_accum);
+#endif
+
+#ifdef OUTPUT_IS_INDICES
+    t_out[out_bufi] = int(local_accum.idx);
+#else
+    t_out[out_bufi] = local_accum.val;
+#endif
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml
new file mode 100644
index 00000000000..e5a94165b96
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.yaml
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+reduce_per_row_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    INIT_ACCUM_FN: init_accum
+    UPDATE_ACCUM_FN: update_accum_sum
+    MERGE_ACCUM_FN: merge_accum_sum
+    POSTPROCESS_ACCUM_FN: none
+    OOB_INIT_MODE: zero
+    OUTPUT_IS_INDICES: false
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+      - VALUE: half
+      - VALUE: int32
+  shader_variants:
+    - NAME: sum_per_row_buffer
+    - NAME: mean_per_row_buffer
+      POSTPROCESS_ACCUM_FN: postprocess_accum_mean
+    - NAME: amax_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amax
+      MERGE_ACCUM_FN: merge_accum_amax
+      OOB_INIT_MODE: first_element
+    - NAME: amin_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amin
+      MERGE_ACCUM_FN: merge_accum_amin
+      OOB_INIT_MODE: first_element
+    - NAME: argmax_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amax
+      MERGE_ACCUM_FN: merge_accum_amax
+      OOB_INIT_MODE: first_element
+      OUTPUT_IS_INDICES: true
+    - NAME: argmin_per_row_buffer
+      UPDATE_ACCUM_FN: update_accum_amin
+      MERGE_ACCUM_FN: merge_accum_amin
+      OOB_INIT_MODE: first_element
+      OUTPUT_IS_INDICES: true
diff --git a/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp
new file mode 100644
index 00000000000..68a51602f74
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/ArgReduce.cpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+#include
+
+#include
+
+namespace vkcompute {
+
+void arg_reduce_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args,
+    const std::string& op_name) {
+  int arg_idx = 0;
+  const ValueRef in = args.at(arg_idx++);
+  const ValueRef dim = args.at(arg_idx++);
+  const ValueRef keepdim = args.at(arg_idx++);
+  const ValueRef out = args.at(arg_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in));
+
+  int64_t dim_val = 0;
+  if (graph.val_is_not_none(dim)) {
+    dim_val = graph.extract_scalar<int64_t>(dim);
+  }
+  const int64_t ndim = graph.dim_of(in);
+  const int64_t normalized_dim = normalize(dim_val, graph.dim_of(in));
+
+  VK_CHECK_COND(normalized_dim == ndim - 1);
+
+  // Use the reduce_per_row_node function
+  add_reduce_per_row_node(graph, in, keepdim, out, op_name);
+}
+
+void argmin(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  arg_reduce_impl(graph, args, "argmin");
+}
+
+void argmax(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  arg_reduce_impl(graph, args, "argmax");
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.argmin.default, argmin);
+  VK_REGISTER_OP(aten.argmax.default, argmax);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
index 6ad1d7f371d..feb36301202 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -52,6 +52,26 @@ void resize_reduce2d_node(
   graph->virtual_resize(out, new_sizes);
 }
 
+void resize_reduce_per_row_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+
+  const bool keepdim = graph->extract_scalar<bool>(resize_args.at(0));
+
+  std::vector<int64_t> new_sizes = graph->sizes_of(in);
+  if (keepdim) {
+    // Per-row reduction always reduces along the last dimension (width)
+    new_sizes.back() = 1;
+  } else {
+    // Remove the last dimension
+    new_sizes.pop_back();
+  }
+  graph->virtual_resize(out, new_sizes);
+}
+
 utils::uvec3 reduce_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -237,12 +257,83 @@ void add_reduce2d_node(
       resize_reduce2d_node));
 }
 
+utils::uvec3 reduce_per_row_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  const ValueRef in = args.at(1).refs.at(0);
+  return {1u, utils::safe_downcast<uint32_t>(graph->numel_of(in)), 1u};
+}
+
+utils::uvec3 reduce_per_row_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
+
+  uint32_t outputs_per_wg = 1u;
+  uint32_t workers_per_output = 64u;
+
+  return {workers_per_output, outputs_per_wg, 1u};
+}
+
+void add_reduce_per_row_node(
+    ComputeGraph& graph,
+    const ValueRef input,
+    const ValueRef keepdim_ref,
+    const ValueRef output,
+    const std::string& op_name) {
+  std::string kernel_name = op_name + "_per_row";
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(input));
+
+  vkapi::ParamsBindList param_ubos = {
+      graph.meta_ubo(output),
+      graph.meta_ubo(input),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      // Global workgroup size function
+      reduce_per_row_global_wg_size,
+      // Local workgroup size function
+      reduce_per_row_local_wg_size,
+      // Inputs and Outputs
+      {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      {},
+      // Specialization Constants
+      {},
+      // Resize Args
+      {keepdim_ref},
+      // Resizing Logic
+      resize_reduce_per_row_node));
+}
+
 #define DEFINE_REDUCE_FN(op_name, out_arg_idx)                                \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {      \
     const std::vector<int64_t> dims_list =                                    \
         graph.extract_int_or_symint_list(args[1]);                            \
     if (dims_list.size() == 1) {                                              \
-      const int64_t dim_val = dims_list.at(0);                                \
+      int64_t dim_val = dims_list.at(0);                                      \
+      int64_t ndim = graph.dim_of(args[0]);                                   \
+      if ((dim_val == -1 || dim_val == ndim - 1) &&                           \
+          graph.is_buffer_storage(args[0])) {                                 \
+        return add_reduce_per_row_node(                                       \
+            graph, args[0], args[2], args[out_arg_idx], #op_name);            \
+      }                                                                       \
       const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);       \
       return add_reduce_node(                                                 \
           graph, args[0], dim_ref, args[out_arg_idx], #op_name);              \
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.h b/backends/vulkan/runtime/graph/ops/impl/Reduce.h
new file mode 100644
index 00000000000..7d38e438d31
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+#include
+
+namespace vkcompute {
+
+void add_reduce_per_row_node(
+    ComputeGraph& graph,
+    const ValueRef input,
+    const ValueRef keepdim_ref,
+    const ValueRef output,
+    const std::string& op_name);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 76e5dc1bc0e..dfb9a2865ba 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1557,6 +1557,21 @@ def get_reduce_inputs(is_softmax: bool = False):
     ]
 
 
+def get_reduce_per_row_inputs():
+    inputs = [
+        ((5, 10), 1, False),
+        ((5, 16), -1, True),
+        ((5, 16), -1, False),
+        ((7, 21), -1, True),
+        ((7, 21), -1, False),
+        ((3, 7, 280), -1, True),
+        ((3, 7, 280), -1, False),
+        ((3, 17, 77), -1, True),
+        ((3, 17, 77), -1, False),
+    ]
+    return inputs
+
+
 @register_test_suite(["aten._softmax.default", "aten._log_softmax.default"])
 def get_softmax_inputs():
     test_suite = VkTestSuite(get_reduce_inputs(is_softmax=True))
@@ -1576,6 +1591,20 @@ def get_reduce_op_inputs():
         "utils::kChannelsPacked",
         "utils::kWidthPacked",
     ]
+
+    per_row_suite = VkTestSuite(get_reduce_per_row_inputs())
+    per_row_suite.layouts = ["utils::kWidthPacked"]
+    per_row_suite.storage_types = ["utils::kBuffer"]
+    per_row_suite.test_name_suffix = "per_row"
+    return [test_suite, per_row_suite]
+
+
+@register_test_suite(["aten.argmin.default", "aten.argmax.default"])
+def get_reduce_arg_op_inputs():
+    test_suite = VkTestSuite(get_reduce_per_row_inputs())
+    test_suite.layouts = ["utils::kWidthPacked"]
+    test_suite.storage_types = ["utils::kBuffer"]
+    test_suite.dtypes = ["at::kFloat"]
     return test_suite
 
 
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 6a510e65925..00147dab2c3 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -757,6 +757,7 @@ def make_filtered_tensor_repset(
 CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set())
 
 WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED})
+HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED})
 CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED})
 ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts)
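
Usage sketch (not part of the patch): a minimal, hypothetical example of how the new per-row reduction path could be exercised through the standard ExecuTorch export flow. The module and input shape are illustrative assumptions; whether the node actually lowers to the reduce_per_row shader depends on the partitioner selecting buffer storage for a last-dim reduction, as registered above.

    import torch
    from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
    from executorch.exir import to_edge_transform_and_lower


    class LastDimSum(torch.nn.Module):
        def forward(self, x):
            # Reduce over the last (width) dim so that
            # is_reduce_node_supported_by_per_row_impl can match.
            return x.sum(dim=-1)


    ep = torch.export.export(LastDimSum(), (torch.randn(3, 7, 280),))
    edge = to_edge_transform_and_lower(ep, partitioner=[VulkanPartitioner()])
    et_program = edge.to_executorch()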