Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@

- op: scalar_tensor.out

- op: scatter.src_out

- op: scatter.value_out

- op: scatter_add.out
Expand Down
81 changes: 74 additions & 7 deletions kernels/portable/cpu/op_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,46 @@ using ScalarType = exec_aten::ScalarType;

namespace {

template <typename CTYPE>
void scatter_src_helper(
const Tensor& in,
int64_t dim,
const Tensor& index,
const Tensor& src,
Tensor& out) {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
const long* index_data = index.const_data_ptr<long>();
const CTYPE* src_data = src.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

memcpy(out_data, in_data, in.nbytes());

if (dim < 0) {
dim += nonzero_dim(in);
}

for (size_t ix = 0; ix < index.numel(); ++ix) {
// @lint-ignore CLANGTIDY facebook-hte-CArray
size_t ix_coord[kTensorDimensionLimit];
indexToCoordinate(index, ix, ix_coord);

size_t src_ix = coordinateToIndex(src, ix_coord);

// @lint-ignore CLANGTIDY facebook-hte-CArray
size_t out_coord[kTensorDimensionLimit];
for (size_t i = 0; i < out.dim(); ++i) {
if (i == dim) {
out_coord[i] = index_data[ix];
} else {
out_coord[i] = ix_coord[i];
}
}
size_t out_ix = coordinateToIndex(out, out_coord);

out_data[out_ix] = src_data[src_ix];
}
}

template <typename CTYPE, typename CTYPE_VAL>
void scatter_value_helper(
const Tensor& in,
Expand All @@ -36,15 +76,16 @@ void scatter_value_helper(

memcpy(out_data, in_data, in.nbytes());

if (index.dim() == 0) {
out_data[index_data[0]] = static_cast<CTYPE>(val);
return;
if (dim < 0) {
dim += nonzero_dim(in);
}

for (size_t ix = 0; ix < index.numel(); ++ix) {
// @lint-ignore CLANGTIDY facebook-hte-CArray
size_t ix_coord[kTensorDimensionLimit];
indexToCoordinate(index, ix, ix_coord);

// @lint-ignore CLANGTIDY facebook-hte-CArray
size_t out_coord[kTensorDimensionLimit];
for (size_t i = 0; i < out.dim(); ++i) {
if (i == dim) {
Expand All @@ -61,6 +102,36 @@ void scatter_value_helper(

} // namespace

Tensor& scatter_src_out(
RuntimeContext& context,
const Tensor& in,
int64_t dim,
const Tensor& index,
const Tensor& src,
Tensor& out) {
(void)context;

ET_KERNEL_CHECK(
context,
check_scatter_src_args(in, dim, index, src, out),
InvalidArgument,
out);

ET_KERNEL_CHECK(
context,
resize_tensor(out, in.sizes()) == Error::Ok,
InvalidArgument,
out);

constexpr auto name = "scatter.src_out";

ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
scatter_src_helper<CTYPE>(in, dim, index, src, out);
});

return out;
}

Tensor& scatter_value_out(
RuntimeContext& ctx,
const Tensor& in,
Expand All @@ -79,10 +150,6 @@ Tensor& scatter_value_out(
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

if (dim < 0) {
dim += nonzero_dim(in);
}

ScalarType val_type = utils::get_scalar_dtype(value);

constexpr auto name = "scatter.value_out";
Expand Down
9 changes: 9 additions & 0 deletions kernels/portable/cpu/util/index_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ bool check_scatter_add_args(
return true;
}

bool check_scatter_src_args(
const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& src,
Tensor& out) {
return check_scatter_add_args(self, dim, index, src, out);
}

bool check_scatter_value_args(
const Tensor& self,
int64_t dim,
Expand Down
7 changes: 7 additions & 0 deletions kernels/portable/cpu/util/index_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ bool check_scatter_add_args(
const Tensor& src,
Tensor& out);

bool check_scatter_src_args(
const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& src,
Tensor& out);

bool check_scatter_value_args(
const Tensor& self,
int64_t dim,
Expand Down
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,11 @@
- arg_meta: null
kernel_name: torch::executor::scalar_tensor_out

- op: scatter.src_out
kernels:
- arg_meta: null
kernel_name: torch::executor::scatter_src_out

- op: scatter.value_out
kernels:
- arg_meta: null
Expand Down
Loading
Loading