Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
@update_features(
[
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
]
)
Expand Down
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ ${layout_declare_ubo(B, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "pad

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "write_indices", "1")}

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

Expand Down Expand Up @@ -55,5 +57,7 @@ void main() {
}

imageStore(t_out, pos, out_texel);
imageStore(t_idx, pos, idx_texel);
if (write_indices > 0) {
imageStore(t_idx, pos, idx_texel);
}
}
20 changes: 16 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/Pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,19 @@ void add_max_pool2d_node(
const ValueRef dilation,
const ValueRef ceil_mode,
const ValueRef out) {
const auto out_val = graph.get_value_list(out);
const ValueRef out_tensor = out_val->at(0);
ValueRef out_tensor = out;
// Placeholder tensor to fill binding slot for indices tensor in case we are
// computing max_pool2d instead of max_pool2d_with_indices.
TmpTensor tmp_indices_tensor =
TmpTensor(&graph, {}, graph.dtype_of(in), graph.storage_type_of(in));
ValueRef indices_tensor = tmp_indices_tensor.vref;
int32_t write_indices = 0;
if (graph.val_is_value_list(out)) {
const auto out_val = graph.get_value_list(out);
out_tensor = out_val->at(0);
indices_tensor = out_val->at(1);
write_indices = 1;
}

check_pool2d_args(graph, in, out_tensor);

Expand All @@ -98,7 +109,7 @@ void add_max_pool2d_node(
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{{out_val->at(0), out_val->at(1)}, vkapi::kWrite}, {in, vkapi::kRead}},
{{{out_tensor, indices_tensor}, vkapi::kWrite}, {in, vkapi::kRead}},
// Shader params buffers
{
graph.logical_limits_ubo(out_tensor),
Expand All @@ -108,7 +119,7 @@ void add_max_pool2d_node(
// Push Constants
{},
// Specialization Constants
{},
{write_indices},
// Resize Args
{kernel_size, stride, padding, dilation, ceil_mode},
// Resizing Logic
Expand Down Expand Up @@ -203,6 +214,7 @@ void avg_pool2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
REGISTER_OPERATORS {
VK_REGISTER_OP(aten.avg_pool2d.default, avg_pool2d);
VK_REGISTER_OP(aten.max_pool2d_with_indices.default, max_pool2d);
VK_REGISTER_OP(aten.max_pool2d.default, max_pool2d);
}

} // namespace vkcompute
6 changes: 4 additions & 2 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,13 @@ def get_avg_pool2d_inputs():
return test_suite


@register_test_suite("aten.max_pool2d_with_indices.default")
@register_test_suite(
["aten.max_pool2d_with_indices.default", "aten.max_pool2d.default"]
)
def get_max_pool2d_inputs():
test_suite = VkTestSuite(
[
((S, M1, M2), [2, 2], [1, 1], [0, 0], [1, 1]),
((1, 7, 89, 77), [2, 2], [1, 1], [0, 0], [1, 1]),
]
)
return test_suite
Expand Down
Loading