From 6a187ccddaebb741ea77fc3201c6e36625f0aadb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 4 May 2016 07:46:46 -0800 Subject: [PATCH] Add support for 3d convolutions and pooling. CPU kernels use Eigen, GPU kernels use CuDNN. Change: 121484787 --- tensorflow/core/kernels/BUILD | 21 +- tensorflow/core/kernels/conv_2d.h | 87 ++- tensorflow/core/kernels/conv_3d.h | 48 ++ tensorflow/core/kernels/conv_grad_ops.cc | 96 +-- tensorflow/core/kernels/conv_grad_ops_3d.cc | 739 ++++++++++++++++++ tensorflow/core/kernels/conv_ops.cc | 280 +++---- tensorflow/core/kernels/conv_ops_3d.cc | 355 +++++++++ tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 202 +++-- tensorflow/core/kernels/cudnn_pooling_gpu.cc | 216 +++++ tensorflow/core/kernels/cudnn_pooling_gpu.h | 65 ++ tensorflow/core/kernels/ops_util.cc | 28 + tensorflow/core/kernels/ops_util.h | 15 + tensorflow/core/kernels/pooling_ops_3d.cc | 515 ++++++++++++ tensorflow/core/kernels/pooling_ops_common.cc | 24 +- tensorflow/core/ops/nn_ops.cc | 153 ++++ tensorflow/core/util/tensor_format.h | 52 +- tensorflow/python/BUILD | 1 + .../python/kernel_tests/conv_ops_3d_test.py | 420 ++++++++++ .../python/kernel_tests/conv_ops_test.py | 4 +- .../kernel_tests/pooling_ops_3d_test.py | 340 ++++++++ .../python/kernel_tests/pooling_ops_test.py | 4 +- tensorflow/python/ops/common_shapes.py | 124 +-- tensorflow/python/ops/nn.py | 3 + tensorflow/python/ops/nn_grad.py | 101 ++- tensorflow/python/ops/nn_ops.py | 158 +++- 25 files changed, 3591 insertions(+), 460 deletions(-) create mode 100644 tensorflow/core/kernels/conv_3d.h create mode 100644 tensorflow/core/kernels/conv_grad_ops_3d.cc create mode 100644 tensorflow/core/kernels/conv_ops_3d.cc create mode 100644 tensorflow/core/kernels/cudnn_pooling_gpu.cc create mode 100644 tensorflow/core/kernels/cudnn_pooling_gpu.h create mode 100644 tensorflow/core/kernels/pooling_ops_3d.cc create mode 100644 tensorflow/python/kernel_tests/conv_ops_3d_test.py create mode 100644 tensorflow/python/kernel_tests/pooling_ops_3d_test.py diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0307765ac06047..a6a9fee68ab9bf 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -71,6 +71,16 @@ cc_library( ], ) +cc_library( + name = "conv_3d", + hdrs = ["conv_3d.h"], + deps = [ + ":eigen_helpers", + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + cc_library( name = "fill_functor", hdrs = ["fill_functor.h"], @@ -1106,11 +1116,15 @@ tf_cuda_cc_test( # conv_ops_gpu.h has be separated into its own library. 
tf_kernel_library( name = "conv_ops", - srcs = ["conv_grad_ops.cc"], + srcs = [ + "conv_grad_ops.cc", + "conv_grad_ops_3d.cc", + ], prefix = "conv_ops", deps = [ ":bounds_check", ":conv_2d", + ":conv_3d", ":ops_util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -1238,11 +1252,14 @@ tf_kernel_library( name = "pooling_ops", srcs = [ "avgpooling_op.cc", + "cudnn_pooling_gpu.cc", "maxpooling_op.cc", + "pooling_ops_3d.cc", "pooling_ops_common.cc", ], hdrs = [ "avgpooling_op.h", + "cudnn_pooling_gpu.h", "maxpooling_op.h", "pooling_ops_common.h", ], @@ -1257,6 +1274,8 @@ tf_kernel_library( ], deps = [ ":conv_2d", + ":conv_3d", + ":conv_ops", ":eigen_helpers", ":ops_util", "//tensorflow/core:core_cpu", diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index c7d5c3aeeb9716..9bbc67520f35d5 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -116,23 +116,31 @@ struct MatMulConvFunctor { } }; -template +// Shuffles a filter tensor from: +// [, in, out] +// to: +// [out, in, ] +template struct TransformFilter { void operator()(const Device& d, - typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { - // We want a 3, 2, 0, 1 shuffle. We can merge dimensions 0 and 1 together - // to help speedup the shuffle operation. + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together + // to speed up the shuffle operation. Eigen::DSizes merged_dims; - merged_dims[0] = in.dimension(0) * in.dimension(1); - merged_dims[1] = in.dimension(2); - merged_dims[2] = in.dimension(3); - - Eigen::DSizes expanded_dims; - expanded_dims[0] = in.dimension(3); - expanded_dims[1] = in.dimension(2); - expanded_dims[2] = in.dimension(0); - expanded_dims[3] = in.dimension(1); + merged_dims[0] = in.dimension(0); // spatial dimensions + for (int i = 1; i < NDIMS - 2; ++i) { + merged_dims[0] *= in.dimension(i); + } + merged_dims[1] = in.dimension(NDIMS - 2); // input filters + merged_dims[2] = in.dimension(NDIMS - 1); // output filters + + Eigen::DSizes expanded_dims; + expanded_dims[0] = in.dimension(NDIMS - 1); // output filters + expanded_dims[1] = in.dimension(NDIMS - 2); // input filters + for (int i = 0; i < NDIMS; ++i) { // spatial dimensions + expanded_dims[i + 2] = in.dimension(i); + } out.device(d) = in.reshape(merged_dims) .shuffle(Eigen::DSizes(2, 1, 0)) @@ -194,41 +202,50 @@ struct TransformDepth { } }; -template +template struct PadInput { void operator()(const Device& d, - typename TTypes::ConstTensor in, - int padding_rows_left, int padding_rows_right, - int padding_cols_left, int padding_cols_right, - typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const std::array& padding_left, + const std::array& padding_right, + typename TTypes::Tensor out, TensorFormat format) { - Eigen::array, 4> padding; - padding[GetTensorDimIndex(format, 'N')] = std::make_pair(0, 0); - padding[GetTensorDimIndex(format, 'H')] = - std::make_pair(padding_rows_left, padding_rows_right); - padding[GetTensorDimIndex(format, 'W')] = - std::make_pair(padding_cols_left, padding_cols_right); - padding[GetTensorDimIndex(format, 'C')] = std::make_pair(0, 0); + Eigen::array, NDIMS> padding; + padding[GetTensorDimIndex(format, 'N')] = std::make_pair(0, 0); + for (int i = 0; i < NDIMS - 2; ++i) { + padding[GetTensorDimIndex(format, '0' + i)] = + std::make_pair(padding_left[i], padding_right[i]); + } + padding[GetTensorDimIndex(format, 'C')] = std::make_pair(0, 
0); out.device(d) = in.pad(padding); } }; -template +// Converts a tensor from: +// [batch, , filters] +// to: +// [batch, filters, ] +template struct NHWCToNCHW { - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out); + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out); }; -template +// Converts a tensor from: +// [batch, filters, ] +// to: +// [batch, , filters] +template struct NCHWToNHWC { - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out); + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out); }; -template +// Reverses the effect of TransformFilter above. +template struct ReverseTransformFilter { - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out); + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out); }; } // namespace functor diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h new file mode 100644 index 00000000000000..af3841ad4a9ce9 --- /dev/null +++ b/tensorflow/core/kernels/conv_3d.h @@ -0,0 +1,48 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functors for 3d convolution. + +#ifndef TENSORFLOW_KERNELS_CONV_3D_H_ +#define TENSORFLOW_KERNELS_CONV_3D_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_cuboid_convolution.h" + +namespace tensorflow { +namespace functor { + +// Applies a 3D convolution to a batch of multi-channel volumes. 
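// Expected layouts (matching the Conv3D call sites elsewhere in this patch):
//   input:  [batch, in_planes, in_rows, in_cols, in_channels]
//   filter: [filter_planes, filter_rows, filter_cols, in_channels, out_channels]
//   output: [batch, out_planes, out_rows, out_cols, out_channels]
// The three stride arguments apply to the planes, rows and cols dimensions
// respectively; striding over the batch and channel dimensions is not
// supported by the ops that use this functor.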
+template +struct CuboidConvolution; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +struct CuboidConvolution { + void operator()(const CPUDevice& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_planes, + int stride_rows, int stride_cols, + const Eigen::PaddingType& padding) { + output.device(d) = Eigen::CuboidConvolution( + input, filter, stride_planes, stride_rows, stride_cols, padding); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONV_3D_H_ diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index f5daa9b2ec88f3..84cc7017c4019e 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -946,7 +946,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { filter_rows, filter_cols}), &transformed_filter)); - functor::TransformFilter()( + functor::TransformFilter()( context->eigen_device(), To32Bit(filter.tensor()), To32Bit(transformed_filter.tensor())); @@ -959,9 +959,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel { output_cols, out_depth), &transformed_out_backprop)); - functor::NHWCToNCHW()(context->eigen_device(), - out_backprop.tensor(), - transformed_out_backprop.tensor()); + functor::NHWCToNCHW()( + context->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); } else { transformed_out_backprop = out_backprop; } @@ -1022,11 +1022,11 @@ class Conv2DSlowBackpropInputOp : public OpKernel { &in_backprop_remove_padding)); // Remove the padding for odd rows or cols. - functor::PadInput()( + functor::PadInput()( context->template eigen_device(), To32Bit(const_cast(pre_transformed_in_backprop) .tensor()), - 0, -rows_odd, 0, -cols_odd, + {{0, 0}}, {{-rows_odd, -cols_odd}}, To32Bit(in_backprop_remove_padding.tensor()), FORMAT_NCHW); pre_transformed_in_backprop = in_backprop_remove_padding; @@ -1034,7 +1034,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel { if (data_format_ == FORMAT_NHWC) { auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::NCHWToNHWC()( + functor::NCHWToNHWC()( context->eigen_device(), toConstTensor(pre_transformed_in_backprop).template tensor(), in_backprop->tensor()); @@ -1167,9 +1167,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { input_cols + cols_odd, in_depth), &compatible_input)); - functor::PadInput()( + functor::PadInput()( context->template eigen_device(), - To32Bit(input.tensor()), 0, rows_odd, 0, cols_odd, + To32Bit(input.tensor()), {{0, 0}}, {{rows_odd, cols_odd}}, To32Bit(compatible_input.tensor()), data_format_); } else { compatible_input = input; @@ -1227,9 +1227,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { ShapeFromFormat(FORMAT_NCHW, batch, output_rows, output_cols, out_depth), &transformed_out_backprop)); - functor::NHWCToNCHW()(context->eigen_device(), - out_backprop.tensor(), - transformed_out_backprop.tensor()); + functor::NHWCToNCHW()( + context->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); } else { transformed_out_backprop = out_backprop; } @@ -1246,7 +1246,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { GetTensorDim(compatible_input, data_format_, 'W'), GetTensorDim(compatible_input, data_format_, 'C')), &transformed_input)); - functor::NHWCToNCHW()( + functor::NHWCToNCHW()( context->eigen_device(), const_cast(compatible_input).tensor(), transformed_input.tensor()); @@ -1284,7 +1284,7 @@ class 
Conv2DSlowBackpropFilterOp : public OpKernel { } auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::ReverseTransformFilter()( + functor::ReverseTransformFilter()( context->eigen_device(), toConstTensor(pre_transformed_filter_backprop).template tensor(), filter_backprop->tensor()); @@ -1301,40 +1301,40 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void ShuffleAndReverse::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& order, \ - const Eigen::array& reverse_dims, \ - typename TTypes::Tensor output); \ - extern template struct ShuffleAndReverse; \ - template <> \ - void InflatePadAndShuffle::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor input, \ - const Eigen::DSizes& strides, \ - const Eigen::array, 4>& pad_dims, \ - const Eigen::DSizes& order, \ - typename TTypes::Tensor output); \ - extern template struct InflatePadAndShuffle; \ - template <> \ - void TransformFilter::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - typename TTypes::Tensor out); \ - extern template struct TransformFilter; \ - template <> \ - void TransformDepth::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - const Eigen::DSizes& shuffle, \ - typename TTypes::Tensor out); \ - extern template struct TransformDepth; \ - template <> \ - void PadInput::operator()( \ - const GPUDevice& d, typename TTypes::ConstTensor in, \ - int padding_rows_left, int padding_rows_right, int padding_cols_left, \ - int padding_cols_right, typename TTypes::Tensor out, \ - TensorFormat data_format); \ - extern template struct PadInput; +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ShuffleAndReverse::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& order, \ + const Eigen::array& reverse_dims, \ + typename TTypes::Tensor output); \ + extern template struct ShuffleAndReverse; \ + template <> \ + void InflatePadAndShuffle::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + const Eigen::array, 4>& pad_dims, \ + const Eigen::DSizes& order, \ + typename TTypes::Tensor output); \ + extern template struct InflatePadAndShuffle; \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ + template <> \ + void TransformDepth::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const Eigen::DSizes& shuffle, \ + typename TTypes::Tensor out); \ + extern template struct TransformDepth; \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat data_format); \ + extern template struct PadInput; DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc new file mode 100644 index 00000000000000..1be72034a4da16 --- /dev/null +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -0,0 +1,739 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/conv_3d.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/stream_executor.h" +using perftools::gputools::dnn::DimIndex; +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// TODO(mjanusz): Get rid of the macro and return shapes directly. +#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ + const Tensor& input = context->input(0); \ + const Tensor& filter = context->input(1); \ + const Tensor& out_backprop = context->input(2); \ + OP_REQUIRES( \ + context, input.dims() == 5, \ + errors::InvalidArgument(label, ": input must be 5-dimensional")); \ + OP_REQUIRES( \ + context, filter.dims() == 5, \ + errors::InvalidArgument(label, ": filter must be 5-dimensional")); \ + OP_REQUIRES( \ + context, out_backprop.dims() == 5, \ + errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \ + const int64 batch = input.dim_size(0); \ + OP_REQUIRES( \ + context, batch == out_backprop.dim_size(0), \ + errors::InvalidArgument( \ + label, ": input and out_backprop must have the same batch size")); \ + const std::array input_size = { \ + {input.dim_size(1), input.dim_size(2), input.dim_size(3)}}; \ + const std::array filter_size = { \ + {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}}; \ + const int64 output_cols = out_backprop.dim_size(3); \ + const int64 output_rows = out_backprop.dim_size(2); \ + const int64 output_planes = out_backprop.dim_size(1); \ + const int64 in_depth = input.dim_size(4); \ + OP_REQUIRES(context, in_depth == filter.dim_size(3), \ + errors::InvalidArgument( \ + label, ": input and filter must have the same depth")); \ + const int64 out_depth = filter.dim_size(4); \ + OP_REQUIRES( \ + context, out_depth == out_backprop.dim_size(4), \ + errors::InvalidArgument( \ + label, ": filter and out_backprop must have the same out_depth")); \ + const std::array strides = {{stride_[1], stride_[2], stride_[3]}}; \ + std::array out, padding; \ + OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides, \ + padding_, &out, &padding)); \ + OP_REQUIRES(context, output_planes == out[0], \ + errors::InvalidArgument( \ + label, \ + ": Number of planes of out_backprop doesn't match " \ + "computed: actual = ", \ + output_planes, ", computed = ", out[0])); \ + 
OP_REQUIRES( \ + context, output_rows == out[1], \ + errors::InvalidArgument( \ + label, ": Number of rows of out_backprop doesn't match computed: ", \ + "actual = ", output_rows, ", computed = ", out[1])); \ + OP_REQUIRES( \ + context, output_cols == out[2], \ + errors::InvalidArgument( \ + label, ": Number of cols of out_backprop doesn't match computed: ", \ + "actual = ", output_cols, ", computed = ", out[2])); \ + const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \ + const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \ + const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \ + const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \ + const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \ + const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \ + const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \ + const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \ + const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \ + const auto bottom_pad_planes = \ + padded_out_planes - expanded_out_planes - top_pad_planes; \ + const auto bottom_pad_rows = \ + padded_out_rows - expanded_out_rows - top_pad_rows; \ + const auto right_pad_cols = \ + padded_out_cols - expanded_out_cols - left_pad_cols; \ + VLOG(2) << "Conv3d: " << label \ + << ": expanded_out_planes = " << expanded_out_planes \ + << ": expanded_out_rows = " << expanded_out_rows \ + << ", expanded_out_cols = " << expanded_out_cols \ + << ", padded_out_planes = " << padded_out_planes \ + << ", padded_out_rows = " << padded_out_rows \ + << ", padded_out_cols = " << padded_out_cols \ + << ", top_pad_planes = " << top_pad_planes \ + << ", top_pad_rows = " << top_pad_rows \ + << ", left_pad_cols = " << left_pad_cols \ + << ", bottom_pad_planes = " << bottom_pad_planes \ + << ", bottom_pad_rows = " << bottom_pad_rows \ + << ", right_pad_cols = " << right_pad_cols + +// Backprop for input. +template +class Conv3DBackpropInputOp : public OpKernel { + public: + explicit Conv3DBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, (stride_[0] == 1 && stride_[4] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); + Eigen::array, 5> pad_dims{ + {0, 0}, + {top_pad_planes, bottom_pad_planes}, + {top_pad_rows, bottom_pad_rows}, + {left_pad_cols, right_pad_cols}, + {0, 0}}; + Tensor* in_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &in_backprop)); + + // Fill out a padded out_backprop. 
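// The code below relies on the standard identity for the gradient of a
// convolution with respect to its input. A 1-D sketch, illustrative only:
// for a VALID convolution y[i] = sum_j x[i + j] * w[j] with stride 1 and a
// filter of length K, the input gradient is
//   dx[p] = sum_i dy[i] * w[p - i],
// which is itself a VALID convolution of dy, zero-padded by K - 1 on each
// side, with the reversed filter w~[j] = w[K - 1 - j]. For strided and SAME
// convolutions the same identity holds once dy is dilated by the stride and
// the padding amounts are adjusted, which is what the expanded_* / padded_* /
// top_pad_* quantities computed above implement. The in/out channel axes of
// the filter must also be swapped, hence the "reverted" r_filter below.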
+ TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows, + padded_out_cols, out_depth}); + Tensor padded_output; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + padded_out_shape, &padded_output)); + Eigen::DSizes no_op_shuffle{0, 1, 2, 3, 4}; + Eigen::DSizes eigen_strides{1, strides[0], strides[1], + strides[2], 1}; + functor::InflatePadAndShuffle()( + context->eigen_device(), out_backprop.tensor(), + eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor()); + const Tensor& padded_output_cref = padded_output; + + // Fill a new "reverted" filter. We need to transpose the in_depth and + // out_depth for the filter and reverse the planes, rows and cols. + TensorShape r_filter_shape( + {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth}); + Tensor r_filter; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), + r_filter_shape, &r_filter)); + Eigen::DSizes filter_order{0, 1, 2, 4, 3}; + Eigen::array filter_rev_dims{true, true, true, false, false}; + functor::ShuffleAndReverse()( + context->eigen_device(), filter.tensor(), filter_order, + filter_rev_dims, r_filter.tensor()); + const Tensor& r_filter_cref = r_filter; + + // Now we can call conv_3d directly. + functor::CuboidConvolution()( + context->eigen_device(), in_backprop->tensor(), + padded_output_cref.tensor(), r_filter_cref.tensor(), 1, 1, + 1, BrainPadding2EigenPadding(VALID)); + } + + private: + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DBackpropInputOp); +#ifndef __ANDROID__ +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DBackpropInputOp); +#endif + +// Backprop for filter. +template +class Conv3DBackpropFilterOp : public OpKernel { + public: + explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, (stride_[0] == 1 && stride_[4] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); + Eigen::array, 5> pad_dims{ + {0, 0}, + {top_pad_planes, bottom_pad_planes}, + {top_pad_rows, bottom_pad_rows}, + {left_pad_cols, right_pad_cols}, + {0, 0}}; + Tensor* filter_backprop; + OP_REQUIRES_OK( + context, context->allocate_output(0, filter.shape(), &filter_backprop)); + + // For the backprop of the filter, we need to also transpose the + // out_backprop. 
+ // The shape of backprop is + // [batch, out_z, out_y, out_x, out_depth] + // And we need to change it to + // [out_depth, out_x, out_y, out_z, batch] + Eigen::DSizes out_order{4, 1, 2, 3, 0}; + TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows, + padded_out_cols, batch}); + Tensor padded_output; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + padded_out_shape, &padded_output)); + Eigen::DSizes eigen_strides{1, strides[0], strides[1], + strides[2], 1}; + functor::InflatePadAndShuffle()( + context->eigen_device(), out_backprop.tensor(), + eigen_strides, pad_dims, out_order, padded_output.tensor()); + const Tensor& padded_output_cref = padded_output; + + // For the backprop of the filter, we need to transpose the input. + // The shape of input is + // [batch, in_z, in_y, in_x, in_depth] + // And we need to change it to + // [in_z, in_y, in_x, batch, in_depth] + Eigen::DSizes in_order{1, 2, 3, 0, 4}; + TensorShape in_shuffle_shape( + {input_size[0], input_size[1], input_size[2], batch, in_depth}); + Tensor in_shuffle; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::v(), + in_shuffle_shape, &in_shuffle)); + // No need for reversing this time. + Eigen::array no_reverse{false, false, false, false, false}; + functor::ShuffleAndReverse()( + context->eigen_device(), input.tensor(), in_order, + no_reverse, in_shuffle.tensor()); + const Tensor& in_shuffle_cref = in_shuffle; + + // The output of the conv_3d would be + // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth] + // and we need to shuffle it back to + // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth]; + // And we need to reverse the filter backprops. + // So we need to allocate (sigh) yet another piece of memory to hold the + // output. + TensorShape filter_shuffle_shape( + {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth}); + Tensor filter_shuffle; + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::v(), + filter_shuffle_shape, &filter_shuffle)); + functor::CuboidConvolution()( + context->eigen_device(), filter_shuffle.tensor(), + padded_output_cref.tensor(), in_shuffle_cref.tensor(), 1, 1, + 1, BrainPadding2EigenPadding(VALID)); + + // Now copy the filter_backprop back to the destination. + Eigen::DSizes filter_order{1, 2, 3, 4, 0}; + Eigen::array filter_rev_dims{true, true, true, false, false}; + const Tensor& filter_shuffle_cref = filter_shuffle; + functor::ShuffleAndReverse()( + context->eigen_device(), filter_shuffle_cref.tensor(), + filter_order, filter_rev_dims, filter_backprop->tensor()); + } + + private: + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DBackpropFilterOp); +#ifndef __ANDROID__ +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DBackpropFilterOp); +#endif + +// GPU definitions of both ops. +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. +// This ensures that the custom implementation is used instead of the default +// Eigen one (which is used for CPU). 
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + template <> \ + void ReverseTransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat format); + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC +} // namespace functor + +template +class Conv3DBackpropInputOp : public OpKernel { + public: + explicit Conv3DBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, (stride_[0] == 1 && stride_[4] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + void Compute(OpKernelContext* context) override { + EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); + Tensor* in_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &in_backprop)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 && + stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1) { + const uint64 m = batch * input_size[1] * input_size[2] * input_size[0]; + const uint64 k = out_depth; + const uint64 n = in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), + in_backprop->template flat().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream + ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } + + int padding_rows = 0, padding_cols = 0, padding_planes = 0; + + if (padding_ == Padding::SAME) { + padding_planes = + (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]; + padding_cols = + (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]; + padding_rows = + (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]; + } + const bool rows_odd = (padding_rows % 2 != 0); + const bool cols_odd = (padding_cols % 2 != 0); + const bool planes_odd = (padding_planes % 2 != 0); + + TensorShape compatible_input_shape; + if (rows_odd || cols_odd || planes_odd) { + // cuDNN only supports the same amount of padding on both sides. 
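// The workaround used below: grow the (conceptual) input shape by one along
// any dimension whose total SAME padding is odd, hand cuDNN the symmetric
// padding of padding / 2 per side, and trim the extra plane / row / column
// off the computed input gradient further down (via PadInput with negative
// padding) before converting back to the NHWC-style layout.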
+ compatible_input_shape = { + batch, + in_depth, + input_size[0] + planes_odd, + input_size[1] + rows_odd, + input_size[2] + cols_odd, + }; + } else { + compatible_input_shape = {batch, in_depth, input_size[0], input_size[1], + input_size[2]}; + } + + perftools::gputools::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(batch) + .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4)) + .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3)) + .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2)) + .set_feature_map_count(in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(batch) + .set_spatial_dim(DimIndex::X, output_cols) + .set_spatial_dim(DimIndex::Y, output_rows) + .set_spatial_dim(DimIndex::Z, output_planes) + .set_feature_map_count(out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc(3); + filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) + .set_spatial_dim(DimIndex::Y, filter_size[1]) + .set_spatial_dim(DimIndex::Z, filter_size[0]) + .set_input_feature_map_count(in_depth) + .set_output_feature_map_count(out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3); + conv_desc.set_filter_stride(DimIndex::X, strides[2]) + .set_filter_stride(DimIndex::Y, strides[1]) + .set_filter_stride(DimIndex::Z, strides[0]) + .set_zero_padding(DimIndex::X, padding_cols / 2) + .set_zero_padding(DimIndex::Y, padding_rows / 2) + .set_zero_padding(DimIndex::Z, padding_planes / 2); + + // Shape: out, in, z, y, x. + Tensor transformed_filter; + OP_REQUIRES_OK( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({out_depth, in_depth, filter_size[0], + filter_size[1], filter_size[2]}), + &transformed_filter)); + functor::TransformFilter()( + context->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + // Shape: batch, filters, z, y, x. + Tensor transformed_out_backprop; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + {batch, out_depth, output_planes, + output_rows, output_cols}, + &transformed_out_backprop)); + functor::NHWCToNCHW()( + context->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); + + // Shape: batch, filters, z, y, x. 
+ Tensor pre_transformed_in_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp(DataTypeToEnum::value, compatible_input_shape, + &pre_transformed_in_backprop)); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto in_backprop_ptr = + AsDeviceMemory(pre_transformed_in_backprop.template flat().data(), + pre_transformed_in_backprop.template flat().size()); + + static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default + + CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, + context); + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithScratch( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator) + .ok(); + + if (!cudnn_launch_status) { + context->SetStatus(errors::Internal( + "cuDNN Backward Data function launch failure : input shape(", + input.shape().DebugString(), ") filter shape(", + filter.shape().DebugString(), ")")); + } + + if (rows_odd || cols_odd || planes_odd) { + Tensor in_backprop_remove_padding; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + {batch, in_depth, input_size[0], + input_size[1], input_size[2]}, + &in_backprop_remove_padding)); + + // Remove the padding for odd spatial dimensions. + functor::PadInput()( + context->eigen_device(), + To32Bit(const_cast(pre_transformed_in_backprop) + .tensor()), + {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}}, + To32Bit(in_backprop_remove_padding.tensor()), FORMAT_NCHW); + + pre_transformed_in_backprop = in_backprop_remove_padding; + } + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::NCHWToNHWC()( + context->eigen_device(), + toConstTensor(pre_transformed_in_backprop).template tensor(), + in_backprop->tensor()); + } + + private: + std::vector stride_; + Padding padding_; +}; + +template +class Conv3DBackpropFilterOp : public OpKernel { + public: + explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, (stride_[0] == 1 && stride_[4] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); + + Tensor* filter_backprop; + OP_REQUIRES_OK( + context, context->allocate_output(0, filter.shape(), &filter_backprop)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 && + strides[2] == 1 && strides[1] == 1 && strides[0] == 1) { + const uint64 m = in_depth; + const uint64 k = batch * input_size[1] * input_size[2] * input_size[0]; + const uint64 n = out_depth; + + // The shape of output backprop is + // [batch, out_z, out_y, out_x, out_depth] + // From cublas's perspective, it is: n x k + auto a_ptr = 
AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + + // The shape of input is: + // [batch, in_z, in_y, in_x, in_depth], + // From cublas's perspective, it is: m x k + auto b_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + + // The shape of the filter backprop is: + // [1, 1, 1, in_depth, out_depth] + // From cublas's perspective, it is: n x m + auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), + filter_backprop->template flat().size()); + + bool blas_launch_status = + stream + ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, + n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } + int padding_rows = 0, padding_cols = 0, padding_planes = 0; + + if (padding_ == Padding::SAME) { + padding_planes = + (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]; + padding_cols = + (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]; + padding_rows = + (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]; + } + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); + bool planes_odd = (padding_planes % 2 != 0); + + Tensor compatible_input; + if (rows_odd || cols_odd || planes_odd) { + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::value, + {batch, input_size[0] + planes_odd, + input_size[1] + rows_odd, + input_size[2] + cols_odd, in_depth}, + &compatible_input)); + + functor::PadInput()( + context->template eigen_device(), + To32Bit(input.tensor()), {{0, 0, 0}}, + {{planes_odd, rows_odd, cols_odd}}, + To32Bit(compatible_input.tensor()), FORMAT_NHWC); + } else { + compatible_input = input; + } + + perftools::gputools::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(batch) + .set_spatial_dim(DimIndex::X, compatible_input.dim_size(3)) + .set_spatial_dim(DimIndex::Y, compatible_input.dim_size(2)) + .set_spatial_dim(DimIndex::Z, compatible_input.dim_size(1)) + .set_feature_map_count(in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(batch) + .set_spatial_dim(DimIndex::X, output_cols) + .set_spatial_dim(DimIndex::Y, output_rows) + .set_spatial_dim(DimIndex::Z, output_planes) + .set_feature_map_count(out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc(3); + filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) + .set_spatial_dim(DimIndex::Y, filter_size[1]) + .set_spatial_dim(DimIndex::Z, filter_size[0]) + .set_input_feature_map_count(in_depth) + .set_output_feature_map_count(out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3); + conv_desc.set_filter_stride(DimIndex::X, strides[2]) + .set_filter_stride(DimIndex::Y, strides[1]) + .set_filter_stride(DimIndex::Z, strides[0]) + .set_zero_padding(DimIndex::X, padding_cols / 2) + .set_zero_padding(DimIndex::Y, padding_rows / 2) + .set_zero_padding(DimIndex::Z, padding_planes / 2); + + Tensor pre_transformed_filter_backprop; + OP_REQUIRES_OK( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({out_depth, in_depth, filter_size[0], + filter_size[1], filter_size[2]}), + &pre_transformed_filter_backprop)); + + Tensor 
transformed_out_backprop; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + {batch, out_depth, output_planes, + output_rows, output_cols}, + &transformed_out_backprop)); + functor::NHWCToNCHW()( + context->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); + + Tensor transformed_input; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + {batch, in_depth, compatible_input.dim_size(1), + compatible_input.dim_size(2), + compatible_input.dim_size(3)}, + &transformed_input)); + functor::NHWCToNCHW()( + context->eigen_device(), + const_cast(compatible_input).tensor(), + transformed_input.tensor()); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_backprop_ptr = AsDeviceMemory( + pre_transformed_filter_backprop.template flat().data(), + pre_transformed_filter_backprop.template flat().size()); + auto input_ptr = + AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + + static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default + CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + context); + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithScratch( + input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, + filter_desc, &filter_backprop_ptr, &scratch_allocator) + .ok(); + + if (!cudnn_launch_status) { + context->SetStatus(errors::Internal( + "cuDNN Backward Filter function launch failure : input shape(", + input.shape().DebugString(), ") filter shape(", + filter.shape().DebugString(), ")")); + } + + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::ReverseTransformFilter()( + context->eigen_device(), + toConstTensor(pre_transformed_filter_backprop).template tensor(), + filter_backprop->tensor()); + } + + private: + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint("T"), + Conv3DBackpropInputOp); +REGISTER_KERNEL_BUILDER( + Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint("T"), + Conv3DBackpropFilterOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 69fd0dc00b9536..d88c1025af245c 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -303,145 +303,145 @@ struct LaunchConvOp { ", n=", n, ", k=", k)); } return; - } - int padding_rows = 0; - int padding_cols = 0; - const int64 in_batch = GetTensorDim(input, data_format, 'N'); - int64 in_rows = GetTensorDim(input, data_format, 'H'); - int64 in_cols = GetTensorDim(input, data_format, 'W'); - const int64 in_depths = GetTensorDim(input, data_format, 'C'); - const int64 out_batch = GetTensorDim(*output, data_format, 'N'); - const int64 out_rows = GetTensorDim(*output, data_format, 'H'); - const int64 out_cols = GetTensorDim(*output, data_format, 'W'); - const int64 out_depths = GetTensorDim(*output, data_format, 'C'); - const int64 patch_rows = filter.dim_size(0); - const int64 patch_cols = filter.dim_size(1); - if (padding == Eigen::PADDING_SAME) { - // Total padding on rows and cols is - // Pr = (R' - 1) * S + Kr - R - // Pc = (C' - 1) * S + Kc - C - // where (R', C') are output dimensions, (R, C) are input dimensions, S 
- // is stride, (Kr, Kc) are filter dimensions. - // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top - // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means - // we pad more on the right and bottom than on the top and left. - padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows; - padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols; - const bool rows_odd = (padding_rows % 2 != 0); - const bool cols_odd = (padding_cols % 2 != 0); - if (rows_odd || cols_odd) { - Tensor transformed_input; - int64 new_in_rows = in_rows + rows_odd; - int64 new_in_cols = in_cols + cols_odd; - OP_REQUIRES_OK(ctx, - ctx->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat(data_format, in_batch, new_in_rows, - new_in_cols, in_depths), - &transformed_input)); - - functor::PadInput()( - ctx->eigen_device(), - To32Bit(input_param.tensor()), 0, rows_odd, 0, cols_odd, - To32Bit(transformed_input.tensor()), data_format); - input = transformed_input; - in_rows = new_in_rows; - in_cols = new_in_cols; - } - } - - if (data_format == FORMAT_NHWC) { - // Convert the input tensor from NHWC to NCHW. + } + int padding_rows = 0; + int padding_cols = 0; + const int64 in_batch = GetTensorDim(input, data_format, 'N'); + int64 in_rows = GetTensorDim(input, data_format, 'H'); + int64 in_cols = GetTensorDim(input, data_format, 'W'); + const int64 in_depths = GetTensorDim(input, data_format, 'C'); + const int64 out_batch = GetTensorDim(*output, data_format, 'N'); + const int64 out_rows = GetTensorDim(*output, data_format, 'H'); + const int64 out_cols = GetTensorDim(*output, data_format, 'W'); + const int64 out_depths = GetTensorDim(*output, data_format, 'C'); + const int64 patch_rows = filter.dim_size(0); + const int64 patch_cols = filter.dim_size(1); + if (padding == Eigen::PADDING_SAME) { + // Total padding on rows and cols is + // Pr = (R' - 1) * S + Kr - R + // Pc = (C' - 1) * S + Kc - C + // where (R', C') are output dimensions, (R, C) are input dimensions, S + // is stride, (Kr, Kc) are filter dimensions. + // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top + // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means + // we pad more on the right and bottom than on the top and left. 
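// Worked example with made-up numbers: R = 7 rows, stride S = 2 and kernel
// Kr = 3 give R' = ceil(R / S) = 4 output rows and
//   Pr = (R' - 1) * S + Kr - R = 3 * 2 + 3 - 7 = 2,
// i.e. one padded row on the top and one on the bottom. With R = 8 instead,
// R' = 4 and Pr = 6 + 3 - 8 = 1, an odd total, so the single padded row goes
// on the bottom (Pr / 2 = 0 on top, Pr - Pr / 2 = 1 on the bottom); the
// rows_odd / cols_odd handling below deals with exactly this case.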
+ padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows; + padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols; + const bool rows_odd = (padding_rows % 2 != 0); + const bool cols_odd = (padding_cols % 2 != 0); + if (rows_odd || cols_odd) { Tensor transformed_input; - OP_REQUIRES_OK(ctx, ctx->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, - in_cols, in_depths), - &transformed_input)); - functor::NHWCToNCHW()( - ctx->eigen_device(), - const_cast(input).tensor(), - transformed_input.tensor()); + int64 new_in_rows = in_rows + rows_odd; + int64 new_in_cols = in_cols + cols_odd; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(data_format, in_batch, new_in_rows, + new_in_cols, in_depths), + &transformed_input)); + + functor::PadInput()( + ctx->eigen_device(), To32Bit(input_param.tensor()), + {{0, 0}}, {{rows_odd, cols_odd}}, + To32Bit(transformed_input.tensor()), data_format); input = transformed_input; + in_rows = new_in_rows; + in_cols = new_in_cols; } + } - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(in_batch) - .set_feature_map_count(in_depths) - .set_height(in_rows) - .set_width(in_cols) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(out_batch) - .set_height(out_rows) - .set_width(out_cols) - .set_feature_map_count(out_depths) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(filter.dim_size(0)) - .set_input_filter_width(filter.dim_size(1)) - .set_input_feature_map_count(filter.dim_size(2)) - .set_output_feature_map_count(filter.dim_size(3)); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(row_stride) - .set_horizontal_filter_stride(col_stride) - .set_zero_padding_height(padding_rows / 2) - .set_zero_padding_width(padding_cols / 2); - - Tensor transformed_filter; - OP_REQUIRES_OK(ctx, - ctx->allocate_temp( - DataTypeToEnum::value, - TensorShape({filter.dim_size(3), filter.dim_size(2), - filter.dim_size(0), filter.dim_size(1)}), - &transformed_filter)); - - functor::TransformFilter()( - ctx->eigen_device(), To32Bit(filter.tensor()), - To32Bit(transformed_filter.tensor())); - - Tensor transformed_output; - OP_REQUIRES_OK(ctx, ctx->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows, - out_cols, out_depths), - &transformed_output)); - - auto input_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); - auto filter_ptr = - AsDeviceMemory(transformed_filter.template flat().data(), - transformed_filter.template flat().size()); - auto output_ptr = - AsDeviceMemory(transformed_output.template flat().data(), - transformed_output.template flat().size()); - - static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default - ); - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); - bool cudnn_launch_status = - stream - ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc, - filter_ptr, conv_desc, output_desc, - &output_ptr, &scratch_allocator) - .ok(); + if (data_format == FORMAT_NHWC) { + // Convert the input tensor from NHWC to NCHW. 
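// This is an axis permutation only: element (n, h, w, c) of the NHWC input
// becomes element (n, c, h, w) of the NCHW tensor, e.g. a [2, 5, 7, 3] input
// becomes a [2, 3, 5, 7] tensor holding the same values. (On the CPU the same
// effect could be obtained with an Eigen shuffle of {0, 3, 1, 2}; here a
// dedicated GPU functor performs the transpose.)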
+ Tensor transformed_input; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + ShapeFromFormat(FORMAT_NCHW, in_batch, + in_rows, in_cols, in_depths), + &transformed_input)); + functor::NHWCToNCHW()( + ctx->eigen_device(), + const_cast(input).tensor(), + transformed_input.tensor()); + input = transformed_input; + } - if (!cudnn_launch_status) { - ctx->SetStatus(errors::Internal( - "cuDNN launch failure : input shape(", input.shape().DebugString(), - ") filter shape(", filter.shape().DebugString(), ")")); - } + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(in_batch) + .set_feature_map_count(in_depths) + .set_height(in_rows) + .set_width(in_cols) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(out_batch) + .set_height(out_rows) + .set_width(out_cols) + .set_feature_map_count(out_depths) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter.dim_size(0)) + .set_input_filter_width(filter.dim_size(1)) + .set_input_feature_map_count(filter.dim_size(2)) + .set_output_feature_map_count(filter.dim_size(3)); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(row_stride) + .set_horizontal_filter_stride(col_stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + Tensor transformed_filter; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape({filter.dim_size(3), filter.dim_size(2), + filter.dim_size(0), filter.dim_size(1)}), + &transformed_filter)); + + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + Tensor transformed_output; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + ShapeFromFormat(FORMAT_NCHW, out_batch, + out_rows, out_cols, out_depths), + &transformed_output)); + + auto input_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto output_ptr = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + + static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + ); + CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + bool cudnn_launch_status = + stream + ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc, + filter_ptr, conv_desc, output_desc, + &output_ptr, &scratch_allocator) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN launch failure : input shape(", input.shape().DebugString(), + ") filter shape(", filter.shape().DebugString(), ")")); + } - // Convert the output tensor back from NHWC to NCHW. - if (data_format == FORMAT_NHWC) { - functor::NCHWToNHWC()( - ctx->eigen_device(), - const_cast(transformed_output).tensor(), - output->tensor()); - } else { - *output = transformed_output; - } + // Convert the output tensor back from NHWC to NCHW. 
+ if (data_format == FORMAT_NHWC) { + functor::NCHWToNHWC()( + ctx->eigen_device(), + const_cast(transformed_output).tensor(), + output->tensor()); + } else { + *output = transformed_output; + } } }; @@ -466,17 +466,17 @@ namespace functor { const Eigen::array, 1>& dim_pair); \ extern template struct MatMulConvFunctor; \ template <> \ - void TransformFilter::operator()( \ + void TransformFilter::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor in, \ typename TTypes::Tensor out); \ - extern template struct TransformFilter; \ + extern template struct TransformFilter; \ template <> \ - void PadInput::operator()( \ + void PadInput::operator()( \ const GPUDevice& d, typename TTypes::ConstTensor in, \ - int padding_rows_left, int padding_rows_right, int padding_cols_left, \ - int padding_cols_right, typename TTypes::Tensor out, \ - TensorFormat data_format); \ - extern template struct PadInput + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat data_format); \ + extern template struct PadInput DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc new file mode 100644 index 00000000000000..ea3c4c90cb972c --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -0,0 +1,355 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_3d.h" + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/stream_executor.h" +using perftools::gputools::dnn::DimIndex; +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchConvOp; + +template +struct LaunchConvOp { + static void launch(OpKernelContext* context, const Tensor& input, + const Tensor& filter, const std::array& strides, + const Padding padding, Tensor* output) { + functor::CuboidConvolution()( + context->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), strides[0], strides[1], + strides[2], BrainPadding2EigenPadding(padding)); + } +}; + +template +class Conv3DOp : public BinaryOp { + public: + explicit Conv3DOp(OpKernelConstruction* context) : BinaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, (stride_[0] == 1 && stride_[4] == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_z, in_y, in_x, in_channels ] + const Tensor& input = context->input(0); + + // Input filter is of the following dimensions: + // [ filter_z, filter_y, filter_x, in_channels, out_channels] + const Tensor& filter = context->input(1); + + // NOTE: The ordering of the spatial dimensions is arbitrary, but has to be + // kept consistent between input/filter/output. + OP_REQUIRES(context, input.dims() == 5, + errors::InvalidArgument("input must be 5-dimensional")); + OP_REQUIRES(context, filter.dims() == 5, + errors::InvalidArgument("filter must be 5-dimensional")); + + const int64 in_depth = input.dim_size(4); + const int64 in_batch = input.dim_size(0); + + const int64 out_depth = filter.dim_size(4); + OP_REQUIRES( + context, in_depth == filter.dim_size(3), + errors::InvalidArgument("input and filter must have the same depth")); + + std::array input_size = { + {input.dim_size(1), input.dim_size(2), input.dim_size(3)}}; + std::array filter_size = { + {filter.dim_size(0), filter.dim_size(1), filter.dim_size(2)}}; + std::array strides = {{stride_[1], stride_[2], stride_[3]}}; + std::array out, padding; + + OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides, + padding_, &out, &padding)); + + TensorShape out_shape = {in_batch, out[0], out[1], out[2], out_depth}; + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + // Return early if nothing to do. 
+ if (out_shape.num_elements() == 0) return; + + LaunchConvOp::launch(context, input, filter, strides, padding_, + output); + } + + private: + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER( + Name("Conv3D").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DOp); + +#ifndef __ANDROID__ +REGISTER_KERNEL_BUILDER( + Name("Conv3D").Device(DEVICE_CPU).TypeConstraint("T"), + Conv3DOp); +#endif + +#if GOOGLE_CUDA + +// TODO(mjanusz): Share logic with 2d implementation as much as possible. +template +struct LaunchConvOp { + static void launch(OpKernelContext* ctx, const Tensor& input_param, + const Tensor& filter, const std::array& strides, + const Padding padding, Tensor* output) { + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + Tensor input = input_param; + + const int64 in_batch = input.dim_size(0); + int64 in_planes = input.dim_size(1); + int64 in_rows = input.dim_size(2); + int64 in_cols = input.dim_size(3); + const int64 in_depth = input.dim_size(4); + + const int64 filter_planes = filter.dim_size(0); + const int64 filter_rows = filter.dim_size(1); + const int64 filter_cols = filter.dim_size(2); + const int64 out_depth = filter.dim_size(4); + + int64 pad_planes = 0, pad_rows = 0, pad_cols = 0; + int64 out_planes = output->dim_size(1); + int64 out_rows = output->dim_size(2); + int64 out_cols = output->dim_size(3); + + if (padding == Padding::SAME) { + pad_planes = (out_planes - 1) * strides[0] + filter_planes - in_planes; + pad_rows = (out_rows - 1) * strides[1] + filter_rows - in_rows; + pad_cols = (out_cols - 1) * strides[2] + filter_cols - in_cols; + } + + // NOTE: This only works in NHWC. + if (filter_planes == 1 && filter_rows == 1 && filter_cols == 1 && + strides[0] == 1 && strides[1] == 1 && strides[2] == 1) { + // 1x1 filter, so call cublas directly. + const uint64 m = in_batch * in_cols * in_rows * in_planes; + const uint64 k = in_depth; + const uint64 n = out_depth; + + auto a_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(output->template flat().data(), + output->template flat().size()); + + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, + n, a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } + + if (padding == Padding::SAME) { + const bool rows_odd = (pad_rows % 2 != 0); + const bool cols_odd = (pad_cols % 2 != 0); + const bool planes_odd = (pad_planes % 2 != 0); + + // Necessary because cuDNN only supports symmetric padding. + // TODO(mjanusz): Consider making this optional? This would save some + // overhead and would work as long as an op trained this way is only + // used on GPU. 
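To make the workaround concrete: with SAME padding the total padding needed in a dimension is (out - 1) * stride + filter - in, and cuDNN can only apply the same amount on both sides. When that total is odd, the code that follows first pads the input by one extra zero at the high end of the affected dimension and then hands the remaining, now even, amount to cuDNN. A small illustrative Python sketch of the split (not part of the patch):

def split_same_padding(in_size, filter_size, stride, out_size):
    # Total SAME padding required so that out_size positions fit.
    total = max((out_size - 1) * stride + filter_size - in_size, 0)
    odd = total % 2          # 1 if symmetric padding alone cannot express it
    symmetric = total // 2   # amount given to cuDNN on each side
    return symmetric, odd    # pre-pad the input by `odd` on the right first

# in=10, filter=4, stride=3 gives out=4 and total=3: one zero is appended to
# the input up front, then cuDNN pads one element on each side.
print(split_same_padding(10, 4, 3, 4))   # (1, 1)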
+ if (rows_odd || cols_odd || planes_odd) { + Tensor transformed_input; + int64 new_in_rows = in_rows + rows_odd; + int64 new_in_cols = in_cols + cols_odd; + int64 new_in_planes = in_planes + planes_odd; + + TensorShape transformed_shape( + {in_batch, new_in_planes, new_in_rows, new_in_cols, in_depth}); + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, transformed_shape, + &transformed_input)); + + functor::PadInput()( + ctx->eigen_device(), To32Bit(input_param.tensor()), + {{0, 0, 0}}, {{planes_odd, rows_odd, cols_odd}}, + To32Bit(transformed_input.tensor()), FORMAT_NHWC); + input = transformed_input; + in_rows = new_in_rows; + in_cols = new_in_cols; + in_planes = new_in_planes; + } + } + + Tensor transformed_input; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape({in_batch, in_depth, in_planes, in_rows, in_cols}), + &transformed_input)); + // input: [b, x, y, z, d] + // t_input: [b, d, x, y, z] + // NCDHW is the only format universally supported by cuDNN. + functor::NHWCToNCHW()( + ctx->eigen_device(), + const_cast(input).tensor(), + transformed_input.tensor()); + input = transformed_input; + + perftools::gputools::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(in_batch) + .set_feature_map_count(in_depth) + .set_spatial_dim(DimIndex::X, in_cols) + .set_spatial_dim(DimIndex::Y, in_rows) + .set_spatial_dim(DimIndex::Z, in_planes) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(in_batch) + .set_spatial_dim(DimIndex::X, out_cols) + .set_spatial_dim(DimIndex::Y, out_rows) + .set_spatial_dim(DimIndex::Z, out_planes) + .set_feature_map_count(out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc(3); + filter_desc.set_spatial_dim(DimIndex::X, filter_cols) + .set_spatial_dim(DimIndex::Y, filter_rows) + .set_spatial_dim(DimIndex::Z, filter_planes) + .set_input_feature_map_count(in_depth) + .set_output_feature_map_count(out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3); + conv_desc.set_filter_stride(DimIndex::X, strides[2]) + .set_filter_stride(DimIndex::Y, strides[1]) + .set_filter_stride(DimIndex::Z, strides[0]) + .set_zero_padding(DimIndex::X, pad_cols / 2) + .set_zero_padding(DimIndex::Y, pad_rows / 2) + .set_zero_padding(DimIndex::Z, pad_planes / 2); + + Tensor transformed_filter; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({out_depth, in_depth, filter_planes, + filter_rows, filter_cols}), + &transformed_filter)); + // filter: [x, y, z, in, out] + // t_filter: [out, in, x, y, z] + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + Tensor transformed_output; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({in_batch, out_depth, out_planes, + out_rows, out_cols}), + &transformed_output)); + + auto input_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto output_ptr = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + + static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32); // 4GB by default + CudnnScratchAllocator 
scratch_allocator(ConvolveScratchSize, ctx); + bool cudnn_launch_status = + stream + ->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc, + filter_ptr, conv_desc, output_desc, + &output_ptr, &scratch_allocator) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN launch failure : input shape(", input.shape().DebugString(), + ") filter shape(", filter.shape().DebugString(), ")")); + } + + // t_output: [b, out, x, y, z] + // output: [b, x, y, z, out] + functor::NCHWToNHWC()( + ctx->eigen_device(), + const_cast(transformed_output).tensor(), + output->tensor()); + } +}; + +// Forward declarations of the functor specializations for GPU. +// This ensures that the custom implementation is used instead of the default +// Eigen one (which is used for CPU). +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + template <> \ + void ReverseTransformFilter::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat format); + +DECLARE_GPU_SPEC(float); +#undef DECLARE_GPU_SPEC + +} // namespace functor + +// Registration of the GPU implementations. +REGISTER_KERNEL_BUILDER( + Name("Conv3D").Device(DEVICE_GPU).TypeConstraint("T"), + Conv3DOp); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index ccd983833d4ae2..5af9bc0e5bfb75 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -18,6 +18,7 @@ limitations under the License. #define EIGEN_USE_GPU #include +#include #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/conv_2d.h" @@ -30,6 +31,7 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { +// TODO(mjanusz): Move this to a shared util file. // A simple array that contains data that can be passed between CPU and GPU. template struct Array { @@ -65,6 +67,11 @@ struct Array { data[i] = DefaultValue; } } + EIGEN_STRONG_INLINE Array(const std::array& array) { + for (int i = 0; i < IndexCount; i++) { + data[i] = array[i]; + } + } T data[IndexCount]; }; @@ -78,6 +85,8 @@ struct Dimension : Array { : Base(a0, a1) {} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2) : Base(a0, a1, a2) {} + EIGEN_STRONG_INLINE Dimension(const std::array& array) + : Base(array) {} }; // An index type with compile-time known size. @@ -248,25 +257,28 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, // A Cuda custom kernel that convert input to output, given proper padding on // the left and the top. The padded value is zero. 
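The padding kernels that follow generalize the 4-D versions to any rank by looping over the spatial dimensions: each output element's flat index is decomposed into per-dimension coordinates, the left padding is subtracted, and the element is either copied from the input or set to zero, while the batch and channel coordinates pass through unchanged. A rough Python model of that mapping (illustrative only, not part of the patch; names are made up):

def flat_to_index(flat, dims):
    # Row-major decomposition, innermost dimension last, as in the kernels.
    idx = [0] * len(dims)
    for i in reversed(range(len(dims))):
        idx[i] = flat % dims[i]
        flat //= dims[i]
    return idx

def padded_value(flat_out, in_dims, out_dims, pad_left, src):
    out_idx = flat_to_index(flat_out, out_dims)
    in_idx = [o - p for o, p in zip(out_idx, pad_left)]
    inside = all(0 <= i < d for i, d in zip(in_idx, in_dims))
    return src(in_idx) if inside else 0   # zero outside the original input

# 1-D example: an input of length 4 padded by 1 on the left into 6 outputs.
print([padded_value(i, [4], [6], [1], lambda ix: 10 + ix[0]) for i in range(6)])
# -> [0, 10, 11, 12, 13, 0]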
-template +template __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input, - Dimension<4> input_dims, T* output, - Dimension<4> output_dims, - int padding_rows_left, - int padding_cols_left) { + Dimension input_dims, T* output, + Dimension output_dims, + Dimension padding_left) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int output_index = index; - Index<4> output_tensor_index = FlatToTensorIndex(output_index, output_dims); - - Index<4> input_tensor_index; - input_tensor_index[0] = output_tensor_index[0]; - input_tensor_index[1] = output_tensor_index[1] - padding_rows_left; - input_tensor_index[2] = output_tensor_index[2] - padding_cols_left; - input_tensor_index[3] = output_tensor_index[3]; + Index output_tensor_index = + FlatToTensorIndex(output_index, output_dims); + + Index input_tensor_index; + input_tensor_index[0] = output_tensor_index[0]; // batch + bool ok = true; + for (int i = 1; i < NDIMS - 1; i++) { + input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1]; + ok &= + (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]); + } + input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1]; // channels - if (input_tensor_index[1] >= 0 && input_tensor_index[1] < input_dims[1] && - input_tensor_index[2] >= 0 && input_tensor_index[2] < input_dims[2]) { - int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + if (ok) { + const int input_index = TensorIndexToFlat(input_tensor_index, input_dims); output[output_index] = input[input_index]; } else { output[output_index] = T(0); @@ -274,25 +286,28 @@ __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input, } } -template +template __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input, - Dimension<4> input_dims, T* output, - Dimension<4> output_dims, - int padding_rows_left, - int padding_cols_left) { + Dimension input_dims, T* output, + Dimension output_dims, + Dimension padding_left) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int output_index = index; - Index<4> output_tensor_index = FlatToTensorIndex(output_index, output_dims); - - Index<4> input_tensor_index; - input_tensor_index[0] = output_tensor_index[0]; - input_tensor_index[1] = output_tensor_index[1]; - input_tensor_index[2] = output_tensor_index[2] - padding_rows_left; - input_tensor_index[3] = output_tensor_index[3] - padding_cols_left; + Index output_tensor_index = + FlatToTensorIndex(output_index, output_dims); + + Index input_tensor_index; + input_tensor_index[0] = output_tensor_index[0]; // batch + input_tensor_index[1] = output_tensor_index[1]; // channels + bool ok = true; + for (int i = 2; i < NDIMS; i++) { + input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2]; + ok &= + (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]); + } - if (input_tensor_index[2] >= 0 && input_tensor_index[2] < input_dims[2] && - input_tensor_index[3] >= 0 && input_tensor_index[3] < input_dims[3]) { - int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + if (ok) { + const int input_index = TensorIndexToFlat(input_tensor_index, input_dims); output[output_index] = input[input_index]; } else { output[output_index] = T(0); @@ -302,15 +317,19 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input, // A GPU helper function that converts TensorFlow filter format to Cudnn filter // format. 
-template -struct TransformFilter { +template +struct TransformFilter { typedef GPUDevice Device; - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { + void operator()(const Device& d, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { Dimension<3> combined_dims; - combined_dims[0] = in.dimension(0) * in.dimension(1); - combined_dims[1] = in.dimension(2); - combined_dims[2] = in.dimension(3); + combined_dims[0] = in.dimension(0); // spatial dimensions + for (int i = 1; i < NDIMS - 2; i++) { + combined_dims[0] *= in.dimension(i); + } + combined_dims[1] = in.dimension(NDIMS - 2); // input filters + combined_dims[2] = in.dimension(NDIMS - 1); // output filters CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); SwapDimension0And2InTensor3< T><<>>( @@ -319,15 +338,18 @@ struct TransformFilter { }; // Converts Cudnn filter format back to TensorFlow filter format. -template -struct ReverseTransformFilter { +template +struct ReverseTransformFilter { typedef GPUDevice Device; - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { Dimension<3> combined_dims; - combined_dims[0] = in.dimension(0); - combined_dims[1] = in.dimension(1); - combined_dims[2] = in.dimension(2) * in.dimension(3); + combined_dims[0] = in.dimension(0); // output filters + combined_dims[1] = in.dimension(1); // input filters + combined_dims[2] = in.dimension(2); // spatial dimensions + for (int i = 3; i < NDIMS; ++i) { + combined_dims[2] *= in.dimension(i); + } CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); SwapDimension0And2InTensor3< T><<>>( @@ -337,33 +359,37 @@ struct ReverseTransformFilter { // A GPU helper function that converts input tensor to a larger output tensor, // given proper padding values. The padded value is zero. 
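Functionally this padding step is just zero-padding every spatial dimension of an NDHWC (or NCDHW) tensor by independent left and right amounts, with the batch and channel dimensions left untouched. An equivalent NumPy sketch for the NDHWC case (illustrative only, not part of the patch):

import numpy as np

def pad_input_ndhwc(x, padding_left, padding_right):
    # x: [batch, planes, rows, cols, channels]; pad spatial dims with zeros.
    pads = [(0, 0)] + list(zip(padding_left, padding_right)) + [(0, 0)]
    return np.pad(x, pads, mode='constant')

x = np.ones((2, 4, 4, 4, 3))
y = pad_input_ndhwc(x, (0, 0, 0), (1, 1, 1))   # the odd-padding case above
print(y.shape)   # (2, 5, 5, 5, 3)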
-template -struct PadInput { +template +struct PadInput { typedef GPUDevice Device; - void operator()(const Device& d, typename TTypes::ConstTensor in, - int padding_rows_left, int padding_rows_right, - int padding_cols_left, int padding_cols_right, - typename TTypes::Tensor out, TensorFormat format) { + void operator()(const Device& d, + typename TTypes::ConstTensor in, + const std::array& padding_left, + const std::array& padding_right, + typename TTypes::Tensor out, + TensorFormat format) { CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - Dimension<4> input_dims; - for (int i = 0; i < 4; i++) { + Dimension input_dims; + for (int i = 0; i < NDIMS; ++i) { input_dims[i] = in.dimension(i); } - Dimension<4> output_dims; - for (int i = 0; i < 4; i++) { + Dimension output_dims; + for (int i = 0; i < NDIMS; ++i) { output_dims[i] = out.dimension(i); } + const Dimension padding_left_dim(padding_left); + if (format == FORMAT_NHWC) { - PadInputCustomKernelNHWC< - T><<>>( + PadInputCustomKernelNHWC<<< + config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), input_dims, out.data(), - output_dims, padding_rows_left, padding_cols_left); + output_dims, padding_left_dim); } else if (format == FORMAT_NCHW) { - PadInputCustomKernelNCHW< - T><<>>( + PadInputCustomKernelNCHW<<< + config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), input_dims, out.data(), - output_dims, padding_rows_left, padding_cols_left); + output_dims, padding_left_dim); } else { LOG(FATAL) << "Invalid data format: " << format; } @@ -405,30 +431,36 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, // A GPU helper functor that converts NHWC TensorFlow data format to // NCHW format that is accepted by Cudnn. -template -struct NHWCToNCHW { +template +struct NHWCToNCHW { typedef GPUDevice Device; - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { Dimension<3> combined_dims; - combined_dims[0] = in.dimension(0); - combined_dims[1] = in.dimension(1) * in.dimension(2); - combined_dims[2] = in.dimension(3); + combined_dims[0] = in.dimension(0); // N (batch) + combined_dims[1] = in.dimension(1); // spatial dimensions (HW) + for (int i = 2; i < NDIMS - 1; ++i) { + combined_dims[1] *= in.dimension(i); + } + combined_dims[2] = in.dimension(NDIMS - 1); // C (channels) RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; // A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow // Format. 
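Both layout conversions come down to the same primitive: view the tensor as rank 3 by merging contiguous dimensions, then swap dimensions 1 and 2. Going from NDHWC to NCDHW the merged groups are [N][D*H*W][C]; going back they are [N][C][D*H*W]. A NumPy sketch of the forward direction (illustrative only, not part of the patch):

import numpy as np

x = np.random.rand(2, 3, 4, 5, 6)      # [N, D, H, W, C]

# Reference: a direct 5-D transpose to [N, C, D, H, W].
direct = x.transpose(0, 4, 1, 2, 3)

# Rank-3 view: merge the spatial dimensions, swap the last two groups, expand.
merged = x.reshape(2, 3 * 4 * 5, 6)    # [N, D*H*W, C]
swapped = merged.transpose(0, 2, 1)    # [N, C, D*H*W]
assert np.array_equal(swapped.reshape(2, 6, 3, 4, 5), direct)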
-template -struct NCHWToNHWC { +template +struct NCHWToNHWC { typedef GPUDevice Device; - void operator()(const Device& d, typename TTypes::ConstTensor in, - typename TTypes::Tensor out) { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { Dimension<3> combined_dims; - combined_dims[0] = in.dimension(0); - combined_dims[1] = in.dimension(1); - combined_dims[2] = in.dimension(2) * in.dimension(3); + combined_dims[0] = in.dimension(0); // N (batch) + combined_dims[1] = in.dimension(1); // C (channel) + combined_dims[2] = in.dimension(2); // spatial dimensions (HW) + for (int i = 3; i < NDIMS; ++i) { + combined_dims[2] *= in.dimension(i); + } RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; @@ -440,17 +472,21 @@ template struct functor::ShuffleAndReverse; template struct functor::ShuffleAndReverse; -template struct functor::TransformFilter; - -template struct functor::ReverseTransformFilter; - -template struct functor::PadInput; - template struct functor::TransformDepth; -template struct functor::NHWCToNCHW; - -template struct functor::NCHWToNHWC; +// For 2d ops. +template struct functor::TransformFilter; +template struct functor::ReverseTransformFilter; +template struct functor::NHWCToNCHW; +template struct functor::NCHWToNHWC; +template struct functor::PadInput; + +// For 3d ops. +template struct functor::TransformFilter; +template struct functor::ReverseTransformFilter; +template struct functor::NHWCToNCHW; +template struct functor::NCHWToNHWC; +template struct functor::PadInput; } // namespace tensorflow diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc new file mode 100644 index 00000000000000..35bbca53e65fb4 --- /dev/null +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc @@ -0,0 +1,216 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/cudnn_pooling_gpu.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_3d.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" + +typedef Eigen::GpuDevice GPUDevice; + +namespace tensorflow { + +#if GOOGLE_CUDA + +template +void DnnPooling3dOp::Compute( + OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::array& window, const std::array& stride, + const std::array& padding, const Tensor& tensor_in, + Tensor* output) { + const auto in_shape = tensor_in.shape(); + const auto out_shape = output->shape(); + + const int64 in_batch = in_shape.dim_size(0); + const int64 in_features = in_shape.dim_size(4); + + Tensor transformed_input; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + {in_shape.dim_size(0), in_shape.dim_size(4), + in_shape.dim_size(1), in_shape.dim_size(2), + in_shape.dim_size(3)}, + &transformed_input)); + functor::NHWCToNCHW()(context->eigen_device(), + tensor_in.tensor(), + transformed_input.tensor()); + Tensor transformed_output; + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + {out_shape.dim_size(0), out_shape.dim_size(4), + out_shape.dim_size(1), out_shape.dim_size(2), + out_shape.dim_size(3)}, + &transformed_output)); + + perftools::gputools::dnn::PoolingDescriptor pooling_desc(3); + pooling_desc.set_pooling_mode(pooling_mode); + perftools::gputools::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(in_batch) + .set_feature_map_count(in_features) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(in_batch) + .set_feature_map_count(in_features) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + for (size_t i = 0; i < window.size(); ++i) { + const auto dim_i = static_cast(i); + pooling_desc.set_window(dim_i, window.rbegin()[i]); + pooling_desc.set_stride(dim_i, stride.rbegin()[i]); + pooling_desc.set_padding(dim_i, padding.rbegin()[i]); + input_desc.set_spatial_dim(dim_i, in_shape.dim_size(3 - i)); + output_desc.set_spatial_dim(dim_i, out_shape.dim_size(3 - i)); + } + + auto input_data = AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + auto output_data = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + bool status = stream + ->ThenPoolForward(pooling_desc, input_desc, input_data, + output_desc, &output_data) + .ok(); + OP_REQUIRES(context, status, + errors::Internal("cudnn PoolForward launch failed")); + + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::NCHWToNHWC()( + context->eigen_device(), + toConstTensor(transformed_output).template tensor(), + output->tensor()); +} + +template +void DnnPooling3dGradOp::Compute( + OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::array& window, const std::array& stride, + const std::array& padding, + const std::array& output_size, const Tensor& out_backprop, + const TensorShape& tensor_in_shape, const Tensor* tensor_in, + const Tensor* tensor_out, Tensor* input_backprop) { + 
CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) || + (tensor_in && tensor_out)) + << "For MaxPoolGrad, both tensor_in and tensor_out needs to be " + "specified"; + + const int64 in_batch = tensor_in_shape.dim_size(0); + const int64 in_features = tensor_in_shape.dim_size(4); + + Tensor transformed_input; + TensorShape transformed_input_shape = { + in_batch, in_features, tensor_in_shape.dim_size(1), + tensor_in_shape.dim_size(2), tensor_in_shape.dim_size(3)}; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_input_shape, + &transformed_input)); + Tensor transformed_output; + TensorShape transformed_output_shape = { + out_backprop.dim_size(0), out_backprop.dim_size(4), + out_backprop.dim_size(1), out_backprop.dim_size(2), + out_backprop.dim_size(3)}; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_output_shape, + &transformed_output)); + Tensor transformed_input_backprop; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_input_shape, + &transformed_input_backprop)); + Tensor transformed_output_backprop; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_output_shape, + &transformed_output_backprop)); + if (tensor_in != nullptr) { + functor::NHWCToNCHW()(context->eigen_device(), + tensor_in->tensor(), + transformed_input.tensor()); + } + if (tensor_out != nullptr) { + functor::NHWCToNCHW()(context->eigen_device(), + tensor_out->tensor(), + transformed_output.tensor()); + } + functor::NHWCToNCHW()( + context->eigen_device(), out_backprop.tensor(), + transformed_output_backprop.tensor()); + + perftools::gputools::dnn::PoolingDescriptor pooling_desc(3); + pooling_desc.set_pooling_mode(pooling_mode); + + perftools::gputools::dnn::BatchDescriptor orig_output_desc(3); + orig_output_desc.set_count(in_batch) + .set_feature_map_count(in_features) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + + perftools::gputools::dnn::BatchDescriptor orig_input_desc(3); + orig_input_desc.set_count(in_batch) + .set_feature_map_count(in_features) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + + for (size_t i = 0; i < window.size(); ++i) { + const auto dim_i = static_cast(i); + pooling_desc.set_window(dim_i, window[i]); + pooling_desc.set_stride(dim_i, stride[i]); + pooling_desc.set_padding(dim_i, padding[i]); + orig_input_desc.set_spatial_dim(dim_i, tensor_in_shape.dim_size(3 - i)); + orig_output_desc.set_spatial_dim(dim_i, output_size[i]); + } + + auto orig_output_data = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + auto orig_input_data = + AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + auto output_backprop_data = + AsDeviceMemory(transformed_output_backprop.template flat().data(), + transformed_output_backprop.template flat().size()); + auto input_backprop_data = + AsDeviceMemory(transformed_input_backprop.template flat().data(), + transformed_input_backprop.template flat().size()); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + bool status = + stream + ->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data, + orig_output_desc, orig_output_data, + output_backprop_data, &input_backprop_data) + .ok(); + OP_REQUIRES(context, status, + errors::Internal("cudnn PoolBackward launch failed")); 
+ + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::NCHWToNHWC()( + context->eigen_device(), + toConstTensor(transformed_input_backprop).template tensor(), + input_backprop->tensor()); +} + +template class DnnPooling3dOp; +template class DnnPooling3dGradOp; + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h new file mode 100644 index 00000000000000..2e28d69601d964 --- /dev/null +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h @@ -0,0 +1,65 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions to run 3d pooling on GPU using CuDNN. + +#ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_ +#define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/stream_executor.h" +#endif + +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +#if GOOGLE_CUDA + +// Runs (avg/max)pooling on GPU. +template +class DnnPooling3dOp { + public: + static void Compute(OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::array& size, + const std::array& stride, + const std::array& padding, + const Tensor& tensor_in, Tensor* output); +}; + +// Computes the gradient of (avg/max)pooling on GPU. +template +class DnnPooling3dGradOp { + public: + static void Compute(OpKernelContext* context, + perftools::gputools::dnn::PoolingMode pooling_mode, + const std::array& window, + const std::array& stride, + const std::array& padding, + const std::array& output_size, + const Tensor& out_backprop, + const TensorShape& tensor_in_shape, + const Tensor* tensor_in, const Tensor* tensor_out, + Tensor* input_backprop); +}; + +#endif + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_ diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc index a32443b841d102..c0e939c845eacb 100644 --- a/tensorflow/core/kernels/ops_util.cc +++ b/tensorflow/core/kernels/ops_util.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "tensorflow/core/kernels/ops_util.h" @@ -73,6 +74,33 @@ Status Get2dOutputSizeVerbose(const int in_height, const int in_width, return Status::OK(); } +Status Get3dOutputSize(const std::array& input, + const std::array& window, + const std::array& strides, + Padding padding_type, std::array* output_ptr, + std::array* padding_ptr) { + auto& output = *output_ptr; + auto& padding = *padding_ptr; + switch (padding_type) { + case Padding::VALID: + for (size_t i = 0; i < input.size(); ++i) { + output[i] = (input[i] - window[i] + strides[i]) / strides[i]; + padding[i] = 0; + } + break; + case Padding::SAME: + for (size_t i = 0; i < input.size(); ++i) { + output[i] = (input[i] + strides[i] - 1) / strides[i]; + const int64 delta = (output[i] - 1) * strides[i] + window[i] - input[i]; + // For odd values of total padding, add more padding at the 'right' + // side of the given dimension. + padding[i] = std::max(delta / 2, 0ll); + } + break; + } + return Status::OK(); +} + Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) { switch (padding) { case Padding::VALID: diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h index f27a5bc4231c04..2cdad1c415886d 100644 --- a/tensorflow/core/kernels/ops_util.h +++ b/tensorflow/core/kernels/ops_util.h @@ -18,6 +18,8 @@ limitations under the License. // This file contains utilities for various operations. +#include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -89,6 +91,19 @@ Status Get2dOutputSizeVerbose(const int in_height, const int in_width, int* new_height, int* new_width, int* pad_top, int* pad_bottom, int* pad_left, int* pad_right); +// Given an input tensor, kernel, stride and padding type, populates the 3D size +// of the output tensor and padding to be applied to the input tensor at the +// lower end of every dimension. Use for 3D convolutions, where the input data +// is padded with zeros, as well as for 3D avg/max pooling, where the input data +// is padded with invalid values that are not considered for pooling. +// +// TODO(mjanusz): Unify this with Get2dOutputSize by using a common template. +Status Get3dOutputSize(const std::array& input, + const std::array& window, + const std::array& strides, + Padding padding_type, std::array* output, + std::array* padding); + // Calculates broadcast starting index and size. For SAME padding, addition // padding could be applied to right, left, top and bottom. Depending on the // current index, input size, kernel size, stride, padding size, the starting diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc new file mode 100644 index 00000000000000..e9a95b72403bc3 --- /dev/null +++ b/tensorflow/core/kernels/pooling_ops_3d.cc @@ -0,0 +1,515 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#define EIGEN_USE_THREADS + +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/eigen_pooling.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/padding.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/cudnn_pooling_gpu.h" +#endif +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +enum PoolingType { MAX, AVG }; + +template +struct LaunchPoolingOp; + +template +struct LaunchPoolingOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const std::array& window, + const std::array& stride, + const std::array& padding, Padding padding_type, + Tensor* output) { + output->tensor().device(context->eigen_device()) = + Eigen::CuboidAvgPooling(tensor_in.tensor(), window[0], window[1], + window[2], stride[0], stride[1], stride[2], + BrainPadding2EigenPadding(padding_type)); + } +}; + +template +struct LaunchPoolingOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const std::array& window, + const std::array& stride, + const std::array& padding, Padding padding_type, + Tensor* output) { + output->tensor().device(context->eigen_device()) = + Eigen::CuboidMaxPooling(tensor_in.tensor(), window[0], window[1], + window[2], stride[0], stride[1], stride[2], + BrainPadding2EigenPadding(padding_type)); + } +}; + +template +class Pooling3DOp : public UnaryOp { + public: + explicit Pooling3DOp(OpKernelConstruction* context) : UnaryOp(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 5, + errors::InvalidArgument("Sliding window ksize field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window stride field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the depth dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + OP_REQUIRES(context, tensor_in.dims() == 5, + errors::InvalidArgument("tensor_in must be 5-dimensional")); + const int64 depth = tensor_in.dim_size(4); + const int64 in_batch = tensor_in.dim_size(0); + + std::array input_size{ + {tensor_in.dim_size(3), tensor_in.dim_size(2), tensor_in.dim_size(1)}}; + std::array window({{ksize_[3], ksize_[2], ksize_[1]}}); + std::array stride({{stride_[3], stride_[2], stride_[1]}}); + std::array padding, out; + + OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride, + padding_, &out, &padding)); + + TensorShape 
out_shape({in_batch, out[2], out[1], out[0], depth}); + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + LaunchPoolingOp::launch(context, tensor_in, window, stride, + padding, padding_, output); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; +REGISTER_KERNEL_BUILDER( + Name("AvgPool3D").Device(DEVICE_CPU).TypeConstraint("T"), + Pooling3DOp); +REGISTER_KERNEL_BUILDER( + Name("MaxPool3D").Device(DEVICE_CPU).TypeConstraint("T"), + Pooling3DOp); + +template +struct LaunchMaxPooling3dGradOp; + +template +struct LaunchMaxPooling3dGradOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const Tensor& tensor_out, const Tensor& out_backprop, + const std::array& window, + const std::array& stride, + const std::array& out, + const std::array& padding, Tensor* output) { + output->flat().setZero(); + for (int64 p = 0; p < out_backprop.dim_size(3); ++p) { + // Calculate broadcast size for planes/rows/cols. For SAME padding, + // current index could be in the padding area, and + // p * stride_planes + window_planes + // could be beyond the input tensor's boundary. In such cases, change + // the starting index and reduce the broadcast size. + // + // The same procedure is repeated for every spatial dimension in the + // nested loops below. + int pindex, psize; + std::array input_size{{tensor_in.dim_size(3), + tensor_in.dim_size(2), + tensor_in.dim_size(1)}}; + OP_REQUIRES_OK(context, + GetBroadcastSize(p, input_size[0], window[0], stride[0], + padding[0], &pindex, &psize)); + for (int64 r = 0; r < out_backprop.dim_size(2); ++r) { + int rindex, rsize; + OP_REQUIRES_OK(context, + GetBroadcastSize(r, input_size[1], window[1], stride[1], + padding[1], &rindex, &rsize)); + for (int64 c = 0; c < out_backprop.dim_size(1); ++c) { + int cindex, csize; + OP_REQUIRES_OK( + context, GetBroadcastSize(c, input_size[2], window[2], stride[2], + padding[2], &cindex, &csize)); + TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}}; + TensorSlice dst{{0, -1}, + {cindex, csize}, + {rindex, rsize}, + {pindex, psize}, + {0, -1}}; + Eigen::DSizes src_indices; + Eigen::DSizes src_sizes; + Eigen::DSizes dst_indices; + Eigen::DSizes dst_sizes; + src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices, + &src_sizes); + dst.FillIndicesAndSizes<5>(tensor_in.shape(), &dst_indices, + &dst_sizes); + +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::array bcast = {1, csize, rsize, psize, 1}; +#else + Eigen::IndexList, int, int, int, + Eigen::type2index<1> > + bcast; + bcast.set(1, csize); + bcast.set(2, rsize); + bcast.set(3, psize); +#endif + + // Slice from tensor_in. + Eigen::Tensor tensor_in_slice(dst_sizes); + tensor_in_slice.device(context->eigen_cpu_device()) = + tensor_in.tensor().slice(dst_indices, dst_sizes); + + // Slice from tensor_out. + Eigen::Tensor tensor_out_slice(src_sizes); + tensor_out_slice.device(context->eigen_cpu_device()) = + tensor_out.tensor().slice(src_indices, src_sizes); + + // Backprop slice. + Eigen::Tensor out_backprop_slice(src_sizes); + out_backprop_slice.device(context->eigen_cpu_device()) = + out_backprop.tensor().slice(src_indices, src_sizes); + + // The true backprop slice: if an element is the max, choose + // the backprop slice; otherwise set to 0. 
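The selection that follows implements the usual max-pooling gradient rule: within each pooling window, the incoming gradient is routed to every input element whose value matches the window maximum (compared with a small tolerance), and all other positions receive zero. A minimal Python sketch of that rule for one window (illustrative only, not part of the patch):

def maxpool_window_grad(window_vals, out_val, out_grad, tol=1e-5):
    # Route out_grad to each element that attained the maximum out_val.
    return [out_grad if abs(v - out_val) < tol else 0.0 for v in window_vals]

vals = [0.2, 1.7, 1.7, -0.3]
print(maxpool_window_grad(vals, max(vals), 5.0))   # [0.0, 5.0, 5.0, 0.0]

Note that, as in the kernel, ties route the gradient to every maximal element rather than to a single argmax.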
+ Eigen::Tensor select_slice(dst_sizes); + Eigen::Tensor mat0(dst_sizes); + mat0.setZero(); + select_slice = + ((tensor_in_slice - tensor_out_slice.broadcast(bcast)).abs() < + tensor_in_slice.constant(1e-5)) + .select(out_backprop_slice.broadcast(bcast), mat0); + + output->tensor() + .slice(dst_indices, dst_sizes) + .device(context->eigen_cpu_device()) += select_slice; + } + } + } + } +}; + +template +class MaxPooling3dGradOp : public OpKernel { + public: + explicit MaxPooling3dGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 5, + errors::InvalidArgument("Sliding window ksize field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window stride field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the depth dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + const Tensor& tensor_out = context->input(1); + const Tensor& out_backprop = context->input(2); + OP_REQUIRES(context, tensor_in.dims() == 5, + errors::InvalidArgument("tensor_in must be 5-dimensional")); + OP_REQUIRES(context, tensor_out.dims() == 5, + errors::InvalidArgument("tensor_out must be 5-dimensional")); + OP_REQUIRES(context, out_backprop.dims() == 5, + errors::InvalidArgument("out_backprop must be 5-dimensional")); + + const TensorShape& output_shape = tensor_in.shape(); + Tensor* input_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &input_backprop)); + + std::array input_size = {{output_shape.dim_size(3), + output_shape.dim_size(2), + output_shape.dim_size(1)}}; + std::array window = {{ksize_[3], ksize_[2], ksize_[1]}}; + std::array stride = {{stride_[3], stride_[2], stride_[1]}}; + std::array out, padding; + + OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride, + padding_, &out, &padding)); + LaunchMaxPooling3dGradOp::launch(context, tensor_in, tensor_out, + out_backprop, window, stride, + out, padding, input_backprop); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MaxPool3DGrad").Device(DEVICE_CPU).TypeConstraint("T"), + MaxPooling3dGradOp); + +template +struct LaunchAvgPooling3dGradOp; + +template +struct LaunchAvgPooling3dGradOp { + static void launch(OpKernelContext* context, + const TensorShape& tensor_in_shape, + const Tensor& out_backprop, + const std::array& window, + const std::array& stride, + const std::array& output_shape, + const std::array& padding, Tensor* output) { + output->flat().setZero(); + std::array input_size = {{tensor_in_shape.dim_size(3), + tensor_in_shape.dim_size(2), + tensor_in_shape.dim_size(1)}}; + for (int64 p = 0; p < out_backprop.dim_size(3); ++p) { + // Calculate broadcast size for planes/rows/cols. For SAME padding, + // current index could be in the padding area, and + // p * stride_planes + window_planes + // could be beyond the input tensor's boundary. In such cases, change + // the starting index and reduce the broadcast size. 
+ // + // The same procedure is repeated for every spatial dimension in the + // nested loops below. + int pindex, psize; + OP_REQUIRES_OK(context, + GetBroadcastSize(p, input_size[0], window[0], stride[0], + padding[0], &pindex, &psize)); + for (int64 r = 0; r < out_backprop.dim_size(2); ++r) { + int rindex, rsize; + OP_REQUIRES_OK(context, + GetBroadcastSize(r, input_size[1], window[1], stride[1], + padding[1], &rindex, &rsize)); + for (int64 c = 0; c < out_backprop.dim_size(1); ++c) { + int cindex, csize; + OP_REQUIRES_OK( + context, GetBroadcastSize(c, input_size[2], window[2], stride[2], + padding[2], &cindex, &csize)); + TensorSlice src{{0, -1}, {c, 1}, {r, 1}, {p, 1}, {0, -1}}; + TensorSlice dst{{0, -1}, + {cindex, csize}, + {rindex, rsize}, + {pindex, psize}, + {0, -1}}; + Eigen::DSizes src_indices; + Eigen::DSizes src_sizes; + Eigen::DSizes dst_indices; + Eigen::DSizes dst_sizes; + src.FillIndicesAndSizes<5>(out_backprop.shape(), &src_indices, + &src_sizes); + dst.FillIndicesAndSizes<5>(tensor_in_shape, &dst_indices, &dst_sizes); +#if !defined(EIGEN_HAS_INDEX_LIST) + Eigen::array bcast = {1, csize, rsize, psize, 1}; +#else + Eigen::IndexList, int, int, int, + Eigen::type2index<1> > + bcast; + bcast.set(1, csize); + bcast.set(2, rsize); + bcast.set(3, psize); +#endif + Eigen::Tensor slices(src_sizes); + slices.device(context->eigen_cpu_device()) = + out_backprop.tensor().slice(src_indices, src_sizes); + // Divide by the size of the actual patch (psize * rsize * csize). + float divide_size = rsize * csize * psize * 1.0f; + slices *= slices.constant(1.0f / divide_size); + + output->tensor() + .slice(dst_indices, dst_sizes) + .device(context->eigen_cpu_device()) += slices.broadcast(bcast); + } + } + } + } +}; + +template +class AvgPooling3dGradOp : public OpKernel { + public: + explicit AvgPooling3dGradOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 5, + errors::InvalidArgument("Sliding window ksize field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window stride field must " + "specify 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + OP_REQUIRES(context, ksize_[4] == 1 && stride_[4] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the depth dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in_shape = context->input(0); + const Tensor& out_backprop = context->input(1); + OP_REQUIRES(context, tensor_in_shape.dims() == 1 && + tensor_in_shape.NumElements() == 5, + errors::InvalidArgument("tensor_in must be 1-dimensional and 5 " + "elements")); + OP_REQUIRES(context, out_backprop.dims() == 5, + errors::InvalidArgument("out_backprop must be 5-dimensional")); + + TensorShape output_shape; + auto shape_vec = tensor_in_shape.vec(); + for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) { + output_shape.AddDim(shape_vec(i)); + } + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + std::array input_size = {{output_shape.dim_size(3), + output_shape.dim_size(2), + output_shape.dim_size(1)}}; + std::array window = {{ksize_[3], ksize_[2], ksize_[1]}}; + 
std::array stride = {{stride_[3], stride_[2], stride_[1]}}; + std::array padding, out; + + OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride, + padding_, &out, &padding)); + + LaunchAvgPooling3dGradOp::launch(context, output_shape, + out_backprop, window, stride, + out, padding, output); + } + + private: + std::vector ksize_; + std::vector stride_; + Padding padding_; +}; + +REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") + .Device(DEVICE_CPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape"), + AvgPooling3dGradOp); + +#if GOOGLE_CUDA + +template +struct LaunchPoolingOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const std::array& window, + const std::array& stride, + const std::array& padding, Padding padding_type, + Tensor* output) { + DnnPooling3dOp::Compute(context, + perftools::gputools::dnn::PoolingMode::kAverage, + window, stride, padding, tensor_in, output); + } +}; + +template +struct LaunchPoolingOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const std::array& window, + const std::array& stride, + const std::array& padding, Padding padding_type, + Tensor* output) { + DnnPooling3dOp::Compute(context, + perftools::gputools::dnn::PoolingMode::kMaximum, + window, stride, padding, tensor_in, output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AvgPool3D").Device(DEVICE_GPU).TypeConstraint("T"), + Pooling3DOp); +REGISTER_KERNEL_BUILDER( + Name("MaxPool3D").Device(DEVICE_GPU).TypeConstraint("T"), + Pooling3DOp); + +template +struct LaunchMaxPooling3dGradOp { + static void launch(OpKernelContext* context, const Tensor& tensor_in, + const Tensor& tensor_out, const Tensor& out_backprop, + const std::array& window, + const std::array& stride, + const std::array& out, + const std::array& padding, + Tensor* input_backprop) { + const TensorShape output_shape = tensor_in.shape(); + DnnPooling3dGradOp::Compute( + context, perftools::gputools::dnn::PoolingMode::kMaximum, window, + stride, padding, out, out_backprop, output_shape, &tensor_in, + &tensor_out, input_backprop); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("MaxPool3DGrad").Device(DEVICE_GPU).TypeConstraint("T"), + MaxPooling3dGradOp); + +template +struct LaunchAvgPooling3dGradOp { + static void launch(OpKernelContext* context, + const TensorShape& tensor_in_shape, + const Tensor& out_backprop, + const std::array& window, + const std::array& stride, + const std::array& out, + const std::array& padding, Tensor* output) { + DnnPooling3dGradOp::Compute( + context, perftools::gputools::dnn::PoolingMode::kAverage, window, + stride, padding, out, out_backprop, tensor_in_shape, nullptr, nullptr, + output); + } +}; +REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("orig_input_shape"), + AvgPooling3dGradOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index eecaf25c2b2365..017b789473f0c0 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -153,9 +153,9 @@ void DnnPoolingOp::Compute( ShapeFromFormat(FORMAT_NCHW, tensor_in.shape(), data_format), &transformed_input)); - functor::NHWCToNCHW()(context->eigen_device(), - tensor_in.tensor(), - transformed_input.tensor()); + functor::NHWCToNCHW()(context->eigen_device(), + tensor_in.tensor(), + transformed_input.tensor()); } else { transformed_input = tensor_in; } @@ -213,7 +213,7 @@ void 
DnnPoolingOp::Compute( if (data_format == FORMAT_NHWC) { /// Transform the output data from NCHW back to NHWC auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::NCHWToNHWC()( + functor::NCHWToNHWC()( context->eigen_device(), toConstTensor(transformed_output).template tensor(), tensor_out->tensor()); @@ -292,19 +292,19 @@ void DnnPoolingGradOp::Compute( // For AvgPoolGrad, the original input tensor is not necessary. However, // cudnn still requires them to run, although they do not affect the // results. - functor::NHWCToNCHW()(context->eigen_device(), - tensor_in->tensor(), - transformed_input.tensor()); + functor::NHWCToNCHW()(context->eigen_device(), + tensor_in->tensor(), + transformed_input.tensor()); } if (tensor_out) { // For AvgPoolGrad, the original output tensor is not necessary. However, // cudnn still requires them to run, although they do not affect the // results. - functor::NHWCToNCHW()(context->eigen_device(), - tensor_out->tensor(), - transformed_output.tensor()); + functor::NHWCToNCHW()(context->eigen_device(), + tensor_out->tensor(), + transformed_output.tensor()); } - functor::NHWCToNCHW()( + functor::NHWCToNCHW()( context->eigen_device(), out_backprop.tensor(), transformed_output_backprop.tensor()); } @@ -361,7 +361,7 @@ void DnnPoolingGradOp::Compute( if (data_format == FORMAT_NHWC) { /// Transform the output data from NCHW back to NHWC. auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::NCHWToNHWC()( + functor::NCHWToNHWC()( context->eigen_device(), toConstTensor(transformed_input_backprop).template tensor(), input_backprop->tensor()); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 5cb9b5fa265059..e68be084a41065 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -415,6 +415,159 @@ output: 4-D with shape the `filter` input of the convolution. )doc"); +// -------------------------------------------------------------------------- +REGISTER_OP("Conv3D") + .Input("input: T") + .Input("filter: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes a 3-D convolution given 5-D `input` and `filter` tensors. + +In signal processing, cross-correlation is a measure of similarity of +two waveforms as a function of a time-lag applied to one of them. This +is also known as a sliding dot product or sliding inner-product. + +Our Conv3D implements a form of cross-correlation. + +input: Shape `[batch, in_depth, in_height, in_width, in_channels]`. +filter: Shape `[filter_depth, filter_height, filter_width, in_channels, out_channels]`. + `in_channels` must match between `input` and `filter`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. + +)doc"); + +REGISTER_OP("Conv3DBackpropInput") + .Input("input: T") + .Input("filter: T") + .Input("out_backprop: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradients of 3D convolution with respect to the input. + +input: Shape `[batch, depth, rows, cols, in_channels]`. +filter: Shape `[depth, rows, cols, in_channels, out_channels]`. + `in_channels` must match between `input` and `filter`. +out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, out_channels]`. 
+strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. + +)doc"); + +REGISTER_OP("Conv3DBackpropFilter") + .Input("input: T") + .Input("filter: T") + .Input("out_backprop: T") + .Output("output: T") + .Attr("T: numbertype") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradients of 3D convolution with respect to the filter. + +input: Shape `[batch, depth, rows, cols, in_channels]`. +filter: Shape `[depth, rows, cols, in_channels, out_channels]`. + `in_channels` must match between `input` and `filter`. +out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, out_channels]`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. + +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("AvgPool3D") + .Input("input: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr("T: numbertype") + .Doc(R"doc( +Performs 3D average pooling on the input. + +ksize: 1-D tensor of length 5. The size of the window for each dimension of + the input tensor. Must have `ksize[0] = ksize[1] = 1`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. +input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +output: The average pooled output tensor. +)doc"); + +REGISTER_OP("AvgPool3DGrad") + .Input("orig_input_shape: int32") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr("T: numbertype") + .Doc(R"doc( +Computes gradients of average pooling function. + +ksize: 1-D tensor of length 5. The size of the window for each dimension of + the input tensor. Must have `ksize[0] = ksize[1] = 1`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. +orig_input_shape: The original input dimensions. +grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +output: The backprop for input. +)doc"); + +// -------------------------------------------------------------------------- + +REGISTER_OP("MaxPool3D") + .Input("input: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr("T: numbertype") + .Doc(R"doc( +Performs 3D max pooling on the input. + +ksize: 1-D tensor of length 5. The size of the window for each dimension of + the input tensor. Must have `ksize[0] = ksize[1] = 1`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. +input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over. +output: The max pooled output tensor. 
+)doc"); + +REGISTER_OP("MaxPool3DGrad") + .Input("orig_input: float") + .Input("orig_output: float") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5 ") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr("T: numbertype") + .Doc(R"doc( +Computes gradients of max pooling function. + +ksize: 1-D tensor of length 5. The size of the window for each dimension of + the input tensor. Must have `ksize[0] = ksize[1] = 1`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("L2Loss") diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index ee5f3703ce95f8..4115afb2b1a317 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -36,18 +36,26 @@ bool FormatFromString(const string& format_str, TensorFormat* format); string ToString(TensorFormat format); // Return the position index from a format given a dimension specification with -// a char. +// a char. The chars can be N (batch), C (channels), H (y), W (x), or +// 0 .. (NDIMS-1). +template inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { if (format == FORMAT_NHWC) { switch (dimension) { case 'N': return 0; - case 'H': + case '0': return 1; - case 'W': + case '1': return 2; - case 'C': + case '2': return 3; + case 'H': + return NDIMS - 1; + case 'W': + return NDIMS; + case 'C': + return 1 + NDIMS; default: LOG(FATAL) << "Invalid dimension: " << dimension; } @@ -57,10 +65,16 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { return 0; case 'C': return 1; - case 'H': + case '0': return 2; - case 'W': + case '1': return 3; + case '2': + return 4; + case 'H': + return NDIMS; + case 'W': + return NDIMS + 1; default: LOG(FATAL) << "Invalid dimension: " << dimension; } @@ -69,11 +83,15 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { } } +inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { + return GetTensorDimIndex<2>(format, dimension); +} + // Return the given tensor dimension from a tensor. The tensor is interpretted // using the specified format, and a dimension specification using a char. inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -86,7 +104,7 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, // specification using a char. 
inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor_shape.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -99,7 +117,7 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, template T GetTensorDim(const std::vector& attributes, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < attributes.size()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -113,10 +131,10 @@ string GetConvnetDataFormatAttrString(); inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, int64 W, int64 C) { std::vector dim_sizes(4); - dim_sizes[GetTensorDimIndex(format, 'N')] = N; - dim_sizes[GetTensorDimIndex(format, 'H')] = H; - dim_sizes[GetTensorDimIndex(format, 'W')] = W; - dim_sizes[GetTensorDimIndex(format, 'C')] = C; + dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N; + dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H; + dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W; + dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C; return TensorShape(dim_sizes); } @@ -128,13 +146,13 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format, return src_shape; } std::vector dim_sizes(4); - dim_sizes[GetTensorDimIndex(dst_format, 'N')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] = GetTensorDim(src_shape, src_format, 'N'); - dim_sizes[GetTensorDimIndex(dst_format, 'H')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] = GetTensorDim(src_shape, src_format, 'H'); - dim_sizes[GetTensorDimIndex(dst_format, 'W')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] = GetTensorDim(src_shape, src_format, 'W'); - dim_sizes[GetTensorDimIndex(dst_format, 'C')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] = GetTensorDim(src_shape, src_format, 'C'); return TensorShape(dim_sizes); } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 05df66e4bb9356..66963145f786bb 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1119,6 +1119,7 @@ sharded_kernel_test_list = glob([ "kernel_tests/cwise_ops_test.py", "kernel_tests/embedding_ops_test.py", "kernel_tests/linalg_grad_test.py", + "kernel_tests/conv_ops_3d_test.py", ]) cpu_only_kernel_test_list = glob([ diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py new file mode 100644 index 00000000000000..a86d1b60eaaacb --- /dev/null +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -0,0 +1,420 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Functional tests for 3d convolutional operations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import tensorflow as tf + + +class Conv3DTest(tf.test.TestCase): + + def _VerifyValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride, + padding, expected, use_gpu): + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x1 = [f * 1.0 for f in range(1, total_size_1 + 1)] + x2 = [f * 1.0 for f in range(1, total_size_2 + 1)] + with self.test_session(use_gpu=use_gpu) as sess: + t1 = tf.constant(x1, shape=tensor_in_sizes) + t2 = tf.constant(x2, shape=filter_in_sizes) + conv = tf.nn.conv3d(t1, + t2, [1, stride, stride, stride, 1], + padding=padding) + value = sess.run(conv) + print("expected = ", expected) + print("actual = ", value) + self.assertArrayNear(expected, value.flatten(), 1e-5) + + def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding, + expected): + self._VerifyValuesForDevice(tensor_in_sizes, + filter_in_sizes, + stride, + padding, + expected, + use_gpu=False) + self._VerifyValuesForDevice(tensor_in_sizes, + filter_in_sizes, + stride, + padding, + expected, + use_gpu=True) + + def testConv3D1x1x1Filter(self): + expected_output = [30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, + 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, + 312.0] + + # These are equivalent to the Conv2D1x1 case. + self._VerifyValues(tensor_in_sizes=[1, 2, 3, 1, 3], + filter_in_sizes=[1, 1, 1, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + self._VerifyValues(tensor_in_sizes=[1, 2, 1, 3, 3], + filter_in_sizes=[1, 1, 1, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + self._VerifyValues(tensor_in_sizes=[1, 1, 2, 3, 3], + filter_in_sizes=[1, 1, 1, 3, 3], + stride=1, + padding="VALID", + expected=expected_output) + + # Expected values computed using scipy's correlate function. + def testConv3D2x2x2Filter(self): + expected_output = [19554., 19962., 20370., 22110., 22590., 23070., 34890., + 35730., 36570., 37446., 38358., 39270., 50226., 51498., + 52770., 52782., 54126., 55470.] + # expected_shape = [1, 3, 1, 2, 5] + self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3], # b, z, y, x, fin + filter_in_sizes=[2, 2, 2, 3, 3], # z, y, x, fin, fout + stride=1, padding="VALID", + expected=expected_output) + + def testConv3D2x2x2FilterStride2(self): + expected_output = [19554., 19962., 20370., 50226., 51498., 52770.] + self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3], + filter_in_sizes=[2, 2, 2, 3, 3], + stride=2, + padding="VALID", + expected=expected_output) + + def testConv3DStride3(self): + expected_output = [ + 36564., 38022., 39480., 37824., 39354., 40884., 39084., 40686., 42288., + 46644., 48678., 50712., 47904., 50010., 52116., 49164., 51342., 53520., + 107124., 112614., 118104., 108384., 113946., 119508., 109644., 115278., + 120912., 117204., 123270., 129336., 118464., 124602., 130740., 119724., + 125934., 132144. 
+ ] + self._VerifyValues(tensor_in_sizes=[1, 6, 7, 8, 2], + filter_in_sizes=[3, 2, 1, 2, 3], + stride=3, + padding="VALID", + expected=expected_output) + + def testConv3D2x2x2FilterStride2Same(self): + expected_output = [ + 19554., 19962., 20370., 10452., 10710., 10968., 50226., 51498., 52770., + 23844., 24534., 25224. + ] + self._VerifyValues(tensor_in_sizes=[1, 4, 2, 3, 3], + filter_in_sizes=[2, 2, 2, 3, 3], + stride=2, + padding="SAME", + expected=expected_output) + + def testKernelSmallerThanStride(self): + expected_output = [1., 3., 7., 9., 19., 21., 25., 27.] + self._VerifyValues(tensor_in_sizes=[1, 3, 3, 3, 1], + filter_in_sizes=[1, 1, 1, 1, 1], + stride=2, + padding="SAME", + expected=expected_output) + self._VerifyValues(tensor_in_sizes=[1, 3, 3, 3, 1], + filter_in_sizes=[1, 1, 1, 1, 1], + stride=2, + padding="VALID", + expected=expected_output) + + expected_output = [1484., 1592., 770., + 2240., 2348., 1106., + 1149., 1191., 539., + + 6776., 6884., 3122., + 7532., 7640., 3458., + 3207., 3249., 1421., + + 3005., 3035., 1225., + 3215., 3245., 1309., + 1013., 1022., 343.] + self._VerifyValues(tensor_in_sizes=[1, 7, 7, 7, 1], + filter_in_sizes=[2, 2, 2, 1, 1], + stride=3, + padding="SAME", + expected=expected_output) + + expected_output = [1484., 1592., + 2240., 2348., + + 6776., 6884., + 7532., 7640.] + self._VerifyValues(tensor_in_sizes=[1, 7, 7, 7, 1], + filter_in_sizes=[2, 2, 2, 1, 1], + stride=3, + padding="VALID", + expected=expected_output) + + def ConstructAndTestGradient(self, batch, input_planes, input_rows, + input_cols, filter_planes, filter_rows, + filter_cols, in_depth, out_depth, stride, + padding, test_input, use_gpu): + input_shape = [batch, input_planes, input_rows, input_cols, in_depth] + filter_shape = [filter_planes, filter_rows, filter_cols, in_depth, + out_depth] + if padding == "VALID": + output_planes = int(math.ceil((input_planes - filter_planes + 1.0) / + stride)) + output_rows = int(math.ceil((input_rows - filter_rows + 1.0) / stride)) + output_cols = int(math.ceil((input_cols - filter_cols + 1.0) / stride)) + else: + output_planes = int(math.ceil(float(input_planes) / stride)) + output_rows = int(math.ceil(float(input_rows) / stride)) + output_cols = int(math.ceil(float(input_cols) / stride)) + output_shape = [batch, output_planes, output_rows, output_cols, out_depth] + input_size = 1 + for x in input_shape: + input_size *= x + filter_size = 1 + for x in filter_shape: + filter_size *= x + input_data = [x * 1.0 / input_size for x in range(0, input_size)] + filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)] + if use_gpu: + data_type = tf.float32 + tolerance = 4e-3 + else: + data_type = tf.float64 + tolerance = 1e-8 + with self.test_session(use_gpu=use_gpu): + input_tensor = tf.constant(input_data, + shape=input_shape, + dtype=data_type, + name="input") + filter_tensor = tf.constant(filter_data, + shape=filter_shape, + dtype=data_type, + name="filter") + conv = tf.nn.conv3d(input_tensor, + filter_tensor, [1, stride, stride, stride, 1], + padding, + name="conv") + + if test_input: + err = tf.test.compute_gradient_error(input_tensor, input_shape, conv, + output_shape) + else: + err = tf.test.compute_gradient_error(filter_tensor, filter_shape, conv, + output_shape) + print("conv3d gradient error = ", err) + self.assertLess(err, tolerance) + + def testInputGradientValidPaddingStrideOne(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=3, + input_rows=5, + input_cols=4, + filter_planes=3, + 
filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride=1, + padding="VALID", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientValidPaddingStrideOne(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=4, + input_planes=4, + input_rows=6, + input_cols=5, + filter_planes=2, + filter_rows=2, + filter_cols=2, + in_depth=2, + out_depth=3, + stride=1, + padding="VALID", + test_input=False, + use_gpu=use_gpu) + + def testInputGradientValidPaddingStrideTwo(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=6, + input_rows=3, + input_cols=5, + filter_planes=3, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride=2, + padding="VALID", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientValidPaddingStrideTwo(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=7, + input_rows=6, + input_cols=5, + filter_planes=2, + filter_rows=2, + filter_cols=2, + in_depth=2, + out_depth=3, + stride=2, + padding="VALID", + test_input=False, + use_gpu=use_gpu) + + def testInputGradientValidPaddingStrideThree(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=3, + input_rows=7, + input_cols=6, + filter_planes=3, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride=3, + padding="VALID", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientValidPaddingStrideThree(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=4, + input_rows=4, + input_cols=7, + filter_planes=4, + filter_rows=4, + filter_cols=4, + in_depth=2, + out_depth=3, + stride=3, + padding="VALID", + test_input=False, + use_gpu=use_gpu) + + def testInputGradientSamePaddingStrideOne(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=3, + input_rows=2, + input_cols=2, + filter_planes=3, + filter_rows=2, + filter_cols=1, + in_depth=2, + out_depth=1, + stride=1, + padding="SAME", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientSamePaddingStrideOne(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=3, + input_rows=6, + input_cols=5, + filter_planes=2, + filter_rows=2, + filter_cols=2, + in_depth=2, + out_depth=3, + stride=1, + padding="SAME", + test_input=False, + use_gpu=use_gpu) + + def testInputGradientSamePaddingStrideTwo(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=6, + input_rows=3, + input_cols=4, + filter_planes=3, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride=2, + padding="SAME", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientSamePaddingStrideTwo(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=4, + input_planes=7, + input_rows=3, + input_cols=5, + filter_planes=2, + filter_rows=2, + filter_cols=2, + in_depth=2, + out_depth=3, + stride=2, + padding="SAME", + test_input=False, + use_gpu=use_gpu) + + def testInputGradientSamePaddingStrideThree(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + input_planes=9, + input_rows=3, + input_cols=6, + filter_planes=3, + filter_rows=3, + filter_cols=3, + in_depth=2, + out_depth=3, + stride=3, + padding="SAME", + test_input=True, + use_gpu=use_gpu) + + def testFilterGradientSamePaddingStrideThree(self): + for use_gpu in [False, True]: + self.ConstructAndTestGradient(batch=2, + 
input_planes=9, + input_rows=4, + input_cols=7, + filter_planes=4, + filter_rows=4, + filter_cols=4, + in_depth=2, + out_depth=3, + stride=3, + padding="SAME", + test_input=False, + use_gpu=use_gpu) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 6a9b9978c16aea..b59bb5f3b8ae72 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -862,14 +862,14 @@ def testShapeFunctionEdgeCases(self): # Filter larger than input. with self.assertRaisesRegexp(ValueError, - "filter must not be larger than the input"): + "Filter must not be larger than the input"): tf.nn.conv2d(tf.placeholder(tf.float32, shape=[32, 20, 20, 3]), tf.placeholder(tf.float32, shape=[20, 21, 3, 2]), strides=[1, 1, 1, 1], padding="SAME") with self.assertRaisesRegexp(ValueError, - "filter must not be larger than the input"): + "Filter must not be larger than the input"): tf.nn.conv2d(tf.placeholder(tf.float32, shape=[32, 20, 20, 3]), tf.placeholder(tf.float32, diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py new file mode 100644 index 00000000000000..e7686871e1633b --- /dev/null +++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py @@ -0,0 +1,340 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for 3d pooling operations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +class PoolingTest(tf.test.TestCase): + + def _VerifyValues(self, pool_func, input_sizes, window, strides, padding, + expected, use_gpu): + """Verifies the output values of the pooling function. + + Args: + pool_func: Function to be called: co.MaxPool, co.AvgPool. + input_sizes: Input tensor dimensions. + window: Tuple of kernel dims: planes, rows, cols. + strides: Tuple of strides for dims: planes, rows, cols. + padding: Padding type. + expected: An array containing the expected operation outputs. + use_gpu: Whether we are running on GPU. + """ + total_size = 1 + for s in input_sizes: + total_size *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. + x = [f * 1.0 for f in range(1, total_size + 1)] + with self.test_session(use_gpu=use_gpu) as sess: + t = tf.constant(x, shape=input_sizes) + t = pool_func(t, + ksize=[1, window[0], window[1], window[2], 1], + strides=[1, strides[0], strides[1], strides[2], 1], + padding=padding) + vals = sess.run(t) + # Verifies values. 
+ actual = vals.flatten() + self.assertAllClose(expected, actual) + + def _testAvgPool3dValidPadding(self, use_gpu): + expected_output = [20.5, 21.5, 22.5] + self._VerifyValues(tf.nn.avg_pool3d, + input_sizes=[1, 3, 3, 3, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="VALID", + expected=expected_output, + use_gpu=use_gpu) + + def _testAvgPool3dSamePadding(self, use_gpu): + expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5] + self._VerifyValues(tf.nn.avg_pool3d, + input_sizes=[1, 2, 2, 4, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="SAME", + expected=expected_output, + use_gpu=use_gpu) + + def _testMaxPool3dValidPadding(self, use_gpu): + expected_output = [40.0, 41.0, 42.0] + self._VerifyValues(tf.nn.max_pool3d, + input_sizes=[1, 3, 3, 3, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="VALID", + expected=expected_output, + use_gpu=use_gpu) + + def _testMaxPool3dSamePadding(self, use_gpu): + expected_output = [31., 32., 33., 34., 35., 36.] + self._VerifyValues(tf.nn.max_pool3d, + input_sizes=[1, 2, 2, 3, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="SAME", + expected=expected_output, + use_gpu=use_gpu) + + def testAvgPooling3d(self): + for use_gpu in [False, True]: + self._testAvgPool3dValidPadding(use_gpu) + self._testAvgPool3dSamePadding(use_gpu) + + def testMaxPooling3d(self): + for use_gpu in [False, True]: + self._testMaxPool3dValidPadding(use_gpu) + self._testMaxPool3dSamePadding(use_gpu) + + def testKernelSmallerThanStride(self): + for use_gpu in [True, False]: + self._VerifyValues(tf.nn.max_pool3d, input_sizes=[1, 3, 3, 3, 1], + window=[1, 1, 1], strides=[2, 2, 2], + padding="SAME", + expected=[1, 3, 7, 9, 19, 21, 25, 27], + use_gpu=use_gpu) + + self._VerifyValues(tf.nn.max_pool3d, input_sizes=[1, 7, 7, 7, 1], + window=[2, 2, 2], strides=[3, 3, 3], + padding="VALID", + expected=[58, 61, 79, 82, 205, 208, 226, 229], + use_gpu=use_gpu) + + self._VerifyValues(tf.nn.avg_pool3d, input_sizes=[1, 3, 3, 3, 1], + window=[1, 1, 1], strides=[2, 2, 2], + padding="SAME", + expected=[1, 3, 7, 9, 19, 21, 25, 27], + use_gpu=use_gpu) + + self._VerifyValues(tf.nn.avg_pool3d, input_sizes=[1, 7, 7, 7, 1], + window=[2, 2, 2], strides=[3, 3, 3], + padding="VALID", + expected=[29.5, 32.5, 50.5, 53.5, + 176.5, 179.5, 197.5, 200.5], + use_gpu=use_gpu) + + def _ConstructAndTestGradient(self, + pool_func, + input_sizes, + output_sizes, + window, + strides, + padding, + x_init_value=None, + use_gpu=False): + """Verifies the gradients of the avg pooling function. + + Args: + pool_func: Function to be called, co.MaxPool, co.AvgPool, + or the Lua version. + input_sizes: Input tensor dimensions. + output_sizes: Output tensor dimensions. + window: Tuple of kernel dims: planes, rows, cols. + strides: Tuple of strides for dims: planes, rows, cols. + padding: Padding type. + x_init_value: Values to be passed to the gradient checker. + use_gpu: Whether to run pooling on GPU. + """ + total_size = 1 + for s in input_sizes: + total_size *= s + # Initializes the input tensor with array containing incrementing + # numbers from 1. 
+ x = [f * 1.0 for f in range(1, total_size + 1)] + with self.test_session(use_gpu=use_gpu): + input_tensor = tf.constant(x, shape=input_sizes, name="input") + err_margin = 1e-3 + if pool_func == tf.nn.avg_pool3d: + func_name = "avg_pool3d" + else: + if x_init_value is None: + x_init_value = np.asfarray( + np.arange(1, total_size + 1), + dtype=np.float32).reshape(input_sizes) + func_name = "max_pool3d" + + t = pool_func(input_tensor, + ksize=[1, window[0], window[1], window[2], 1], + strides=[1, strides[0], strides[1], strides[2], 1], + padding=padding, + name=func_name) + + err = tf.test.compute_gradient_error(input_tensor, + input_sizes, + t, + output_sizes, + x_init_value=x_init_value, + delta=1e-2) + print("%s gradient error = " % func_name, err) + self.assertLess(err, err_margin) + + def testMaxPoolGradValidPadding1_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[1, 3, 3, 3, 1], + output_sizes=[1, 3, 3, 3, 1], + window=(1, 1, 1), + strides=(1, 1, 1), + padding="VALID", + use_gpu=use_gpu) + + def testMaxPoolGradValidPadding2_1_6_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 3, 3, 6, 3], + output_sizes=[2, 2, 2, 5, 3], + window=(2, 2, 2), + strides=(1, 1, 1), + padding="VALID", + use_gpu=use_gpu) + + def testMaxPoolGradValidPadding2_1_7_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 3, 5, 7, 3], + output_sizes=[2, 2, 4, 6, 3], + window=(2, 2, 2), + strides=(1, 1, 1), + padding="VALID", + use_gpu=use_gpu) + + def testMaxPoolGradValidPadding2_2_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 2, 2, 2, 3], + output_sizes=[2, 1, 1, 1, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="VALID", + use_gpu=use_gpu) + + def testMaxPoolGradSamePadding1_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 3, 2, 4, 1], + output_sizes=[2, 3, 2, 4, 1], + window=(1, 1, 1), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + def testMaxPoolGradSamePadding2_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 3, 2, 4, 1], + output_sizes=[2, 3, 2, 4, 1], + window=(2, 2, 2), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + def testMaxPoolGradSamePadding2_2_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[2, 5, 2, 4, 3], + output_sizes=[2, 3, 1, 2, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="SAME", + use_gpu=use_gpu) + + def testMaxPoolGradSamePadding3_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.max_pool3d, + input_sizes=[1, 3, 3, 7, 1], + output_sizes=[1, 3, 3, 7, 1], + window=(3, 3, 3), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + def testAvgPoolGradValidPadding1_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[2, 3, 3, 3, 3], + output_sizes=[2, 3, 3, 3, 3], + window=(1, 1, 1), + strides=(1, 1, 1), + padding="VALID", + use_gpu=use_gpu) + + def testAvgPoolGradValidPadding2_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[2, 3, 3, 3, 3], + output_sizes=[2, 2, 2, 2, 3], + window=(2, 2, 2), + strides=(1, 1, 1), + padding="VALID", + use_gpu=use_gpu) + + def 
testAvgPoolGradValidPadding2_2_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[2, 2, 2, 2, 3], + output_sizes=[2, 1, 1, 1, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="VALID", + use_gpu=use_gpu) + + def testAvgPoolGradSamePadding1_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[2, 3, 2, 4, 3], + output_sizes=[2, 3, 2, 4, 3], + window=(1, 1, 1), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + def testAvgPoolGradSamePadding2_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[1, 2, 2, 2, 1], + output_sizes=[1, 2, 2, 2, 1], + window=(2, 2, 2), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + def testAvgPoolGradSamePadding2_2_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[2, 5, 2, 4, 3], + output_sizes=[2, 3, 1, 2, 3], + window=(2, 2, 2), + strides=(2, 2, 2), + padding="SAME", + use_gpu=use_gpu) + + def testAvgPoolGradSamePadding3_1_3d(self): + for use_gpu in (False, True): + self._ConstructAndTestGradient(tf.nn.avg_pool3d, + input_sizes=[1, 3, 6, 7, 1], + output_sizes=[1, 3, 6, 7, 1], + window=(3, 3, 3), + strides=(1, 1, 1), + padding="SAME", + use_gpu=use_gpu) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 9165748579d4ab..7e0993bba3ec81 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -902,12 +902,12 @@ def testShapeFunctionEdgeCases(self): for pool_func in [tf.nn.max_pool, tf.nn.avg_pool, tf.nn.max_pool_with_argmax]: with self.assertRaisesRegexp(ValueError, - "filter must not be larger than the input"): + "Filter must not be larger than the input"): pool_func(tf.placeholder(tf.float32, shape=[32, 20, 20, 3]), ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME") with self.assertRaisesRegexp(ValueError, - "filter must not be larger than the input"): + "Filter must not be larger than the input"): pool_func(tf.placeholder(tf.float32, shape=[32, 20, 20, 3]), ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME") diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py index db5ed6c5513800..d746cb1ab3a824 100644 --- a/tensorflow/python/ops/common_shapes.py +++ b/tensorflow/python/ops/common_shapes.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """A library of common shape functions.""" from __future__ import absolute_import from __future__ import division @@ -41,8 +40,10 @@ def unchanged_shape_with_rank(rank): A shape function for ops that output a tensor of the same size as their input, with a particular rank. """ + def _ShapeFunction(op): return [op.inputs[0].get_shape().with_rank(rank)] + return _ShapeFunction @@ -56,8 +57,10 @@ def unchanged_shape_with_rank_at_least(rank): A shape function for ops that output a tensor of the same size as their input, with a particular rank. 
""" + def _ShapeFunction(op): return [op.inputs[0].get_shape().with_rank_at_least(rank)] + return _ShapeFunction @@ -71,8 +74,10 @@ def unchanged_shape_with_rank_at_most(rank): A shape function for ops that output a tensor of the same size as their input, with a particular rank. """ + def _ShapeFunction(op): return [op.inputs[0].get_shape().with_rank_at_most(rank)] + return _ShapeFunction @@ -103,12 +108,11 @@ def bias_add_shape(op): data_format = None if data_format == b"NCHW": # Merge the length of bias_shape into the third-to-last dimension. - output_shape = input_shape[0:-3].concatenate( - input_shape[-3].merge_with(bias_shape[0])).concatenate( - input_shape[-2:]) + output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with( + bias_shape[0])).concatenate(input_shape[-2:]) else: - output_shape = input_shape[0:-1].concatenate( - input_shape[-1].merge_with(bias_shape[0])) + output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with( + bias_shape[0])) else: output_shape = tensor_shape.unknown_shape() return [output_shape] @@ -130,47 +134,54 @@ def bias_add_grad_shape(op): return [output_shape] -def get2d_conv_output_size(input_height, input_width, filter_height, - filter_width, row_stride, col_stride, padding_type): - """Returns the number of rows and columns in a convolution/pooling output.""" - input_height = tensor_shape.as_dimension(input_height) - input_width = tensor_shape.as_dimension(input_width) - filter_height = tensor_shape.as_dimension(filter_height) - filter_width = tensor_shape.as_dimension(filter_width) - row_stride = int(row_stride) - col_stride = int(col_stride) - - if filter_height.value == 1 and filter_width.value == 1 and ( - row_stride == 1 and col_stride == 1): - return input_height, input_width +def get_conv_output_size(input_size, filter_size, strides, padding_type): + """Returns the spatial size of a n-d convolution/pooling output.""" + input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size]) + filter_size = tuple([tensor_shape.as_dimension(x).value for x in filter_size]) + strides = [int(x) for x in strides] + + if all(x == 1 for x in input_size) and all(x == 1 for x in filter_size): + return input_size + + if any(x is not None and y is not None and x > y for x, y in + zip(filter_size, input_size)): + raise ValueError("Filter must not be larger than the input: " + "Filter: %r Input: %r" % (filter_size, input_size)) + + if padding_type == b"VALID": + + def _valid(in_dim, k_dim, s_dim): + if in_dim is not None and k_dim is not None: + return (in_dim - k_dim + s_dim) // s_dim + else: + return None + + output_size = [ + _valid(in_dim, k_dim, s_dim) + for in_dim, k_dim, s_dim in zip(input_size, filter_size, strides) + ] + elif padding_type == b"SAME": + + def _same(in_dim, s_dim): + if in_dim is not None: + return (in_dim + s_dim - 1) // s_dim + else: + return None + + output_size = [_same(in_dim, s_dim) + for in_dim, s_dim in zip(input_size, strides)] else: - if filter_height > input_height or filter_width > input_width: - raise ValueError( - "filter must not be larger than the input: " - "Filter: [%sx%s] Input: [%sx%s]" - % (filter_height, filter_width, input_height, input_width)) - - # Compute number of rows in the output, based on the padding. 
- if input_height.value is None or filter_height.value is None: - out_rows = None - elif padding_type == b"VALID": - out_rows = ((input_height.value - filter_height.value + row_stride) // - row_stride) - elif padding_type == b"SAME": - out_rows = (input_height.value + row_stride - 1) // row_stride - else: - raise ValueError("Invalid value for padding: %r" % padding_type) + raise ValueError("Invalid padding: %r" % padding_type) + + return tuple(output_size) - # Compute number of columns in the output, based on the padding. - if input_width.value is None or filter_width.value is None: - out_cols = None - elif padding_type == b"VALID": - out_cols = ((input_width.value - filter_width.value + col_stride) // - col_stride) - elif padding_type == b"SAME": - out_cols = (input_width.value + col_stride - 1) // col_stride - return out_rows, out_cols +def get2d_conv_output_size(input_height, input_width, filter_height, + filter_width, row_stride, col_stride, padding_type): + """Returns the number of rows and columns in a convolution/pooling output.""" + return get_conv_output_size((input_height, input_width), + (filter_height, filter_width), + (row_stride, col_stride), padding_type) def conv2d_shape(op): @@ -230,8 +241,9 @@ def conv2d_shape(op): # information in the input to be ignored. This will require a change # in the kernel implementation. padding = op.get_attr("padding") - out_rows, out_cols = get2d_conv_output_size( - in_rows, in_cols, filter_rows, filter_cols, stride_r, stride_c, padding) + out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, + filter_cols, stride_r, stride_c, + padding) output_shape = [batch_size, out_rows, out_cols, depth_out] if data_format == b"NCHW": @@ -290,8 +302,9 @@ def depthwise_conv2d_native_shape(op): # in the kernel implementation. stride = stride_r padding = op.get_attr("padding") - out_rows, out_cols = get2d_conv_output_size( - in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding) + out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, + filter_cols, stride, stride, + padding) return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] @@ -352,8 +365,9 @@ def separable_conv2d_shape(op): # in the kernel implementation. stride = stride_r padding = op.get_attr("padding") - out_rows, out_cols = get2d_conv_output_size( - in_rows, in_cols, filter_rows, filter_cols, stride, stride, padding) + out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, filter_rows, + filter_cols, stride, stride, + padding) return [tensor_shape.TensorShape([batch_size, out_rows, out_cols, depth_out])] @@ -414,8 +428,9 @@ def avg_pool_shape(op): # in the kernel implementation. padding = op.get_attr("padding") - out_rows, out_cols = get2d_conv_output_size( - in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding) + out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, + ksize_c, stride_r, stride_c, + padding) output_shape = [batch_size, out_rows, out_cols, depth] if data_format == b"NCHW": @@ -485,8 +500,9 @@ def max_pool_shape(op): # in the kernel implementation. 
if ksize_d == 1: padding = op.get_attr("padding") - out_rows, out_cols = get2d_conv_output_size( - in_rows, in_cols, ksize_r, ksize_c, stride_r, stride_c, padding) + out_rows, out_cols = get2d_conv_output_size(in_rows, in_cols, ksize_r, + ksize_c, stride_r, stride_c, + padding) output_shape = [batch_size, out_rows, out_cols, depth] else: if depth % ksize_d > 0: diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index c911e33fe35625..cf26be31e91efe 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -108,6 +108,7 @@ @@separable_conv2d @@atrous_conv2d @@conv2d_transpose +@@conv3d ## Pooling @@ -127,6 +128,8 @@ @@avg_pool @@max_pool @@max_pool_with_argmax +@@avg_pool3d +@@max_pool3d ## Normalization diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 9dac828896587d..188d936c0ec3be 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Gradients for operators defined in nn_ops.py.""" from __future__ import absolute_import @@ -40,14 +39,48 @@ def _Conv2DBackpropGrad(op, grad): the gradients w.r.t. the input and the filter """ return [None, - nn_ops.conv2d_backprop_filter( - grad, array_ops.shape(op.inputs[1]), op.inputs[2], - op.get_attr("strides"), op.get_attr("padding"), - op.get_attr("use_cudnn_on_gpu"), op.get_attr("data_format")), - nn_ops.conv2d( - grad, op.inputs[1], op.get_attr("strides"), - op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format"))] + nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]), + op.inputs[2], op.get_attr("strides"), + op.get_attr("padding"), + op.get_attr("use_cudnn_on_gpu"), + op.get_attr("data_format")), + nn_ops.conv2d(grad, op.inputs[1], op.get_attr("strides"), + op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"), + op.get_attr("data_format"))] + + +@ops.RegisterGradient("Conv3D") +def _Conv3DGrad(op, grad): + return [nn_ops.conv3d_backprop_input(op.inputs[0], + op.inputs[1], + grad, + strides=op.get_attr("strides"), + padding=op.get_attr("padding")), + nn_ops.conv3d_backprop_filter(op.inputs[0], + op.inputs[1], + grad, + strides=op.get_attr("strides"), + padding=op.get_attr("padding"))] + + +@ops.RegisterGradient("AvgPool3D") +def _AvgPool3DGrad(op, grad): + return nn_ops.avg_pool3d_grad( + array_ops.shape(op.inputs[0]), + grad, + ksize=op.get_attr("ksize"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding")) + + +@ops.RegisterGradient("MaxPool3D") +def _MaxPool3DGrad(op, grad): + return nn_ops.max_pool3d_grad(op.inputs[0], + op.outputs[0], + grad, + ksize=op.get_attr("ksize"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding")) @ops.RegisterGradient("Softmax") @@ -74,10 +107,8 @@ def _SoftmaxGrad(op, grad_softmax): # graph-construction time? Alternatively: do different things # depending on the dimensionality of the input tensors. softmax = op.outputs[0] - grad_x = ((grad_softmax - - array_ops.reshape(math_ops.reduce_sum(grad_softmax * softmax, [1]), - [-1, 1])) - * softmax) + grad_x = ((grad_softmax - array_ops.reshape( + math_ops.reduce_sum(grad_softmax * softmax, [1]), [-1, 1])) * softmax) return grad_x @@ -128,7 +159,8 @@ def _BiasAddGradV1(unused_bias_op, received_grad): the second one for the "bias" input of the BiasOp. 
""" reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1) - return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor)) + return (received_grad, math_ops.reduce_sum(received_grad, + reduction_dim_tensor)) @ops.RegisterGradient("Relu") @@ -159,8 +191,8 @@ def _SoftsignGrad(op, grad): @ops.RegisterGradient("ReluGrad") def _ReluGradGrad(op, grad): x = op.inputs[1] - return (gen_nn_ops._relu_grad(grad, x), - array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)) + return (gen_nn_ops._relu_grad(grad, x), array_ops.zeros( + shape=array_ops.shape(x), dtype=x.dtype)) def _BroadcastMul(vec, mat): @@ -196,12 +228,10 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _): @ops.RegisterGradient("Conv2D") def _Conv2DGrad(op, grad): - return [nn_ops.conv2d_backprop_input(array_ops.shape(op.inputs[0]), - op.inputs[1], grad, - op.get_attr("strides"), - op.get_attr("padding"), - op.get_attr("use_cudnn_on_gpu"), - op.get_attr("data_format")), + return [nn_ops.conv2d_backprop_input( + array_ops.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr("strides"), + op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"), + op.get_attr("data_format")), nn_ops.conv2d_backprop_filter(op.inputs[0], array_ops.shape(op.inputs[1]), grad, op.get_attr("strides"), @@ -228,28 +258,30 @@ def _LRNGrad(op, grad): bias = op.get_attr("bias") alpha = op.get_attr("alpha") beta = op.get_attr("beta") - return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], - depth_radius, bias, alpha, beta)] + return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius, + bias, alpha, beta)] @ops.RegisterGradient("AvgPool") def _AvgPoolGrad(op, grad): - return gen_nn_ops._avg_pool_grad(array_ops.shape(op.inputs[0]), grad, - op.get_attr("ksize"), - op.get_attr("strides"), - op.get_attr("padding"), - data_format=op.get_attr("data_format") - ) + return gen_nn_ops._avg_pool_grad( + array_ops.shape(op.inputs[0]), + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + op.get_attr("padding"), + data_format=op.get_attr("data_format")) @ops.RegisterGradient("MaxPool") def _MaxPoolGrad(op, grad): - return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad, + return gen_nn_ops._max_pool_grad(op.inputs[0], + op.outputs[0], + grad, op.get_attr("ksize"), op.get_attr("strides"), padding=op.get_attr("padding"), - data_format=op.get_attr("data_format") - ) + data_format=op.get_attr("data_format")) @ops.RegisterGradient("BatchNormWithGlobalNormalization") @@ -328,5 +360,4 @@ def _TopKGrad(op, grad, _): array_ops.reshape(grad, [-1]), validate_indices=False), in_shape), array_ops.zeros( - [], - dtype=dtypes.int32)] + [], dtype=dtypes.int32)] diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 661dc790d298c2..68fc9a364fb8f5 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Wrappers for primitive Neural Net (NN) Operations.""" # pylint: disable=invalid-name @@ -37,7 +36,6 @@ from tensorflow.python.ops.gen_nn_ops import * # pylint: enable=wildcard-import - # Aliases for some automatically-generated names. 
local_response_normalization = gen_nn_ops.lrn @@ -153,12 +151,13 @@ def atrous_conv2d(value, filters, rate, padding, name=None): "value's input channels does not match filters' input channels, " "{} != {}".format(value_shape[3], filter_shape[2])) if rate < 1: - raise ValueError( - "rate {} cannot be less than one".format(rate)) + raise ValueError("rate {} cannot be less than one".format(rate)) if rate == 1: - value = gen_nn_ops.conv2d(input=value, filter=filters, - strides=[1, 1, 1, 1], padding=padding) + value = gen_nn_ops.conv2d(input=value, + filter=filters, + strides=[1, 1, 1, 1], + padding=padding) return value # We have two padding contributions. The first is used for converting "SAME" @@ -201,20 +200,30 @@ def atrous_conv2d(value, filters, rate, padding, name=None): # The paddings argument to space_to_batch includes both padding components space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra], [pad_left, pad_right + pad_right_extra]] - value = array_ops.space_to_batch( - input=value, paddings=space_to_batch_pad, block_size=rate) + value = array_ops.space_to_batch(input=value, + paddings=space_to_batch_pad, + block_size=rate) - value = gen_nn_ops.conv2d(input=value, filter=filters, strides=[1, 1, 1, 1], - padding="VALID", name=name) + value = gen_nn_ops.conv2d(input=value, + filter=filters, + strides=[1, 1, 1, 1], + padding="VALID", + name=name) # The crops argument to batch_to_space is just the extra padding component batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]] - value = array_ops.batch_to_space( - input=value, crops=batch_to_space_crop, block_size=rate) + value = array_ops.batch_to_space(input=value, + crops=batch_to_space_crop, + block_size=rate) return value -def conv2d_transpose(value, filter, output_shape, strides, padding="SAME", + +def conv2d_transpose(value, + filter, + output_shape, + strides, + padding="SAME", name=None): """The transpose of `conv2d`. 
@@ -248,9 +257,9 @@ def conv2d_transpose(value, filter, output_shape, strides, padding="SAME", value = ops.convert_to_tensor(value, name="value") filter = ops.convert_to_tensor(filter, name="filter") if not value.get_shape()[3].is_compatible_with(filter.get_shape()[3]): - raise ValueError( - "input channels does not match filter's input channels, " - "{} != {}".format(value.get_shape()[3], filter.get_shape()[3])) + raise ValueError("input channels does not match filter's input channels, " + "{} != {}".format(value.get_shape()[3], filter.get_shape( + )[3])) output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)): @@ -302,8 +311,8 @@ def bias_add(value, bias, data_format=None, name=None): bias = ops.convert_to_tensor(bias, dtype=value.dtype, name="bias") return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name) -ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape) +ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape) ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape) @@ -338,7 +347,6 @@ def bias_add_v1(value, bias, name=None): ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape) - ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape) @@ -490,7 +498,9 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.op_scope([value], name, "AvgPool") as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._avg_pool(value, ksize=ksize, strides=strides, + return gen_nn_ops._avg_pool(value, + ksize=ksize, + strides=strides, padding=padding, data_format=data_format, name=name) @@ -515,7 +525,9 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """ with ops.op_scope([value], name, "MaxPool") as name: value = ops.convert_to_tensor(value, name="input") - return gen_nn_ops._max_pool(value, ksize=ksize, strides=strides, + return gen_nn_ops._max_pool(value, + ksize=ksize, + strides=strides, padding=padding, data_format=data_format, name=name) @@ -547,7 +559,6 @@ def _BinaryElementwiseShape(op): ops.RegisterShape("L2Loss")(common_shapes.scalar_shape) - ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4)) @@ -560,12 +571,9 @@ def _LRNGradShape(op): return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)] -ops.RegisterShape("Softmax")( - common_shapes.unchanged_shape_with_rank(2)) - +ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank(2)) -ops.RegisterShape("LogSoftmax")( - common_shapes.unchanged_shape_with_rank(2)) +ops.RegisterShape("LogSoftmax")(common_shapes.unchanged_shape_with_rank(2)) @ops.RegisterShape("InTopK") @@ -744,6 +752,93 @@ def _calc_conv_weight_params(graph, node): filter_in_depth * filter_out_depth)) +@ops.RegisterShape("Conv3D") +def _Conv3DShape(op): + """Shape function for Conv3D.""" + input_shape = op.inputs[0].get_shape().with_rank(5) + filter_shape = op.inputs[1].get_shape().with_rank(5) + + batch_size = input_shape[0] + out_channels = filter_shape[4] + # Check that the input number of channels is compatible between + # input data and filter size. 
+ input_shape[4].assert_is_compatible_with(filter_shape[3]) + + stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides") + assert stride_b == 1 + assert stride_d == 1 + + padding_type = op.get_attr("padding") + out_planes, out_rows, out_cols = common_shapes.get_conv_output_size( + input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c), + padding_type) + + return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols, + out_channels])] + + +@ops.RegisterShape("MaxPool3D") +@ops.RegisterShape("AvgPool3D") +def _Pool3DShape(op): + """Shape function for Max/AvgPool3D.""" + input_shape = op.inputs[0].get_shape().with_rank(5) + ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize") + assert ksize_b == 1 + assert ksize_d == 1 + + stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides") + assert stride_b == 1 + assert stride_d == 1 + + batch_size = input_shape[0] + channels = input_shape[4] + + padding = op.get_attr("padding") + out_planes, out_rows, out_cols = common_shapes.get_conv_output_size( + input_shape[1:4], (ksize_p, ksize_r, ksize_c), + (stride_p, stride_r, stride_c), padding) + return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols, + channels])] + + +def _ShapeOrUnknown(input_shape, ndims=5): + if input_shape == None: # pylint:disable=g-equals-none + return [tensor_shape.unknown_shape(ndims=ndims)] + else: + return [input_shape] + + +@ops.RegisterShape("Conv3DBackpropFilter") +def _Conv3DBackpropFilterShape(op): + """Shape function for the Conv3DBackpropFilter op.""" + filter_shape = op.inputs[1].get_shape() + return _ShapeOrUnknown(filter_shape) + + +@ops.RegisterShape("Conv3DBackpropInput") +def _Conv3DBackpropInputShape(op): + """Shape function for the Conv3DBackpropInput op.""" + input_shape = op.inputs[0].get_shape() + return _ShapeOrUnknown(input_shape) + + +@ops.RegisterShape("AvgPool3DGrad") +def _AvgPool3DGradShape(op): + """Shape function for the AvgPool3DGrad op.""" + orig_input_shape = tensor_util.constant_value(op.inputs[0]) + if orig_input_shape != None: # pylint:disable=g-equals-none + return [tensor_shape.TensorShape(orig_input_shape.tolist())] + else: + return [tensor_shape.unknown_shape(ndims=5)] + + +@ops.RegisterShape("MaxPool3DGrad") +def _MaxPool3DGradShape(op): + """Shape function for the MaxPoolGrad op.""" + orig_input_shape = op.inputs[0].get_shape().with_rank(5) + return [orig_input_shape] + + @ops.RegisterStatistics("BiasAdd", "flops") def _calc_bias_add_flops(graph, node): """Calculates the computing needed for BiasAdd.""" @@ -846,15 +941,17 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: raise ValueError("keep_prob must be a scalar tensor or a float in the " "range (0, 1], got %g" % keep_prob) - keep_prob = ops.convert_to_tensor( - keep_prob, dtype=x.dtype, name="keep_prob") + keep_prob = ops.convert_to_tensor(keep_prob, + dtype=x.dtype, + name="keep_prob") keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) # uniform [keep_prob, 1.0 + keep_prob) random_tensor = keep_prob - random_tensor += random_ops.random_uniform( - noise_shape, seed=seed, dtype=x.dtype) + random_tensor += random_ops.random_uniform(noise_shape, + seed=seed, + dtype=x.dtype) # 0. if [keep_prob, 1.0) and 1. 
if [1.0, 1.0 + keep_prob) binary_tensor = math_ops.floor(random_tensor) ret = x * math_ops.inv(keep_prob) * binary_tensor @@ -890,5 +987,4 @@ def top_k(input, k=1, sorted=True, name=None): """ return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name) - # pylint: enable=invalid-name
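
As a quick end-to-end illustration of the new Python endpoints (this sketch is separate from the patch itself and written against the session-style API used by the tests above), the snippet below runs tf.nn.conv3d followed by tf.nn.max_pool3d; shapes follow the NDHWC layout documented for the ops, and the constant values are arbitrary.

import tensorflow as tf

with tf.Session() as sess:
  # Input laid out as [batch, depth, height, width, in_channels].
  x = tf.constant([float(i) for i in range(1, 73)], shape=[1, 4, 2, 3, 3])
  # Filter laid out as [depth, height, width, in_channels, out_channels].
  f = tf.constant([0.1] * 72, shape=[2, 2, 2, 3, 3])
  # strides[0] and strides[4] must be 1: no striding over batch or channels.
  conv = tf.nn.conv3d(x, f, strides=[1, 1, 1, 1, 1], padding="VALID")
  # ksize[0] and ksize[4] must likewise be 1 for the 3D pooling ops.
  pool = tf.nn.max_pool3d(conv, ksize=[1, 2, 2, 2, 1],
                          strides=[1, 2, 2, 2, 1], padding="SAME")
  print(sess.run(conv).shape)  # (1, 3, 1, 2, 3) under the VALID rule
  print(sess.run(pool).shape)  # (1, 2, 1, 1, 3): ceil(dim / 2) per SAME rule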
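The per-dimension output-size rules applied by the new get_conv_output_size helper in common_shapes.py can also be checked by hand. The snippet below (again not part of the patch; the helper name out_dim is made up for illustration) reproduces the VALID and SAME formulas:

def out_dim(in_dim, k_dim, stride, padding):
  # VALID: floor((in - k + stride) / stride); SAME: ceil(in / stride).
  if padding == "VALID":
    return (in_dim - k_dim + stride) // stride
  return (in_dim + stride - 1) // stride

# Spatial dims (4, 2, 3) with a 2x2x2 filter:
print([out_dim(i, 2, 1, "VALID") for i in (4, 2, 3)])  # [3, 1, 2]
print([out_dim(i, 2, 2, "SAME") for i in (4, 2, 3)])   # [2, 1, 2]

These values match the output shapes exercised by the 2x2x2-filter cases in conv_ops_3d_test.py; the channel count of the output always equals the filter's last dimension.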
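Finally, the index mapping performed by the templated GetTensorDimIndex&lt;NDIMS&gt; in tensor_format.h is easy to sanity-check outside C++. This sketch (not part of the patch; the function name nhwc_index is illustrative only) mirrors how the FORMAT_NHWC branch places the batch dimension first, the numbered spatial dimensions next, and the channel dimension last:

def nhwc_index(dim, ndims):
  # 'N' is first, numbered spatial dimensions follow, 'C' is last (1 + ndims).
  if dim == 'N':
    return 0
  if dim == 'C':
    return 1 + ndims
  return 1 + int(dim)  # '0' .. str(ndims - 1)

# For NDIMS = 3 (Conv3D / Pool3D), NDHWC order is N, 0, 1, 2, C.
print([nhwc_index(d, 3) for d in ('N', '0', '1', '2', 'C')])  # [0, 1, 2, 3, 4]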