Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/native/MaxPooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ static void check_max_pool1d(
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);

const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
TORCH_CHECK(OW > 0, "max_pool1d() Invalid computed output size: ", OW);
}

} // namespace
Expand Down
10 changes: 10 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7228,6 +7228,16 @@ def test_fractional_max_pool2d_invalid_output_ratio(self):
"fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
res = arg_class(*arg_3)

def test_max_pool1d_invalid_output_size(self):
arg_1 = 3
arg_2 = 255
arg_3 = False
arg_class = torch.nn.MaxPool1d(kernel_size=arg_1, stride=arg_2, return_indices=arg_3)
arg_4_0 = torch.as_tensor([[0.3204]])
arg_4 = [arg_4_0,]

with self.assertRaises(RuntimeError):
res = arg_class(*arg_4)

class TestFusionEval(TestCase):
@given(X=hu.tensor(shapes=((5, 3, 5, 5),)),
Expand Down