Skip to content

Commit

Permalink
Added option to drop bad gather index. This is to help developer add …
Browse files Browse the repository at this point in the history
…GPU-compatible version of GatherNd afterwards.

PiperOrigin-RevId: 636206934
  • Loading branch information
tensorflower-gardener committed May 22, 2024
1 parent 98a4c09 commit df40f8d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class GatherNdOp : public OpKernel {

Tensor out;
OP_REQUIRES_OK(
c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
c, functor::DoGatherNd<Device, T, Index, /*kDropBadIndices=*/false>(
c, params, indices, &out));
c->set_output(0, out);
}
};
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/gather_nd_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ struct GatherNdSlice {
typename TTypes<T>::Matrix Tout);
};

template <typename Device, typename T, typename Index>
template <typename Device, typename T, typename Index,
bool kDropBadIndices = false>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
const Tensor& indices, Tensor* out) {
if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
Expand Down Expand Up @@ -151,6 +152,10 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
indices_nd);
}

if constexpr (kDropBadIndices) {
return absl::OkStatus();
}

// bad_i will only return >= 0 on CPUs right now.
if (bad_i >= 0) {
auto shape = indices.shape();
Expand Down

0 comments on commit df40f8d

Please sign in to comment.