[functorch] Prevent using for-loop for out-of-place index_fill batch rule

ghstack-source-id: f31f7166b41d8dca2c6b3bdb94bba4478600ad2b
Pull Request resolved: #99229
qqaatw committed Apr 17, 2023
1 parent 05809c7 commit bbcdbd2
Showing 2 changed files with 54 additions and 24 deletions.
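
In vmap terms, the rule changed below implements torch.index_fill for batched inputs. Before this commit the out-of-place path looped over the batch and called index_fill_ once per sample; afterwards it issues a single call on a reshaped tensor. A minimal sketch of the kind of batched call this rule serves (shapes chosen only for illustration):

import torch

B = 2
x = torch.zeros(B, 3)                      # batched self
index = torch.tensor([[0], [1]])           # one index per batch element
# Per sample i this computes torch.index_fill(x[i], 0, index[i], 1.0),
# which is exactly what the batch rule has to vectorize.
out = torch.vmap(torch.index_fill, in_dims=(0, None, 0, None))(x, 0, index, 1.0)
print(out)                                 # [[1., 0., 0.], [0., 1., 0.]]
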
50 changes: 40 additions & 10 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1051,7 +1051,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(

if (inplace) {
// Do for-loop for in-place because we cannot reshape
// `self_` having an incompatible stride without copying
// `self_` having an incompatible stride without copying.
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
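
The comment in the hunk above says the in-place path must keep the per-sample loop because the batched `self_` can have strides that rule out the reshape used by the vectorized path. A minimal Python sketch of that constraint, on made-up tensors rather than the rule's actual `self_`:

import torch

base = torch.zeros(3, 2)
batched = base.t()              # non-contiguous view, shape (2, 3)
try:
    batched.view(-1)            # flattening this view is impossible without a copy
except RuntimeError as err:
    print("view failed:", err)
flat = batched.reshape(-1)      # reshape succeeds, but only by copying...
flat.fill_(1.0)
print(base.sum())               # ...so the write never reaches `base`: tensor(0.)
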
@@ -1100,6 +1100,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
const Tensor & value, optional<int64_t> value_bdim,
const bool inplace) {
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor index_ = moveBatchDimToFront(index, index_bdim);
Tensor value_ = moveBatchDimToFront(value, value_bdim);
@@ -1123,20 +1124,49 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

if (inplace || value_bdim.has_value()) {
// Do for-loop for in-place because we cannot reshape
// `self_` having an incompatible stride without copying.
// If value has a batch dim, we do for-loop as well because
// index_fill_ supports 1-element tensor only.
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
const auto& value_slice = value_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value_slice
);
}
return std::make_tuple(self_, 0);
}

self_ = self_bdim.has_value() ? self_ : self_.clone();

for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
const auto& value_slice = value_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value_slice
if (self_logical_rank != 0){
auto index_offset = at::arange(
batch_size,
at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
);
if (index_logical_rank == 0){
index_ = index_.unsqueeze(-1);
}
index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
index_ = reshape_dim_into(0, 0, index_);
self_ = reshape_dim_into(0, dim, self_);
self_.index_fill_(dim, index_, value);
self_ = reshape_dim_outof(dim, batch_size, self_);
return std::make_tuple(self_, dim);
}

if (index_logical_rank != 0){
index_ = reshape_dim_into(0, 0, index_);
}
self_.unsqueeze_(-1);
self_.index_fill_(dim + 1, index_, value);
self_.squeeze_(-1);

return std::make_tuple(self_, 0);
}
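
For a concrete picture of what the new non-loop path computes, here is a rough Python analogue of the `self_logical_rank != 0` branch: fold the batch dimension into `dim`, offset each sample's indices by `i * self_.size(dim + 1)`, and fill once. The shapes are made up, and the real rule also handles rank-0 indices and unbatched operands:

import torch

B, N = 2, 3
self_ = torch.zeros(B, N)                    # batch dim already at the front
index_ = torch.tensor([[0], [1]])            # shape (B, 1)
value, dim = 7.0, 0                          # logical dim inside each sample

# Reference: the per-sample loop this commit removes from the out-of-place path.
ref = torch.stack([torch.index_fill(self_[i], dim, index_[i], value) for i in range(B)])

# Vectorized version: offset each sample's indices into the folded batch*dim axis.
offset = torch.arange(B, dtype=index_.dtype)
index_flat = (index_ + offset.unsqueeze(-1) * self_.size(dim + 1)).reshape(-1)
flat = self_.clone().reshape(B * N)          # analogue of reshape_dim_into(0, dim, self_)
flat.index_fill_(dim, index_flat, value)
out = flat.reshape(B, N)                     # analogue of reshape_dim_outof(dim, B, flat)

assert torch.equal(out, ref)
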
28 changes: 14 additions & 14 deletions test/functorch/test_vmap.py
@@ -3885,56 +3885,56 @@ def test2():
x = torch.zeros(B, 3, device=device)
dim = 0
index = torch.tensor([[0], [1]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

def test3():
# self batched, self logical rank 1, index logical rank 0
x = torch.zeros(B, 3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

def test4():
# self not batched, self logical rank 0, index logical rank 1
x = torch.zeros([], device=device)
dim = 0
index = torch.tensor([[0], [0]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test5():
# self not batched, self logical rank 0, index logical rank 0
x = torch.zeros([], device=device)
dim = 0
index = torch.tensor([0, 0], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test6():
# self not batched, self logical rank 1, index logical rank 1
x = torch.zeros(3, device=device)
dim = 0
index = torch.tensor([[0], [1]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test7():
# self not batched, self logical rank 1, index logical rank 0
x = torch.zeros(3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test8():
# self batched, self logical rank > 1, index logical rank 0
x = torch.zeros(B, 3, 3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))
for value in (1.0, torch.rand((), device=device)):
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

for test in (test1, test2, test3, test4, test5, test6, test7, test8):
check_vmap_fallback(self, test, torch.index_fill)
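
The widened loops above now exercise both a Python scalar and a 0-dim tensor `value`. A standalone way to see what each case asserts, mirroring the test2 shapes but without the harness's vmap_outplace_test plumbing:

import torch

B = 2
x = torch.zeros(B, 3)
index = torch.tensor([[0], [1]])

for value in (1.0, torch.rand(())):
    out = torch.vmap(torch.index_fill, in_dims=(0, None, 0, None))(x, 0, index, value)
    ref = torch.stack([torch.index_fill(x[i], 0, index[i], value) for i in range(B)])
    assert torch.equal(out, ref)
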
