
Commit

Add support for 3d convolutions and pooling. CPU kernels use Eigen, GPU kernels use CuDNN.

Change: 121484787
A. Unique TensorFlower authored and tensorflower-gardener committed May 4, 2016
1 parent e5df6ad commit 6a187cc
Showing 25 changed files with 3,591 additions and 460 deletions.
21 changes: 20 additions & 1 deletion tensorflow/core/kernels/BUILD
@@ -71,6 +71,16 @@ cc_library(
],
)

cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
deps = [
":eigen_helpers",
"//tensorflow/core:framework",
"//third_party/eigen3",
],
)

cc_library(
name = "fill_functor",
hdrs = ["fill_functor.h"],
@@ -1106,11 +1116,15 @@ tf_cuda_cc_test(
# conv_ops_gpu.h has to be separated into its own library.
tf_kernel_library(
name = "conv_ops",
srcs = ["conv_grad_ops.cc"],
srcs = [
"conv_grad_ops.cc",
"conv_grad_ops_3d.cc",
],
prefix = "conv_ops",
deps = [
":bounds_check",
":conv_2d",
":conv_3d",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -1238,11 +1252,14 @@ tf_kernel_library(
name = "pooling_ops",
srcs = [
"avgpooling_op.cc",
"cudnn_pooling_gpu.cc",
"maxpooling_op.cc",
"pooling_ops_3d.cc",
"pooling_ops_common.cc",
],
hdrs = [
"avgpooling_op.h",
"cudnn_pooling_gpu.h",
"maxpooling_op.h",
"pooling_ops_common.h",
],
@@ -1257,6 +1274,8 @@ tf_kernel_library(
],
deps = [
":conv_2d",
":conv_3d",
":conv_ops",
":eigen_helpers",
":ops_util",
"//tensorflow/core:core_cpu",
87 changes: 52 additions & 35 deletions tensorflow/core/kernels/conv_2d.h
@@ -116,23 +116,31 @@ struct MatMulConvFunctor {
}
};

template <typename Device, typename T, typename IndexType>
// Shuffles a filter tensor from:
// [<spatial_dims>, in, out]
// to:
// [out, in, <spatial_dims>]
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
void operator()(const Device& d,
typename TTypes<T, 4, IndexType>::ConstTensor in,
typename TTypes<T, 4, IndexType>::Tensor out) {
// We want a 3, 2, 0, 1 shuffle. We can merge dimensions 0 and 1 together
// to help speedup the shuffle operation.
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
// We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
// to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
merged_dims[0] = in.dimension(0) * in.dimension(1);
merged_dims[1] = in.dimension(2);
merged_dims[2] = in.dimension(3);

Eigen::DSizes<IndexType, 4> expanded_dims;
expanded_dims[0] = in.dimension(3);
expanded_dims[1] = in.dimension(2);
expanded_dims[2] = in.dimension(0);
expanded_dims[3] = in.dimension(1);
merged_dims[0] = in.dimension(0); // spatial dimensions
for (int i = 1; i < NDIMS - 2; ++i) {
merged_dims[0] *= in.dimension(i);
}
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters

Eigen::DSizes<IndexType, NDIMS> expanded_dims;
expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
for (int i = 0; i < NDIMS - 2; ++i) {  // spatial dimensions
expanded_dims[i + 2] = in.dimension(i);
}

out.device(d) = in.reshape(merged_dims)
.shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
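
As a side note, here is a minimal standalone sketch of the dimension bookkeeping above for NDIMS == 5 (the filter shape is hypothetical and the snippet is not part of the commit; it only illustrates how the spatial dimensions are merged for the shuffle and re-expanded afterwards):

#include <array>
#include <cstdio>

int main() {
  constexpr int NDIMS = 5;
  // Hypothetical 3-D filter shape: [planes, rows, cols, in_depth, out_depth].
  const std::array<long, NDIMS> in = {3, 5, 5, 16, 32};

  // Merge all spatial dimensions into one, keeping in/out depth separate.
  std::array<long, 3> merged;
  merged[0] = in[0];
  for (int i = 1; i < NDIMS - 2; ++i) merged[0] *= in[i];
  merged[1] = in[NDIMS - 2];
  merged[2] = in[NDIMS - 1];

  // After the (2, 1, 0) shuffle the layout is [out, in, merged_spatial],
  // which the functor then re-expands to [out, in, planes, rows, cols].
  std::array<long, NDIMS> expanded;
  expanded[0] = in[NDIMS - 1];
  expanded[1] = in[NDIMS - 2];
  for (int i = 0; i < NDIMS - 2; ++i) expanded[i + 2] = in[i];

  std::printf("merged:   %ld x %ld x %ld\n", merged[0], merged[1], merged[2]);
  std::printf("expanded: %ld x %ld x %ld x %ld x %ld\n", expanded[0],
              expanded[1], expanded[2], expanded[3], expanded[4]);
  return 0;
}
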
@@ -194,41 +202,50 @@ struct TransformDepth {
}
};

template <typename Device, typename T, typename IndexType>
template <typename Device, typename T, typename IndexType, int NDIMS>
struct PadInput {
void operator()(const Device& d,
typename TTypes<T, 4, IndexType>::ConstTensor in,
int padding_rows_left, int padding_rows_right,
int padding_cols_left, int padding_cols_right,
typename TTypes<T, 4, IndexType>::Tensor out,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
const std::array<int, NDIMS - 2>& padding_left,
const std::array<int, NDIMS - 2>& padding_right,
typename TTypes<T, NDIMS, IndexType>::Tensor out,
TensorFormat format) {
Eigen::array<std::pair<IndexType, IndexType>, 4> padding;
padding[GetTensorDimIndex(format, 'N')] = std::make_pair(0, 0);
padding[GetTensorDimIndex(format, 'H')] =
std::make_pair(padding_rows_left, padding_rows_right);
padding[GetTensorDimIndex(format, 'W')] =
std::make_pair(padding_cols_left, padding_cols_right);
padding[GetTensorDimIndex(format, 'C')] = std::make_pair(0, 0);
Eigen::array<std::pair<IndexType, IndexType>, NDIMS> padding;
padding[GetTensorDimIndex<NDIMS - 2>(format, 'N')] = std::make_pair(0, 0);
for (int i = 0; i < NDIMS - 2; ++i) {
padding[GetTensorDimIndex<NDIMS - 2>(format, '0' + i)] =
std::make_pair(padding_left[i], padding_right[i]);
}
padding[GetTensorDimIndex<NDIMS - 2>(format, 'C')] = std::make_pair(0, 0);
out.device(d) = in.pad(padding);
}
};
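
For orientation, a small self-contained sketch (not from the commit) of the padding bookkeeping PadInput now performs for a 5-D NDHWC tensor: the batch and channel dimensions get zero padding, and each spatial dimension i gets the pair {padding_left[i], padding_right[i]}. The dimension positions are hard-coded for NDHWC here; the real functor derives them via GetTensorDimIndex.

#include <array>
#include <cstdio>
#include <utility>

int main() {
  constexpr int NDIMS = 5;  // [N, D, H, W, C]
  // Hypothetical per-spatial-dimension paddings (planes, rows, cols).
  const std::array<int, NDIMS - 2> left = {1, 2, 0};
  const std::array<int, NDIMS - 2> right = {1, 2, 1};

  std::array<std::pair<int, int>, NDIMS> padding;
  padding[0] = {0, 0};                   // batch ('N')
  for (int i = 0; i < NDIMS - 2; ++i) {  // spatial dimensions ('0' + i)
    padding[1 + i] = {left[i], right[i]};
  }
  padding[NDIMS - 1] = {0, 0};           // channels ('C')

  for (int i = 0; i < NDIMS; ++i) {
    std::printf("dim %d pad: (%d, %d)\n", i, padding[i].first,
                padding[i].second);
  }
  return 0;
}
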

template <typename Device, typename T>
// Converts a tensor from:
// [batch, <spatial>, filters]
// to:
// [batch, filters, <spatial>]
template <typename Device, typename T, int NDIMS>
struct NHWCToNCHW {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
typename TTypes<T, NDIMS>::Tensor out);
};

template <typename Device, typename T>
// Converts a tensor from:
// [batch, filters, <spatial>]
// to:
// [batch, <spatial>, filters]
template <typename Device, typename T, int NDIMS>
struct NCHWToNHWC {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
typename TTypes<T, NDIMS>::Tensor out);
};
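
A rough Eigen-only sketch (an assumption written for illustration, not code from the commit) of the 5-D reordering the NHWCToNCHW functor describes: moving the channel dimension next to the batch dimension while keeping the spatial order.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

// [batch, planes, rows, cols, channels] -> [batch, channels, planes, rows, cols]
Eigen::Tensor<float, 5, Eigen::RowMajor> NdhwcToNcdhw(
    const Eigen::Tensor<float, 5, Eigen::RowMajor>& in) {
  const Eigen::array<int, 5> shuffle = {0, 4, 1, 2, 3};
  return in.shuffle(shuffle);
}
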

template <typename Device, typename T>
// Reverses the effect of TransformFilter above.
template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
typename TTypes<T, NDIMS>::Tensor out);
};

} // namespace functor
48 changes: 48 additions & 0 deletions tensorflow/core/kernels/conv_3d.h
@@ -0,0 +1,48 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Functors for 3d convolution.

#ifndef TENSORFLOW_KERNELS_CONV_3D_H_
#define TENSORFLOW_KERNELS_CONV_3D_H_

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"

namespace tensorflow {
namespace functor {

// Applies a 3D convolution to a batch of multi-channel volumes.
template <typename Device, typename T>
struct CuboidConvolution;

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename T>
struct CuboidConvolution<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T, 5>::Tensor output,
typename TTypes<T, 5>::ConstTensor input,
typename TTypes<T, 5>::ConstTensor filter, int stride_planes,
int stride_rows, int stride_cols,
const Eigen::PaddingType& padding) {
output.device(d) = Eigen::CuboidConvolution(
input, filter, stride_planes, stride_rows, stride_cols, padding);
}
};

} // namespace functor
} // namespace tensorflow

#endif // TENSORFLOW_KERNELS_CONV_3D_H_
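
A rough CPU usage sketch for the functor above (a hypothetical, self-contained example rather than code from the commit; it assumes the bundled Eigen thread-pool device and the NDHWC input / [planes, rows, cols, in, out] filter layout used by the 3-D convolution kernels in this change):

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/conv_3d.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

void RunCuboidConvolutionSketch() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, 4);

  // Hypothetical shapes: batch 2, 8x8x8 volume, 3 input and 16 output
  // channels, 3x3x3 filter, stride 1, SAME padding (output volume == input).
  Eigen::Tensor<float, 5, Eigen::RowMajor> input(2, 8, 8, 8, 3);    // NDHWC
  Eigen::Tensor<float, 5, Eigen::RowMajor> filter(3, 3, 3, 3, 16);  // [p,r,c,in,out]
  Eigen::Tensor<float, 5, Eigen::RowMajor> output(2, 8, 8, 8, 16);
  input.setRandom();
  filter.setRandom();

  tensorflow::TTypes<float, 5>::ConstTensor in_map(input.data(), 2, 8, 8, 8, 3);
  tensorflow::TTypes<float, 5>::ConstTensor filter_map(filter.data(), 3, 3, 3, 3, 16);
  tensorflow::TTypes<float, 5>::Tensor out_map(output.data(), 2, 8, 8, 8, 16);

  tensorflow::functor::CuboidConvolution<Eigen::ThreadPoolDevice, float>()(
      device, out_map, in_map, filter_map,
      /*stride_planes=*/1, /*stride_rows=*/1, /*stride_cols=*/1,
      Eigen::PADDING_SAME);
}
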
96 changes: 48 additions & 48 deletions tensorflow/core/kernels/conv_grad_ops.cc
@@ -946,7 +946,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
filter_rows, filter_cols}),
&transformed_filter));

functor::TransformFilter<Device, T, int>()(
functor::TransformFilter<Device, T, int, 4>()(
context->eigen_device<Device>(), To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));

@@ -959,9 +959,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
output_cols, out_depth),
&transformed_out_backprop));

functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
functor::NHWCToNCHW<Device, T, 4>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
} else {
transformed_out_backprop = out_backprop;
}
@@ -1022,19 +1022,19 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
&in_backprop_remove_padding));

// Remove the padding for odd rows or cols.
functor::PadInput<GPUDevice, T, int>()(
functor::PadInput<GPUDevice, T, int, 4>()(
context->template eigen_device<GPUDevice>(),
To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
.tensor<T, 4>()),
0, -rows_odd, 0, -cols_odd,
{{0, 0}}, {{-rows_odd, -cols_odd}},
To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);

pre_transformed_in_backprop = in_backprop_remove_padding;
}

if (data_format_ == FORMAT_NHWC) {
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::NCHWToNHWC<Device, T>()(
functor::NCHWToNHWC<Device, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
in_backprop->tensor<T, 4>());
@@ -1167,9 +1167,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
input_cols + cols_odd, in_depth),
&compatible_input));

functor::PadInput<GPUDevice, T, int>()(
functor::PadInput<GPUDevice, T, int, 4>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
To32Bit(input.tensor<T, 4>()), {{0, 0}}, {{rows_odd, cols_odd}},
To32Bit(compatible_input.tensor<T, 4>()), data_format_);
} else {
compatible_input = input;
@@ -1227,9 +1227,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
ShapeFromFormat(FORMAT_NCHW, batch, output_rows,
output_cols, out_depth),
&transformed_out_backprop));
functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
functor::NHWCToNCHW<Device, T, 4>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
} else {
transformed_out_backprop = out_backprop;
}
@@ -1246,7 +1246,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
GetTensorDim(compatible_input, data_format_, 'W'),
GetTensorDim(compatible_input, data_format_, 'C')),
&transformed_input));
functor::NHWCToNCHW<Device, T>()(
functor::NHWCToNCHW<Device, T, 4>()(
context->eigen_device<Device>(),
const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
transformed_input.tensor<T, 4>());
@@ -1284,7 +1284,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
}

auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::ReverseTransformFilter<Device, T>()(
functor::ReverseTransformFilter<Device, T, 4>()(
context->eigen_device<Device>(),
toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
filter_backprop->tensor<T, 4>());
@@ -1301,40 +1301,40 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& order, \
const Eigen::array<bool, 4>& reverse_dims, \
typename TTypes<T, 4, int>::Tensor output); \
extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
template <> \
void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& strides, \
const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
const Eigen::DSizes<int, 4>& order, \
typename TTypes<T, 4, int>::Tensor output); \
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int>; \
template <> \
void TransformDepth<GPUDevice, T, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const Eigen::DSizes<int, 4>& shuffle, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformDepth<GPUDevice, T, int>; \
template <> \
void PadInput<GPUDevice, T, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
int padding_rows_left, int padding_rows_right, int padding_cols_left, \
int padding_cols_right, typename TTypes<T, 4, int>::Tensor out, \
TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int>;
#define DECLARE_GPU_SPEC(T) \
template <> \
void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& order, \
const Eigen::array<bool, 4>& reverse_dims, \
typename TTypes<T, 4, int>::Tensor output); \
extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>; \
template <> \
void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
const Eigen::DSizes<int, 4>& strides, \
const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims, \
const Eigen::DSizes<int, 4>& order, \
typename TTypes<T, 4, int>::Tensor output); \
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
void TransformDepth<GPUDevice, T, int>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const Eigen::DSizes<int, 4>& shuffle, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformDepth<GPUDevice, T, int>; \
template <> \
void PadInput<GPUDevice, T, int, 4>::operator()( \
const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
const std::array<int, 2>& padding_left, \
const std::array<int, 2>& padding_right, \
typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
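
For reference, the updated PadInput forward declaration that DECLARE_GPU_SPEC(float) expands to, written out by hand here for illustration (the expansion itself is not part of the diff):

template <>
void PadInput<GPUDevice, float, int, 4>::operator()(
    const GPUDevice& d, typename TTypes<float, 4, int>::ConstTensor in,
    const std::array<int, 2>& padding_left,
    const std::array<int, 2>& padding_right,
    typename TTypes<float, 4, int>::Tensor out, TensorFormat data_format);
extern template struct PadInput<GPUDevice, float, int, 4>;
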
