Skip to content

Commit

Permalink
Fix tests
Browse files — browse the repository at this point in the history
  • Loading branch information
qqaatw committed Dec 27, 2022
1 parent 1ec685e commit b0cbf0f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 40 deletions.
100 changes: 62 additions & 38 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1056,19 +1056,7 @@ std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
return std::make_tuple(result, 0);
}

// Returns `tensor` unchanged if it already carries a leading batch dim;
// otherwise materializes one by repeating the whole tensor `batch_size`
// times along a new leading dim. Unlike expand(), repeat() allocates, so
// the resulting batch slices do not alias each other.
Tensor ensure_has_bdim_copy(const Tensor& tensor, bool has_bdim, int64_t batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sizes();
  // Tensor::repeat takes per-dimension repeat COUNTS, not target sizes:
  // [batch_size, 1, 1, ...] yields shape [batch_size, sizes...]. Passing the
  // original sizes here (as the previous code did) would have produced
  // [batch_size, s0*s0, s1*s1, ...].
  DimVector repeats;
  repeats.reserve(sizes.size() + 1);  // batch dim + one count per original dim
  repeats.emplace_back(batch_size);
  repeats.insert(repeats.end(), sizes.size(), 1);
  return tensor.repeat(repeats);
}

// Shared implementation behind the batch rules for index_fill.int_Scalar and
// index_fill_.int_Scalar: fills entries of `self` along `dim` at positions
// `index` with the scalar `value`. Either of self/index may carry a vmap
// batch dimension (reported via *_bdim). Returns the result together with
// its batch-dim position (always 0 here).
std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {

  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  // Normalize layout: whatever batch dim exists moves to position 0.
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  // Wrap dim against the logical (per-example) rank, not the physical rank.
  dim = maybe_wrap_dim(dim, self_logical_rank);

  // An in-place write requires self to already own the batch dim: a single
  // unbatched self cannot hold distinct per-example results.
  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim) {
    // Unbatched index: a single index_fill_ call covers every example.
    // dim + 1 skips self_'s leading batch dim; the trailing unsqueeze/squeeze
    // presumably keeps dim + 1 in range when the logical self is 0-dim —
    // TODO(review): confirm.
    self_.unsqueeze_(-1);
    self_.index_fill_(dim + 1, index_, value);
    self_.squeeze_(-1);
    return std::make_tuple(self_, 0);
  }

  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (!inplace){
    // ensure_has_bdim on an unbatched self yields an expanded view whose
    // batch slices alias one another; clone before the per-example writes.
    if (!self_bdim.has_value())
      self_ = self_.clone();
  }

  // index_fill_ has no batched-index form, so fill one example at a time.
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_.select(0, i);
    const auto& index_slice = index_.select(0, i);
    self_slice.index_fill_(
      dim,
      index_slice,
      value
    );
  }
  return std::make_tuple(self_, 0);
}

// Shared implementation behind the batch rules for index_fill.int_Tensor and
// index_fill_.int_Tensor: fills entries of `self` along `dim` at positions
// `index` with the tensor `value`. Any of self/index/value may carry a vmap
// batch dimension (reported via *_bdim). Returns the result together with
// its batch-dim position (always 0 here).
std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Tensor & value, optional<int64_t> value_bdim,
    const bool inplace) {

  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  // Normalize layout: whatever batch dim exists moves to position 0.
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  Tensor value_ = moveBatchDimToFront(value, value_bdim);
  // Wrap dim against the logical (per-example) rank, not the physical rank.
  dim = maybe_wrap_dim(dim, self_logical_rank);

  // An in-place write requires self to already own the batch dim: a single
  // unbatched self cannot hold distinct per-example results.
  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim && !value_bdim) {
    // Only self is batched: a single index_fill_ call covers every example.
    // dim + 1 skips self_'s leading batch dim; the trailing unsqueeze/squeeze
    // presumably keeps dim + 1 in range when the logical self is 0-dim —
    // TODO(review): confirm.
    self_.unsqueeze_(-1);
    self_.index_fill_(dim + 1, index_, value);
    self_.squeeze_(-1);
    return std::make_tuple(self_, 0);
  }

  // Align all three inputs on a common batch size, adding a batch dim to any
  // input that lacks one.
  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

  if (!inplace){
    // ensure_has_bdim on an unbatched self yields an expanded view whose
    // batch slices alias one another; clone before the per-example writes.
    if (!self_bdim.has_value())
      self_ = self_.clone();
  }

  // index_fill_ has no batched-index/value form, so fill one example at a time.
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_.select(0, i);
    const auto& index_slice = index_.select(0, i);
    const auto& value_slice = value_.select(0, i);
    self_slice.index_fill_(
      dim,
      index_slice,
      value_slice
    );
  }
  return std::make_tuple(self_, 0);
}

// Batch rule for in-place index_fill_.int_Scalar: delegates to the shared
// impl with inplace = true. (The scrape carried both the pre- and post-commit
// return lines; only the post-commit call to the renamed impl is kept.)
std::tuple<Tensor,optional<int64_t>> index_fill__int_scalar_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value) {
  return index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

// Batch rule for in-place index_fill_.int_Tensor: delegates to the shared
// impl with inplace = true, forwarding value's batch dim instead of the old
// value.item() scalar collapse (which would have failed on batched values).
std::tuple<Tensor,optional<int64_t>> index_fill__int_tensor_batch_rule(
    Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Tensor & value, optional<int64_t> value_bdim) {
  return index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}

// Batch rule for out-of-place index_fill.int_Scalar: clone self first so the
// shared impl may write in place without touching the caller's tensor.
// NOTE(review): the self/self_bdim/dim parameters were hidden by collapsed
// diff context; reconstructed from the parallel wrappers and the call below.
std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
    const Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Scalar & value) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
}

// Batch rule for out-of-place index_fill.int_Tensor: clone self first so the
// shared impl may write in place without touching the caller's tensor, and
// forward value's batch dim instead of the old value.item() scalar collapse.
// NOTE(review): the self/self_bdim/dim parameters were hidden by collapsed
// diff context; reconstructed from the parallel wrappers and the call below.
std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
    const Tensor & self, optional<int64_t> self_bdim,
    int dim,
    const Tensor & index, optional<int64_t> index_bdim,
    const Tensor & value, optional<int64_t> value_bdim) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
}


Expand Down
2 changes: 0 additions & 2 deletions test/functorch/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,6 @@ def test():
xfail('nanmean'),
xfail('narrow'), # Batching rule not implemented for `narrow.Tensor` (and view op)
xfail('special.log_ndtr'),
xfail('index_copy'),
xfail('index_fill'),
xfail('linalg.eig'),
xfail('linalg.householder_product'),
xfail('lu'),
Expand Down

0 comments on commit b0cbf0f

Please sign in to comment.