Add vmap support for torch.index_fill (#91364)
Fixes #91177

Pull Request resolved: #91364
Approved by: https://github.com/zou3519
qqaatw authored and pytorchmergebot committed Jan 30, 2023
1 parent 08035b1 commit 5112f44
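
As an illustration of what this commit enables (a sketch for context, not taken from the PR itself), `torch.index_fill` can now be vmapped over a batch of inputs and per-example index tensors. `torch.func.vmap` is assumed below; on releases that predate it, `functorch.vmap` plays the same role.

import torch
from torch.func import vmap  # or: from functorch import vmap

B = 2
x = torch.randn(B, 5, 5)
index = torch.tensor([[2, 3], [0, 4]])  # one index tensor per example

# Batched index_fill: fill dim 1 of each example at its own indices.
out = vmap(torch.index_fill, in_dims=(0, None, 0, None))(x, 1, index, 5.0)

# Reference: the same result computed example by example.
expected = torch.stack([xi.index_fill(1, idx, 5.0) for xi, idx in zip(x, index)])
assert torch.allclose(out, expected)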
Showing 3 changed files with 235 additions and 5 deletions.
162 changes: 162 additions & 0 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1056,6 +1056,164 @@ std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
return std::make_tuple(result, 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Scalar & value,
const bool inplace) {
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor index_ = moveBatchDimToFront(index, index_bdim);
dim = maybe_wrap_dim(dim, self_logical_rank);

if (inplace && !self_bdim.has_value()) {
vmapIncompatibleInplaceError("index_fill_");
}

if (!index_bdim) {
if (self_logical_rank == 0){
self_.unsqueeze_(-1);
}
self_.index_fill_(dim + 1, index_, value);
if (self_logical_rank == 0) {
self_.squeeze_(-1);
}
return std::make_tuple(self_, 0);
}

auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

if (inplace) {
// Use a for-loop for the in-place case: reshaping `self_` when it has
// an incompatible stride would require a copy
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value
);
}
return std::make_tuple(self_, 0);
}

self_ = self_bdim.has_value() ? self_ : self_.clone();

if (self_logical_rank != 0){
auto index_offset = at::arange(
batch_size,
at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
);
if (index_logical_rank == 0){
index_ = index_.unsqueeze(-1);
}
index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
index_ = reshape_dim_into(0, 0, index_);
self_ = reshape_dim_into(0, dim, self_);
self_.index_fill_(dim, index_, value);
self_ = reshape_dim_outof(dim, batch_size, self_);
return std::make_tuple(self_, dim);
}

// If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
if (index_logical_rank != 0){
index_ = reshape_dim_into(0, 0, index_);
}
self_.unsqueeze_(-1);
self_.index_fill_(dim + 1, index_, value);
self_.squeeze_(-1);

return std::make_tuple(self_, 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Tensor & value, optional<int64_t> value_bdim,
const bool inplace) {
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor index_ = moveBatchDimToFront(index, index_bdim);
Tensor value_ = moveBatchDimToFront(value, value_bdim);
dim = maybe_wrap_dim(dim, self_logical_rank);

if (inplace && !self_bdim.has_value()) {
vmapIncompatibleInplaceError("index_fill_");
}

if (!index_bdim && !value_bdim) {
if (self_logical_rank == 0){
self_.unsqueeze_(-1);
}
self_.index_fill_(dim + 1, index_, value);
if (self_logical_rank == 0) {
self_.squeeze_(-1);
}
return std::make_tuple(self_, 0);
}

auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

self_ = self_bdim.has_value() ? self_ : self_.clone();

for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
const auto& value_slice = value_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value_slice
);
}

return std::make_tuple(self_, 0);
}

void index_fill__int_scalar_batch_rule(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Scalar & value) {
index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

void index_fill__int_tensor_batch_rule(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Tensor & value, optional<int64_t> value_bdim) {
index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
const Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Scalar & value) {
auto self_ = self.clone(at::MemoryFormat::Preserve);
return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
const Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor & index, optional<int64_t> index_bdim,
const Tensor & value, optional<int64_t> value_bdim) {
auto self_ = self.clone(at::MemoryFormat::Preserve);
return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
}


TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("index.Tensor", index_plumbing);
m.impl("index_put_", index_put__plumbing);
@@ -1066,6 +1224,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("index_copy", index_copy_decomp);
m.impl("index_select", index_select_decomp);
VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
VMAP_SUPPORT(index_add, index_add_batch_rule);
VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
VMAP_SUPPORT(gather, gather_batch_rule);
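For readers skimming the batch rule above: the out-of-place scalar path avoids a per-example loop by offsetting each example's indices by `i * self.size(dim + 1)`, folding the batch dimension into the fill dimension, and issuing a single `index_fill`. Below is a rough standalone Python sketch of the same idea; the helper name and the `(B, K)` index layout are assumptions for illustration, not functorch API.

import torch

def batched_index_fill_sketch(x, dim, index, value):
    # x:     (B, *shape) -- batched input, batch dim in front
    # index: (B, K)      -- per-example indices into x.shape[dim + 1]
    # dim:   logical fill dim, i.e. not counting the batch dim
    # Reference semantics:
    #   torch.stack([xi.index_fill(dim, idx, value) for xi, idx in zip(x, index)])
    B = x.shape[0]
    n = x.shape[dim + 1]
    # Shift each example's indices into its own slice of the folded dimension.
    offset = torch.arange(B, device=index.device, dtype=index.dtype) * n
    flat_index = (index + offset.unsqueeze(-1)).reshape(-1)
    # Fold the batch dim into the fill dim so one index_fill covers every example.
    folded = x.movedim(0, dim).flatten(dim, dim + 1)
    folded = folded.index_fill(dim, flat_index, value)
    # Unfold and move the batch dim back to the front.
    return folded.unflatten(dim, (B, n)).movedim(dim, 0)

B = 2
x = torch.randn(B, 5, 5)
index = torch.tensor([[2, 3], [0, 4]])
out = batched_index_fill_sketch(x, 1, index, 5.0)
expected = torch.stack([xi.index_fill(1, idx, 5.0) for xi, idx in zip(x, index)])
assert torch.equal(out, expected)

The in-place variants cannot use this trick, since folding the batch dimension may require a copy; that is why the implementation falls back to a per-example loop for `index_fill_`.
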
3 changes: 0 additions & 3 deletions test/functorch/test_ops.py
@@ -1043,7 +1043,6 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('fill'),
skip('masked.mean'), # ???
xfail('masked_scatter'),
xfail('index_fill'),
xfail('put'),
xfail('take'),
xfail('nn.functional.max_pool3d'),
@@ -1114,8 +1113,6 @@ def test():
xfail('fill'),
xfail('narrow'), # Batching rule not implemented for `narrow.Tensor` (and view op)
xfail('special.log_ndtr'),
xfail('index_copy'),
xfail('index_fill'),
xfail('linalg.householder_product'),
xfail('lu'),
xfail('lu_solve'),
75 changes: 73 additions & 2 deletions test/functorch/test_vmap.py
@@ -3613,7 +3613,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('native_batch_norm'),
xfail('_native_batch_norm_legit'),
xfail('histogram'),
xfail('index_fill'),
xfail('scatter_reduce', 'sum'),
xfail('scatter_reduce', 'mean'),
xfail('scatter_reduce', 'amax'),
@@ -3861,11 +3860,83 @@ def test_slogdet(self, device):
# There's no OpInfo for this
def test():
B = 2
x = torch.randn(2, 5, 5, device=device)
x = torch.randn(B, 5, 5, device=device)
self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))

check_vmap_fallback(self, test, torch.slogdet)

def test_index_fill(self, device):
# There's no OpInfo for these tests

B = 2

def test1():
# negative dim
x = torch.randn(B, 5, 5, device=device)
dim = -2
index = torch.tensor([[2, 3], [0, 4]], device=device)
value = 5.0
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test2():
# self batched, self logical rank 1, index logical rank 1
x = torch.zeros(B, 3, device=device)
dim = 0
index = torch.tensor([[0], [1]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

def test3():
# self batched, self logical rank 1, index logical rank 0
x = torch.zeros(B, 3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

def test4():
# self not batched, self logical rank 0, index logical rank 1
x = torch.zeros([], device=device)
dim = 0
index = torch.tensor([[0], [0]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test5():
# self not batched, self logical rank 0, index logical rank 0
x = torch.zeros([], device=device)
dim = 0
index = torch.tensor([0, 0], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test6():
# self not batched, self logical rank 1, index logical rank 1
x = torch.zeros(3, device=device)
dim = 0
index = torch.tensor([[0], [1]], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test7():
# self not batched, self logical rank 1, index logical rank 0
x = torch.zeros(3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (None, None, 0, None))

def test8():
# self batched, self logical rank > 1, index logical rank 0
x = torch.zeros(B, 3, 3, device=device)
dim = 0
index = torch.tensor([0, 1], device=device)
value = 1
self.vmap_outplace_test(torch.index_fill, (x, dim, index, value), {}, (0, None, 0, None))

for test in (test1, test2, test3, test4, test5, test6, test7, test8):
check_vmap_fallback(self, test, torch.index_fill)

def test_fill__Tensor(self, device):
# There's no OpInfo for fill_.Tensor, so here's an extra test for it.
def test():
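The new tests above exercise the out-of-place `torch.index_fill`; the in-place overloads (`index_fill_.int_Scalar` / `index_fill_.int_Tensor`) are registered as well. The sketch below (again assuming `torch.func.vmap`) shows the pattern they permit: the mutated tensor must itself carry the batch dimension, otherwise the batch rule raises the usual vmap-incompatible in-place error, as guarded in the C++ above.

import torch
from torch.func import vmap

B = 2
x = torch.zeros(B, 3)
index = torch.tensor([[0], [2]])  # per-example index to fill

def fill_(xi, idx):
    # xi is a per-example slice of a batched tensor, so the in-place
    # batch rule for index_fill_ can update it.
    return xi.index_fill_(0, idx, 1.0)

out = vmap(fill_)(x, index)
# out[0] is tensor([1., 0., 0.]); out[1] is tensor([0., 0., 1.])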
