[vmap] Add max_pool3d batch rule (#99522)
Also add a helper to integrate `max_pool2d_with_indices` and `max_pool3d_with_indices`
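
A quick way to see the effect (illustrative only, not part of the PR; assumes a PyTorch build that includes this commit):

```python
# vmap over max_pool3d previously had no batch rule; with this commit it
# should dispatch through the new rule. Shapes below are arbitrary.
import torch
import torch.nn.functional as F
from torch.func import vmap

x = torch.randn(4, 2, 3, 8, 8, 8)  # [B, N, C, D, H, W]; B is the vmapped dim
out = vmap(lambda t: F.max_pool3d(t, kernel_size=2))(x)
print(out.shape)  # torch.Size([4, 2, 3, 4, 4, 4])
```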

Pull Request resolved: #99522
Approved by: https://github.com/zou3519
qqaatw authored and pytorchmergebot committed Apr 20, 2023
1 parent d31a00e · commit c0674c4
Showing 5 changed files with 31 additions and 14 deletions.
aten/src/ATen/functorch/BatchRulesDecompositions.cpp (1 addition, 0 deletions)
@@ -181,6 +181,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(max_pool1d);
   OP_DECOMPOSE(max_pool1d_with_indices);
   OP_DECOMPOSE(max_pool2d);
+  OP_DECOMPOSE(max_pool3d);
   OP_DECOMPOSE(meshgrid);
   OP_DECOMPOSE2(meshgrid, indexing);
   OP_DECOMPOSE(mH);
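Context for the decomposition above: `max_pool3d` itself never needs a dedicated vmap rule because it is decomposed into `max_pool3d_with_indices`, which receives the batch rule below. A rough Python analogue of that relationship (a sketch, not the actual ATen decomposition):

```python
import torch
import torch.nn.functional as F

def max_pool3d_via_indices(x, kernel_size):
    # max_pool3d is the first output of max_pool3d_with_indices
    out, _indices = F.max_pool3d(x, kernel_size, return_indices=True)
    return out
```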
aten/src/ATen/functorch/BatchRulesPooling.cpp (30 additions, 10 deletions)
@@ -11,30 +11,48 @@
 
 namespace at { namespace functorch {
 
+template <typename Func>
 std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
-max_pool2d_with_indices_batch_rule(
-    const Tensor& self, optional<int64_t> self_bdim,
-    IntArrayRef kernel_size, IntArrayRef stride,
-    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
+max_pool_with_indices_batch_rule_helper(
+    const Tensor& self, optional<int64_t> self_bdim,
+    IntArrayRef kernel_size, IntArrayRef stride,
+    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, int64_t n, Func pooling_fn) {
+
   auto logical_rank = rankWithoutBatchDim(self, self_bdim);
-  TORCH_INTERNAL_ASSERT(logical_rank == 3 || logical_rank == 4);
-  // Tensor[B, C, H, W] -> just call max_pool2d
-  if (logical_rank == 3) {
+  TORCH_INTERNAL_ASSERT(logical_rank == n + 1 || logical_rank == n + 2);
+  // Tensor[B, logical_rank...] -> just call max_poolnd
+  if (logical_rank == n + 1) {
     auto self_ = moveBatchDimToFront(self, self_bdim);
-    auto result = at::max_pool2d_with_indices(
+    auto result = pooling_fn(
         self_, kernel_size, stride, padding, dilation, ceil_mode);
     return std::make_tuple(std::move(std::get<0>(result)), 0, std::move(std::get<1>(result)), 0);
   }
-  // Tensor[B, N, C, H, W] -> Tensor[B * N, C, H, W]
+  // Tensor[B, N, logical_rank...] -> Tensor[B * N, logical_rank...]
   auto bdim_size = self.size(*self_bdim);
   auto self_ = reshape_dim_into(*self_bdim, 0, self);
-  auto result = at::max_pool2d_with_indices(
+  auto result = pooling_fn(
      self_, kernel_size, stride, padding, dilation, ceil_mode);
   return std::make_tuple(
      reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0,
      reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
 }
+
+std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+max_pool3d_with_indices_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    IntArrayRef kernel_size, IntArrayRef stride,
+    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
+  return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 3, at::max_pool3d_with_indices);
+}
+
+std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
+max_pool2d_with_indices_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    IntArrayRef kernel_size, IntArrayRef stride,
+    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
+  return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 2, at::max_pool2d_with_indices);
+}
 
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
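
The helper's two cases reduce to a simple reshape trick; in plain Python terms (a sketch assuming the batch dim is already at the front; `pool_with_flattened_batch` is an illustrative name, not a functorch API):

```python
import torch
import torch.nn.functional as F

def pool_with_flattened_batch(x):  # x: [B, N, C, D, H, W]
    B, N = x.shape[:2]
    flat = x.reshape(B * N, *x.shape[2:])  # like reshape_dim_into(bdim, 0, x)
    out, idx = F.max_pool3d(flat, kernel_size=2, return_indices=True)
    # like reshape_dim_outof(0, B, t): split dim 0 back into [B, N]
    unflatten = lambda t: t.reshape(B, N, *t.shape[1:])
    return unflatten(out), unflatten(idx)

out, idx = pool_with_flattened_batch(torch.randn(4, 2, 3, 8, 8, 8))
print(out.shape, idx.shape)  # torch.Size([4, 2, 3, 4, 4, 4]) for both
```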
@@ -49,7 +67,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);
   VMAP_SUPPORT(max_pool2d_with_indices, max_pool2d_with_indices_batch_rule);
+  VMAP_SUPPORT(max_pool3d_with_indices, max_pool3d_with_indices_batch_rule);
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
+  ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, max_pool3d_with_indices_backward, 2);
 }
 
 }}
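
Because the forward rule and the boxed backward registration both land here, per-sample gradients should also work under vmap; a hedged smoke test (again assuming a build with this patch):

```python
import torch
import torch.nn.functional as F
from torch.func import grad, vmap

def loss(t):  # t: a single sample, [C, D, H, W]
    return F.max_pool3d(t, kernel_size=2).sum()

x = torch.randn(4, 3, 8, 8, 8)          # [B, C, D, H, W]
per_sample_grads = vmap(grad(loss))(x)  # exercises max_pool3d_with_indices_backward
print(per_sample_grads.shape)           # torch.Size([4, 3, 8, 8, 8])
```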
test/functorch/test_ops.py (0 additions, 2 deletions)
@@ -1073,7 +1073,6 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('masked_scatter'),
         xfail('put'),
         xfail('take'),
-        xfail('nn.functional.max_pool3d'),
         xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('linalg.lu_factor', ''),
         xfail('nn.functional.dropout2d', ''),
@@ -1171,7 +1170,6 @@ def test():
         xfail('stft'),
         xfail('nn.functional.rrelu'),
         xfail('nn.functional.embedding_bag'),
-        xfail('nn.functional.max_pool3d'),
         xfail('nn.functional.fractional_max_pool2d'),
         xfail('linalg.lu_factor', ''),
         xfail('nn.functional.feature_alpha_dropout', 'with_train'),
test/functorch/test_vmap.py (0 additions, 1 deletion)
@@ -3663,7 +3663,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('unique'),
         xfail('nn.functional.ctc_loss'),
         xfail('nn.functional.gaussian_nll_loss'),
-        xfail('nn.functional.max_pool3d'),
         xfail('histc'),
         xfail('as_strided'),
         xfail('istft'),
test/functorch/test_vmap_registrations.py (0 additions, 1 deletion)
@@ -172,7 +172,6 @@
     "aten::matrix_exp_backward",
     "aten::max.names_dim",
     "aten::max.names_dim_max",
-    "aten::max_pool3d",
     "aten::mean.names_dim",
     "aten::median.names_dim",
     "aten::median.names_dim_values",
