# Patch summary: route aten.max_pool2d.default through the existing Vulkan
# max_pool2d_with_indices implementation.
#
# - backends/vulkan/op_registry.py: adds exir_ops.edge.aten.max_pool2d.default
#   to the @update_features op list that already contains avg_pool2d and
#   max_pool2d_with_indices.
# - backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl: declares
#   `write_indices` via layout_declare_spec_const(C, "int", "write_indices", "1")
#   and wraps the imageStore to t_idx in `if (write_indices > 0)`. The
#   imageStore to t_out stays unconditional.
# - backends/vulkan/runtime/graph/ops/impl/Pool.cpp (add_max_pool2d_node):
#   `out` is no longer assumed to be a value list. If
#   graph.val_is_value_list(out) is true, element 0 becomes the output tensor,
#   element 1 the indices tensor, and write_indices is set to 1. Otherwise
#   `out` itself is the output tensor, a TmpTensor occupies the indices binding
#   slot, and write_indices stays 0. write_indices is passed as the node's
#   single specialization constant. aten.max_pool2d.default is registered to
#   the same `max_pool2d` entry point as aten.max_pool2d_with_indices.default.
# - backends/vulkan/test/op_tests/cases.py: get_max_pool2d_inputs is registered
#   for both op names, and its only test case changes its input shape from
#   (S, M1, M2) to (1, 7, 89, 77).
#
# NOTE(review): the cases.py hunk replaces the (S, M1, M2) input instead of
# adding the 4-D shape next to it, so the 3-element-shape case is no longer
# exercised for max_pool2d_with_indices either - confirm this is intended.
# NOTE(review): the placeholder TmpTensor is created with empty sizes and with
# graph.dtype_of(in) / graph.storage_type_of(in). The declaration of t_idx is
# outside this diff - confirm that binding a tensor of the input's dtype to
# that slot is acceptable when write_indices == 0.
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index f9f8aeb79e3..ade82bcde3b 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -434,6 +434,7 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool: @update_features( [ exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.max_pool2d.default, exir_ops.edge.aten.max_pool2d_with_indices.default, ] ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl index 9d78b7a6a6e..28afe5a822f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl @@ -24,6 +24,8 @@ ${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "pad layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "write_indices", "1")} + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -55,5 +57,7 @@ void main() { } imageStore(t_out, pos, out_texel); - imageStore(t_idx, pos, idx_texel); + if (write_indices > 0) { + imageStore(t_idx, pos, idx_texel); + } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index b3791a4f7d1..250fcdd5490 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -76,8 +76,19 @@ void add_max_pool2d_node( const ValueRef dilation, const ValueRef ceil_mode, const ValueRef out) { - const auto out_val = graph.get_value_list(out); - const ValueRef out_tensor = out_val->at(0); + ValueRef out_tensor = out; + // Placeholder tensor to fill binding slot for indices tensor in case we are + // computing max_pool2d instead of max_pool2d_with_indices. 
+ TmpTensor tmp_indices_tensor = + TmpTensor(&graph, {}, graph.dtype_of(in), graph.storage_type_of(in)); + ValueRef indices_tensor = tmp_indices_tensor.vref; + int32_t write_indices = 0; + if (graph.val_is_value_list(out)) { + const auto out_val = graph.get_value_list(out); + out_tensor = out_val->at(0); + indices_tensor = out_val->at(1); + write_indices = 1; + } check_pool2d_args(graph, in, out_tensor); @@ -98,7 +109,7 @@ void add_max_pool2d_node( default_pick_global_wg_size, default_pick_local_wg_size, // Inputs and Outputs - {{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}}, + {{{out_tensor, indices_tensor}, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers { graph.logical_limits_ubo(out_tensor), @@ -108,7 +119,7 @@ void add_max_pool2d_node( // Push Constants {}, // Specialization Constants - {}, + {write_indices}, // Resize Args {kernel_size, stride, padding, dilation, ceil_mode}, // Resizing Logic @@ -203,6 +214,7 @@ void avg_pool2d(ComputeGraph& graph, const std::vector& args) { REGISTER_OPERATORS { VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d); VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d); + VK_REGISTER_OP(aten.max_pool2d.default, max_pool2d); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index cb29d836056..8c5d0c4797b 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -270,11 +270,13 @@ def get_avg_pool2d_inputs(): return test_suite -@register_test_suite("aten.max_pool2d_with_indices.default") +@register_test_suite( + ["aten.max_pool2d_with_indices.default", "aten.max_pool2d.default"] +) def get_max_pool2d_inputs(): test_suite = VkTestSuite( [ - ((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]), + ((1, 7, 89, 77), [2, 2], [1, 1], [0, 0], [1, 1]), ] ) return test_suite