From 7e3a395ba944d015cb03785b473b51ad322b6a46 Mon Sep 17 00:00:00 2001 From: "Kang, Letian" Date: Thu, 7 Mar 2019 14:39:23 +0800 Subject: [PATCH 1/6] The BiasAddGrad op is running on single thread, which badly influences the training performance of corresponding models. We provide a optimized parallel implementation of BiasAddGrad op. Change-Id: I3a8da878eea67a4903a3b68302c6c86c3e536025 --- tensorflow/core/kernels/bias_op.cc | 220 ++++++++++++++---- .../python/kernel_tests/bias_op_test.py | 20 +- 2 files changed, 198 insertions(+), 42 deletions(-) diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 074f64a634aa83..5a334146ee1691 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -18,13 +18,15 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/bias_op.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/kernels/bias_op_gpu.h" @@ -139,10 +141,10 @@ class BiasOp : public BinaryOp { Eigen::DSizes three_dims(1, channel, 1); Eigen::DSizes broad_cast_dims(batch, 1, height); const Device& d = context->eigen_device(); - output->tensor().device(d) = - input.tensor() + bias.tensor() - .reshape(three_dims) - .broadcast(broad_cast_dims); + output->tensor().device(d) = input.tensor() + + bias.tensor() + .reshape(three_dims) + .broadcast(broad_cast_dims); } break; case 4: { Eigen::DSizes four_dims(1, channel, 1, 1); @@ -250,9 +252,8 @@ class BiasGradOp : public OpKernel { output_backprop.shape().DebugString())); OP_REQUIRES( - context, - FastBoundsCheck(output_backprop.NumElements(), - std::numeric_limits::max()), + context, FastBoundsCheck(output_backprop.NumElements(), + std::numeric_limits::max()), errors::InvalidArgument("BiasGrad requires tensor size <= int32 max")); int32 batch, height, width, depth, channel; @@ -268,43 +269,184 @@ class BiasGradOp : public OpKernel { // Eigen often crashes by design on empty tensors, but setZero is safe output->template flat().setZero(); } else { - // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. - if (data_format_ == FORMAT_NCHW) { - Eigen::DSizes three_dims(batch, channel, - height * width * depth); -#ifdef EIGEN_HAS_INDEX_LIST - using idx0 = Eigen::type2index<0>; - using idx2 = Eigen::type2index<2>; - Eigen::IndexList reduction_axes; -#else - Eigen::array reduction_axes = {0, 2}; -#endif - output->template flat().device(context->eigen_device()) = - output_backprop.flat() - .template cast::type>() - .reshape(three_dims) - .sum(reduction_axes) - .template cast(); // End of code by intel_tf. 
- } else { - Eigen::DSizes two_dims(batch * height * width * depth, - channel); -#ifdef EIGEN_HAS_INDEX_LIST - Eigen::IndexList > reduction_axis; -#else - Eigen::array reduction_axis = {0}; -#endif - output->template flat().device(context->eigen_device()) = - output_backprop.flat() - .template cast::type>() - .reshape(two_dims) - .sum(reduction_axis) - .template cast(); + // Modified for performance tune :: begin here + //****************************************************************************** + // Divide the input tensor into several blocks. + // As we don't know anything about the detail shape of incoming input + // tensor, and it will be too complex to deal with different shapes + // separately, so we just evenly distribute total workloads to each + // block. + //****************************************************************************** + + // Init the output to zero + output->template flat().setZero(); + // Get the location of input/output data. + const T* input_ptr = output_backprop.template flat().data(); + T* result_ptr = output->template flat().data(); + + const bool format_is_nhwc = (data_format_ == FORMAT_NHWC); + + // Get the intra-thread pool handle. + auto worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + const int num_threads = worker_threads.num_threads; + auto workers = worker_threads.workers; + + // Get the workload parameters. + const int reduce_dims = batch * height * width * depth; + const int total_workload = reduce_dims * channel; + + // For small workloads, large block number will waste + // scheduling and compute resources. + // Use minimum workload and max_parallelism to limit total + // thread number and guarantee the workload for a each thread. + // Roughly, persume the CPU is 2GHz, using 1ns as a block, + // then each block gets about 2000 FLOP. + const int min_block_workloads = 2000; + // For NHWC format, we use each channel layer as a scheduling unit, + // while for NCHW format, we use each FLOP as a scheduling unit. + int parallel_cell_size = 1; + if ((format_is_nhwc) || + ((!format_is_nhwc) && (height * width * depth == 1))) + parallel_cell_size = channel; + const int max_parallelism = total_workload / parallel_cell_size; + const int min_block_size = + (min_block_workloads + parallel_cell_size - 1) / parallel_cell_size; + const int max_num_blocks = + std::min(max_parallelism, + (total_workload + min_block_size - 1) / min_block_size); + // As the BiasAddGradOp is a reducing op, + // it is necessary to build buffer for each block to avoid hazard. + // To minimize the buffer, the block number is no more than thread number. + const int num_blocks = std::min(max_num_blocks, num_threads); + + // Build&initialize buffers for blocks. 
+ TensorShape output_buffer_shape({num_blocks, channel}); + Tensor block_results_buffer(output->dtype(), output_buffer_shape); + block_results_buffer.template flat().setZero(); + T* block_results_buffer_ptr = + block_results_buffer.template flat().data(); + + //****************************************************************************** + // Job func for each thread + //****************************************************************************** + auto BiasGradWorker = [this, input_ptr, format_is_nhwc, total_workload, + num_blocks, block_results_buffer_ptr, height, + width, depth, channel](int64 my_job_begin, + int64 my_job_end) -> void { + // We generate a cover of [0,total_workload), which is comprised of + // num_blocks non-overlapping divisions of [0,total_workload) + // EXP: If we get 22 elements in input tensor, which are divided + // into 4 blocks: + // + // lockId : 0 | 1 | 2 | 3 | res + // Elements: $$$$$ | $$$$$ | $$$$$ | $$$$$ | ** + // ↓ + // BlockId : 0 | 1 | 2 | 3 + // Elements: $$$$$ | $$$$$* | $$$$$ | $$$$$* + // Range : [0,5) | [5,11) | [11,16) | [16,22) + // 22*0/4=0 22*1/4=5 22*2/4=11 22*3/4=16 22*4/4=22 + const int64 block_begin = total_workload * my_job_begin / num_blocks; + const int64 block_end = total_workload * my_job_end / num_blocks; + + // Get buffer pointer. + T* block_result_ptr = &block_results_buffer_ptr[my_job_begin * channel]; + + if ((format_is_nhwc) || + ((!format_is_nhwc) && (height * width * depth == 1))) { + // Align the calculation by inner most dim. + const int64 align_begin = (block_begin / channel) * channel; + const int64 align_end = (block_end / channel) * channel; + // Apply the calculation. + for (int64 i = align_begin; i < align_end; i += channel) { + InplaceVecAdd(block_result_ptr, &input_ptr[i], channel); + } + } else { // For NCHW format + // A straight forward impl for NCHW could be like: + // for(int64 i=block_begin;i void { workers->Schedule(c); }, + max_parallelism); + + //****************************************************************************** + // Sum block results up + //****************************************************************************** + for (int64 i = 0; i < num_blocks; i++) { + InplaceVecAdd(result_ptr, &block_results_buffer_ptr[i * channel], + channel); } + // Modified for performance tune :: end here } } private: TensorFormat data_format_; + + // Modified for performance tune :: new funcs + // Apply X[0:length-1] = X[0:length-1] + Y[0:length-1]; + inline void InplaceVecAdd(T* X, const T* Y, const int64 length) { + //#pragma simd + for (int64 i = 0; i < length; i++) { + X[i] = X[i] + Y[i]; + } + } + // Return sum(X[0:length-1]) + inline T VecSumReduce(const T* X, const int64 length) { + T result = (T)0; + //#pragma simd + for (int64 i = 0; i < length; i++) { + result += X[i]; + } + return result; + } + // Modified for performance tune :: end here }; // Registration of the GPU implementations. 
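The blocking scheme above can be read as follows: the flattened gradient of size total_workload = batch*height*width*depth*channel is split into num_blocks contiguous ranges with block_begin = total_workload * block_id / num_blocks, every block reduces into its own channel-sized buffer (so no two threads ever write the same memory), and the per-block buffers are summed into the output at the end. Below is a minimal standalone sketch of that pattern for the NHWC case, using a hypothetical helper BiasGradNHWCBlocked rather than the kernel's actual code:

#include <cstdint>
#include <vector>

// Sum a flattened NHWC gradient over all rows, using the same block
// partition as the patch: both block boundaries are rounded down to a
// multiple of `channel`, so adjacent blocks agree on the boundary and every
// row is covered exactly once.
std::vector<float> BiasGradNHWCBlocked(const std::vector<float>& grad,
                                       int64_t channel, int num_blocks) {
  const int64_t total_workload = static_cast<int64_t>(grad.size());
  // One private accumulator per block avoids write hazards between threads.
  std::vector<std::vector<float>> block_buffers(
      num_blocks, std::vector<float>(channel, 0.0f));

  // In the kernel each block is scheduled through Shard(); a serial loop is
  // enough here to show the index arithmetic.
  for (int b = 0; b < num_blocks; ++b) {
    const int64_t block_begin = total_workload * b / num_blocks;
    const int64_t block_end = total_workload * (b + 1) / num_blocks;
    const int64_t align_begin = block_begin / channel * channel;
    const int64_t align_end = block_end / channel * channel;
    for (int64_t i = align_begin; i < align_end; i += channel) {
      for (int64_t c = 0; c < channel; ++c) {
        block_buffers[b][c] += grad[i + c];
      }
    }
  }

  // Sum the per-block partial results into the final bias gradient.
  std::vector<float> result(channel, 0.0f);
  for (int b = 0; b < num_blocks; ++b) {
    for (int64_t c = 0; c < channel; ++c) result[c] += block_buffers[b][c];
  }
  return result;
}

For example, a [2, 3, 4, 2] NHWC tensor gives total_workload = 48 and channel = 2; any num_blocks from 1 up to the 24 available rows produces the same 2-element gradient, which is the invariant the enlarged Python test shapes below are meant to exercise.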
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py index 94e20d93017b07..41300ee304de2d 100644 --- a/tensorflow/python/kernel_tests/bias_op_test.py +++ b/tensorflow/python/kernel_tests/bias_op_test.py @@ -130,6 +130,12 @@ def test4DFloatTypes(self): self._testAll( np.random.rand(4, 3, 2, 3).astype(t), np.random.rand(3).astype(t)) + self._testAll( + np.random.rand(2048, 4, 4, 4).astype(t), + np.random.rand(4).astype(t)) + self._testAll( + np.random.rand(4, 4, 4, 2048).astype(t), + np.random.rand(2048).astype(t)) @test_util.run_deprecated_v1 def test5DFloatTypes(self): @@ -186,7 +192,7 @@ def _testGradient(self, np_input, bias, dtype, data_format, use_gpu): bias_add_grad, bias.shape) - threshold = 2e-3 + threshold = 5e-3 if dtype == dtypes.float64: threshold = 1e-10 self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold) @@ -215,14 +221,22 @@ def testGradientTensor3D(self): @test_util.run_deprecated_v1 def testGradientTensor4D(self): - for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True), - ("NCHW", False), ("NCHW", True)]: + for (data_format, use_gpu) in [("NHWC", False)]: for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): np_input = np.arange( 1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape( [2, 3, 4, 2]).astype(np.float32) bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) self._testGradient(np_input, bias, dtype, data_format, use_gpu) + np_input = np.arange( + 1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape( + [64, 2, 2, 2]).astype(np.float32) + self._testGradient(np_input, bias, dtype, data_format, use_gpu) + np_input = np.arange( + 1.0, 513.0, dtype=dtype.as_numpy_dtype).reshape( + [2, 2, 2, 64]).astype(np.float32) + self._testGradient(np_input, + np.random.rand(64).astype(dtype.as_numpy_dtype), dtype, data_format, use_gpu) @test_util.run_deprecated_v1 def testGradientTensor5D(self): From d18de10d5657409777605b6027cd91118f1a7670 Mon Sep 17 00:00:00 2001 From: "Kang, Letian" Date: Mon, 11 Mar 2019 11:52:39 +0800 Subject: [PATCH 2/6] Codes normalization: 1. Group all tensorflow includes together. 2. Delete unnecessary comments. 3. Using eigen ops instead. Change-Id: I67edcd71cb4feeaf8ab0c1820d2c011e3409344d --- tensorflow/core/kernels/bias_op.cc | 135 ++++++++++++++--------------- 1 file changed, 63 insertions(+), 72 deletions(-) diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 5a334146ee1691..8091a4cd7a6e8f 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -24,9 +24,8 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - #include "tensorflow/core/util/work_sharder.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/kernels/bias_op_gpu.h" @@ -242,7 +241,6 @@ class BiasGradOp : public OpKernel { data_format_ = FORMAT_NHWC; } } - void Compute(OpKernelContext* context) override { const Tensor& output_backprop = context->input(0); @@ -269,7 +267,6 @@ class BiasGradOp : public OpKernel { // Eigen often crashes by design on empty tensors, but setZero is safe output->template flat().setZero(); } else { - // Modified for performance tune :: begin here //****************************************************************************** // Divide the input tensor into several blocks. // As we don't know anything about the detail shape of incoming input @@ -277,13 +274,12 @@ class BiasGradOp : public OpKernel { // separately, so we just evenly distribute total workloads to each // block. //****************************************************************************** - // Init the output to zero output->template flat().setZero(); // Get the location of input/output data. const T* input_ptr = output_backprop.template flat().data(); - T* result_ptr = output->template flat().data(); + // Get the format of input/output data. const bool format_is_nhwc = (data_format_ == FORMAT_NHWC); // Get the intra-thread pool handle. @@ -294,6 +290,7 @@ class BiasGradOp : public OpKernel { // Get the workload parameters. const int reduce_dims = batch * height * width * depth; + const int hwd_size = height * width * depth; const int total_workload = reduce_dims * channel; // For small workloads, large block number will waste @@ -306,9 +303,9 @@ class BiasGradOp : public OpKernel { // For NHWC format, we use each channel layer as a scheduling unit, // while for NCHW format, we use each FLOP as a scheduling unit. 
int parallel_cell_size = 1; - if ((format_is_nhwc) || - ((!format_is_nhwc) && (height * width * depth == 1))) + if ((format_is_nhwc) || ((!format_is_nhwc) && (hwd_size == 1))) { parallel_cell_size = channel; + } const int max_parallelism = total_workload / parallel_cell_size; const int min_block_size = (min_block_workloads + parallel_cell_size - 1) / parallel_cell_size; @@ -327,13 +324,16 @@ class BiasGradOp : public OpKernel { T* block_results_buffer_ptr = block_results_buffer.template flat().data(); + using Shell = Eigen::TensorMap>; + using ConstShell = Eigen::TensorMap>; + //****************************************************************************** // Job func for each thread //****************************************************************************** - auto BiasGradWorker = [this, input_ptr, format_is_nhwc, total_workload, - num_blocks, block_results_buffer_ptr, height, - width, depth, channel](int64 my_job_begin, - int64 my_job_end) -> void { + auto BiasGradWorker = [this, &total_workload, &num_blocks, + &format_is_nhwc, &input_ptr, + &block_results_buffer_ptr, &hwd_size, &channel]( + int my_job_begin, int my_job_end) mutable -> void { // We generate a cover of [0,total_workload), which is comprised of // num_blocks non-overlapping divisions of [0,total_workload) // EXP: If we get 22 elements in input tensor, which are divided @@ -346,62 +346,73 @@ class BiasGradOp : public OpKernel { // Elements: $$$$$ | $$$$$* | $$$$$ | $$$$$* // Range : [0,5) | [5,11) | [11,16) | [16,22) // 22*0/4=0 22*1/4=5 22*2/4=11 22*3/4=16 22*4/4=22 - const int64 block_begin = total_workload * my_job_begin / num_blocks; - const int64 block_end = total_workload * my_job_end / num_blocks; - - // Get buffer pointer. - T* block_result_ptr = &block_results_buffer_ptr[my_job_begin * channel]; - - if ((format_is_nhwc) || - ((!format_is_nhwc) && (height * width * depth == 1))) { - // Align the calculation by inner most dim. - const int64 align_begin = (block_begin / channel) * channel; - const int64 align_end = (block_end / channel) * channel; + const int block_begin = total_workload * my_job_begin / num_blocks; + const int block_end = total_workload * my_job_end / num_blocks; + + T* buffer_ptr = &block_results_buffer_ptr[my_job_begin * channel]; + Shell my_buffer(buffer_ptr, channel); + + if ((format_is_nhwc) || ((!format_is_nhwc) && (hwd_size == 1))) { + // For NHWC, it is easy to divide workload, because the parallelism + // mainly comes from layers outside channel (N*H*W). + // So we just divide NHW layers. + // Align the calculation by inner most layer (channel). + const int align_begin = (block_begin / channel) * channel; + const int align_end = (block_end / channel) * channel; // Apply the calculation. - for (int64 i = align_begin; i < align_end; i += channel) { - InplaceVecAdd(block_result_ptr, &input_ptr[i], channel); + for (int i = align_begin; i < align_end; i += channel) { + my_buffer += ConstShell(&input_ptr[i], channel); } } else { // For NCHW format // A straight forward impl for NCHW could be like: - // for(int64 i=block_begin;i sum = + ConstShell(&input_ptr[block_begin], align_begin - block_begin) + .sum(); + my_buffer(channel_id) += sum(0); + // Init channel_id to avoid the error when align_begin == block_begin. 
- channel_id = align_begin / stride % channel; - // Apply the reduction - for (int64 i = align_begin; i < align_end; i += stride) { + channel_id = align_begin / hwd_size % channel; + + for (int i = align_begin; i < align_end; i += hwd_size) { + // Apply the reduction if (channel_id < channel) { // When channel_id is in channel, // just add the sum of inside dim to block buffer. - block_result_ptr[channel_id] += - VecSumReduce(&input_ptr[i], stride); + sum = ConstShell(&input_ptr[i], hwd_size).sum(); + my_buffer(channel_id) += sum(0); channel_id++; } else { // When channel_id exceed the range of channel, // go back to the beginning of block buffer. channel_id = channel_id - channel; - block_result_ptr[channel_id] += - VecSumReduce(&input_ptr[i], stride); + sum = ConstShell(&input_ptr[i], hwd_size).sum(); + my_buffer(channel_id) += sum(0); channel_id++; } } // Dealing with back residual. - block_result_ptr[channel_id] += - VecSumReduce(&input_ptr[align_end], block_end - align_end); + sum = ConstShell(&input_ptr[align_end], block_end - align_end).sum(); + my_buffer(channel_id) += sum(0); } }; // Run multi-threads @@ -416,37 +427,17 @@ class BiasGradOp : public OpKernel { max_parallelism); //****************************************************************************** - // Sum block results up + // Now sum block results up //****************************************************************************** - for (int64 i = 0; i < num_blocks; i++) { - InplaceVecAdd(result_ptr, &block_results_buffer_ptr[i * channel], - channel); + for (int i = 0; i < num_blocks; i++) { + Shell buffer_i(&block_results_buffer_ptr[channel * i], channel); + output->template flat() += buffer_i; } - // Modified for performance tune :: end here } } private: TensorFormat data_format_; - - // Modified for performance tune :: new funcs - // Apply X[0:length-1] = X[0:length-1] + Y[0:length-1]; - inline void InplaceVecAdd(T* X, const T* Y, const int64 length) { - //#pragma simd - for (int64 i = 0; i < length; i++) { - X[i] = X[i] + Y[i]; - } - } - // Return sum(X[0:length-1]) - inline T VecSumReduce(const T* X, const int64 length) { - T result = (T)0; - //#pragma simd - for (int64 i = 0; i < length; i++) { - result += X[i]; - } - return result; - } - // Modified for performance tune :: end here }; // Registration of the GPU implementations. From 1c8e84166133a95aa8f906b47806a6cb700c6c70 Mon Sep 17 00:00:00 2001 From: Letian Kang Date: Fri, 15 Mar 2019 19:14:10 +0800 Subject: [PATCH 3/6] 1. Add ReduceMiddleDimensions to support middle dim reduce in tensorflow/core/kernels/redux_functor.h 2. Opt ReduceOuterDimensions for large inner dim. 3. Rewrite NCHW BiasAddGradOp with ReduceOuterDimensions. --- tensorflow/core/kernels/bias_op.cc | 628 ------------------------ tensorflow/core/kernels/redux_functor.h | 124 ----- 2 files changed, 752 deletions(-) delete mode 100644 tensorflow/core/kernels/bias_op.cc delete mode 100644 tensorflow/core/kernels/redux_functor.h diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc deleted file mode 100644 index 04ffdd942121c7..00000000000000 --- a/tensorflow/core/kernels/bias_op.cc +++ /dev/null @@ -1,628 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// See docs in ../ops/nn_ops.cc. - -#define EIGEN_USE_THREADS - -#include "tensorflow/core/kernels/bias_op.h" -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/numeric_op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/redux_functor.h" -#include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/util/work_sharder.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - -#if GOOGLE_CUDA -#include "tensorflow/core/kernels/bias_op_gpu.h" -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/stream_executor/cuda/cuda_stream.h" -#endif // GOOGLE_CUDA - -namespace tensorflow { - -typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; -#ifdef TENSORFLOW_USE_SYCL -typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL - -namespace { - -void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format, - int32* batch, int32* height, int32* width, int32* depth, - int32* channel) { - *batch = 1; - *height = 1; - *width = 1; - *depth = 1; - *channel = 1; - if (data_format == FORMAT_NHWC) { - int32 channel_dim = value_tensor.dims() - 1; - *channel = static_cast(value_tensor.dim_size(channel_dim)); - for (int32 i = 0; i < channel_dim; i++) { - *batch *= static_cast(value_tensor.dim_size(i)); - } - } else if (data_format == FORMAT_NCHW) { - *batch = static_cast(value_tensor.dim_size(0)); - *channel = static_cast(value_tensor.dim_size(1)); - *height = static_cast(value_tensor.dim_size(2)); - if (value_tensor.dims() > 3) { - *width = static_cast(value_tensor.dim_size(3)); - } - if (value_tensor.dims() > 4) { - *depth = static_cast(value_tensor.dim_size(4)); - } - } -} - -template -struct AccumulatorType { - typedef T type; -}; - -// float is faster on the CPU than half, and also more precise, -// so use float for the temporary accumulators. -template <> -struct AccumulatorType { - typedef float type; -}; - -} // namespace - -template -class BiasOp : public BinaryOp { - public: - explicit BiasOp(OpKernelConstruction* context) : BinaryOp(context) { - string data_format; - if (context->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - } else { - data_format_ = FORMAT_NHWC; - } - } - - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& bias = context->input(1); - - OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()), - errors::InvalidArgument("Input tensor must be at least 2D: ", - input.shape().DebugString())); - OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()), - errors::InvalidArgument("Biases must be 1D: ", - bias.shape().DebugString())); - - // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. 
- size_t channel_dim; - if (data_format_ == FORMAT_NCHW) { - channel_dim = 1; // NCHW always have channel dim in 1 (with 3, 4, 5 - // dimensions data). - } else { - channel_dim = input.shape().dims() - 1; // End of code by intel_tf. - } - - OP_REQUIRES( - context, - bias.shape().dim_size(0) == input.shape().dim_size(channel_dim), - errors::InvalidArgument( - "Must provide as many biases as the last dimension " - "of the input tensor: ", - bias.shape().DebugString(), " vs. ", input.shape().DebugString())); - - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {0}, 0, input.shape(), &output)); - if (input.NumElements() == 0) return; - - // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. - if (data_format_ == FORMAT_NCHW) { - int32 batch, height, width, depth, channel; - GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth, - &channel); - switch (input.shape().dims()) { - case 3: { - Eigen::DSizes three_dims(1, channel, 1); - Eigen::DSizes broad_cast_dims(batch, 1, height); - const Device& d = context->eigen_device(); - output->tensor().device(d) = input.tensor() + - bias.tensor() - .reshape(three_dims) - .broadcast(broad_cast_dims); - } break; - case 4: { - Eigen::DSizes four_dims(1, channel, 1, 1); - Eigen::DSizes broad_cast_dims(batch, 1, height, width); - const Device& d = context->eigen_device(); - output->tensor().device(d) = - input.tensor() + - bias.tensor().reshape(four_dims).broadcast(broad_cast_dims); - } break; - case 5: { - Eigen::DSizes five_dims(1, channel, 1, 1, 1); - Eigen::DSizes broad_cast_dims(batch, 1, height, width, - depth); - const Device& d = context->eigen_device(); - output->tensor().device(d) = - input.tensor() + - bias.tensor().reshape(five_dims).broadcast(broad_cast_dims); - } break; - default: - OP_REQUIRES(context, false, - errors::InvalidArgument("Only ranks up to 5 supported: ", - input.shape().DebugString())); - } - return; - } // End of code by intel_tf. - - switch (input.shape().dims()) { - case 2: - Compute<2>(context, input, bias, output); - break; - case 3: - Compute<3>(context, input, bias, output); - break; - case 4: - Compute<4>(context, input, bias, output); - break; - case 5: - Compute<5>(context, input, bias, output); - break; - default: - OP_REQUIRES(context, false, - errors::InvalidArgument("Only ranks up to 5 supported: ", - input.shape().DebugString())); - } - } - - // Add biases for an input matrix of rank Dims, by using the Bias. 
- template - void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias, - Tensor* output) { - functor::Bias functor; - functor(ctx->eigen_device(), input.tensor(), bias.vec(), - output->tensor()); - } - - private: - TensorFormat data_format_; -}; - -#define REGISTER_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint("T"), \ - BiasOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint("T"), \ - BiasOp); - -TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); -#undef REGISTER_KERNEL - -#ifdef TENSORFLOW_USE_SYCL -#define REGISTER_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint("T"), \ - BiasOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint("T"), \ - BiasOp); - -TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); -REGISTER_KERNEL(float); -REGISTER_KERNEL(double); -#undef REGISTER_KERNEL -#endif // TENSORFLOW_USE_SYCL - -template -class BiasGradOp : public OpKernel { - public: - explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) { - string data_format; - if (context->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - } else { - data_format_ = FORMAT_NHWC; - } - } - void Compute(OpKernelContext* context) override { - const Tensor& output_backprop = context->input(0); - - OP_REQUIRES(context, - TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()), - errors::InvalidArgument("Input tensor must be at least 2D: ", - output_backprop.shape().DebugString())); - - OP_REQUIRES( - context, FastBoundsCheck(output_backprop.NumElements(), - std::numeric_limits::max()), - errors::InvalidArgument("BiasGrad requires tensor size <= int32 max")); - - int32 batch, height, width, depth, channel; - GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width, - &depth, &channel); - Tensor* output = nullptr; - TensorShape output_shape{channel}; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - - if (channel == 0) { - return; // Nothing to do - } else if (output_backprop.NumElements() == 0) { - // Eigen often crashes by design on empty tensors, but setZero is safe - output->template flat().setZero(); - } else { - // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. - if (data_format_ == FORMAT_NCHW) { - Eigen::DSizes three_dims(batch, channel, - height * width * depth); -#ifdef EIGEN_HAS_INDEX_LIST - using idx0 = Eigen::type2index<0>; - using idx2 = Eigen::type2index<2>; - Eigen::IndexList reduction_axes; -#else - Eigen::array reduction_axes = {0, 2}; -#endif - output->template flat().device(context->eigen_device()) = - output_backprop.flat() - .template cast::type>() - .reshape(three_dims) - .sum(reduction_axes) - .template cast(); // End of code by intel_tf. - } else { - using AccumT = typename AccumulatorType::type; - const functor::ReduceOuterDimensions< - T, AccumT, Eigen::internal::scalar_sum_op> - redux; - - Eigen::DSizes two_dims(batch * height * width * depth, - channel); - redux(context->eigen_device(), two_dims, output_backprop, - output); - } - } - } - - private: - TensorFormat data_format_; -}; - -// Registration of the GPU implementations. 
-#define REGISTER_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ - BiasGradOp); - -TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); -#undef REGISTER_KERNEL - -#ifdef TENSORFLOW_USE_SYCL -#define REGISTER_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ - BiasGradOp); - -TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); -REGISTER_KERNEL(float); -REGISTER_KERNEL(double); -#undef REGISTER_KERNEL -#endif // TENSORFLOW_USE_SYCL - -#if GOOGLE_CUDA -template -class BiasOp : public BinaryOp { - public: - typedef GPUDevice Device; - explicit BiasOp(OpKernelConstruction* context) : BinaryOp(context) { - string data_format; - if (context->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - } else { - data_format_ = FORMAT_NHWC; - } - } - - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& bias = context->input(1); - - OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()), - errors::InvalidArgument("Input tensor must be at least 2D: ", - input.shape().DebugString())); - OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()), - errors::InvalidArgument("Biases must be 1D: ", - bias.shape().DebugString())); - int32 batch, height, width, depth, channel; - GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth, - &channel); - OP_REQUIRES(context, bias.shape().dim_size(0) == channel, - errors::InvalidArgument( - "Must provide as many biases as the channel dimension " - "of the input tensor: ", - bias.shape().DebugString(), " vs. ", channel, " in ", - input.shape().DebugString())); - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {0}, 0, input.shape(), &output)); - if (input.NumElements() > 0) { - BiasGPU::compute(context->template eigen_device(), - input.flat().data(), bias.flat().data(), - output->flat().data(), batch, width, height, depth, - channel, data_format_); - } - } - - private: - TensorFormat data_format_; -}; - -// Registration of the GPU implementations. -#define REGISTER_GPU_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint("T"), \ - BiasOp); \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint("T"), \ - BiasOp); - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); -#undef REGISTER_GPU_KERNEL - -struct BiasGradAutotuneGroup { - static string name() { return "BiasGrad"; } -}; - -class BiasAddGradGPUConfig { - public: - BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {} - string ToString() const { - if (mode_ == BiasAddGradGPUMode::kNative) { - return "native CUDA kernel."; - } - if (mode_ == BiasAddGradGPUMode::kReduction) { - return "cub reduction kernel."; - } - return "unknown kernel."; - } - BiasAddGradGPUMode get_mode() const { return mode_; } - void set_mode(BiasAddGradGPUMode val) { mode_ = val; } - - bool operator==(const BiasAddGradGPUConfig& other) const { - return this->mode_ == other.get_mode(); - } - - bool operator!=(const BiasAddGradGPUConfig& other) const { - return !(*this == other); - } - - private: - BiasAddGradGPUMode mode_; -}; - -// Encapsulate all the shape information that is used in bias add grad -// operations. -class BiasAddParams { - public: - // We use a list to maintain both the shape value and the order (data format). 
- using SpatialArray = gtl::InlinedVector; - BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format, - DataType dtype, int device_id) - : in_shape_(in_shape), - data_format_(data_format), - dtype_(dtype), - device_id_(device_id) { - for (int64 val : in_shape_) { - hash_code_ = Hash64Combine(hash_code_, val); - } - hash_code_ = Hash64Combine(hash_code_, data_format); - hash_code_ = Hash64Combine(hash_code_, dtype); - hash_code_ = Hash64Combine(hash_code_, device_id); - } - bool operator==(const BiasAddParams& other) const { - return this->get_data_as_tuple() == other.get_data_as_tuple(); - } - - bool operator!=(const BiasAddParams& other) const { - return !(*this == other); - } - uint64 hash() const { return hash_code_; } - - string ToString() const { - // clang-format off - return strings::StrCat( - "(", str_util::Join(in_shape_, ", "), "), ", - data_format_, ", ", dtype_, ", ", device_id_); - // clang-format on - } - - protected: - using ParamsDataType = std::tuple; - - ParamsDataType get_data_as_tuple() const { - return std::make_tuple(in_shape_, data_format_, dtype_, device_id_); - } - - uint64 hash_code_ = 0; - - private: - SpatialArray in_shape_; - TensorFormat data_format_; - DataType dtype_; - int device_id_; -}; - -typedef AutoTuneSingleton - AutotuneBiasGrad; - -template -class BiasGradOp : public OpKernel { - public: - typedef GPUDevice Device; - explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) { - string data_format; - if (context->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); - } else { - data_format_ = FORMAT_NCHW; - } - } - - void ComputeWithCustomKernel(OpKernelContext* context, - const Tensor& output_backprop, int32 batch, - int32 width, int32 height, int32 depth, - int32 channel, Tensor* output) { - BiasGradGPU::compute(context->template eigen_device(), - output_backprop.template flat().data(), - output->flat().data(), batch, width, height, - depth, channel, data_format_); - } - - void ComputeWithReduceSum(OpKernelContext* context, - const Tensor& output_backprop, int32 batch, - int32 width, int32 height, int32 depth, - int32 channel, Tensor* output) { - if (data_format_ == FORMAT_NCHW) { - int32 row_count = batch * channel; - int32 col_count = height * width * depth; - Tensor temp_grad_outputs; - // For 'NCHW' format, we perform reduction twice: first HW, then N. - TensorShape temp_grad_output_shape{row_count, col_count}; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, - temp_grad_output_shape, - &temp_grad_outputs)); - BiasGradGPU::DoRowReduction( - context, temp_grad_outputs.flat().data(), - output_backprop.template flat().data(), row_count, col_count); - - row_count = batch; - col_count = channel; - BiasGradGPU::DoColReduction(context, output->flat().data(), - temp_grad_outputs.flat().data(), - row_count, col_count); - } else { - // For 'NHWC', we simply apply reduction once on NHW. 
- int32 row_count = batch * height * width * depth; - int32 col_count = channel; - BiasGradGPU::DoColReduction( - context, const_cast(output->flat().data()), - reinterpret_cast(output_backprop.template flat().data()), - row_count, col_count); - } - } - - void Compute(OpKernelContext* context) override { - const Tensor& output_backprop = context->input(0); - - OP_REQUIRES(context, - TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()), - errors::InvalidArgument("Input tensor must be at least 2D: ", - output_backprop.shape().DebugString())); - int32 batch, height, width, depth, channel; - GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width, - &depth, &channel); - Tensor* output = nullptr; - TensorShape output_shape{channel}; - OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - if (channel == 0) return; - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - se::DeviceMemoryBase output_ptr(output->flat().data(), - output->NumElements() * sizeof(T)); - stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T)); - if (output_backprop.NumElements() <= 0) return; - - int device_id = stream->parent()->device_ordinal(); - DataType dtype = output_backprop.dtype(); - BiasAddParams bias_parameters = { - {batch, height * width * depth, channel}, - data_format_, - dtype, - device_id, - }; - - // Autotune two algorithm: customized - BiasAddGradGPUConfig algo_config; - if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) { - BiasGradGPUProfileResult best_result; - // Initialize the timer. - perftools::gputools::Timer timer(stream->parent()); - stream->InitTimer(&timer); - stream->ThenStartTimer(&timer); - ComputeWithCustomKernel(context, output_backprop, batch, width, height, - depth, channel, output); - stream->ThenStopTimer(&timer); - uint64 elapsed_microseconds = timer.Microseconds(); - VLOG(1) << "BiasAddGrad " << bias_parameters.ToString() - << " Native algo latency: " << elapsed_microseconds; - if (elapsed_microseconds < best_result.elapsed_time()) { - best_result.set_algorithm(BiasAddGradGPUMode::kNative); - best_result.set_elapsed_time(elapsed_microseconds); - } - - // Try reduction and profile. - stream->ThenStartTimer(&timer); - ComputeWithReduceSum(context, output_backprop, batch, width, height, - depth, channel, output); - stream->ThenStopTimer(&timer); - - elapsed_microseconds = timer.Microseconds(); - VLOG(1) << "BiasAddGrad " << bias_parameters.ToString() - << " Reduction algo latency: " << elapsed_microseconds; - if (elapsed_microseconds < best_result.elapsed_time()) { - best_result.set_algorithm(BiasAddGradGPUMode::kReduction); - best_result.set_elapsed_time(elapsed_microseconds); - } - - algo_config.set_mode(best_result.algorithm()); - AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config); - - // Results are already available during autotune, so no need to continue. - return; - } - - // Choose the best algorithm based on autotune results. - if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) { - ComputeWithReduceSum(context, output_backprop, batch, width, height, - depth, channel, output); - } else { - // Default to the customized kernel. - ComputeWithCustomKernel(context, output_backprop, batch, width, height, - depth, channel, output); - } - } - - private: - TensorFormat data_format_; -}; - -// Registration of the GPU implementations. 
-#define REGISTER_GPU_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint("T"), \ - BiasGradOp); - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); -#undef REGISTER_GPU_KERNEL - -#endif // GOOGLE_CUDA - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h deleted file mode 100644 index c542099cc0870f..00000000000000 --- a/tensorflow/core/kernels/redux_functor.h +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ -#define TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -using CPUDevice = Eigen::ThreadPoolDevice; - -namespace functor { - -// Compute reduction over all outer dimensions. -// Example: -// input: [32, 32, 256] -// -> -// output: [256] -template -struct ReduceOuterDimensions { - template - void operator()(const CPUDevice& device, - const Eigen::DSizes& input_dims, - const Tensor& input, Tensor* output) const { - static_assert(num_dims >= 2, "Input dimensions must at least 2"); - - // Compute inner and outer dim after reshaping into 2d tensor. - int64 inner_dim = input_dims[num_dims - 1]; - int64 outer_dim = 1; - for (int i = 0; i < num_dims - 1; ++i) outer_dim *= input_dims[i]; - - // Compute block size along the outer dimension for efficiency. - const int64 parallel_cell_size = inner_dim; - const int64 total_workload = outer_dim * inner_dim; - const int64 max_parallelism = total_workload / parallel_cell_size; - - const int64 min_block_workload = 2000; - const int64 min_block_size = - Eigen::divup(min_block_workload, parallel_cell_size); - const int64 max_num_blocks = - std::min(max_parallelism, Eigen::divup(total_workload, min_block_size)); - - // Do not create more blocks than there are threads in a pool. - const int64 num_threads = device.numThreads(); - const int64 num_blocks = std::min(max_num_blocks, num_threads); - - // Block size along the outer dimension. - const int64 outer_block_size = Eigen::divup(outer_dim, num_blocks); - - const T* input_data = input.template flat().data(); - - // Allocate temporary buffer for partial reductions. 
- Tensor buffer(DataTypeToEnum::v(), {num_blocks, inner_dim}); - buffer.template flat().setZero(); - AccumT* buffer_data = buffer.template flat().data(); - - using Buffer = Eigen::TensorMap< - Eigen::Tensor, - Eigen::Unaligned>; - - using Input = Eigen::TensorMap< - Eigen::Tensor, - Eigen::Unaligned>; - - const auto compute = [inner_dim, num_blocks, outer_block_size, buffer_data, - input_data, outer_dim](Eigen::Index start, - Eigen::Index limit) -> void { - DCHECK(start >= 0 && limit <= num_blocks); - int64 outer_dim_start = start * outer_block_size; - int64 outer_dim_limit = limit * outer_block_size; - outer_dim_limit = std::min(outer_dim, outer_dim_limit); - - Buffer buf(buffer_data + start * inner_dim, inner_dim); - for (int64 i = outer_dim_start; i < outer_dim_limit; ++i) { - auto in = Input(input_data + i * inner_dim, inner_dim); - auto cast = in.template cast(); - buf = Eigen::TensorCwiseBinaryOp(buf, cast); - } - }; - - // Compute cost of reducing a single block. - const int64 compute_size = outer_block_size * inner_dim; - const int64 compute_input_bytes = compute_size * sizeof(T); - const Eigen::TensorOpCost cost( - compute_input_bytes, - 0, // We'll be mostly writing to L1, assume store cost is 0 - compute_size * Eigen::internal::functor_traits::Cost); - - device.parallelFor(num_blocks, cost, compute); - - // Aggregate partial results from temporary buffer into first block. - auto buf0 = Buffer(buffer_data, inner_dim); - // TODO(ezhulenev): Parallelize this loop for large inner dimensions? - for (int i = 1; i < num_blocks; ++i) { - auto buf = Buffer(buffer_data + i * inner_dim, inner_dim); - buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); - } - - // Write final result to the output. - output->template flat() = buf0.template cast(); - } -}; - -} // namespace functor -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ From c9c64cdc4bb95bcea30d4081d5803d8d76d8331e Mon Sep 17 00:00:00 2001 From: Letian Kang Date: Fri, 15 Mar 2019 19:20:19 +0800 Subject: [PATCH 4/6] 1. Add ReduceMiddleDimensions to support middle dim reduce in tensorflow/core/kernels/redux_functor.h 2. Opt ReduceOuterDimensions for large inner dim. 3. Rewrite NCHW BiasAddGradOp with ReduceOuterDimensions. --- tensorflow/core/kernels/bias_op.cc | 621 ++++++++++++++++++++++++ tensorflow/core/kernels/redux_functor.h | 317 ++++++++++++ 2 files changed, 938 insertions(+) create mode 100644 tensorflow/core/kernels/bias_op.cc create mode 100644 tensorflow/core/kernels/redux_functor.h diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc new file mode 100644 index 00000000000000..41b115983b3b2f --- /dev/null +++ b/tensorflow/core/kernels/bias_op.cc @@ -0,0 +1,621 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/bias_op.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/redux_functor.h" +#include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/bias_op_gpu.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + +namespace { + +void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format, + int32* batch, int32* height, int32* width, int32* depth, + int32* channel) { + *batch = 1; + *height = 1; + *width = 1; + *depth = 1; + *channel = 1; + if (data_format == FORMAT_NHWC) { + int32 channel_dim = value_tensor.dims() - 1; + *channel = static_cast(value_tensor.dim_size(channel_dim)); + for (int32 i = 0; i < channel_dim; i++) { + *batch *= static_cast(value_tensor.dim_size(i)); + } + } else if (data_format == FORMAT_NCHW) { + *batch = static_cast(value_tensor.dim_size(0)); + *channel = static_cast(value_tensor.dim_size(1)); + *height = static_cast(value_tensor.dim_size(2)); + if (value_tensor.dims() > 3) { + *width = static_cast(value_tensor.dim_size(3)); + } + if (value_tensor.dims() > 4) { + *depth = static_cast(value_tensor.dim_size(4)); + } + } +} + +template +struct AccumulatorType { + typedef T type; +}; + +// float is faster on the CPU than half, and also more precise, +// so use float for the temporary accumulators. +template <> +struct AccumulatorType { + typedef float type; +}; + +} // namespace + +template +class BiasOp : public BinaryOp { + public: + explicit BiasOp(OpKernelConstruction* context) : BinaryOp(context) { + string data_format; + if (context->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NHWC; + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& bias = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()), + errors::InvalidArgument("Input tensor must be at least 2D: ", + input.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()), + errors::InvalidArgument("Biases must be 1D: ", + bias.shape().DebugString())); + + // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. + size_t channel_dim; + if (data_format_ == FORMAT_NCHW) { + channel_dim = 1; // NCHW always have channel dim in 1 (with 3, 4, 5 + // dimensions data). + } else { + channel_dim = input.shape().dims() - 1; // End of code by intel_tf. + } + + OP_REQUIRES( + context, + bias.shape().dim_size(0) == input.shape().dim_size(channel_dim), + errors::InvalidArgument( + "Must provide as many biases as the last dimension " + "of the input tensor: ", + bias.shape().DebugString(), " vs. 
", input.shape().DebugString())); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + if (input.NumElements() == 0) return; + + // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. + if (data_format_ == FORMAT_NCHW) { + int32 batch, height, width, depth, channel; + GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth, + &channel); + switch (input.shape().dims()) { + case 3: { + Eigen::DSizes three_dims(1, channel, 1); + Eigen::DSizes broad_cast_dims(batch, 1, height); + const Device& d = context->eigen_device(); + output->tensor().device(d) = input.tensor() + + bias.tensor() + .reshape(three_dims) + .broadcast(broad_cast_dims); + } break; + case 4: { + Eigen::DSizes four_dims(1, channel, 1, 1); + Eigen::DSizes broad_cast_dims(batch, 1, height, width); + const Device& d = context->eigen_device(); + output->tensor().device(d) = + input.tensor() + + bias.tensor().reshape(four_dims).broadcast(broad_cast_dims); + } break; + case 5: { + Eigen::DSizes five_dims(1, channel, 1, 1, 1); + Eigen::DSizes broad_cast_dims(batch, 1, height, width, + depth); + const Device& d = context->eigen_device(); + output->tensor().device(d) = + input.tensor() + + bias.tensor().reshape(five_dims).broadcast(broad_cast_dims); + } break; + default: + OP_REQUIRES(context, false, + errors::InvalidArgument("Only ranks up to 5 supported: ", + input.shape().DebugString())); + } + return; + } // End of code by intel_tf. + + switch (input.shape().dims()) { + case 2: + Compute<2>(context, input, bias, output); + break; + case 3: + Compute<3>(context, input, bias, output); + break; + case 4: + Compute<4>(context, input, bias, output); + break; + case 5: + Compute<5>(context, input, bias, output); + break; + default: + OP_REQUIRES(context, false, + errors::InvalidArgument("Only ranks up to 5 supported: ", + input.shape().DebugString())); + } + } + + // Add biases for an input matrix of rank Dims, by using the Bias. 
+ template + void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias, + Tensor* output) { + functor::Bias functor; + functor(ctx->eigen_device(), input.tensor(), bias.vec(), + output->tensor()); + } + + private: + TensorFormat data_format_; +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint("T"), \ + BiasOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint("T"), \ + BiasOp); + +TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BiasOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BiasOp); + +TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + +template +class BiasGradOp : public OpKernel { + public: + explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) { + string data_format; + if (context->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NHWC; + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& output_backprop = context->input(0); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()), + errors::InvalidArgument("Input tensor must be at least 2D: ", + output_backprop.shape().DebugString())); + + OP_REQUIRES( + context, FastBoundsCheck(output_backprop.NumElements(), + std::numeric_limits::max()), + errors::InvalidArgument("BiasGrad requires tensor size <= int32 max")); + + int32 batch, height, width, depth, channel; + GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width, + &depth, &channel); + Tensor* output = nullptr; + TensorShape output_shape{channel}; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + if (channel == 0) { + return; // Nothing to do + } else if (output_backprop.NumElements() == 0) { + // Eigen often crashes by design on empty tensors, but setZero is safe + output->template flat().setZero(); + } else { + // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. + using AccumT = typename AccumulatorType::type; + if (data_format_ == FORMAT_NCHW) { + const functor::ReduceMiddleDimensions< + T, AccumT, Eigen::internal::scalar_sum_op, + Eigen::internal::SumReducer> + redux; + Eigen::DSizes three_dims(batch, channel, + height * width * depth); + redux(context->eigen_device(), three_dims, output_backprop, + output, 1); + } else { + const functor::ReduceOuterDimensions< + T, AccumT, Eigen::internal::scalar_sum_op> + redux; + + Eigen::DSizes two_dims(batch * height * width * depth, + channel); + redux(context->eigen_device(), two_dims, output_backprop, + output); + } + } + } + + private: + TensorFormat data_format_; +}; + +// Registration of the GPU implementations. 
+#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint("T"), \ + BiasGradOp); + +TF_CALL_NUMBER_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BiasGradOp); + +TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +template +class BiasOp : public BinaryOp { + public: + typedef GPUDevice Device; + explicit BiasOp(OpKernelConstruction* context) : BinaryOp(context) { + string data_format; + if (context->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NHWC; + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& bias = context->input(1); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()), + errors::InvalidArgument("Input tensor must be at least 2D: ", + input.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()), + errors::InvalidArgument("Biases must be 1D: ", + bias.shape().DebugString())); + int32 batch, height, width, depth, channel; + GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth, + &channel); + OP_REQUIRES(context, bias.shape().dim_size(0) == channel, + errors::InvalidArgument( + "Must provide as many biases as the channel dimension " + "of the input tensor: ", + bias.shape().DebugString(), " vs. ", channel, " in ", + input.shape().DebugString())); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + if (input.NumElements() > 0) { + BiasGPU::compute(context->template eigen_device(), + input.flat().data(), bias.flat().data(), + output->flat().data(), batch, width, height, depth, + channel, data_format_); + } + } + + private: + TensorFormat data_format_; +}; + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint("T"), \ + BiasOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint("T"), \ + BiasOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +struct BiasGradAutotuneGroup { + static string name() { return "BiasGrad"; } +}; + +class BiasAddGradGPUConfig { + public: + BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {} + string ToString() const { + if (mode_ == BiasAddGradGPUMode::kNative) { + return "native CUDA kernel."; + } + if (mode_ == BiasAddGradGPUMode::kReduction) { + return "cub reduction kernel."; + } + return "unknown kernel."; + } + BiasAddGradGPUMode get_mode() const { return mode_; } + void set_mode(BiasAddGradGPUMode val) { mode_ = val; } + + bool operator==(const BiasAddGradGPUConfig& other) const { + return this->mode_ == other.get_mode(); + } + + bool operator!=(const BiasAddGradGPUConfig& other) const { + return !(*this == other); + } + + private: + BiasAddGradGPUMode mode_; +}; + +// Encapsulate all the shape information that is used in bias add grad +// operations. +class BiasAddParams { + public: + // We use a list to maintain both the shape value and the order (data format). 
+ using SpatialArray = gtl::InlinedVector; + BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format, + DataType dtype, int device_id) + : in_shape_(in_shape), + data_format_(data_format), + dtype_(dtype), + device_id_(device_id) { + for (int64 val : in_shape_) { + hash_code_ = Hash64Combine(hash_code_, val); + } + hash_code_ = Hash64Combine(hash_code_, data_format); + hash_code_ = Hash64Combine(hash_code_, dtype); + hash_code_ = Hash64Combine(hash_code_, device_id); + } + bool operator==(const BiasAddParams& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } + + bool operator!=(const BiasAddParams& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } + + string ToString() const { + // clang-format off + return strings::StrCat( + "(", str_util::Join(in_shape_, ", "), "), ", + data_format_, ", ", dtype_, ", ", device_id_); + // clang-format on + } + + protected: + using ParamsDataType = std::tuple; + + ParamsDataType get_data_as_tuple() const { + return std::make_tuple(in_shape_, data_format_, dtype_, device_id_); + } + + uint64 hash_code_ = 0; + + private: + SpatialArray in_shape_; + TensorFormat data_format_; + DataType dtype_; + int device_id_; +}; + +typedef AutoTuneSingleton + AutotuneBiasGrad; + +template +class BiasGradOp : public OpKernel { + public: + typedef GPUDevice Device; + explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) { + string data_format; + if (context->GetAttr("data_format", &data_format).ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } else { + data_format_ = FORMAT_NCHW; + } + } + + void ComputeWithCustomKernel(OpKernelContext* context, + const Tensor& output_backprop, int32 batch, + int32 width, int32 height, int32 depth, + int32 channel, Tensor* output) { + BiasGradGPU::compute(context->template eigen_device(), + output_backprop.template flat().data(), + output->flat().data(), batch, width, height, + depth, channel, data_format_); + } + + void ComputeWithReduceSum(OpKernelContext* context, + const Tensor& output_backprop, int32 batch, + int32 width, int32 height, int32 depth, + int32 channel, Tensor* output) { + if (data_format_ == FORMAT_NCHW) { + int32 row_count = batch * channel; + int32 col_count = height * width * depth; + Tensor temp_grad_outputs; + // For 'NCHW' format, we perform reduction twice: first HW, then N. + TensorShape temp_grad_output_shape{row_count, col_count}; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + temp_grad_output_shape, + &temp_grad_outputs)); + BiasGradGPU::DoRowReduction( + context, temp_grad_outputs.flat().data(), + output_backprop.template flat().data(), row_count, col_count); + + row_count = batch; + col_count = channel; + BiasGradGPU::DoColReduction(context, output->flat().data(), + temp_grad_outputs.flat().data(), + row_count, col_count); + } else { + // For 'NHWC', we simply apply reduction once on NHW. 
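+      // For example (shapes are illustrative only): with output_backprop of
+      // shape [32, 64, 14, 14] in NCHW, the branch above first reduces a
+      // (32*64) x (14*14) matrix row-wise and then a 32 x 64 matrix
+      // column-wise; in NHWC the same 64 bias gradients come from a single
+      // column reduction of a (32*14*14) x 64 matrix, set up below.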
+ int32 row_count = batch * height * width * depth; + int32 col_count = channel; + BiasGradGPU::DoColReduction( + context, const_cast(output->flat().data()), + reinterpret_cast(output_backprop.template flat().data()), + row_count, col_count); + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& output_backprop = context->input(0); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()), + errors::InvalidArgument("Input tensor must be at least 2D: ", + output_backprop.shape().DebugString())); + int32 batch, height, width, depth, channel; + GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width, + &depth, &channel); + Tensor* output = nullptr; + TensorShape output_shape{channel}; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (channel == 0) return; + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + se::DeviceMemoryBase output_ptr(output->flat().data(), + output->NumElements() * sizeof(T)); + stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T)); + if (output_backprop.NumElements() <= 0) return; + + int device_id = stream->parent()->device_ordinal(); + DataType dtype = output_backprop.dtype(); + BiasAddParams bias_parameters = { + {batch, height * width * depth, channel}, + data_format_, + dtype, + device_id, + }; + + // Autotune two algorithm: customized + BiasAddGradGPUConfig algo_config; + if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) { + BiasGradGPUProfileResult best_result; + // Initialize the timer. + perftools::gputools::Timer timer(stream->parent()); + stream->InitTimer(&timer); + stream->ThenStartTimer(&timer); + ComputeWithCustomKernel(context, output_backprop, batch, width, height, + depth, channel, output); + stream->ThenStopTimer(&timer); + uint64 elapsed_microseconds = timer.Microseconds(); + VLOG(1) << "BiasAddGrad " << bias_parameters.ToString() + << " Native algo latency: " << elapsed_microseconds; + if (elapsed_microseconds < best_result.elapsed_time()) { + best_result.set_algorithm(BiasAddGradGPUMode::kNative); + best_result.set_elapsed_time(elapsed_microseconds); + } + + // Try reduction and profile. + stream->ThenStartTimer(&timer); + ComputeWithReduceSum(context, output_backprop, batch, width, height, + depth, channel, output); + stream->ThenStopTimer(&timer); + + elapsed_microseconds = timer.Microseconds(); + VLOG(1) << "BiasAddGrad " << bias_parameters.ToString() + << " Reduction algo latency: " << elapsed_microseconds; + if (elapsed_microseconds < best_result.elapsed_time()) { + best_result.set_algorithm(BiasAddGradGPUMode::kReduction); + best_result.set_elapsed_time(elapsed_microseconds); + } + + algo_config.set_mode(best_result.algorithm()); + AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config); + + // Results are already available during autotune, so no need to continue. + return; + } + + // Choose the best algorithm based on autotune results. + if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) { + ComputeWithReduceSum(context, output_backprop, batch, width, height, + depth, channel, output); + } else { + // Default to the customized kernel. + ComputeWithCustomKernel(context, output_backprop, batch, width, height, + depth, channel, output); + } + } + + private: + TensorFormat data_format_; +}; + +// Registration of the GPU implementations. 
+#define REGISTER_GPU_KERNEL(type)                                       \
+  REGISTER_KERNEL_BUILDER(                                              \
+      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      BiasGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+#endif  // GOOGLE_CUDA
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h
new file mode 100644
index 00000000000000..9df61d07170c58
--- /dev/null
+++ b/tensorflow/core/kernels/redux_functor.h
@@ -0,0 +1,317 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "third_party/eigen3/Eigen/Core"
+
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+// Compute reduction over outer dimensions.
+// Example:
+//   input: [D1, D2, ... , DN]
+//   ->
+//   output: [Di, ... , DN] where i belongs to set [1,N]
+template <typename T, typename AccumT, typename BinaryFunctor>
+struct ReduceOuterDimensions {
+  template <int num_dims>
+  void operator()(const CPUDevice& device,
+                  const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
+                  const Tensor& input, Tensor* output,
+                  const int num_reduce_dims = 1) const {
+    // Compute inner and outer dim after reshaping into 2d tensor.
+    int64 inner_dim = 1, outer_dim = 1;
+    for (int i = 0; i < num_dims - num_reduce_dims; ++i)
+      outer_dim *= input_dims[i];
+    for (int i = num_dims - num_reduce_dims; i < num_dims; ++i)
+      inner_dim *= input_dims[i];
+
+    if (1 == inner_dim) {
+      // Nothing to do but passing input to output.
+      *output = input;
+      return;
+    }
+
+    // Get device thread num.
+    const int64 num_threads = device.numThreads();
+
+    // If the inner dim parallelism is large enough
+    if (inner_dim > num_threads * 16) {
+      // Do not create more blocks than there are threads in a pool.
+      const int64 num_blocks = num_threads;
+
+      // Block size along the inner dimension.
+      const int64 inner_block_size = Eigen::divup(inner_dim, num_blocks);
+      const T* input_data = input.template flat<T>().data();
+
+      // Allocate temporary buffer for partial reductions.
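+      // The buffer below has one AccumT element per inner index and is shared
+      // by all blocks: block [start, limit) only ever touches the slice
+      // [start * inner_block_size, limit * inner_block_size), so the blocks
+      // accumulate into disjoint ranges and never race with each other.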
+      Eigen::Tensor<AccumT, 1, Eigen::RowMajor> buffer({inner_dim});
+      buffer.setZero();
+      AccumT* buffer_data = buffer.data();
+
+      using Buffer = Eigen::TensorMap<
+          Eigen::Tensor<AccumT, 1, Eigen::RowMajor>,
+          Eigen::Unaligned>;
+
+      using Input = Eigen::TensorMap<
+          Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+          Eigen::Unaligned>;
+
+      const auto compute = [inner_dim, outer_dim, num_blocks, inner_block_size,
+                            input_data, buffer_data](
+          Eigen::Index start, Eigen::Index limit) -> void {
+        DCHECK(start >= 0 && limit <= num_blocks);
+        int64 inner_dim_start = start * inner_block_size;
+        int64 inner_dim_limit = limit * inner_block_size;
+        inner_dim_limit = std::min(inner_dim, inner_dim_limit);
+        int64 my_job_len = inner_dim_limit - inner_dim_start;
+
+        const T* my_job_start = input_data + inner_dim_start;
+        Buffer buf(buffer_data + inner_dim_start, my_job_len);
+
+        for (int64 i = 0; i < outer_dim; ++i) {
+          auto in = Input(my_job_start + i * inner_dim, my_job_len);
+          auto cast = in.template cast<AccumT>();
+          buf = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf),
+                                           const decltype(cast)>(buf, cast);
+        }
+      };
+
+      // Compute cost of reducing a single block.
+      const int64 compute_size = outer_dim * inner_block_size;
+      const int64 compute_input_bytes = compute_size * sizeof(T);
+      const Eigen::TensorOpCost cost(
+          compute_input_bytes,
+          0,  // We'll be mostly writing to L1, assume store cost is 0
+          compute_size * Eigen::internal::functor_traits<BinaryFunctor>::Cost);
+
+      device.parallelFor(num_blocks, cost, compute);
+
+      // Write final result to the output.
+      output->template flat<T>() = buffer.template cast<T>();
+    } else {
+      // Compute block size along the outer dimension for efficiency.
+      const int64 parallel_cell_size = inner_dim;
+      const int64 total_workload = outer_dim * inner_dim;
+      const int64 max_parallelism = total_workload / parallel_cell_size;
+
+      const int64 min_block_workload = 2000;
+      const int64 min_block_size =
+          Eigen::divup(min_block_workload, parallel_cell_size);
+      const int64 max_num_blocks = std::min(
+          max_parallelism, Eigen::divup(total_workload, min_block_size));
+
+      // Do not create more blocks than there are threads in a pool.
+      const int64 num_blocks = std::min(max_num_blocks, num_threads);
+
+      // Block size along the outer dimension.
+      const int64 outer_block_size = Eigen::divup(outer_dim, num_blocks);
+
+      const T* input_data = input.template flat<T>().data();
+
+      // Allocate temporary buffer for partial reductions.
+      Tensor buffer(DataTypeToEnum<AccumT>::v(), {num_blocks, inner_dim});
+      buffer.template flat<AccumT>().setZero();
+      AccumT* buffer_data = buffer.template flat<AccumT>().data();
+
+      using Buffer = Eigen::TensorMap<
+          Eigen::Tensor<AccumT, 1, Eigen::RowMajor>,
+          Eigen::Unaligned>;
+
+      using Input = Eigen::TensorMap<
+          Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+          Eigen::Unaligned>;
+
+      const auto compute = [inner_dim, num_blocks, outer_block_size,
+                            buffer_data, input_data, outer_dim](
+          Eigen::Index start, Eigen::Index limit) -> void {
+        DCHECK(start >= 0 && limit <= num_blocks);
+        int64 outer_dim_start = start * outer_block_size;
+        int64 outer_dim_limit = limit * outer_block_size;
+        outer_dim_limit = std::min(outer_dim, outer_dim_limit);
+
+        Buffer buf(buffer_data + start * inner_dim, inner_dim);
+        for (int64 i = outer_dim_start; i < outer_dim_limit; ++i) {
+          auto in = Input(input_data + i * inner_dim, inner_dim);
+          auto cast = in.template cast<AccumT>();
+          buf = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf),
+                                           const decltype(cast)>(buf, cast);
+        }
+      };
+
+      // Compute cost of reducing a single block.
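+      // The estimate below charges the bytes loaded plus one BinaryFunctor
+      // application per element and assumes stores are free; Eigen's
+      // parallelFor uses it to decide how many of the num_blocks work items
+      // to batch onto each thread.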
+      const int64 compute_size = outer_block_size * inner_dim;
+      const int64 compute_input_bytes = compute_size * sizeof(T);
+      const Eigen::TensorOpCost cost(
+          compute_input_bytes,
+          0,  // We'll be mostly writing to L1, assume store cost is 0
+          compute_size * Eigen::internal::functor_traits<BinaryFunctor>::Cost);
+
+      device.parallelFor(num_blocks, cost, compute);
+
+      // Aggregate partial results from temporary buffer into first block.
+      auto buf0 = Buffer(buffer_data, inner_dim);
+      // Just sum the buffers up, as the inner dimension is not large in this
+      // case.
+      for (int i = 1; i < num_blocks; ++i) {
+        auto buf = Buffer(buffer_data + i * inner_dim, inner_dim);
+        buf0 = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf0),
+                                          const decltype(buf)>(buf0, buf);
+      }
+      // Write final result to the output.
+      output->template flat<T>() = buf0.template cast<T>();
+    }
+  }
+};
+
+// Compute reduction over some consecutive middle dimensions (like an axis).
+// Example:
+//   input: [D1, D2, ... , DN]
+//   ->
+//   output: [Di, ... , Dj] where i & j belong to set [1,N].
+template <typename T, typename AccumT, typename BinaryFunctor,
+          typename Reducer>
+struct ReduceMiddleDimensions {
+  template <int num_dims>
+  void operator()(const CPUDevice& device,
+                  const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
+                  const Tensor& input, Tensor* output, const int axis_begin_dim,
+                  const int axis_num_dim = 1) const {
+    // Compute dims after reshaping into 3d tensor.
+    int64 inner_dim = 1, middle_dim = 1, outer_dim = 1;
+    for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i];
+    for (int i = axis_begin_dim; i < axis_begin_dim + axis_num_dim; ++i)
+      middle_dim *= input_dims[i];
+    for (int i = axis_begin_dim + axis_num_dim; i < num_dims; ++i)
+      inner_dim *= input_dims[i];
+
+    if ((1 == inner_dim * outer_dim)) {
+      // Nothing to do.
+      *output = input;
+      return;
+    } else if (1 == inner_dim) {
+      // Equivalent to ReduceOuterDimensions.
+      const ReduceOuterDimensions<T, AccumT, BinaryFunctor> redux;
+      redux(device, input_dims, input, output);
+      return;
+    }
+
+    // Compute block size along the outer dimension for efficiency.
+    const int64 parallel_cell_size = inner_dim;
+    const int64 max_parallelism = outer_dim * middle_dim;
+    const int64 total_workload = max_parallelism * inner_dim;
+
+    const int64 min_block_workload = 2000;
+    const int64 min_block_size =
+        Eigen::divup(min_block_workload, parallel_cell_size);
+    const int64 max_num_blocks =
+        std::min(max_parallelism, Eigen::divup(total_workload, min_block_size));
+
+    // Do not create more blocks than there are threads in a pool.
+    const int64 num_threads = device.numThreads();
+    const int64 num_blocks = std::min(max_num_blocks, num_threads);
+
+    // Block size along the outer dimension.
+    const int64 outer_block_size = Eigen::divup(total_workload, num_blocks);
+
+    const T* input_data = input.template flat<T>().data();
+
+    // Allocate temporary buffer for partial reductions.
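+    // Unlike the outer-dimension case above, each block gets a private row of
+    // middle_dim partial sums (the buffer allocated below is num_blocks x
+    // middle_dim); the rows are folded into row 0 once the parallel phase has
+    // finished.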
+ Eigen::Tensor buffer(num_blocks, middle_dim); + buffer.setZero(); + AccumT* buffer_data = buffer.data(); + + using Buffer = Eigen::TensorMap>; + using Input = Eigen::TensorMap>; + + Eigen::array reduction_axis = {0}; + const Reducer reducer; + const BinaryFunctor binary_op; + + const auto compute = [inner_dim, middle_dim, input_data, buffer_data, + total_workload, num_blocks, outer_block_size, + reduction_axis, reducer, binary_op]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + int64 block_start = start * outer_block_size; + int64 block_limit = limit * outer_block_size; + block_limit = std::min(total_workload, block_limit); + Buffer buf(buffer_data + start * middle_dim, middle_dim); + + const int align_start = + ((block_start + inner_dim - 1) / inner_dim) * inner_dim; + const int align_end = (block_limit / inner_dim) * inner_dim; + + int64 coordinate = block_start / inner_dim % middle_dim; + Eigen::Tensor reduced = + Input(&input_data[block_start], align_start - block_start) + .reduce(reduction_axis, reducer) + .template cast(); + + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + + coordinate = align_start / inner_dim % middle_dim; + for (int i = align_start; i < align_end; i += inner_dim) { + reduced = Input(&input_data[i], inner_dim) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + ++coordinate; + if (middle_dim == coordinate) coordinate = 0; + } + + reduced = Input(&input_data[align_end], block_limit - align_end) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + }; + + // Compute cost of reducing a single block. + const int64 compute_size = outer_block_size * inner_dim; + const int64 compute_input_bytes = compute_size * sizeof(T); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + using Output = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + // Aggregate partial results from temporary buffer into first block. + auto buf0 = Output(buffer_data, middle_dim); + // TODO(ezhulenev): Parallelize this loop for large inner dimensions? + for (int i = 1; i < num_blocks; ++i) { + auto buf = Output(buffer_data + i * middle_dim, middle_dim); + buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); + } + + // Write final result to the output. + output->template flat() = buf0.template cast(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ From 886c4d8069f28189bc47daf320d14b3576d81cd0 Mon Sep 17 00:00:00 2001 From: Letian Kang Date: Fri, 22 Mar 2019 12:44:04 +0800 Subject: [PATCH 5/6] Fix logical bug in front of ReduceOuterDimensions. --- tensorflow/core/kernels/redux_functor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h index 9df61d07170c58..e1d8fe00c76e8e 100644 --- a/tensorflow/core/kernels/redux_functor.h +++ b/tensorflow/core/kernels/redux_functor.h @@ -45,7 +45,7 @@ struct ReduceOuterDimensions { for (int i = num_dims - num_reduce_dims; i < num_dims; ++i) inner_dim *= input_dims[i]; - if (1 == inner_dim) { + if (1 == outer_dim) { // Nothing to do but passing input to output. 
*output = input; return; From 12a55780b28cead8b5387c9058369b87764333cf Mon Sep 17 00:00:00 2001 From: Letian Kang Date: Mon, 25 Mar 2019 10:46:28 +0800 Subject: [PATCH 6/6] Fix shape of result(input when nothing to do) unmatch output error. --- tensorflow/core/kernels/redux_functor.h | 34 +++++++++++++++---------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h index e1d8fe00c76e8e..e903618662fcf3 100644 --- a/tensorflow/core/kernels/redux_functor.h +++ b/tensorflow/core/kernels/redux_functor.h @@ -36,18 +36,21 @@ struct ReduceOuterDimensions { template void operator()(const CPUDevice& device, const Eigen::DSizes& input_dims, - const Tensor& input, Tensor* output, - const int num_reduce_dims = 1) const { + const Tensor& input, Tensor* output) const { // Compute inner and outer dim after reshaping into 2d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + int64 inner_dim = 1, outer_dim = 1; - for (int i = 0; i < num_dims - num_reduce_dims; ++i) + for (int i = 0; i < num_dims - num_output_dims; ++i) outer_dim *= input_dims[i]; - for (int i = num_dims - num_reduce_dims; i < num_dims; ++i) + for (int i = num_dims - num_output_dims; i < num_dims; ++i) inner_dim *= input_dims[i]; if (1 == outer_dim) { // Nothing to do but passing input to output. - *output = input; + output->template flat() = + input.template flat().reshape(output_dims); return; } @@ -108,7 +111,8 @@ struct ReduceOuterDimensions { device.parallelFor(num_blocks, cost, compute); // Write final result to the output. - output->template flat() = buffer.template cast(); + output->template flat() = + buffer.template cast().reshape(output_dims); } else { // Compute block size along the outer dimension for efficiency. const int64 parallel_cell_size = inner_dim; @@ -178,7 +182,7 @@ struct ReduceOuterDimensions { const decltype(buf)>(buf0, buf); } // Write final result to the output. - output->template flat() = buf0.template cast(); + output->template flat() = buf0.template cast().reshape(output_dims); } } }; @@ -193,19 +197,23 @@ struct ReduceMiddleDimensions { template void operator()(const CPUDevice& device, const Eigen::DSizes& input_dims, - const Tensor& input, Tensor* output, const int axis_begin_dim, - const int axis_num_dim = 1) const { + const Tensor& input, Tensor* output, + const int axis_begin_dim) const { // Compute dims after reshaping into 3d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + int64 inner_dim = 1, middle_dim = 1, outer_dim = 1; for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i]; - for (int i = axis_begin_dim; i < axis_begin_dim + axis_num_dim; ++i) + for (int i = axis_begin_dim; i < axis_begin_dim + num_output_dims; ++i) middle_dim *= input_dims[i]; - for (int i = axis_begin_dim + axis_num_dim; i < num_dims; ++i) + for (int i = axis_begin_dim + num_output_dims; i < num_dims; ++i) inner_dim *= input_dims[i]; if ((1 == inner_dim * outer_dim)) { // Nothing to do. - *output = input; + output->template flat() = + input.template flat().reshape(output_dims); return; } else if (1 == inner_dim) { // Equivalent to ReduceOuterDimensions. @@ -307,7 +315,7 @@ struct ReduceMiddleDimensions { } // Write final result to the output. - output->template flat() = buf0.template cast(); + output->template flat() = buf0.template cast().reshape(output_dims); } };
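
As a usage sketch, the snippet below shows one way a CPU BiasAddGrad kernel could dispatch to the two functors above for the two data formats. It is illustrative only: the helper name BiasGradReduceSketch is made up for this sketch, and the functor template-argument order (T, AccumT, BinaryFunctor, plus Reducer for ReduceMiddleDimensions) is an assumption based on how the parameters are used in the header, not necessarily the exact call sites added to bias_op.cc.

// Illustrative sketch only -- the functor template-argument order and the
// helper below are assumptions, not the exact code in bias_op.cc.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

// `output` must already be allocated with shape {channel} and dtype T,
// matching how BiasGradOp allocates its output before the reduction.
template <typename T>
void BiasGradReduceSketch(const CPUDevice& device,
                          const Tensor& output_backprop,
                          TensorFormat data_format, int32 batch, int32 height,
                          int32 width, int32 depth, int32 channel,
                          Tensor* output) {
  using AccumT = T;  // A wider accumulator (e.g. float for half) could be used.
  using SumFunctor = Eigen::internal::scalar_sum_op<AccumT>;
  using SumReducer = Eigen::internal::SumReducer<T>;

  if (data_format == FORMAT_NHWC) {
    // View the gradient as [N*H*W*D, C] and reduce away the outer dimension,
    // keeping the trailing channel dimension.
    const Eigen::DSizes<Eigen::Index, 2> dims(
        static_cast<Eigen::Index>(batch) * height * width * depth, channel);
    const functor::ReduceOuterDimensions<T, AccumT, SumFunctor> redux;
    redux(device, dims, output_backprop, output);
  } else {
    // View the gradient as [N, C, H*W*D] and reduce both the batch and the
    // spatial block, keeping the middle channel dimension.
    const Eigen::DSizes<Eigen::Index, 3> dims(
        batch, channel, static_cast<Eigen::Index>(height) * width * depth);
    const functor::ReduceMiddleDimensions<T, AccumT, SumFunctor, SumReducer>
        redux;
    redux(device, dims, output_backprop, output, /*axis_begin_dim=*/1);
  }
}

}  // namespace tensorflow

The split mirrors the layout of the two formats: in NHWC the channel axis is the innermost kept dimension, so ReduceOuterDimensions suffices, while NCHW needs ReduceMiddleDimensions because the channel axis sits between the batch and spatial dimensions.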