Update on "[vmap] Add max_pool3d batch rule"
Also add a helper shared by the `max_pool2d_with_indices` and `max_pool3d_with_indices` batch rules.



[ghstack-poisoned]
qqaatw committed Apr 20, 2023
2 parents 98deae4 + c463c73 commit 72187b0
Showing 3 changed files with 1 addition and 35 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -178,6 +178,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(matrix_H);
   OP_DECOMPOSE(matrix_power);
   OP_DECOMPOSE2(max, other );
+  OP_DECOMPOSE(max_pool1d);
   OP_DECOMPOSE(max_pool1d_with_indices);
   OP_DECOMPOSE(max_pool2d);
   OP_DECOMPOSE(max_pool3d);
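Registering OP_DECOMPOSE(max_pool1d) makes vmap re-dispatch max_pool1d to its composite implementation rather than requiring a dedicated batch rule, so batching falls out of the ops it decomposes into. A minimal sketch of the relationship the decomposition relies on, assuming only the public at::max_pool1d_with_indices API; the function name max_pool1d_via_indices is illustrative and not part of this change:

#include <ATen/ATen.h>
#include <tuple>

// Sketch: max_pool1d is equivalent to the values half of
// max_pool1d_with_indices, which already has a batch rule, so a
// decomposed max_pool1d is batchable without a rule of its own.
at::Tensor max_pool1d_via_indices(
    const at::Tensor& self,
    at::IntArrayRef kernel_size, at::IntArrayRef stride,
    at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) {
  return std::get<0>(at::max_pool1d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}

The dedicated max_pool1d batch rule removed below existed because this decomposition previously failed on CPU (see the removed comment); with this entry it is no longer needed.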
33 changes: 0 additions & 33 deletions aten/src/ATen/functorch/BatchRulesPooling.cpp
@@ -53,36 +53,6 @@ max_pool2d_with_indices_batch_rule(
   return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 2, at::max_pool2d_with_indices);
 }
 
-std::tuple<Tensor,optional<int64_t>>
-max_pool1d_batch_rule(
-    const Tensor& self, optional<int64_t> self_bdim,
-    IntArrayRef kernel_size,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool ceil_mode) {
-  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
-  TORCH_INTERNAL_ASSERT(logical_rank == 2 || logical_rank == 3);
-
-  // Tensor[B, C, L] -> just call max_pool1d
-  if (logical_rank == 2) {
-    auto self_ = moveBatchDimToFront(self, self_bdim);
-    auto result = at::max_pool1d(
-        self_, kernel_size, stride, padding, dilation, ceil_mode
-    );
-    return std::make_tuple(std::move(result), 0);
-  }
-
-  // Tensor[B, N, C, L] -> Tensor[B * N, C, L]
-  auto bdim_size = self.size(*self_bdim);
-  auto self_ = reshape_dim_into(*self_bdim, 0, self);
-  auto result = at::max_pool1d(
-      self_, kernel_size, stride, padding, dilation, ceil_mode
-  );
-  return std::make_tuple(
-      reshape_dim_outof(0, bdim_size, result), 0);
-}
-
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   EXISTING_BDIM(_adaptive_avg_pool2d);
   EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
@@ -96,9 +66,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);
-  // We can get max_pool1d to work on CUDA through decomposition,
-  // but fails on CPU due to max_pool1d_cpu not having a derivative.
-  VMAP_SUPPORT(max_pool1d, max_pool1d_batch_rule);
   VMAP_SUPPORT(max_pool2d_with_indices, max_pool2d_with_indices_batch_rule);
   VMAP_SUPPORT(max_pool3d_with_indices, max_pool3d_with_indices_batch_rule);
   ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
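The call site retained above shows the signature of the helper mentioned in the commit message, but its body lies in the collapsed portion of the diff. Below is a rough sketch of what such a helper could look like, inferred from that call site and from the removed max_pool1d_batch_rule; it is an illustration under those assumptions, not the committed code, and it reuses the BatchRulesHelper.h utilities (rankWithoutBatchDim, moveBatchDimToFront, reshape_dim_into, reshape_dim_outof) visible in the removed function:

// Sketch only: a batch rule shared by the max_pool{2,3}d_with_indices ops,
// parameterized by the spatial dimensionality n and the underlying op.
// Same pattern as the removed max_pool1d_batch_rule: either move the batch
// dim to the front, or fold it into the leading dim and unfold it afterwards.
template <typename F>
std::tuple<Tensor, optional<int64_t>, Tensor, optional<int64_t>>
max_pool_with_indices_batch_rule_helper(
    const Tensor& self, optional<int64_t> self_bdim,
    IntArrayRef kernel_size, IntArrayRef stride,
    IntArrayRef padding, IntArrayRef dilation,
    bool ceil_mode, int64_t n, F pooling_fn) {
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  TORCH_INTERNAL_ASSERT(logical_rank == n + 1 || logical_rank == n + 2);

  // Tensor[B, C, ...spatial]: call the op directly with the batch dim leading.
  if (logical_rank == n + 1) {
    auto self_ = moveBatchDimToFront(self, self_bdim);
    auto result = pooling_fn(self_, kernel_size, stride, padding, dilation, ceil_mode);
    return std::make_tuple(std::get<0>(result), 0, std::get<1>(result), 0);
  }

  // Tensor[B, N, C, ...spatial] -> Tensor[B * N, C, ...spatial], pool,
  // then split B back out of the leading dimension.
  auto bdim_size = self.size(*self_bdim);
  auto self_ = reshape_dim_into(*self_bdim, 0, self);
  auto result = pooling_fn(self_, kernel_size, stride, padding, dilation, ceil_mode);
  return std::make_tuple(
      reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0,
      reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
}

With a helper of this shape, max_pool2d_with_indices_batch_rule reduces to the one-line call shown above (n = 2, at::max_pool2d_with_indices), and the new max_pool3d rule would presumably pass n = 3 and at::max_pool3d_with_indices.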
2 changes: 0 additions & 2 deletions test/functorch/test_vmap_registrations.py
@@ -30,7 +30,6 @@
     "aten::log_softmax.int",
     "aten::logdet",
     "aten::masked_select_backward",
-    "aten::max_pool1d",
     "aten::movedim.intlist",
     "aten::one_hot",
     "aten::real",
@@ -173,7 +172,6 @@
     "aten::matrix_exp_backward",
     "aten::max.names_dim",
     "aten::max.names_dim_max",
-    "aten::max_pool1d",
     "aten::mean.names_dim",
     "aten::median.names_dim",
     "aten::median.names_dim_values",
