add batch impl. for inplace index_add operation (#112276)
guilhermeleobas authored and pytorchmergebot committed Oct 31, 2023
1 parent 424c093 commit 86196bf
Showing 3 changed files with 52 additions and 9 deletions.
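At a glance, this registers a vmap batch rule for the in-place index_add_, so it composes with torch.vmap instead of relying on the out-of-place path only. A minimal usage sketch of what that enables; the shapes, the closure-captured index, and the alpha value are illustrative, not taken from this PR's tests:

import torch

# Batched buffers and sources; `index` is shared (not batched) across samples.
B, N, M = 4, 5, 3
bufs = torch.zeros(B, N)
srcs = torch.randn(B, M)
index = torch.tensor([0, 2, 4])

def per_sample(buf, src):
    # In-place index_add_ on the vmapped argument; buf carries a batch dim,
    # which is what the new batch rule requires for in-place updates.
    return buf.index_add_(0, index, src, alpha=2.0)

out = torch.vmap(per_sample)(bufs, srcs)
print(out.shape)  # torch.Size([4, 5])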
56 changes: 51 additions & 5 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -934,12 +934,18 @@ std::tuple<Tensor, optional<int64_t>> diagonal_scatter_batch_rule(
   return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
 }
 
-std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
-    const Tensor& self, optional<int64_t> self_bdim,
+std::tuple<Tensor,optional<int64_t>> index_add_batch_rule_impl(
+    Tensor& self, optional<int64_t> self_bdim,
     int64_t dim,
     const Tensor& index, optional<int64_t> index_bdim,
     const Tensor& other, optional<int64_t> other_bdim,
-    const Scalar& alpha) {
+    const Scalar& alpha,
+    const bool inplace) {
+
+  if (inplace && !self_bdim.has_value()){
+    vmapIncompatibleInplaceError("index_add_");
+  }
+
   if (!index_bdim) {
     // Handle scalar tensors... self, other can be scalar tensors
     const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
@@ -958,6 +964,14 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
     self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
     other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size);
 
+    if (inplace) {
+      self_.index_add_(dim + 1, index, other_, alpha);
+      if (self_logical_rank == 0) {
+        self_ = self_.squeeze(-1);
+      }
+      return std::make_tuple(self, 0);
+    }
+
     auto result = self_.index_add(dim + 1, index, other_, alpha);
     if (self_logical_rank == 0) {
       result = result.squeeze(-1);
@@ -969,19 +983,50 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   // right now. We really want generalized index_add kernel in PyTorch
   auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim);
   std::vector<Tensor> results;
-  results.reserve(batch_size);
+  if (!inplace) {
+    results.reserve(batch_size);
+  }
   for (const auto i : c10::irange(0, batch_size)) {
     const auto& self_slice = self_bdim.has_value() ?
       self.select(*self_bdim, i) : self;
     const auto& other_slice = other_bdim.has_value() ?
       other.select(*other_bdim, i) : other;
     const auto& index_slice = index_bdim.has_value() ?
       index.select(*index_bdim, i) : index;
-    results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
+
+    if (inplace) {
+      self_slice.index_add_(dim, index_slice, other_slice, alpha);
+    } else {
+      results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
+    }
   }
+  if (inplace) {
+    return std::make_tuple(at::stack(self), 0);
+  }
   return std::make_tuple(at::stack(results), 0);
 }
 
+void index_add__batch_rule(
+    Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& other, optional<int64_t> other_bdim,
+    const Scalar& alpha) {
+  index_add_batch_rule_impl(self, self_bdim, dim, index, index_bdim, other,
+                            other_bdim, alpha, true);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
+    Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& other, optional<int64_t> other_bdim,
+    const Scalar& alpha) {
+  auto self_ = self.clone(at::MemoryFormat::Preserve);
+  return index_add_batch_rule_impl(self_, self_bdim, dim, index, index_bdim,
+                                   other, other_bdim, alpha, false);
+}
+
 static std::tuple<Tensor,Tensor> binary_pointwise_align(
     const Tensor & self,
     optional<int64_t> self_bdim,
@@ -1208,6 +1253,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
   VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT(index_add_, index_add__batch_rule);
   VMAP_SUPPORT(index_add, index_add_batch_rule);
   VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
   VMAP_SUPPORT(gather, gather_batch_rule);
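For orientation, the rule above takes two routes: when index carries no batch dimension, moving the batch dimension of self/other to the front lets a single index_add over dim + 1 update every sample at once; when index is itself batched, it falls back to a per-sample loop (the slow path). A rough Python sketch of that equivalence, with illustrative shapes and names rather than the actual ATen internals:

import torch

B, N, M = 4, 5, 3
dim = 0                              # logical (per-sample) dim being indexed
self_b = torch.zeros(B, N)           # batch dimension already at the front
other_b = torch.randn(B, M)
index = torch.tensor([0, 2, 4])      # shared index -> fast path applies

# Fast path: one call on the batched tensors, indexing dim + 1.
fast = self_b.index_add(dim + 1, index, other_b)

# Slow path: per-sample loop, needed when `index` carries its own batch dim.
slow = torch.stack([
    self_b[i].index_add(dim, index, other_b[i]) for i in range(B)
])

assert torch.allclose(fast, slow)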
4 changes: 1 addition & 3 deletions test/functorch/test_ops.py
@@ -894,7 +894,6 @@ def vjp_of_vjp(*args_and_cotangents):
         skip('nn.functional.alpha_dropout'),  # randomness
         skip('nn.functional.scaled_dot_product_attention'),  # randomness
         skip('nn.functional.multi_head_attention_forward'),  # randomness
-        xfail('as_strided'),  # as_strided is too wild for us to support, wontfix
         xfail('index_put', ''),  # not possible due to dynamic shapes; we support a subset
         xfail('masked_scatter'),  # dynamic
         xfail('nn.functional.fractional_max_pool2d'),  # random
@@ -957,6 +956,7 @@ def vjp_of_vjp(*args_and_cotangents):
              {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"),
     ))
     @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail.union({
+        xfail('as_strided'),
         xfail('as_strided', 'partial_views'),
     }))
     def test_vmapvjp(self, device, dtype, op):
@@ -1186,8 +1186,6 @@ def test():
         xfail('nn.functional.bilinear'),
         xfail('nn.functional.fractional_max_pool3d'),
         xfail('nn.functional.ctc_loss'),
-        xfail('as_strided'),
-        xfail('stft'),
         xfail('nn.functional.rrelu'),
         xfail('nn.functional.embedding_bag'),
         xfail('nn.functional.fractional_max_pool2d'),
1 change: 0 additions & 1 deletion test/functorch/test_vmap.py
@@ -3778,7 +3778,6 @@ def test_op_has_batch_rule(self, device, dtype, op):
             'hypot',
             'igamma',
             'igammac',
-            'index_add',
             'index_copy',
             'lcm',
             'ldexp',
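With 'index_add' removed from the exclusion list above, test_op_has_batch_rule is expected to cover the in-place variant under vmap as well. A hedged parity check in the same spirit; the shapes, values, and the manual loop are illustrative, not the actual test code:

import torch

x = torch.zeros(3, 4)
src = torch.randn(3, 2)
idx = torch.tensor([1, 3])

# Vmapped in-place op vs. a manual per-sample out-of-place reference.
vmapped = torch.vmap(lambda xi, si: xi.index_add_(0, idx, si))(x.clone(), src)
reference = torch.stack([x[i].index_add(0, idx, src[i]) for i in range(3)])
assert torch.allclose(vmapped, reference)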
