Rollback of #51392 (for the second time).
Causes an internal performance regression.

PiperOrigin-RevId: 394788152
Change-Id: I702fb4ec245823b96ce82f58b1b0d6c505b674c3
reedwm authored and tensorflower-gardener committed Sep 4, 2021
1 parent 7bf4357 commit 51ee415
Showing 13 changed files with 326 additions and 506 deletions.
136 changes: 87 additions & 49 deletions tensorflow/core/kernels/segment_reduction_ops.h
@@ -16,6 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_

// This file requires the following include because it uses GpuAtomicMax:
// #include "tensorflow/core/util/gpu_kernel_helper.h"

// Unfortunately we can't add the #include, since it breaks compilation for
// non-GPU targets. The break only occurs with clang, because clang is
// stricter about template code and GpuAtomicMax is used in a template
// context.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -25,7 +32,6 @@ namespace tensorflow {

class OpKernelContext;

bool UseNonDeterministicSegmentReductions();
bool DisableSegmentReductionOpDeterminismExceptions();

// Type of SparseSegmentReduction operation to perform gradient of.
@@ -34,51 +40,9 @@ enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN };
namespace functor {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Note that we define this ourselves to avoid a dependency on gpuprim.
struct Sum {
template <typename T>
__host__ __device__ T operator()(const T& a, const T& b) const {
return a + b;
}
};

struct Prod {
template <typename T>
__host__ __device__ T operator()(const T& a, const T& b) const {
return a * b;
}
};

// Note that we don't use gpuprim::Min/Max because they use operator<, which is
// not implemented for AlignedVector types.
struct Min {
template <typename T>
__host__ __device__ T operator()(const T& a, const T& b) const {
return min(a, b);
}
};

struct Max {
template <typename T>
__host__ __device__ T operator()(const T& a, const T& b) const {
return max(a, b);
}
};

template <typename ReduceOp, typename T>
struct ReduceOpIsAssociative {};
template <typename T>
struct ReduceOpIsAssociative<functor::Sum, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Prod, T> : std::is_integral<T> {};
template <typename T>
struct ReduceOpIsAssociative<functor::Max, T> : std::true_type {};
template <typename T>
struct ReduceOpIsAssociative<functor::Min, T> : std::true_type {};

typedef Eigen::GpuDevice GPUDevice;
// Functor for SegmentReductionGPUOp.
// Functor for SegmentSumGPUOp & SegmentProdGPUOp & SegmentMaxGPUOp
// & SegmentMinGPUOp.
// output_rows: the number of output segments (unique segment ids in
// 'segment_ids').
// segment_ids_shape: shape of 'segment_ids' tensor.
@@ -88,18 +52,18 @@ typedef Eigen::GpuDevice GPUDevice;
// data: input data tensor.
// output: output reshaped to {output_rows, output.size/output_rows}
template <typename T, typename Index, typename InitialValueF,
typename EmptySegmentValueF, typename ReductionF>
typename ReductionF, typename AtomicReductionF>
struct SegmentReductionFunctor {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
bool is_mean, typename TTypes<Index>::ConstFlat segment_ids,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output);
static constexpr bool atomic_reduction_is_associative =
ReduceOpIsAssociative<ReductionF, T>::value;
AtomicReductionF::is_associative;
};

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif
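
The structural difference between the two versions is visible above: the rolled-back code computed atomic_reduction_is_associative through a separate ReduceOpIsAssociative trait, while the restored code reads it directly from the atomic functor's own is_associative member. The following is a hypothetical, self-contained sketch (not part of this commit) of how a caller might branch on such a compile-time flag; ToyAtomicSum and ReduceSegment are illustrative names only.

// Hypothetical sketch, not from the TensorFlow sources: a reduction functor
// that carries its own associativity flag, and a caller that selects a code
// path on it at compile time.
#include <iostream>
#include <vector>

struct ToyAtomicSum {
  void operator()(int* dest, int value) const { *dest += value; }
  // Integer addition is associative; a real functor such as AtomicSumOpGpu
  // makes this depend on the element type, since float addition is not.
  static constexpr bool is_associative = true;
};

template <typename AtomicReductionF>
void ReduceSegment(const std::vector<int>& values, int* out) {
  if constexpr (AtomicReductionF::is_associative) {
    // Accumulation order does not affect the result, so a nondeterministic
    // (atomic) schedule would still be bitwise reproducible.
    AtomicReductionF reduce;
    for (int v : values) reduce(out, v);
  } else {
    // A deterministic implementation would have to fix the accumulation
    // order; here a plain left-to-right sum stands in for that path.
    for (int v : values) *out += v;
  }
}

int main() {
  std::vector<int> values = {1, 2, 3, 4};
  int total = 0;
  ReduceSegment<ToyAtomicSum>(values, &total);
  std::cout << total << "\n";  // Prints 10.
}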

template <typename Device, typename T, typename Index, typename InitialValueF,
typename ReductionF>
@@ -110,6 +74,80 @@ struct UnsortedSegmentFunctor {
typename TTypes<T, 2>::Tensor output);
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Atomic reduction functors for the gpu.
template <typename T>
struct AtomicSumOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicAdd(dest, value);
}
static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
struct AtomicProdOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMul(dest, value);
}
static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
struct AtomicMaxOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMax(dest, value);
}
static constexpr bool is_associative = true;
};

template <typename T>
struct AtomicMinOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
GpuAtomicMin(dest, value);
}
static constexpr bool is_associative = true;
};

// Non-atomic reduction functors for the gpu.
template <typename T>
struct NonAtomicSumOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest += value;
}
};

template <typename T>
struct NonAtomicProdOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest *= value;
}
};

template <typename T>
struct NonAtomicMaxOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest = max(*dest, value);
}
};

template <typename T>
struct NonAtomicMinOpGpu {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest,
const T& value) {
*dest = min(*dest, value);
}
};

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
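
The restored block pairs every atomic functor (AtomicSumOpGpu, AtomicProdOpGpu, ...) with a non-atomic counterpart, which suggests the usual two-phase pattern: fold a thread-local partial result with the non-atomic op, then commit it to the shared output with the atomic op. Below is a minimal host-side analogue, assuming hypothetical names (NonAtomicSum, AtomicSum) and std::atomic in place of GpuAtomicAdd; it illustrates the pattern and is not code from this commit.

// Host-side analogue of the atomic / non-atomic functor pairing; illustrative
// only.
#include <atomic>
#include <iostream>
#include <vector>

struct NonAtomicSum {
  void operator()(int* dest, int value) const { *dest += value; }
};

struct AtomicSum {
  void operator()(std::atomic<int>* dest, int value) const {
    dest->fetch_add(value, std::memory_order_relaxed);
  }
};

int main() {
  // Each chunk plays the role of one GPU thread's slice of the input.
  std::vector<std::vector<int>> chunks = {{1, 2}, {3, 4}, {5}};
  std::atomic<int> output{0};
  NonAtomicSum partial_op;
  AtomicSum commit_op;
  for (const std::vector<int>& chunk : chunks) {
    int partial = 0;
    // Phase 1: accumulate locally with the non-atomic op (no contention).
    for (int v : chunk) partial_op(&partial, v);
    // Phase 2: commit the partial result once with the atomic op, as a GPU
    // kernel would do with GpuAtomicAdd on the output buffer.
    commit_op(&output, partial);
  }
  std::cout << output.load() << "\n";  // Prints 15.
}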

// Initial value functors.
template <typename T>
struct Zero {
