Skip to content

Commit

Permalink
fix max_pool2d cuda version Dimension out of range issue(#36046) (#36095
Browse files Browse the repository at this point in the history
)

Summary: Pull Request resolved: #36095

Test Plan: Imported from OSS

Differential Revision: D20876733

Pulled By: glaringlee

fbshipit-source-id: a2b92fd2dd0254c5443af469e3fb2faa2323e5c9
  • Loading branch information
lixinyu authored and facebook-github-bot committed Apr 7, 2020
1 parent 3e5d25f commit b55dee9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/DilatedMaxPool2d.cu
Expand Up @@ -358,7 +358,7 @@ void max_pool2d_with_indices_out_cuda_template(

Tensor input = input_.contiguous(memory_format);

const int64_t in_stride_n = input.stride(-4);
const int64_t in_stride_n = input_.ndimension() == 4 ? input.stride(-4) : 0;
const int64_t in_stride_c = input.stride(-3);
const int64_t in_stride_h = input.stride(-2);
const int64_t in_stride_w = input.stride(-1);
Expand Down Expand Up @@ -506,7 +506,7 @@ void max_pool2d_with_indices_backward_out_cuda_template(
const int64_t inputHeight = input.size(-2);
const int64_t inputWidth = input.size(-1);

const int64_t in_stride_n = input.stride(-4);
const int64_t in_stride_n = input.ndimension() == 4 ? input.stride(-4) : 0;
const int64_t in_stride_c = input.stride(-3);
const int64_t in_stride_h = input.stride(-2);
const int64_t in_stride_w = input.stride(-1);
Expand Down
10 changes: 9 additions & 1 deletion torch/testing/_internal/common_nn.py
Expand Up @@ -66,7 +66,7 @@ def get_weight(m):
# and the `cpp_var_map` entry must be
# `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
# used in the C++ constructor argument with the Python tensor value `random_samples`.
#
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
Expand Down Expand Up @@ -1812,12 +1812,20 @@ def fractional_max_pool3d_test(test_case):
cpp_constructor_args='torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)',
input_size=(2, 4, 5, 5),
),
dict(
module_name='MaxPool2d',
constructor_args=((3, 3), (2, 2), (1, 1)),
cpp_constructor_args='torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})',
input_size=(3, 7, 7),
desc='3d_input'
),
dict(
module_name='MaxPool2d',
constructor_args=((3, 3), (2, 2), (1, 1)),
cpp_constructor_args='torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})',
input_size=(1, 3, 7, 7),
check_with_channels_last=True,
desc='4d_input'
),
dict(
module_name='AvgPool1d',
Expand Down

0 comments on commit b55dee9

Please sign in to comment.