Skip to content

Commit

Permalink
AveragePool: expand incomplete kernel_size for the C++ API
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch/pytorch#22075

Differential Revision: D15945260

Pulled By: mrshenli

fbshipit-source-id: 827660c19ebbdb5f0aae2f4eadb6025ae2f93674
  • Loading branch information
skrah authored and facebook-github-bot committed Jun 24, 2019
1 parent 5027395 commit efc5665
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 52 deletions.
22 changes: 10 additions & 12 deletions aten/src/ATen/native/AveragePool2d.cpp
Expand Up @@ -90,18 +90,17 @@ void avg_pool2d_out_cpu_template(
bool ceil_mode,
bool count_include_pad)
{
// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
"avg_pool2d: all IntArrayRef sizes must be 2");

TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 2D or 3D (batch mode) tensor expected for input");

const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
Expand Down Expand Up @@ -236,11 +235,10 @@ Tensor& avg_pool2d_backward_out_cpu_template(
bool ceil_mode,
bool count_include_pad)
{
// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
"avg_pool2d: all IntArrayRef sizes must be 2");

const int64_t ndim = input.ndimension();
Expand All @@ -249,7 +247,7 @@ Tensor& avg_pool2d_backward_out_cpu_template(
"non-empty 3D or 4D (batch mode) tensor expected for input");

const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
Expand Down
26 changes: 12 additions & 14 deletions aten/src/ATen/native/AveragePool3d.cpp
Expand Up @@ -104,19 +104,18 @@ void avg_pool3d_out_cpu_template(
bool ceil_mode,
bool count_include_pad)
{
// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
"avg_pool3d: all IntArrayRef sizes must be 3");

TORCH_CHECK((input_.ndimension() == 4 || input_.ndimension() == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input");

const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
Expand Down Expand Up @@ -293,19 +292,18 @@ Tensor& avg_pool3d_backward_out_cpu_template(
bool ceil_mode,
bool count_include_pad)
{
// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
"avg_pool3d: all IntArrayRef sizes must be 3");

TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input");

const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
Expand Down
22 changes: 10 additions & 12 deletions aten/src/ATen/native/cuda/AveragePool2d.cu
Expand Up @@ -114,18 +114,17 @@ void avg_pool2d_out_cuda_template(

checkAllSameGPU("avg_pool2d_out_cuda", {output_arg, input_arg});

// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
"avg_pool2d: all IntArrayRef sizes must be 2");

TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");

const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
Expand Down Expand Up @@ -230,18 +229,17 @@ Tensor& avg_pool2d_backward_out_cuda_template(
checkAllSameGPU("avg_pool2d_backward_out_cuda",
{gradInput_arg, gradOutput_arg, input_arg});

// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 2 &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 2) &&
(stride.empty() || stride.size() == 2) &&
(padding.size() == 1 || padding.size() == 2),
"avg_pool2d: all IntArrayRef sizes must be 2");

TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
"non-empty 3D or 4D (batch mode) tensor expected for input");

const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
const int kW = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[1]);
Expand Down
26 changes: 12 additions & 14 deletions aten/src/ATen/native/cuda/AveragePool3d.cu
Expand Up @@ -312,19 +312,18 @@ void avg_pool3d_out_cuda_template(

checkAllSameGPU("avg_pool3d_out_cuda", {output_arg, input_arg});

// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
"avg_pool3d: all IntArrayRef sizes must be 3");

TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input");

const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
Expand Down Expand Up @@ -435,11 +434,10 @@ void avg_pool3d_backward_out_cuda_template(
checkAllSameGPU("avg_pool3d_backward_out_cuda",
{gradInput_arg, gradOutput_arg, input_arg});

// #20866 [JIT] stride.empty() is passed through
// #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
// #20866, #22032: Guarantee this for the official C++ API?
TORCH_CHECK((kernel_size.size() == 1 || kernel_size.size() == 3) &&
(stride.empty() || stride.size() == 3) &&
(padding.size() == 1 || padding.size() == 3),
"avg_pool3d: all IntArrayRef sizes must be 3");

TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
Expand All @@ -453,8 +451,8 @@ void avg_pool3d_backward_out_cuda_template(
gradInput.zero_();

const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
const int kH = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[1]);
const int kW = kernel_size.size() == 1 ? kT : safe_downcast<int, int64_t>(kernel_size[2]);

const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
Expand Down

0 comments on commit efc5665

Please sign in to comment.