Use new GPU kernel for [unsorted] segment reductions #51392

Merged
136 changes: 49 additions & 87 deletions tensorflow/core/kernels/segment_reduction_ops.h
@@ -16,13 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

// This file requires the following include because it uses GpuAtomicMax:
// #include "tensorflow/core/util/gpu_kernel_helper.h"

// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and GpuAtomicMax is used in template context.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -32,6 +25,7 @@ namespace tensorflow {

class OpKernelContext;

bool UseNonDeterministicSegmentReductions();
bool DisableSegmentReductionOpDeterminismExceptions();

// Type of SparseSegmentReduction operation to perform gradient of.
@@ -40,9 +34,51 @@ enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN };
namespace functor {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Note that we define this ourselves to avoid a dependency on gpuprim.
struct Sum {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;
  }
};

struct Prod {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

// Note that we don't use gpuprim::Min/Max because they use operator<, which is
// not implemented for AlignedVector types.
struct Min {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return min(a, b);
  }
};

struct Max {
  template <typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return max(a, b);
  }
};

template <typename ReduceOp, typename T>
struct ReduceOpIsAssociative {};
template <typename T>
struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
template <typename T>
struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};

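For context (not part of the diff): this trait records whether the reduction's result is independent of accumulation order. Floating-point addition and multiplication are not associative, so an atomic accumulation whose commit order is unspecified can differ from run to run, whereas integer Sum/Prod and Min/Max are order-independent. A minimal standalone illustration, assuming only standard C++ (nothing from this header):

// Standalone illustration (hypothetical, not from this PR): float addition is
// not associative, which is why ReduceOpIsAssociative<Sum, float> is false and
// order-unspecified atomic float accumulation is treated as nondeterministic.
#include <cstdio>

int main() {
  float a = 1e20f, b = -1e20f, c = 1.0f;
  float left = (a + b) + c;   // 0.0f + 1.0f == 1.0f
  float right = a + (b + c);  // the 1.0f is absorbed into -1e20f, so this is 0.0f
  std::printf("(a+b)+c = %g, a+(b+c) = %g\n", left, right);
  // Integer sums/products and min/max are associative, so accumulation order
  // does not affect the result and atomic reductions stay deterministic.
  return 0;
}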
typedef Eigen::GpuDevice GPUDevice;
// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp
// & SegmentMinGPUOp.
// Functor for SegmentReductionGPUOp.
// output_rows: the number of output segments (unique segment ids in
// 'segment_ids').
// segment_ids_shape: shape of 'segment_ids' tensor.
@@ -52,18 +88,18 @@ typedef Eigen::GpuDevice GPUDevice;
// data: input data tensor.
// output: output reshaped to {output_rows, output.size/output_rows}
template <typename T, typename Index, typename InitialValueF,
          typename ReductionF, typename AtomicReductionF>
          typename EmptySegmentValueF, typename ReductionF>
struct SegmentReductionFunctor {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  const Index output_rows, const TensorShape& segment_ids_shape,
                  typename TTypes<Index>::ConstFlat segment_ids,
                  bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
                  const Index data_size, const T* data,
                  typename TTypes<T, 2>::Tensor output);
  static constexpr bool atomic_reduction_is_associative =
      AtomicReductionF::is_associative;
      ReduceOpIsAssociative<ReductionF, T>::value;
};

#endif
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
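For context (not part of the diff), the comment above the functor describes the shape contract: row i of the flattened input contributes to output row segment_ids[i], and the output is viewed as {output_rows, output.size/output_rows}. A plain host-side reference for the Sum case, as a sketch only (names here are illustrative, not TensorFlow APIs):

// Host-side sketch (hypothetical, not from this PR) of the semantics the
// sorted segment reduction computes for Sum: row i of the input is
// accumulated into output row segment_ids[i]; segments with no rows keep
// the empty-segment value (0 for Sum).
#include <cstddef>
#include <vector>

std::vector<std::vector<float>> SegmentSumReference(
    const std::vector<std::vector<float>>& data,  // shape {N, inner}
    const std::vector<int>& segment_ids,          // sorted, length N
    int output_rows) {
  const std::size_t inner = data.empty() ? 0 : data[0].size();
  std::vector<std::vector<float>> output(
      output_rows, std::vector<float>(inner, 0.0f));  // empty-segment value
  for (std::size_t i = 0; i < data.size(); ++i) {
    for (std::size_t j = 0; j < inner; ++j) {
      output[segment_ids[i]][j] += data[i][j];
    }
  }
  return output;
}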

template <typename Device, typename T, typename Index, typename InitialValueF,
          typename ReductionF>
@@ -74,80 +110,6 @@ struct UnsortedSegmentFunctor {
                  typename TTypes<T, 2>::Tensor output);
};
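For context (not part of the diff): unlike the sorted functor above, the unsorted variant takes the number of output rows explicitly and accepts segment ids in any order; each output row starts from the value produced by the InitialValueF functor, which is also what empty segments keep (e.g. zero for sum, the lowest representable value for max). An illustrative host-side sketch under those assumptions (not TensorFlow code):

// Host-side sketch (hypothetical, not from this PR) of an unsorted segment
// max over a 1-D input: segment_ids may be unsorted, num_segments is given
// explicitly, and rows that receive no input keep the initial value.
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> UnsortedSegmentMaxReference(
    const std::vector<float>& data, const std::vector<int>& segment_ids,
    int num_segments) {
  std::vector<float> output(num_segments,
                            std::numeric_limits<float>::lowest());
  for (std::size_t i = 0; i < data.size(); ++i) {
    output[segment_ids[i]] = std::max(output[segment_ids[i]], data[i]);
  }
  return output;
}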

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Atomic reduction functors for the gpu.
template <typename T>
struct AtomicSumOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    GpuAtomicAdd(dest, value);
  }
  static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
struct AtomicProdOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    GpuAtomicMul(dest, value);
  }
  static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
struct AtomicMaxOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    GpuAtomicMax(dest, value);
  }
  static constexpr bool is_associative = true;
};

template <typename T>
struct AtomicMinOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    GpuAtomicMin(dest, value);
  }
  static constexpr bool is_associative = true;
};

// Non-atomic reduction functors for the gpu.
template <typename T>
struct NonAtomicSumOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    *dest += value;
  }
};

template <typename T>
struct NonAtomicProdOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    *dest *= value;
  }
};

template <typename T>
struct NonAtomicMaxOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    *dest = max(*dest, value);
  }
};

template <typename T>
struct NonAtomicMinOpGpu {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
                                                        const T& value) {
    *dest = min(*dest, value);
  }
};

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Initial value functors.
template <typename T>
struct Zero {