C++ MaxPool Module #24860

Closed · wants to merge 16 commits
@@ -533,6 +533,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/loss.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/maxpool.cpp
This conversation was marked as resolved by yf225

@yf225 (Contributor) · Sep 9, 2019:

@ShahriarSS the Facebook internal build seems to fail, and I am thinking we might need to add the maxpool.cpp entry into

torch_cpp_srcs = [
"torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA
"torch/csrc/api/src/data/datasets/mnist.cpp",
"torch/csrc/api/src/data/samplers/distributed.cpp",
"torch/csrc/api/src/data/samplers/random.cpp",
"torch/csrc/api/src/data/samplers/sequential.cpp",
"torch/csrc/api/src/data/samplers/stream.cpp",
"torch/csrc/api/src/jit.cpp",
"torch/csrc/api/src/nn/init.cpp",
"torch/csrc/api/src/nn/module.cpp",
"torch/csrc/api/src/nn/modules/batchnorm.cpp",
"torch/csrc/api/src/nn/modules/conv.cpp",
"torch/csrc/api/src/nn/modules/dropout.cpp",
"torch/csrc/api/src/nn/modules/embedding.cpp",
"torch/csrc/api/src/nn/modules/functional.cpp",
"torch/csrc/api/src/nn/modules/linear.cpp",
"torch/csrc/api/src/nn/modules/named_any.cpp",
"torch/csrc/api/src/nn/modules/rnn.cpp",
"torch/csrc/api/src/optim/adagrad.cpp",
"torch/csrc/api/src/optim/adam.cpp",
"torch/csrc/api/src/optim/lbfgs.cpp",
"torch/csrc/api/src/optim/optimizer.cpp",
"torch/csrc/api/src/optim/rmsprop.cpp",
"torch/csrc/api/src/optim/serialize.cpp",
"torch/csrc/api/src/optim/sgd.cpp",
"torch/csrc/api/src/serialize/input-archive.cpp",
"torch/csrc/api/src/serialize/output-archive.cpp",
]
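Presumably the fix is one more entry alongside the other nn/modules sources here; the add_torch_libs() diff further down shows the placement that was ultimately used:

"torch/csrc/api/src/nn/modules/maxpool.cpp",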

${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp
${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp
@@ -1258,11 +1258,13 @@ def fractional_max_pool3d_test(test_case):
dict(
module_name='MaxPool1d',
constructor_args=(4,),
cpp_constructor_args='(4)',
input_size=(2, 10, 4),
),
dict(
module_name='MaxPool1d',
constructor_args=(4, 4),
cpp_constructor_args='(torch::nn::MaxPool1dOptions(4).stride(4))',
input_size=(2, 10, 4),
desc='stride',
),
@@ -1376,6 +1378,7 @@ def fractional_max_pool3d_test(test_case):
dict(
module_name='MaxPool2d',
constructor_args=((3, 3), (2, 2), (1, 1)),
cpp_constructor_args='(torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1}))',
input_size=(1, 3, 7, 7),
),
dict(
@@ -1580,17 +1583,20 @@ def fractional_max_pool3d_test(test_case):
dict(
module_name='MaxPool3d',
constructor_args=((2, 2, 2),),
cpp_constructor_args='(torch::nn::MaxPool3dOptions({2, 2, 2}))',
input_size=(2, 3, 5, 5, 5),
),
dict(
module_name='MaxPool3d',
constructor_args=(2, (2, 2, 2)),
cpp_constructor_args='(torch::nn::MaxPool3dOptions(2).stride({2, 2, 2}))',
input_size=(2, 3, 5, 5, 5),
desc='stride',
),
dict(
module_name='MaxPool3d',
constructor_args=(2, 2, (1, 1, 1)),
cpp_constructor_args='(torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1}))',
input_size=(2, 3, 5, 5, 5),
desc='stride_padding',
),
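For reference, a cpp_constructor_args string such as the 'stride_padding' one above corresponds roughly to the following C++ usage (a sketch; the input shape comes from the test dict's input_size):

torch::nn::MaxPool3d pool(
    torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1}));
auto y = pool(torch::randn({2, 3, 5, 5, 5}));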
@@ -95,6 +95,58 @@ TEST_F(ModulesTest, Conv3d) {
ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3);
}

TEST_F(ModulesTest, MaxPool1d) {
MaxPool1d model(MaxPool1dOptions(3).stride(2));
auto x = torch::ones({1, 1, 5}, torch::requires_grad());
auto y = model(x);
torch::Tensor s = y.sum();

s.backward();
ASSERT_EQ(y.ndimension(), 3);
ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2})));
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2}));

@yf225 (Contributor) · Aug 20, 2019:

Besides checking ndimension() and sizes(), can we also check that the values of y are what we expect? (We might need to set the values of x in a specific way in order to test this.)

}
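One way to do that (a sketch, not part of this PR): feed an input in which each pooling window has a distinct maximum and compare against hand-computed values.

auto x = torch::arange(10, torch::kFloat).reshape({1, 1, 10});
MaxPool1d pool(MaxPool1dOptions(3).stride(2));
auto y = pool(x);
// windows [0..2], [2..4], [4..6], [6..8] -> maxima 2, 4, 6, 8
ASSERT_TRUE(torch::allclose(
    y, torch::tensor({2., 4., 6., 8.}, torch::kFloat).reshape({1, 1, 4})));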

TEST_F(ModulesTest, MaxPool2dEven) {
MaxPool2d model(MaxPool2dOptions(3).stride(2));
auto x = torch::ones({2, 5, 5}, torch::requires_grad());
auto y = model(x);
torch::Tensor s = y.sum();

s.backward();
ASSERT_EQ(y.ndimension(), 3);
ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2}));

@yf225 (Contributor) · Aug 20, 2019:

Ditto for checking the values of y.

}

TEST_F(ModulesTest, MaxPool2dUneven) {
MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2}));
auto x = torch::ones({2, 5, 4}, torch::requires_grad());
auto y = model(x);
torch::Tensor s = y.sum();

s.backward();
ASSERT_EQ(y.ndimension(), 3);
ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2})));
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2}));
}
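The expected shape follows the usual pooling arithmetic: H_out = floor((5 - 3) / 2) + 1 = 2 and W_out = floor((4 - 2) / 2) + 1 = 2, so y has shape {2, 2, 2}, and with an all-ones input and no padding every pooled value is 1.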

TEST_F(ModulesTest, MaxPool3d) {
MaxPool3d model(MaxPool3dOptions(3).stride(2));
auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad());
auto y = model(x);
torch::Tensor s = y.sum();

s.backward();
ASSERT_EQ(y.ndimension(), 4);
ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2})));
ASSERT_EQ(s.ndimension(), 0);
ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2}));
}

TEST_F(ModulesTest, Linear) {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
@@ -368,6 +420,24 @@ TEST_F(ModulesTest, PrettyPrintConv) {
"torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
}

TEST_F(ModulesTest, PrettyPrintMaxPool) {
ASSERT_EQ(
c10::str(MaxPool1d(5)),
"torch::nn::MaxPool1d(kernel_size=5, stride=5)");
ASSERT_EQ(
c10::str(MaxPool2d(5)),
"torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5])");
ASSERT_EQ(
c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))),
"torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2])");

const auto options =
MaxPool2dOptions(torch::IntArrayRef{5, 6}).stride({1, 2});
ASSERT_EQ(
c10::str(MaxPool2d(options)),
"torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2])");
}

TEST_F(ModulesTest, PrettyPrintDropout) {
ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
ASSERT_EQ(
@@ -17,9 +17,9 @@ torch.nn.ConvTranspose2d|No|No
torch.nn.ConvTranspose3d|No|No
torch.nn.Unfold|No|No
torch.nn.Fold|Yes|No
-torch.nn.MaxPool1d|No|No
-torch.nn.MaxPool2d|No|No
-torch.nn.MaxPool3d|No|No
+torch.nn.MaxPool1d|Yes|No
+torch.nn.MaxPool2d|Yes|No
+torch.nn.MaxPool3d|Yes|No
torch.nn.MaxUnpool1d|No|No
torch.nn.MaxUnpool2d|No|No
torch.nn.MaxUnpool3d|No|No
@@ -37,9 +37,18 @@
cpp_default_constructor_args="(3, 2)",
num_attrs_recursive=5,
),
-'MaxPool1d': TorchNNModuleMetadata(),
-'MaxPool2d': TorchNNModuleMetadata(),
-'MaxPool3d': TorchNNModuleMetadata(),
+'MaxPool1d': TorchNNModuleMetadata(
+    cpp_default_constructor_args="(2)",
+    num_attrs_recursive=6,
+),
+'MaxPool2d': TorchNNModuleMetadata(
+    cpp_default_constructor_args="(2)",
+    num_attrs_recursive=6,
+),
+'MaxPool3d': TorchNNModuleMetadata(
+    cpp_default_constructor_args="(2)",
+    num_attrs_recursive=6,
+),
'MaxUnpool1d': TorchNNModuleMetadata(),
'MaxUnpool2d': TorchNNModuleMetadata(),
'MaxUnpool3d': TorchNNModuleMetadata(),
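The num_attrs_recursive=6 values presumably count the six fields of the new MaxPoolOptions (kernel_size, stride, padding, dilation, return_indices, ceil_mode); see the header added further down.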
@@ -193,6 +193,7 @@ def add_torch_libs():
"torch/csrc/api/src/nn/modules/functional.cpp",
"torch/csrc/api/src/nn/modules/linear.cpp",
"torch/csrc/api/src/nn/modules/loss.cpp",
"torch/csrc/api/src/nn/modules/maxpool.cpp",
"torch/csrc/api/src/nn/modules/named_any.cpp",
"torch/csrc/api/src/nn/modules/rnn.cpp",
"torch/csrc/api/src/optim/adagrad.cpp",
@@ -9,6 +9,7 @@
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/loss.h>
#include <torch/nn/modules/maxpool.h>
#include <torch/nn/modules/modulelist.h>
#include <torch/nn/modules/named_any.h>
#include <torch/nn/modules/rnn.h>
@@ -0,0 +1,115 @@
#pragma once

#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>

#include <torch/csrc/WindowsTorchApiMacro.h>

namespace torch {
namespace nn {

/// Options for a `D`-dimensional maxpool module.
template <size_t D>
struct MaxPoolOptions {
MaxPoolOptions(ExpandingArray<D> kernel_size)
: kernel_size_(kernel_size), stride_(kernel_size) {}

/// the size of the window to take a max over
TORCH_ARG(ExpandingArray<D>, kernel_size);

/// the stride of the window. Default value is `kernel_size`
TORCH_ARG(ExpandingArray<D>, stride);

/// implicit zero padding to be added on both sides
TORCH_ARG(ExpandingArray<D>, padding) = 0;

/// a parameter that controls the stride of elements in the window
TORCH_ARG(ExpandingArray<D>, dilation) = 1;

/// if true, will return the max indices along with the outputs. Useful
/// for `MaxUnpool1d` later
TORCH_ARG(bool, return_indices) = false;

/// when true, will use `ceil` instead of `floor` to compute the output shape
TORCH_ARG(bool, ceil_mode) = false;
};
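Since TORCH_ARG generates chainable setters, an options object can be configured fluently. A minimal sketch (MaxPool2dOptions below is the MaxPoolOptions<2> alias declared later in this header):

auto opts = MaxPoolOptions<2>(/*kernel_size=*/{3, 3})
                .stride({2, 2})
                .padding({1, 1})
                .ceil_mode(true);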

/// Base class for all (dimension-specialized) maxpool modules.
template <size_t D, typename Derived>
class TORCH_API MaxPoolImpl : public torch::nn::Cloneable<Derived> {
public:
MaxPoolImpl(ExpandingArray<D> kernel_size)
: MaxPoolImpl(MaxPoolOptions<D>(kernel_size)) {}
explicit MaxPoolImpl(MaxPoolOptions<D> options);

void reset() override;

/// Pretty prints the `MaxPool{1,2,3}d` module into the given `stream`.
void pretty_print(std::ostream& stream) const override;

/// The options with which this `Module` was constructed.
MaxPoolOptions<D> options;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies maxpool over a 1-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool1d to learn
/// about the exact behavior of this module.

@yf225 (Contributor) · Aug 20, 2019:

We might want to write the formulas here similar to https://pytorch.org/docs/stable/nn.html#maxpool1d (we can either do it in this PR or in a follow-up PR).
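For reference, the 1-D formulas from the Python docs would look roughly like this:

out(N_i, C_j, k) = \max_{m=0,\dots,\text{kernel\_size}-1} \text{input}(N_i, C_j, \text{stride} \times k + m)

L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation} \times (\text{kernel\_size} - 1) - 1}{\text{stride}} + 1 \right\rfloor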

class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> {
public:
using MaxPoolImpl<1, MaxPool1dImpl>::MaxPoolImpl;
Tensor forward(const Tensor& input);
};

/// `MaxPoolOptions` specialized for 1-D maxpool.
using MaxPool1dOptions = MaxPoolOptions<1>;

/// A `ModuleHolder` subclass for `MaxPool1dImpl`.
/// See the documentation for `MaxPool1dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(MaxPool1d);
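A minimal usage sketch for the resulting module, mirroring the tests in this PR:

MaxPool1d pool(MaxPool1dOptions(3).stride(2));
auto y = pool(torch::ones({1, 1, 5}));  // y has shape {1, 1, 2}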

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies maxpool over a 2-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool2d to learn
/// about the exact behavior of this module.
class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> {
public:
using MaxPoolImpl<2, MaxPool2dImpl>::MaxPoolImpl;
Tensor forward(const Tensor& input);
};

/// `MaxPoolOptions` specialized for 2-D maxpool.
using MaxPool2dOptions = MaxPoolOptions<2>;

/// A `ModuleHolder` subclass for `MaxPool2dImpl`.
/// See the documentation for `MaxPool2dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(MaxPool2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies maxpool over a 3-D input.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool3d to learn
/// about the exact behavior of this module.
class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> {
public:
using MaxPoolImpl<3, MaxPool3dImpl>::MaxPoolImpl;
Tensor forward(const Tensor& input);
};

/// `MaxPoolOptions` specialized for 3-D maxpool.
using MaxPool3dOptions = MaxPoolOptions<3>;

/// A `ModuleHolder` subclass for `MaxPool3dImpl`.
/// See the documentation for `MaxPool3dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(MaxPool3d);

} // namespace nn
} // namespace torch
@@ -0,0 +1,62 @@
#include <torch/nn/modules/maxpool.h>

#include <torch/expanding_array.h>

namespace torch {
namespace nn {

template <size_t D, typename Derived>
MaxPoolImpl<D, Derived>::MaxPoolImpl(MaxPoolOptions<D> options)
: options(std::move(options)) {}

template <size_t D, typename Derived>
void MaxPoolImpl<D, Derived>::reset() {}

template <size_t D, typename Derived>
void MaxPoolImpl<D, Derived>::pretty_print(std::ostream& stream) const {
stream << "torch::nn::MaxPool" << D << "d"
<< "(kernel_size=" << options.kernel_size_
<< ", stride=" << options.stride_ << ")";
}

Tensor MaxPool1dImpl::forward(const Tensor& input) {
return torch::max_pool1d(
input,
options.kernel_size_,
options.stride_,
options.padding_,
options.dilation_,
options.ceil_mode_);

@yf225 (Contributor) · Sep 10, 2019:

@ShahriarSS We might be missing options.return_indices_ here (and in the forward calls for 2d and 3d); I will push a commit to add it.

@ShahriarSS (Author, Contributor) · Sep 10, 2019:

Thanks. But I don't think torch::max_poolXd uses it. That's why I didn't include it.

@ShahriarSS (Author, Contributor) · Sep 10, 2019:

Yes @yf225, here is the documentation:

return_indices: if ``True``, will return the max indices along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool2d` later

@yf225 (Contributor) · Sep 10, 2019:

Ah got it, thanks a lot for the catch!

}
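For completeness, a sketch of how return_indices could be honored in a follow-up; torch::max_pool1d_with_indices does exist in ATen, though how the module would surface the indices is left open here:

// Hypothetical follow-up, not part of this PR:
Tensor output, indices;
std::tie(output, indices) = torch::max_pool1d_with_indices(
    input,
    options.kernel_size_,
    options.stride_,
    options.padding_,
    options.dilation_,
    options.ceil_mode_);
// `indices` would then need to be returned to the caller, e.g. for MaxUnpool1d.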

Tensor MaxPool2dImpl::forward(const Tensor& input) {
return torch::max_pool2d(
input,
options.kernel_size_,
options.stride_,
options.padding_,
options.dilation_,
options.ceil_mode_);
}

Tensor MaxPool3dImpl::forward(const Tensor& input) {
return torch::max_pool3d(
input,
options.kernel_size_,
options.stride_,
options.padding_,
options.dilation_,
options.ceil_mode_);
}
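
// Explicit instantiations: the template definitions above live in this .cpp
// file, so emit them here for the three supported dimensionalities.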

template struct MaxPoolOptions<1>;
template class MaxPoolImpl<1, MaxPool1dImpl>;

template struct MaxPoolOptions<2>;
template class MaxPoolImpl<2, MaxPool2dImpl>;

template struct MaxPoolOptions<3>;
template class MaxPoolImpl<3, MaxPool3dImpl>;

} // namespace nn
} // namespace torch