Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions backends/arm/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
from executorch.backends.arm.tosa_specification import TosaSpecification


# Similarly to Conv2d, the TOSA spec requires that following is exactly divisible:
# `(input + 2 * pad - kernel_size) / stride`
# PyTorch however, does not require this, so as needed, we must adjust the padding.
def adjust_pad_if_needed(
input_size: int, kernel_size: int, stride: int, pad: int
) -> int:
if pad == 0:
return pad

mod_remainder = (input_size + 2 * pad - kernel_size) % stride

# No need to adjust
if mod_remainder == 0:
return pad

return pad - mod_remainder


@register_node_visitor
class MaxPool2dVisitor_0_80(NodeVisitor):
target = "aten.max_pool2d.default"
Expand Down Expand Up @@ -61,6 +79,20 @@ def define_node(
except IndexError:
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pad_if_needed(
input_tensor.shape[2],
kernel_size[0],
stride[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pad_if_needed(
input_tensor.shape[3],
kernel_size[1],
stride[1],
pad_size_list[3],
)

accumulator_type = output.dtype

# Initilize zero point to zero.
Expand Down Expand Up @@ -131,6 +163,20 @@ def define_node(
except IndexError:
pad_size_list = [0, 0, 0, 0]

# Adjust the padding as necessary
pad_size_list[1] = adjust_pad_if_needed(
input_tensor.shape[2],
kernel_size[0],
stride[0],
pad_size_list[1],
)
pad_size_list[3] = adjust_pad_if_needed(
input_tensor.shape[3],
kernel_size[1],
stride[1],
pad_size_list[3],
)

attr = ts.TosaSerializerAttribute()
attr.MaxPool2dAttribute(
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1
Expand Down
1 change: 1 addition & 0 deletions backends/arm/test/ops/test_max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]),
("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]),
("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]),
("non_divisible", torch.rand(1, 16, 112, 112), [3, 2, 1]),
]

test_data_suite_mult_batches = [
Expand Down
Loading