5 changes: 5 additions & 0 deletions aten/src/ATen/native/AveragePool2d.cpp
@@ -58,6 +58,11 @@ static void avg_pool2d_out_frame(
         hend = std::min(hend, inputHeight);
         wend = std::min(wend, inputWidth);
+
+        if (hstart >= hend || wstart >= wend) {
+          ++ptr_output;
+          continue;
+        }
 
         scalar_t sum = 0;
 
         int divide_factor;
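The guard matters because ceil_mode can produce a final pooling window that starts at or beyond the clamped end of the input: the sum over the empty window is 0 and so is the divisor, and 0/0 yields NaN (gh-36977). Below is a minimal pure-Python sketch of the same inner-loop arithmetic, with illustrative names; it simplifies ATen's actual output-shape formula, and writes 0.0 where the CPU path above just advances the output pointer and skips the slot.

```python
def avg_pool1d_ref(xs, kernel, stride, ceil_mode=False):
    # ceil_mode rounds the output length up, so the last window can start
    # past the end of the input (simplified vs. ATen's shape formula).
    n = len(xs)
    out_len = (n - kernel + (stride - 1 if ceil_mode else 0)) // stride + 1
    out = []
    for o in range(out_len):
        start = o * stride
        end = min(start + kernel, n)
        if start >= end:      # empty window: the case the patch skips
            out.append(0.0)   # unguarded, this is sum 0 / size 0 -> NaN
            continue
        out.append(sum(xs[start:end]) / (end - start))
    return out

# The third window starts at index 4 of a length-4 input: empty, guarded to 0.0
print(avg_pool1d_ref([1.0, 2.0, 3.0, 4.0], kernel=1, stride=2, ceil_mode=True))
```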
5 changes: 5 additions & 0 deletions aten/src/ATen/native/AveragePool3d.cpp
@@ -66,6 +66,11 @@ static void avg_pool3d_out_frame(
         hend = std::min(hend, iheight);
         wend = std::min(wend, iwidth);
+
+        if (tstart >= tend || hstart >= hend || wstart >= wend) {
+          ++op;
+          continue;
+        }
 
         int divide_factor;
         if (divisor_override.has_value()) {
           divide_factor = divisor_override.value();
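Both CPU hunks sit directly above the divisor selection: with divisor_override set, every window divides by that constant; otherwise the divisor is the window size, which is exactly the quantity that collapses to zero for an empty window. A quick illustration of the two divisor modes using the public API:

```python
import torch

x = torch.ones(1, 1, 4, 4)
a = torch.nn.functional.avg_pool2d(x, kernel_size=2)                      # sum 4 / 4
b = torch.nn.functional.avg_pool2d(x, kernel_size=2, divisor_override=2)  # sum 4 / 2
print(a[0, 0, 0, 0].item(), b[0, 0, 0, 0].item())  # 1.0 2.0
```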
22 changes: 22 additions & 0 deletions aten/src/ATen/native/cuda/AveragePool2d.cu
@@ -43,6 +43,12 @@ __global__ void avg_pool2d_out_cuda_frame(const int nthreads,
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+
+    if (hstart >= hend || wstart >= wend) {
+      top_data[index] = scalar_t(0);
+      continue;
+    }
+
     accscalar_t aveval = accscalar_t(0);
     const scalar_t* const bottom_slice = bottom_data + (n * channels + c) * height * width;
     for (int h = hstart; h < hend; ++h) {
@@ -86,6 +92,12 @@ __global__ void avg_pool2d_out_cuda_frame_nhwc(const int nthreads,
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+
+    if (hstart >= hend || wstart >= wend) {
+      top_data[index] = scalar_t(0);
+      continue;
+    }
+
     accscalar_t aveval = accscalar_t(0);
     const scalar_t* const bottom_slice = bottom_data + n * channels * height * width + c;
     for (int h = hstart; h < hend; ++h) {
@@ -141,6 +153,11 @@ __global__ void avg_pool2d_backward_out_cuda_frame(const int nthreads, const sca
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+
+    if (hstart >= hend || wstart >= wend) {
+      continue;
+    }
+
     int divide_factor;
     if (use_divisor) {
       divide_factor = divisor_override;
@@ -191,6 +208,11 @@ __global__ void avg_pool2d_backward_out_cuda_frame_nhwc(const int nthreads,
     wstart = max(wstart, 0);
     hend = min(hend, height);
     wend = min(wend, width);
+
+    if (hstart >= hend || wstart >= wend) {
+      continue;
+    }
+
     int divide_factor;
     if (use_divisor) {
       divide_factor = divisor_override;
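Note the asymmetry across the four CUDA hunks: the two forward kernels write an explicit scalar_t(0) so every output element is initialized, while the two backward kernels simply continue, because an empty window contributes nothing to any input gradient. A small end-to-end check of both directions, assuming a CUDA device (avg_pool1d dispatches through the 2-D kernels patched here):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1, 16, 4, device='cuda', requires_grad=True)
    y = torch.nn.functional.avg_pool1d(
        x, kernel_size=1, stride=2, ceil_mode=True, count_include_pad=True)
    y.sum().backward()
    assert not torch.isnan(y).any()       # forward: empty window -> 0, not NaN
    assert not torch.isnan(x.grad).any()  # backward: empty window skipped
```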
10 changes: 10 additions & 0 deletions aten/src/ATen/native/cuda/AveragePool3d.cu
@@ -55,6 +55,11 @@ __global__ void avg_pool3d_cuda_update_output(
     hend = min(hend, input.size(2));
     wend = min(wend, input.size(3));
+
+    if (tstart >= tend || hstart >= hend || wstart >= wend) {
+      output[slice][oFrame][oRow][oCol] = scalar_t(0);
+      return;
+    }
 
     accscalar_t divide_factor;
     if (divisor_override) {
       divide_factor = static_cast<accscalar_t>(divisor_override);
@@ -119,6 +124,11 @@ __global__ void avg_pool3d_cuda_update_output(
     hend = min(hend, input.size(2));
     wend = min(wend, input.size(3));
+
+    if (tstart >= tend || hstart >= hend || wstart >= wend) {
+      output[slice][oFrame][oRow][oCol] = scalar_t(0);
+      return;
+    }
 
     accscalar_t divide_factor;
     if (divisor_override) {
       divide_factor = static_cast<accscalar_t>(divisor_override);
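The 3-D kernels return rather than continue because each thread computes a single output element, whereas the 2-D kernels iterate a grid-stride loop over many; the guard appears twice because this file carries two variants of avg_pool3d_cuda_update_output. A quick no-NaN check of the 3-D CUDA path, assuming a CUDA device (same shapes as the test added below):

```python
import torch

if torch.cuda.is_available():
    x = (10 * torch.randn(1, 16, 4, 4, 4)).cuda()
    y = torch.nn.functional.avg_pool3d(
        x, kernel_size=(1, 2, 3), stride=2, ceil_mode=True, count_include_pad=True)
    assert not torch.isnan(y).any()
```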
41 changes: 41 additions & 0 deletions test/test_nn.py
@@ -269,6 +269,47 @@ def test_avg_pool3d_with_zero_divisor(self):
         self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                                lambda: torch.nn.functional.avg_pool3d(torch.zeros(3, 3, 3, 3), (2, 2, 2), divisor_override=0))
 
+    def test_avg_pool1d_ceil_mode(self):
+        # Regression test for gh-36977
+        x = 10 * torch.randn((1, 16, 4))
+        y = torch.nn.functional.avg_pool1d(
+            x, ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2)
+        self.assertTrue(not torch.isnan(y).any())
+
+        if TEST_CUDA:
+            y = torch.nn.functional.avg_pool1d(
+                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2)
+            self.assertTrue(not torch.isnan(y).any())
+
+
+    def test_avg_pool2d_ceil_mode(self):
+        # Regression test for gh-36977
+        x = 10 * torch.randn((1, 16, 4, 4))
+        y = torch.nn.functional.avg_pool2d(
+            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
+            padding=(0, 1), stride=2)
+        self.assertTrue(not torch.isnan(y).any())
+
+        if TEST_CUDA:
+            y = torch.nn.functional.avg_pool2d(
+                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
+                padding=(0, 1), stride=2)
+            self.assertTrue(not torch.isnan(y).any())
+
+
+    def test_avg_pool3d_ceil_mode(self):
+        # Regression test for gh-36977
+        x = 10 * torch.randn((1, 16, 4, 4, 4))
+        y = torch.nn.functional.avg_pool3d(
+            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2)
+        self.assertTrue(not torch.isnan(y).any())
+
+        if TEST_CUDA:
+            y = torch.nn.functional.avg_pool3d(
+                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2)
+            self.assertTrue(not torch.isnan(y).any())
+
+
 class TestNN(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
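For reference, the 1-D case can also be checked standalone against any build; without the guard, the trailing output positions came back NaN. A repro sketch using the same parameters as test_avg_pool1d_ceil_mode:

```python
import torch

x = 10 * torch.randn(1, 16, 4)
y = torch.nn.functional.avg_pool1d(
    x, kernel_size=1, stride=2, ceil_mode=True, count_include_pad=True)
print(torch.isnan(y).any())  # tensor(False) once the guard is in place
```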