bad_indices_on_cpu for ScatterNd and GatherNd
PiperOrigin-RevId: 636779834
tensorflower-gardener committed May 24, 2024
1 parent b4afcb3 commit 245a343
Showing 10 changed files with 54 additions and 93 deletions.
@@ -347,7 +347,7 @@ def LegalizeGather: Pat<
(TFL_GatherOp $params, $indices, ConstantAttr<I32Attr, "0">,
ConstantAttr<I32Attr, "0">)>;

def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices),
def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices, $bad_indices_on_cpu),
(TFL_GatherNdOp $params, $indices)>;

def LegalizeGatherV2 : Pat<
4 changes: 3 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -6230,7 +6230,9 @@ See also `tf.gather` and `tf.batch_gather`.

let arguments = (ins
Arg<TF_Tensor, [{The tensor from which to gather values.}]>:$params,
Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{Index tensor.}]>:$indices
Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{Index tensor.}]>:$indices,

DefaultValuedOptionalAttr<StrAttr, "\"\"">:$bad_indices_on_cpu
);

let results = (outs
1 change: 0 additions & 1 deletion tensorflow/core/kernels/BUILD
@@ -65,7 +65,6 @@ package_group(
packages = [
"//tensorflow/...",
"//tensorflow_text/...",
"//waymo/ml/compiler/frontend/kernels/...",
"//waymo/onboard/ml/...",
],
)
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/gather_nd_op.cc
@@ -45,8 +45,7 @@ class GatherNdOp : public OpKernel {

Tensor out;
OP_REQUIRES_OK(
c, functor::DoGatherNd<Device, T, Index, /*kDropBadIndices=*/false>(
c, params, indices, &out));
c, functor::DoGatherNd<Device, T, Index>(c, params, indices, &out));
c->set_output(0, out);
}
};
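With the /*kDropBadIndices=*/ template flag removed, DoGatherNd always reports an out-of-range index through its returned Status (on CPU, per the comment in gather_nd_op.h below). A hypothetical call site, for illustration only:

// Illustrative sketch, not part of this commit: gather slices of `params`
// selected by `indices`; a bad index surfaces as a non-OK Status.
Tensor out;
Status s =
    functor::DoGatherNd<CPUDevice, float, int32>(c, params, indices, &out);
if (!s.ok()) {
  // e.g. "indices[3] = [7] does not index into param shape [4,2]".
  c->SetStatus(s);
  return;
}
c->set_output(0, out);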
7 changes: 1 addition & 6 deletions tensorflow/core/kernels/gather_nd_op.h
@@ -43,8 +43,7 @@ struct GatherNdSlice {
typename TTypes<T>::Matrix Tout);
};

template <typename Device, typename T, typename Index,
bool kDropBadIndices = false>
template <typename Device, typename T, typename Index>
Status DoGatherNd(OpKernelContext* c, const Tensor& params,
const Tensor& indices, Tensor* out) {
if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) {
@@ -152,10 +151,6 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
indices_nd);
}

if constexpr (kDropBadIndices) {
return absl::OkStatus();
}

// bad_i will only return >= 0 on CPUs right now.
if (bad_i >= 0) {
auto shape = indices.shape();
63 changes: 21 additions & 42 deletions tensorflow/core/kernels/scatter_nd_op.cc
@@ -878,7 +878,7 @@ class IndexFlattener {
namespace {

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
scatter_nd_op::UpdateOp Op>
Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -925,11 +925,7 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = shape.dim_size(i); \
} \
constexpr bool kShallDropBadIndices = \
kDropBadIndices || std::is_same<Device, GPUDevice>::value; \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM, \
kShallDropBadIndices> \
functor; \
functor::ScatterNdFunctor<Device, T, Index, Op, IXDIM> functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
output_matrix, indices_flat, updates_flat, output_matrix); \
@@ -951,9 +947,6 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices,
slice_dim);
}
}
if constexpr (kDropBadIndices) {
return absl::OkStatus();
}
if (bad_i >= 0) {
auto slice_shape = indices.shape();
slice_shape.RemoveLastDims(1);
@@ -977,8 +970,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
// back to GPU. This is useful because the CPU implementation is deterministic
// and the GPU implementation is not. Tensor inputs to this function must be on
// the GPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp Op,
bool kDropBadIndices>
template <typename T, typename Index, scatter_nd_op::UpdateOp Op>
Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape,
Tensor* out, bool allocate) {
@@ -1023,7 +1015,7 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
}

TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op, kDropBadIndices>(
TF_RETURN_IF_ERROR(DoScatterNd<CPUDevice, T, Index, Op>(
c, host_indices, host_updates, shape, &host_out, /*allocate=*/false));

// Copy 'host_out' to device.
@@ -1041,57 +1033,44 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices,
} // namespace

template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op, bool kDropBadIndices>
scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (std::is_same<Device, GPUDevice>::value &&
tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Run on the CPU for integer types, since the GPU implementation uses
// atomics, which are not supported for all integer types.
if constexpr (std::is_same<Device, GPUDevice>::value &&
std::is_integral<T>::value) {
return DoScatterNdOnCpu<T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
allocate);
} else {
return DoScatterNdImpl<Device, T, Index, Op, kDropBadIndices>(
c, indices, updates, shape, out, allocate);
return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
out, allocate);
}
}
} // namespace functor

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index \
ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, /*kDropBadIndices=*/true>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>; \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/false>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>::operator()( \
const GPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
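Because the hunk above interleaves the removed and added lines, a condensed sketch of how the DoScatterNd dispatch reads after this change may help; it is illustrative and lightly abridged from the added lines, not the verbatim file:

template <typename Device, typename T, typename Index,
          scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
                   const Tensor& updates, const TensorShape& shape,
                   Tensor* out, bool allocate) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // Determinism-requiring GPU runs fall back to the deterministic CPU path.
  if (std::is_same<Device, GPUDevice>::value &&
      tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) {
    return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
                                          allocate);
  }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // Integer types also run on the CPU, since the GPU implementation uses
  // atomics, which are not supported for all integer types.
  if constexpr (std::is_same<Device, GPUDevice>::value &&
                std::is_integral<T>::value) {
    return DoScatterNdOnCpu<T, Index, Op>(c, indices, updates, shape, out,
                                          allocate);
  } else {
    return DoScatterNdImpl<Device, T, Index, Op>(c, indices, updates, shape,
                                                 out, allocate);
  }
}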
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scatter_nd_op.h
@@ -44,7 +44,7 @@ namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM, bool kDropBadIndices>
scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(
@@ -63,7 +63,7 @@ struct ScatterNdFunctor {
// right type (T) and shape. This tensor will not be zeroed out
// before the scatter is executed.
template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp Op, bool kDropBadIndices = false>
scatter_nd_op::UpdateOp Op>
Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
const Tensor& updates, const TensorShape& shape, Tensor* out,
bool allocate);
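A hypothetical call site for the simplified declaration; the ADD update op, the float/int32 types, and the surrounding kernel context are illustrative assumptions, not part of this change:

// Scatter `updates` at `indices` into a newly allocated tensor of `shape`.
Tensor out;
OP_REQUIRES_OK(c, (functor::DoScatterNd<CPUDevice, float, int32,
                                        scatter_nd_op::UpdateOp::ADD>(
                      c, indices, updates, shape, &out, /*allocate=*/true)));
c->set_output(0, out);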
52 changes: 20 additions & 32 deletions tensorflow/core/kernels/scatter_nd_op_cpu_impl.h
@@ -103,9 +103,8 @@ class UpdateExecutor<T, Input, Update, Output, scatter_nd_op::UpdateOp::MAX> {
namespace functor {

// Implementation of update functor for CPU.
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM, kDropBadIndices> {
template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM> {
Index operator()(
const CPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -137,44 +136,33 @@ struct ScatterNdFunctor<CPUDevice, T, Index, OP, IXDIM, kDropBadIndices> {
i += ix_d * batch_strides[dim];
}
if (TF_PREDICT_FALSE(out_of_bounds)) {
if constexpr (kDropBadIndices) {
continue;
}
error_loc = loc;
break;
} else {
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}
auto input_chip = Toutput.template chip<0>(i);
auto output_chip = input_chip;
auto update_chip = Tupdates.template chip<0>(loc);
update_executor::UpdateExecutor<
CPUDevice, decltype(input_chip), decltype(update_chip),
decltype(output_chip), OP>::Execute(d, input_chip, update_chip,
output_chip);
}

return error_loc;
}
};

#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/false>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput); \
template Index ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM, \
/*kDropBadIndices=*/true>:: \
operator()(const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)
#define REGISTER_SCATTER_ND_FULL(T, Index, op) \
template Index \
ScatterNdFunctor<CPUDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator()( \
const CPUDevice& d, const Index slice_size, \
const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
output_shape_prefix, \
typename TTypes<T, 2>::Tensor Tparams, \
typename TTypes<Index, 2>::ConstTensor Tindices, \
typename TTypes<T, 2>::ConstTensor Tupdates, \
typename TTypes<T, 2>::Tensor Toutput)

#define REGISTER_SCATTER_ND_INDEX(type, op) \
REGISTER_SCATTER_ND_FULL(type, int32, op); \
10 changes: 4 additions & 6 deletions tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -124,9 +124,8 @@ __global__ void ScatterNdOpKernel(
namespace functor {

// Functor used by ScatterOp to do the computations.
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM,
bool kDropBadIndices>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, kDropBadIndices> {
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM> {
Index operator()(
const GPUDevice& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
@@ -165,9 +164,8 @@ struct ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, kDropBadIndices> {

} // namespace functor

#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM, \
/*kDropBadIndices=*/true>;
#define DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, IXDIM) \
template struct functor::ScatterNdFunctor<GPUDevice, T, Index, op, IXDIM>;

#define DECLARE_GPU_SPECS_INDEX_OP(T, Index, op) \
DECLARE_GPU_SPECS_INDEX_OP_IXDIM(T, Index, op, 1); \
1 change: 1 addition & 0 deletions tensorflow/core/ops/array_ops.cc
@@ -1306,6 +1306,7 @@ REGISTER_OP("GatherNd")
.Output("output: Tparams")
.Attr("Tparams: type")
.Attr("Tindices: {int16,int32,int64}")
.Attr("bad_indices_on_cpu: string = ''")
.SetShapeFn(shape_inference::GatherNdShape);

// --------------------------------------------------------------------------
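The kernel-side consumer of the new attr is not part of this diff; as a purely hypothetical sketch (names and placement are assumptions, not taken from this commit), a default-valued string attr like this is typically read once in the OpKernel constructor:

// Hypothetical constructor snippet, not from this commit.
std::string bad_indices_on_cpu;
OP_REQUIRES_OK(ctx, ctx->GetAttr("bad_indices_on_cpu", &bad_indices_on_cpu));
// An empty string is the default, matching the registration above.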
