[functorch] Prevent using for-loop for out-of-place index_fill batch rule #99229

Closed · wants to merge 6 commits
96 changes: 57 additions & 39 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1018,6 +1018,42 @@ std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
  return std::make_tuple(result, 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_batch_rule_helper(
    int64_t batch_size,
    int64_t self_logical_rank,
    int64_t index_logical_rank,
    Tensor & self_,
    int64_t dim,
    Tensor & index_,
    const Scalar & value
){
  if (self_logical_rank != 0){
    auto index_offset = at::arange(
      batch_size,
      at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
    );
    if (index_logical_rank == 0){
      index_ = index_.unsqueeze(-1);
    }
    index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
    index_ = reshape_dim_into(0, 0, index_);
    self_ = reshape_dim_into(0, dim, self_);
    self_.index_fill_(dim, index_, value);
    self_ = reshape_dim_outof(dim, batch_size, self_);
    return std::make_tuple(self_, dim);
  }

  // If self_logical_rank == 0, the batch dim is certainly 0, and we must
  // apply batched indices to each row.
  if (index_logical_rank != 0){
    index_ = reshape_dim_into(0, 0, index_);
  }
  self_.unsqueeze_(-1);
  self_.index_fill_(dim + 1, index_, value);
  self_.squeeze_(-1);

  return std::make_tuple(self_, 0);
}
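For intuition, here is a minimal PyTorch-level sketch of the trick this helper implements (Python; B, N, and the tensors are illustrative stand-ins, not functorch internals): flattening the batch dim into `dim` and offsetting batch element i's indices by i * self.size(dim) lets a single index_fill_ replace a per-batch for-loop.

import torch

B, N = 2, 3
self_ = torch.zeros(B, N)
index_ = torch.tensor([[0], [1]])  # batched indices: one row per batch element
value = 1.0

# Reference semantics: one index_fill_ per batch element.
expected = self_.clone()
for i in range(B):
    expected[i].index_fill_(0, index_[i], value)

# Batched equivalent: offset row i's indices by i * N so they address the
# flattened (B * N,) view; this mirrors index_.add(index_offset.unsqueeze(-1),
# self_.size(dim + 1)) plus reshape_dim_into / reshape_dim_outof above.
offset = torch.arange(B).unsqueeze(-1) * N
flat_index = (index_ + offset).reshape(-1)
flat_self = self_.clone().reshape(B * N)
flat_self.index_fill_(0, flat_index, value)
batched = flat_self.reshape(B, N)

assert torch.equal(batched, expected)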

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int64_t dim,

@@ -1051,7 +1087,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(

  if (inplace) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
@@ -1066,31 +1102,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
}
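The "incompatible stride" comment above is about view semantics. A hedged sketch of the failure mode the in-place for-loop avoids (plain PyTorch, illustrative shapes): under vmap, `self_` can be an expanded, stride-0 view, and reshaping such a tensor silently copies, so writes would no longer reach the caller's storage, whereas select() keeps a true view.

import torch

base = torch.zeros(3)
batched = base.expand(2, 3)    # batch dim with stride 0; still aliases `base`

flat = batched.reshape(2 * 3)  # cannot be expressed as a view, so reshape copies
flat.fill_(1.0)                # the write lands in the copy ...
assert base.sum() == 0         # ... and never reaches `base`

# By contrast, select(0, i) returns a real view, so in-place writes propagate;
# this is what the per-batch loop relies on.
batched.select(0, 0).fill_(1.0)
assert base.sum() == 3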

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
@@ -1100,6 +1112,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
    const Tensor & value, optional<int64_t> value_bdim,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  Tensor value_ = moveBatchDimToFront(value, value_bdim);
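For readers new to these helpers, a hedged Python sketch of the bookkeeping they perform (simplified stand-ins, not functorch's actual implementations): a batched tensor carries one extra batch dimension at `bdim`, and the "logical" rank excludes it.

import torch

def move_batch_dim_to_front(t, bdim):
    # simplified stand-in for moveBatchDimToFront: no-op when bdim is None
    return t if bdim is None else t.movedim(bdim, 0)

def rank_without_batch_dim(t, bdim):
    # simplified stand-in for rankWithoutBatchDim
    return t.dim() - (0 if bdim is None else 1)

x = torch.zeros(5, 2, 3)                  # batch dim of size 2 at position 1
assert move_batch_dim_to_front(x, 1).shape == (2, 5, 3)
assert rank_without_batch_dim(x, 1) == 2  # logical shape is (5, 3)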
@@ -1123,22 +1136,27 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (inplace || value_bdim.has_value()) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    // If value has a batch dim, we do for-loop as well because
    // index_fill_ supports 1-element tensor only.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
      self_slice.index_fill_(
        dim,
        index_slice,
        value_bdim.has_value() ? value_.select(0, i) : value_
      );
    }
    return std::make_tuple(self_, 0);
  }

Review discussion on the value_bdim check inside the loop:

Collaborator:
Can we write this for-loop such that we won't have to check value_bdim every iteration?

Collaborator (Author):
I feel this is the most concise version, and checking value_bdim would not be costly. Here are the approaches I came up with:

1. Always ensure value has a batch dim and do a select in every iteration - costly.
2. Duplicate the for-loop and separate the work out; there are four combinations of inplace and value_bdim.has_value() - not concise.
3. Determine value_bdim.has_value() before the for-loop and use lambdas to wrap either value_ or value_.select(0, i) - this feels unnecessary, to be honest. Would it be better?

Or maybe I'm missing something.

Collaborator:
Hmm, makes sense; plus I would assume that most of the time is spent in index_fill, so this should be good.

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value.item());
}
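Why a batched `value` forces the for-loop: `index_fill_` accepts a plain scalar or a 0-dimensional tensor as its value, not a vector of per-batch values; this is also why the out-of-place path above converts the tensor value with value.item() before calling the scalar helper. A small demonstration (the error text is from recent PyTorch and may vary):

import torch

x = torch.zeros(2, 3)
idx = torch.tensor([0, 2])

x[0].index_fill_(0, idx, torch.tensor(1.0))        # OK: 0-dim value tensor
try:
    x[1].index_fill_(0, idx, torch.tensor([1.0]))  # 1-D value tensor is rejected
except RuntimeError as e:
    print(e)  # e.g. "index_fill_ only supports a 0-dimensional value tensor"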

void index_fill__int_scalar_batch_rule(
28 changes: 14 additions & 14 deletions test/functorch/test_vmap.py

@@ -3885,56 +3885,56 @@ def test2():
            x = torch.zeros(B, 3, device=device)
            dim = 0
            index = torch.tensor([[0], [1]], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

        def test3():
            # self batched, self logical rank 1, index logical rank 0
            x = torch.zeros(B, 3, device=device)
            dim = 0
            index = torch.tensor([0, 1], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

        def test4():
            # self not batched, self logical rank 0, index logical rank 1
            x = torch.zeros([], device=device)
            dim = 0
            index = torch.tensor([[0], [0]], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

        def test5():
            # self not batched, self logical rank 0, index logical rank 0
            x = torch.zeros([], device=device)
            dim = 0
            index = torch.tensor([0, 0], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

        def test6():
            # self not batched, self logical rank 1, index logical rank 1
            x = torch.zeros(3, device=device)
            dim = 0
            index = torch.tensor([[0], [1]], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

        def test7():
            # self not batched, self logical rank 1, index logical rank 0
            x = torch.zeros(3, device=device)
            dim = 0
            index = torch.tensor([0, 1], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

        def test8():
            # self batched, self logical rank > 1, index logical rank 0
            x = torch.zeros(B, 3, 3, device=device)
            dim = 0
            index = torch.tensor([0, 1], device=device)
            for value in (1.0, torch.rand((), device=device)):
                self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

        for test in (test1, test2, test3, test4, test5, test6, test7, test8):
            check_vmap_fallback(self, test, torch.index_fill)
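For context, a hedged sketch of the behavior these tests pin down, written against the public torch.func.vmap API (vmap_outplace_test additionally exercises the batching-rule fallback machinery, which this sketch does not reproduce): vmapping torch.index_fill over batched `self` and `index` must match applying it per batch element, now also when `value` is a 0-dim tensor.

import torch
from torch.func import vmap

B = 2
x = torch.zeros(B, 3)
index = torch.tensor([[0], [1]])
value = torch.rand(())  # 0-dim tensor value, the newly covered case

out = vmap(torch.index_fill, in_dims=(0, None, 0, None))(x, 0, index, value)

# Reference: apply index_fill independently to each batch element.
expected = torch.stack(
    [torch.index_fill(x[i], 0, index[i], value) for i in range(B)]
)
assert torch.equal(out, expected)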