26 changes: 10 additions & 16 deletions aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
@@ -173,35 +173,29 @@ __global__ void max_pool_backward_nchw(const int nthreads, const scalar_t* top_diff,
     const int stride_h, const int stride_w, const int pad_h, const int pad_w,
     const int dilation_h, const int dilation_w,
     scalar_t* bottom_diff) {
-  CUDA_KERNEL_LOOP(index, height*width) {
-    int h = index/width;
-    int w = index - h * width;
+  CUDA_KERNEL_LOOP(index, height*width) {
+    int h = index / width;
+    int w = index % width;
     int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h);
     int phend = p_end(h, pad_h, pooled_height, stride_h);
     int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w);
     int pwend = p_end(w, pad_w, pooled_width, stride_w);
-    for (int n = blockIdx.y; n < num; n += gridDim.y)
-      for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
-
+    for (int n = blockIdx.y; n < num; n += gridDim.y) {
+      for (int c = blockIdx.z; c < channels; c+= gridDim.z) {
         accscalar_t gradient = accscalar_t(0);
         int offset = (n * channels + c) * pooled_height * pooled_width;
-        top_diff += offset;
-        top_mask += offset;
-        if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
+        const scalar_t* ptr_top_diff = top_diff + offset;
+        const int64_t* ptr_top_mask = top_mask + offset;
         for (int ph = phstart; ph < phend; ++ph) {
           for (int pw = pwstart; pw < pwend; ++pw) {
-            if (top_mask[ph * pooled_width + pw] == h * width + w) {
-              gradient += ScalarConvert<scalar_t, accscalar_t>::to(top_diff[ph * pooled_width + pw]);
+            if (ptr_top_mask[ph * pooled_width + pw] == h * width + w) {
+              gradient += ScalarConvert<scalar_t, accscalar_t>::to(ptr_top_diff[ph * pooled_width + pw]);
             }
           }
         }
-        } else {
-          if (top_mask[phstart * pooled_width + pwstart] == h * width + w) {
-            gradient += ScalarConvert<scalar_t, accscalar_t>::to(top_diff[phstart * pooled_width + pwstart]);
-          }
-        }
         bottom_diff[(n*channels+c)*height*width+index] = ScalarConvert<accscalar_t, scalar_t>::to(gradient);
       }
+    }
   }
 }

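The substance of the kernel hunk above is easy to miss among the brace changes: the old code advanced the top_diff and top_mask kernel parameters in place (top_diff += offset;) inside the grid-stride loops over n and c, so any thread that handled more than one (n, c) pair (for example when channels exceeds gridDim.z) kept accumulating offsets and read gradients from the wrong slices. The patch leaves the parameters untouched and derives fresh per-iteration pointers ptr_top_diff / ptr_top_mask instead; the deleted else branch was just the single-pooling-window special case, which the general loop already covers. Below is a minimal host-side C++ sketch of the failure mode; it is a toy analogue, not the kernel, and the function names and sizes are invented for illustration.

#include <cstdio>
#include <vector>

// Toy host-side analogue of the (n, c) grid-stride loops in max_pool_backward_nchw.
// 'base' plays the role of top_diff; each (n, c) pair owns a slice of 'slice' floats.

// Old pattern: the base pointer itself is advanced, so offsets accumulate once the
// same caller handles more than one (n, c) pair.
float first_elements_buggy(const float* base, int num, int channels, int slice) {
  float total = 0.f;
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      int offset = (n * channels + c) * slice;
      base += offset;    // offsets pile up: 'base' drifts away from the slice starts
      total += base[0];
    }
  }
  return total;
}

// Patched pattern: the parameter is never modified; each iteration gets its own pointer.
float first_elements_fixed(const float* base, int num, int channels, int slice) {
  float total = 0.f;
  for (int n = 0; n < num; ++n) {
    for (int c = 0; c < channels; ++c) {
      const float* ptr = base + (n * channels + c) * slice;  // fresh pointer per (n, c)
      total += ptr[0];
    }
  }
  return total;
}

int main() {
  const int num = 2, channels = 3, slice = 4;
  // Over-allocate so the buggy variant stays in bounds and its wrong reads are visible.
  std::vector<float> data(64);
  for (int i = 0; i < 64; ++i) data[i] = static_cast<float>(i);
  std::printf("fixed: %g\n", first_elements_fixed(data.data(), num, channels, slice));  // prints 60
  std::printf("buggy: %g\n", first_elements_buggy(data.data(), num, channels, slice));  // prints 140
  return 0;
}

Compiled and run, the fixed variant visits the first element of each (n, c) slice (sum 60 here), while the buggy variant drifts further into the buffer on every iteration (sum 140), which is the same drift the kernel exhibited whenever its grid-stride loops ran more than once per thread.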
21 changes: 21 additions & 0 deletions test/test_nn.py
@@ -9604,6 +9604,27 @@ def helper(n, c, h, w, kernel_size, stride=None,
         helper(10, 512, 31, 31, 3, stride=2)
         helper(1, 129, 8, 8, 3, stride=2)
 
+    @onlyCUDA
+    def test_max_pool2d(self, device):
+        def helper(n, c, h, w, ks):
+            x = torch.randn(n, c, h, w, device='cuda', dtype=torch.float, requires_grad=True)
+            ref_x = x.detach().clone().cpu().requires_grad_()
+
+            pool = torch.nn.MaxPool2d(kernel_size=ks)
+
+            y = pool(x)
+            ref_y = pool(ref_x)
+
+            y.sum().backward()
+            ref_y.sum().backward()
+
+            self.assertEqual(y, ref_y)
+            self.assertEqual(x.grad, ref_x.grad)
+
+        helper(2, 8, 4, 4, ks=2)
+        helper(1, 100000, 32, 32, ks=4)
+        helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d
+
     @onlyCUDA
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     def test_max_pool2d_nhwc(self, device, dtype):
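A note on the new test cases: CUDA caps gridDim.y and gridDim.z at 65535, so the helper(1, 100000, ...) calls force the c += gridDim.z loop in the backward kernel through more than one iteration per thread, which is presumably the situation the fix targets; the ks=(1, 4) case, per its inline comment, covers a max_pool1d-shaped input.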