From 283782a2dc3301f4d9e8dee65654eb97c698d635 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 2 Jun 2016 19:38:26 -0800 Subject: [PATCH] Morphological filtering operations: Dilation and erosion. Change: 123935296 --- tensorflow/core/kernels/BUILD | 14 + tensorflow/core/kernels/dilation_ops.cc | 491 ++++++++++++++++ tensorflow/core/kernels/dilation_ops.h | 66 +++ .../core/kernels/dilation_ops_gpu.cu.cc | 304 ++++++++++ tensorflow/core/ops/nn_ops.cc | 93 +++ .../kernel_tests/morphological_ops_test.py | 541 ++++++++++++++++++ tensorflow/python/ops/nn.py | 40 ++ tensorflow/python/ops/nn_grad.py | 12 + tensorflow/python/ops/nn_ops.py | 138 +++++ 9 files changed, 1699 insertions(+) create mode 100644 tensorflow/core/kernels/dilation_ops.cc create mode 100644 tensorflow/core/kernels/dilation_ops.h create mode 100644 tensorflow/core/kernels/dilation_ops_gpu.cu.cc create mode 100644 tensorflow/python/kernel_tests/morphological_ops_test.py diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 2e4a84d5a17d97..70d195d052df5c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1227,6 +1227,7 @@ tf_kernel_libraries( ":conv_ops", ":depthwise_conv_grad_op", ":depthwise_conv_op", + ":dilation_ops", ":ops_util", ":pooling_ops", "//tensorflow/core:framework", @@ -1347,6 +1348,19 @@ cc_library( ], ) +tf_kernel_library( + name = "dilation_ops", + prefix = "dilation_ops", + deps = [ + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:nn_ops_op_lib", + "//third_party/eigen3", + ], +) + tf_kernel_library( name = "batchtospace_op", prefix = "batchtospace_op", diff --git a/tensorflow/core/kernels/dilation_ops.cc b/tensorflow/core/kernels/dilation_ops.cc new file mode 100644 index 00000000000000..673c9696f20b23 --- /dev/null +++ b/tensorflow/core/kernels/dilation_ops.cc @@ -0,0 +1,491 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#define EIGEN_USE_THREADS + +#include +#include + +#include "tensorflow/core/kernels/dilation_ops.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +void ParseAttributes(OpKernelConstruction* context, std::vector* strides, + std::vector* rates, Padding* padding) { + OP_REQUIRES_OK(context, context->GetAttr("strides", strides)); + OP_REQUIRES(context, strides->size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, (*strides)[0] == 1 && (*strides)[3] == 1, + errors::Unimplemented( + "Stride is only supported across spatial dimensions.")); + + OP_REQUIRES_OK(context, context->GetAttr("rates", rates)); + OP_REQUIRES(context, rates->size() == 4, + errors::InvalidArgument("Input stride (atrous rate) field " + "must specify 4 dimensions")); + OP_REQUIRES(context, (*rates)[0] == 1 && (*rates)[3] == 1, + errors::Unimplemented( + "Rate is only supported across spatial dimensions.")); + + OP_REQUIRES_OK(context, context->GetAttr("padding", padding)); +} + +void ParseSizes(OpKernelContext* context, const std::vector& strides, + const std::vector& rates, const Padding& padding, + int* stride_rows, int* stride_cols, int* rate_rows, + int* rate_cols, int* pad_top, int* pad_left, int* out_rows, + int* out_cols) { + // Input tensor is of the following dimensions: + // [ batch, input_rows, input_cols, depth ] + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + const int input_rows = input.dim_size(1); + const int input_cols = input.dim_size(2); + const int depth = input.dim_size(3); + + // For now we take the stride and rate from the second and third dimensions + // only (we do not support striding on the batch or depth dimension). + *stride_rows = strides[1]; + *stride_cols = strides[2]; + *rate_rows = rates[1]; + *rate_cols = rates[2]; + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, depth ] + const Tensor& filter = context->input(1); + OP_REQUIRES(context, filter.dims() == 3, + errors::InvalidArgument("filter must be 3-dimensional: ", + filter.shape().DebugString())); + const int filter_rows = filter.dim_size(0); + const int filter_cols = filter.dim_size(1); + OP_REQUIRES( + context, depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + depth, " vs ", filter.dim_size(2))); + + // Effective filter size, after introducing rate - 1 zeros between each + // non-zero filter element. + const int filter_rows_eff = + filter_rows + (filter_rows - 1) * (*rate_rows - 1); + const int filter_cols_eff = + filter_cols + (filter_cols - 1) * (*rate_cols - 1); + + int pad_bottom = 0, pad_right = 0; + OP_REQUIRES_OK(context, + Get2dOutputSizeVerbose( + input_rows, input_cols, filter_rows_eff, filter_cols_eff, + *stride_rows, *stride_cols, padding, out_rows, out_cols, + pad_top, &pad_bottom, pad_left, &pad_right)); +} + +template +class DilationOp : public OpKernel { + public: + explicit DilationOp(OpKernelConstruction* context) : OpKernel(context) { + ParseAttributes(context, &strides_, &rates_, &padding_); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter = context->input(1); + + // Determine relevant sizes from input and filters. + int stride_rows = 0, stride_cols = 0; + int rate_rows = 0, rate_cols = 0; + int pad_top = 0, pad_left = 0; + int out_rows = 0, out_cols = 0; + ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols, + &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, + &out_cols); + + // Output tensor is of the following dimensions: + // [ batch, out_rows, out_cols, depth ] + const int batch = input.dim_size(0); + const int depth = input.dim_size(3); + const std::vector out_sizes = {batch, out_rows, out_cols, depth}; + TensorShape out_shape(out_sizes); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + + functor::Dilation()( + context->eigen_device(), input.tensor(), + filter.tensor(), stride_rows, stride_cols, rate_rows, rate_cols, + pad_top, pad_left, output->tensor()); + } + + std::vector strides_; + std::vector rates_; + Padding padding_; +}; + +// Partial specialization of Dilation functor for a CPUDevice. +namespace functor { +template +struct Dilation { + void operator()(const CPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, int pad_top, + int pad_left, typename TTypes::Tensor output) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = output.dimension(1); + const int output_cols = output.dimension(2); + + // This is a reference implementation, likely to be slow. + // TODO(gpapan): Write multi-threaded implementation. + for (int b = 0; b < batch; ++b) { + for (int h_out = 0; h_out < output_rows; ++h_out) { + int h_beg = h_out * stride_rows - pad_top; + for (int w_out = 0; w_out < output_cols; ++w_out) { + int w_beg = w_out * stride_cols - pad_left; + for (int d = 0; d < depth; ++d) { + T cur_val = Eigen::NumTraits::lowest(); + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = input(b, h_in, w_in, d) + filter(h, w, d); + if (val > cur_val) { + cur_val = val; + } + } + } + } + } + output(b, h_out, w_out, d) = cur_val; + } + } + } + } + } +}; +} // namespace functor + +template +class DilationBackpropInputOp : public OpKernel { + public: + explicit DilationBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + ParseAttributes(context, &strides_, &rates_, &padding_); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); + + // Determine relevant sizes from input and filters. + int stride_rows = 0, stride_cols = 0; + int rate_rows = 0, rate_cols = 0; + int pad_top = 0, pad_left = 0; + int out_rows = 0, out_cols = 0; + ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols, + &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, + &out_cols); + + // Verify that the incoming gradient tensor has the expected size + // [ batch, out_rows, out_cols, depth ] + const int batch = input.dim_size(0); + const int depth = input.dim_size(3); + OP_REQUIRES(context, batch == out_backprop.dim_size(0) && + out_rows == out_backprop.dim_size(1) && + out_cols == out_backprop.dim_size(2) && + depth == out_backprop.dim_size(3), + errors::InvalidArgument("out_backprop has incompatible size.")); + + // The computed in_backprop has the same dimensions as the input: + // [ batch, input_rows, input_cols, depth ] + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &in_backprop)); + + // If there is nothing to compute, return. + if (input.shape().num_elements() == 0) { + return; + } + + functor::DilationBackpropInput()( + context->eigen_device(), input.tensor(), + filter.tensor(), out_backprop.tensor(), stride_rows, + stride_cols, rate_rows, rate_cols, pad_top, pad_left, + in_backprop->tensor()); + } + + std::vector strides_; + std::vector rates_; + Padding padding_; +}; + +// Partial specialization of DilationBackpropInput functor for a CPUDevice. +namespace functor { +template +struct DilationBackpropInput { + void operator()(const CPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor in_backprop) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = out_backprop.dimension(1); + const int output_cols = out_backprop.dimension(2); + + // Initialize gradient with all zeros. + in_backprop.setZero(); + + // This is a reference implementation, likely to be slow. + // TODO(gpapan): Write multi-threaded implementation. + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + for (int b = 0; b < batch; ++b) { + for (int h_out = 0; h_out < output_rows; ++h_out) { + int h_beg = h_out * stride_rows - pad_top; + for (int w_out = 0; w_out < output_cols; ++w_out) { + int w_beg = w_out * stride_cols - pad_left; + for (int d = 0; d < depth; ++d) { + T cur_val = Eigen::NumTraits::lowest(); + int h_in_max = (h_beg < 0) ? 0 : h_beg; + int w_in_max = (w_beg < 0) ? 0 : w_beg; + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = input(b, h_in, w_in, d) + filter(h, w, d); + if (val > cur_val) { + cur_val = val; + h_in_max = h_in; + w_in_max = w_in; + } + } + } + } + } + in_backprop(b, h_in_max, w_in_max, d) += + out_backprop(b, h_out, w_out, d); + } + } + } + } + } +}; +} // namespace functor + +template +class DilationBackpropFilterOp : public OpKernel { + public: + explicit DilationBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + ParseAttributes(context, &strides_, &rates_, &padding_); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); + + // Determine relevant sizes from input and filters. + int stride_rows = 0, stride_cols = 0; + int rate_rows = 0, rate_cols = 0; + int pad_top = 0, pad_left = 0; + int out_rows = 0, out_cols = 0; + ParseSizes(context, strides_, rates_, padding_, &stride_rows, &stride_cols, + &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, + &out_cols); + + // Verify that the incoming gradient tensor has the expected size + // [ batch, out_rows, out_cols, depth ] + const int batch = input.dim_size(0); + const int depth = input.dim_size(3); + OP_REQUIRES(context, batch == out_backprop.dim_size(0) && + out_rows == out_backprop.dim_size(1) && + out_cols == out_backprop.dim_size(2) && + depth == out_backprop.dim_size(3), + errors::InvalidArgument("out_backprop has incompatible size.")); + + // The computed filter_backprop has the same dimensions as the filter: + // [ batch, input_rows, input_cols, depth ] + Tensor* filter_backprop = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, filter.shape(), &filter_backprop)); + + // If there is nothing to compute, return. + if (filter.shape().num_elements() == 0) { + return; + } + + functor::DilationBackpropFilter()( + context->eigen_device(), input.tensor(), + filter.tensor(), out_backprop.tensor(), stride_rows, + stride_cols, rate_rows, rate_cols, pad_top, pad_left, + filter_backprop->tensor()); + } + + std::vector strides_; + std::vector rates_; + Padding padding_; +}; + +// Partial specialization of DilationBackpropFilter functor for a CPUDevice. +namespace functor { +template +struct DilationBackpropFilter { + void operator()(const CPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor filter_backprop) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = out_backprop.dimension(1); + const int output_cols = out_backprop.dimension(2); + + // Initialize gradient with all zeros. + filter_backprop.setZero(); + + // This is a reference implementation, likely to be slow. + // TODO(gpapan): Write multi-threaded implementation. + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + for (int b = 0; b < batch; ++b) { + for (int h_out = 0; h_out < output_rows; ++h_out) { + int h_beg = h_out * stride_rows - pad_top; + for (int w_out = 0; w_out < output_cols; ++w_out) { + int w_beg = w_out * stride_cols - pad_left; + for (int d = 0; d < depth; ++d) { + T cur_val = Eigen::NumTraits::lowest(); + int h_max = 0; + int w_max = 0; + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = input(b, h_in, w_in, d) + filter(h, w, d); + if (val > cur_val) { + cur_val = val; + h_max = h; + w_max = w; + } + } + } + } + } + filter_backprop(h_max, w_max, d) += + out_backprop(b, h_out, w_out, d); + } + } + } + } + } +}; +} // namespace functor + +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Dilation2D").Device(DEVICE_CPU).TypeConstraint("T"), \ + DilationOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DilationBackpropInputOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DilationBackpropFilterOp); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER); + +#undef REGISTER + +#if GOOGLE_CUDA + +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Dilation2D").Device(DEVICE_GPU).TypeConstraint("T"), \ + DilationOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropInput") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + DilationBackpropInputOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("Dilation2DBackpropFilter") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + DilationBackpropFilterOp); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER); + +#undef REGISTER + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dilation_ops.h b/tensorflow/core/kernels/dilation_ops.h new file mode 100644 index 00000000000000..63d386aa4643f3 --- /dev/null +++ b/tensorflow/core/kernels/dilation_ops.h @@ -0,0 +1,66 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct Dilation { + // We assume that the tensor sizes are correct. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, int pad_top, + int pad_left, typename TTypes::Tensor output); +}; + +template +struct DilationBackpropInput { + // We assume that the tensor sizes are correct. + // To avoid storing the argmax values during forward computation, we recompute + // the argmax during backward computation, which is the reason why we provide + // filter as argument to the backward computation routine. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor in_backprop); +}; + +template +struct DilationBackpropFilter { + // We assume that the tensor sizes are correct. + // To avoid storing the argmax values during forward computation, we recompute + // the argmax during backward computation, which is the reason why we provide + // filter as argument to the backward computation routine. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor filter_backprop); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ diff --git a/tensorflow/core/kernels/dilation_ops_gpu.cu.cc b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc new file mode 100644 index 00000000000000..ac0775fbefe601 --- /dev/null +++ b/tensorflow/core/kernels/dilation_ops_gpu.cu.cc @@ -0,0 +1,304 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include +#include + +#include "tensorflow/core/kernels/dilation_ops.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +template +__global__ void DilationKernel(const int32 nthreads, const T* input_ptr, + const T* filter_ptr, int batch, int input_rows, + int input_cols, int depth, int filter_rows, + int filter_cols, int output_rows, + int output_cols, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, + int pad_top, int pad_left, T* output_ptr) { + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { + // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b)) + const int d = out_idx % depth; + const int out_idx2 = out_idx / depth; + const int w_out = out_idx2 % output_cols; + const int out_idx3 = out_idx2 / output_cols; + const int h_out = out_idx3 % output_rows; + const int b = out_idx3 / output_rows; + int h_beg = h_out * stride_rows - pad_top; + int w_beg = w_out * stride_cols - pad_left; + T cur_val = Eigen::NumTraits::lowest(); + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = + input_ptr[d + + depth * + (w_in + input_cols * (h_in + input_rows * b))] + + filter_ptr[d + depth * (w + filter_cols * h)]; + if (val > cur_val) { + cur_val = val; + } + } + } + } + } + output_ptr[out_idx] = cur_val; + } +} + +template +__global__ void DilationBackpropInputKernel( + const int32 nthreads, const T* input_ptr, const T* filter_ptr, + const T* out_backprop_ptr, int batch, int input_rows, int input_cols, + int depth, int filter_rows, int filter_cols, int output_rows, + int output_cols, int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, T* in_backprop_ptr) { + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { + // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b)) + const int d = out_idx % depth; + const int out_idx2 = out_idx / depth; + const int w_out = out_idx2 % output_cols; + const int out_idx3 = out_idx2 / output_cols; + const int h_out = out_idx3 % output_rows; + const int b = out_idx3 / output_rows; + int h_beg = h_out * stride_rows - pad_top; + int w_beg = w_out * stride_cols - pad_left; + T cur_val = Eigen::NumTraits::lowest(); + int h_in_max = (h_beg < 0) ? 0 : h_beg; + int w_in_max = (w_beg < 0) ? 0 : w_beg; + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = + input_ptr[d + + depth * + (w_in + input_cols * (h_in + input_rows * b))] + + filter_ptr[d + depth * (w + filter_cols * h)]; + if (val > cur_val) { + cur_val = val; + h_in_max = h_in; + w_in_max = w_in; + } + } + } + } + } + CudaAtomicAdd( + in_backprop_ptr + d + + depth * (w_in_max + input_cols * (h_in_max + input_rows * b)), + out_backprop_ptr[out_idx]); + } +} + +template +__global__ void DilationBackpropFilterKernel( + const int32 nthreads, const T* input_ptr, const T* filter_ptr, + const T* out_backprop_ptr, int batch, int input_rows, int input_cols, + int depth, int filter_rows, int filter_cols, int output_rows, + int output_cols, int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, T* filter_backprop_ptr) { + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) { + // out_idx = d + depth * (w_out + output_cols * (h_out + output_rows * b)) + const int d = out_idx % depth; + const int out_idx2 = out_idx / depth; + const int w_out = out_idx2 % output_cols; + const int out_idx3 = out_idx2 / output_cols; + const int h_out = out_idx3 % output_rows; + const int b = out_idx3 / output_rows; + int h_beg = h_out * stride_rows - pad_top; + int w_beg = w_out * stride_cols - pad_left; + T cur_val = Eigen::NumTraits::lowest(); + int h_max = 0; + int w_max = 0; + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + for (int h = 0; h < filter_rows; ++h) { + const int h_in = h_beg + h * rate_rows; + if (h_in >= 0 && h_in < input_rows) { + for (int w = 0; w < filter_cols; ++w) { + const int w_in = w_beg + w * rate_cols; + if (w_in >= 0 && w_in < input_cols) { + const T val = + input_ptr[d + + depth * + (w_in + input_cols * (h_in + input_rows * b))] + + filter_ptr[d + depth * (w + filter_cols * h)]; + if (val > cur_val) { + cur_val = val; + h_max = h; + w_max = w; + } + } + } + } + } + CudaAtomicAdd( + filter_backprop_ptr + d + depth * (w_max + filter_cols * h_max), + out_backprop_ptr[out_idx]); + } +} + +} // namespace + +namespace functor { + +template +struct Dilation { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, int pad_top, + int pad_left, typename TTypes::Tensor output) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = output.dimension(1); + const int output_cols = output.dimension(2); + + const int total_count = batch * output_rows * output_cols * depth; + CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d); + + DilationKernel<<>>( + config.virtual_thread_count, input.data(), filter.data(), batch, + input_rows, input_cols, depth, filter_rows, filter_cols, output_rows, + output_cols, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, + pad_left, output.data()); + } +}; + +template +struct DilationBackpropInput { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor in_backprop) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = out_backprop.dimension(1); + const int output_cols = out_backprop.dimension(2); + + int total_count; + CudaLaunchConfig config; + + // Initialize in_backprop with all zeros. + total_count = batch * input_rows * input_cols * depth; + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + total_count, in_backprop.data()); + + // Accumulate. + total_count = batch * output_rows * output_cols * depth; + config = GetCudaLaunchConfig(total_count, d); + DilationBackpropInputKernel<<>>( + config.virtual_thread_count, input.data(), filter.data(), + out_backprop.data(), batch, input_rows, input_cols, depth, filter_rows, + filter_cols, output_rows, output_cols, stride_rows, stride_cols, + rate_rows, rate_cols, pad_top, pad_left, in_backprop.data()); + } +}; + +template +struct DilationBackpropFilter { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor filter_backprop) { + const int batch = input.dimension(0); + const int input_rows = input.dimension(1); + const int input_cols = input.dimension(2); + const int depth = input.dimension(3); + + const int filter_rows = filter.dimension(0); + const int filter_cols = filter.dimension(1); + + const int output_rows = out_backprop.dimension(1); + const int output_cols = out_backprop.dimension(2); + + int total_count; + CudaLaunchConfig config; + + // Initialize filter_backprop with all zeros. + total_count = filter_rows * filter_cols * depth; + config = GetCudaLaunchConfig(total_count, d); + SetZero<<>>( + total_count, filter_backprop.data()); + + // Accumulate. + total_count = batch * output_rows * output_cols * depth; + config = GetCudaLaunchConfig(total_count, d); + DilationBackpropFilterKernel<<>>( + config.virtual_thread_count, input.data(), filter.data(), + out_backprop.data(), batch, input_rows, input_cols, depth, filter_rows, + filter_cols, output_rows, output_cols, stride_rows, stride_cols, + rate_rows, rate_cols, pad_top, pad_left, filter_backprop.data()); + } +}; + +} // namespace functor + +#define DEFINE_GPU_SPECS(T) \ + template struct functor::Dilation; \ + template struct functor::DilationBackpropInput; \ + template struct functor::DilationBackpropFilter; + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); + +#undef DEFINE_GPU_SPECS + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 2f509a332351ad..fee145be538865 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -740,6 +740,99 @@ output: Gradients w.r.t. the input of `max_pool`. // -------------------------------------------------------------------------- +REGISTER_OP("Dilation2D") + .Input("input: T") + .Input("filter: T") + .Output("output: T") + .Attr("T: realnumbertype") + .Attr("strides: list(int) >= 4") + .Attr("rates: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the grayscale dilation of 4-D `input` and 3-D `filter` tensors. + +The `input` tensor has shape `[batch, in_height, in_width, depth]` and the +`filter` tensor has shape `[filter_height, filter_width, depth]`, i.e., each +input channel is processed independently of the others with its own structuring +function. The `output` tensor has shape +`[batch, out_height, out_width, depth]`. The spatial dimensions of the output +tensor depend on the `padding` algorithm. We currently only support the default +"NHWC" `data_format`. + +In detail, the grayscale morphological 2-D dilation is the max-sum correlation +(for consistency with `conv2d`, we use unmirrored filters): + + output[b, y, x, c] = + max_{dy, dx} input[b, + strides[1] * y + rates[1] * dy, + strides[2] * x + rates[2] * dx, + c] + + filter[dy, dx, c] + +Max-pooling is a special case when the filter has size equal to the pooling +kernel size and contains all zeros. + +Duality: The dilation of `input` by the `filter` is equal to the negation of +the erosion of `-input` by the reflected `filter`. + +input: 4-D with shape `[batch, in_height, in_width, depth]`. +filter: 3-D with shape `[filter_height, filter_width, depth]`. +strides: The stride of the sliding window for each dimension of the input + tensor. Must be: `[1, stride_height, stride_width, 1]`. +rates: The input stride for atrous morphological dilation. Must be: + `[1, rate_height, rate_width, 1]`. +padding: The type of padding algorithm to use. +output: 4-D with shape `[batch, out_height, out_width, depth]`. +)doc"); + +REGISTER_OP("Dilation2DBackpropInput") + .Input("input: T") + .Input("filter: T") + .Input("out_backprop: T") + .Output("in_backprop: T") + .Attr("T: realnumbertype") + .Attr("strides: list(int) >= 4") + .Attr("rates: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradient of morphological 2-D dilation with respect to the input. + +input: 4-D with shape `[batch, in_height, in_width, depth]`. +filter: 3-D with shape `[filter_height, filter_width, depth]`. +out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +in_backprop: 4-D with shape `[batch, in_height, in_width, depth]`. +strides: 1-D of length 4. The stride of the sliding window for each dimension of + the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +rates: 1-D of length 4. The input stride for atrous morphological dilation. + Must be: `[1, rate_height, rate_width, 1]`. +padding: The type of padding algorithm to use. +)doc"); + +REGISTER_OP("Dilation2DBackpropFilter") + .Input("input: T") + .Input("filter: T") + .Input("out_backprop: T") + .Output("filter_backprop: T") + .Attr("T: realnumbertype") + .Attr("strides: list(int) >= 4") + .Attr("rates: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Computes the gradient of morphological 2-D dilation with respect to the filter. + +input: 4-D with shape `[batch, in_height, in_width, depth]`. +filter: 3-D with shape `[filter_height, filter_width, depth]`. +out_backprop: 4-D with shape `[batch, out_height, out_width, depth]`. +filter_backprop: 3-D with shape `[filter_height, filter_width, depth]`. +strides: 1-D of length 4. The stride of the sliding window for each dimension of + the input tensor. Must be: `[1, stride_height, stride_width, 1]`. +rates: 1-D of length 4. The input stride for atrous morphological dilation. + Must be: `[1, rate_height, rate_width, 1]`. +padding: The type of padding algorithm to use. +)doc"); + +// -------------------------------------------------------------------------- + REGISTER_OP("Relu") .Input("features: T") .Output("activations: T") diff --git a/tensorflow/python/kernel_tests/morphological_ops_test.py b/tensorflow/python/kernel_tests/morphological_ops_test.py new file mode 100644 index 00000000000000..98562429144759 --- /dev/null +++ b/tensorflow/python/kernel_tests/morphological_ops_test.py @@ -0,0 +1,541 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for morphological filtering operations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + + +class DilationTest(tf.test.TestCase): + + def _VerifyValues(self, image, kernel, strides, rates, padding, out, use_gpu): + """Verifies the output values of the dilation function. + + Args: + image: Input tensor with shape: [batch, in_height, in_width, channels]. + kernel: Filter tensor with shape: [filter_height, filter_width, channels]. + strides: Output strides, specified as [stride_height, stride_width]. + rates: Atrous rates, specified as [rate_height, rate_width]. + padding: Padding type. + out: Expected output. + use_gpu: Whether we are running on GPU. + """ + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(use_gpu=use_gpu): + out_tensor = tf.nn.dilation2d( + tf.constant(image), + tf.constant(kernel), + strides=strides, + rates=rates, + padding=padding, + name="dilation2d") + self.assertAllClose(out, out_tensor.eval()) + + def _testDilationValidPadding(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [1, 1, 1, 1] + out = [[[[.5]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def _testDilationSamePadding(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [1, 2, 2, 1] + out = [[[[.5], [.6]], [[.7], [.8]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testDilationSamePaddingDepth(self, use_gpu): + # [1, 2, 2, 3] + image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]] + # [2, 2, 3] + kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]] + # [1, 2, 2, 3] + out = [[[[.5, .7, .3], [.6, .8, .4]], [[.7, .9, .5], [.8, 1., .6]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testDilationSamePaddingBatch(self, use_gpu): + # [2, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [2, 2, 2, 1] + out = [[[[.5], [.6]], [[.7], [.8]]], [[[.6], [.7]], [[.8], [.9]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testDilationValidPaddingNonSquareWindow(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [1, 2, 1] + kernel = [[[.4], [.3]]] + # [1, 2, 1, 1] + out = [[[[.5]], [[.7]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def _testDilationSamePaddingRate(self, use_gpu): + # [1, 3, 3, 1] + image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.2]]] + # Because rate = 2, the effective kernel is [3, 3, 1]: + # kernel_eff = [[[.4], [.0], [.3]], + # [[.0], [.0], [.0]], + # [[.1], [.0], [.2]]] + # [1, 3, 3, 1] + out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[2, 2], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testDilationValidPaddingUnevenStride(self, use_gpu): + # [1, 3, 3, 1] + image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]], + [[.9], [1.0], [1.1], [1.2]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.2]]] + # [1, 2, 2, 1] + out = [[[[.8], [1.0]], [[1.2], [1.4]]]] + self._VerifyValues(image, + kernel, + strides=[1, 2], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def testDilation(self): + for use_gpu in True, False: + self._testDilationValidPadding(use_gpu) + self._testDilationSamePadding(use_gpu) + self._testDilationSamePaddingDepth(use_gpu) + self._testDilationSamePaddingBatch(use_gpu) + self._testDilationValidPaddingNonSquareWindow(use_gpu) + self._testDilationSamePaddingRate(use_gpu) + self._testDilationValidPaddingUnevenStride(use_gpu) + + def _ConstructAndTestGradient(self, image_shape, kernel_shape, strides, rates, + padding, use_gpu): + """Verifies the gradients of the dilation function. + + Args: + image_shape: Input shape, [batch, in_height, in_width, channels]. + kernel_shape: Filter shape, [filter_height, filter_width, channels]. + strides: Output strides, specified as [stride_height, stride_width]. + rates: Atrous rates, specified as [rate_height, rate_width]. + padding: Padding type. + use_gpu: Whether we are running on GPU. + """ + assert image_shape[3] == kernel_shape[2] + + np.random.seed(1) # Make it reproducible. + image = np.random.random_sample(image_shape).astype(np.float32) + kernel = np.random.random_sample(kernel_shape).astype(np.float32) + image_init = np.random.random_sample(image_shape).astype(np.float32) + kernel_init = np.random.random_sample(kernel_shape).astype(np.float32) + + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(use_gpu=use_gpu): + image_tensor = tf.constant(image, shape=image_shape, name="input") + kernel_tensor = tf.constant(kernel, shape=kernel_shape, name="filter") + out_tensor = tf.nn.dilation2d(image_tensor, + kernel_tensor, + strides=strides, + rates=rates, + padding=padding, + name="dilation2d") + out_shape = out_tensor.eval().shape + + # Small delta is necessary for argmax to remain the same. + err = tf.test.compute_gradient_error([image_tensor, kernel_tensor], + [image_shape, kernel_shape], + out_tensor, + out_shape, [image_init, kernel_init], + delta=1e-3) + + print("Dilation gradient error = %f" % err) + self.assertLess(err, 1e-4) + + def _testDilationGradValidPadding_1x1x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[1, 1, 1], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + use_gpu=use_gpu) + + def _testDilationGradSamePadding_1x1x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[1, 1, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testDilationGradSamePadding_1x1x2(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 2], + kernel_shape=[1, 1, 2], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testDilationGradValidPadding_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + use_gpu=use_gpu) + + def _testDilationGradSamePadding_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testDilationGradSamePaddingBatch_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[4, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testDilationGradSamePadding_2x2x4(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 4], + kernel_shape=[2, 2, 4], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def testDilationGrad(self): + for use_gpu in True, False: + self._testDilationGradValidPadding_1x1x1(use_gpu) + self._testDilationGradSamePadding_1x1x1(use_gpu) + self._testDilationGradSamePadding_1x1x2(use_gpu) + self._testDilationGradValidPadding_2x2x1(use_gpu) + self._testDilationGradSamePadding_2x2x1(use_gpu) + self._testDilationGradSamePaddingBatch_2x2x1(use_gpu) + self._testDilationGradSamePadding_2x2x4(use_gpu) + + +class ErosionTest(tf.test.TestCase): + + def _VerifyValues(self, image, kernel, strides, rates, padding, out, use_gpu): + """Verifies the output values of the erosion function. + + Args: + image: Input tensor with shape: [batch, in_height, in_width, channels]. + kernel: Filter tensor with shape: [filter_height, filter_width, channels]. + strides: Output strides, specified as [stride_height, stride_width]. + rates: Atrous rates, specified as [rate_height, rate_width]. + padding: Padding type. + out: Expected output. + use_gpu: Whether we are running on GPU. + """ + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(use_gpu=use_gpu): + out_tensor = tf.nn.erosion2d( + tf.constant(image), + tf.constant(kernel), + strides=strides, + rates=rates, + padding=padding, + name="erosion2d") + self.assertAllClose(out, out_tensor.eval()) + + def _testErosionValidPadding(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [1, 1, 1, 1] + out = [[[[.0]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def _testErosionSamePadding(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [1, 2, 2, 1] + out = [[[[.0], [.1]], [[.3], [.4]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testErosionSamePaddingDepth(self, use_gpu): + # [1, 2, 2, 3] + image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]] + # [2, 2, 3] + kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]] + # [1, 2, 2, 3] + out = [[[[.0, .0, .0], [.1, .1, .1]], [[.3, .3, .3], [.4, .4, .4]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testErosionSamePaddingBatch(self, use_gpu): + # [2, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.0]]] + # [2, 2, 2, 1] + out = [[[[.0], [.1]], [[.3], [.4]]], [[[.1], [.2]], [[.4], [.5]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testErosionValidPaddingNonSquareWindow(self, use_gpu): + # [1, 2, 2, 1] + image = [[[[.1], [.2]], [[.3], [.4]]]] + # [1, 2, 1] + kernel = [[[.4], [.3]]] + # [1, 2, 1, 1] + out = [[[[-.2]], [[.0]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def _testErosionSamePaddingRate(self, use_gpu): + # [1, 3, 3, 1] + image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.2]]] + # Because rate = 2, the effective kernel is [3, 3, 1]: + # kernel_eff = [[[.4], [.0], [.3]], + # [[.0], [.0], [.0]], + # [[.1], [.0], [.2]]] + # [1, 3, 3, 1] + out = [[[[.1], [.1], [.2]], [[0.1], [-.1], [.0]], [[.4], [.2], [.3]]]] + self._VerifyValues(image, + kernel, + strides=[1, 1], + rates=[2, 2], + padding="SAME", + out=out, + use_gpu=use_gpu) + + def _testErosionValidPaddingUnevenStride(self, use_gpu): + # [1, 3, 3, 1] + image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]], + [[.9], [1.0], [1.1], [1.2]]]] + # [2, 2, 1] + kernel = [[[.4], [.3]], [[.1], [.2]]] + # [1, 2, 2, 1] + out = [[[[-.1], [.1]], [[.3], [.5]]]] + self._VerifyValues(image, + kernel, + strides=[1, 2], + rates=[1, 1], + padding="VALID", + out=out, + use_gpu=use_gpu) + + def testErosion(self): + for use_gpu in True, False: + self._testErosionValidPadding(use_gpu) + self._testErosionSamePadding(use_gpu) + self._testErosionSamePaddingDepth(use_gpu) + self._testErosionSamePaddingBatch(use_gpu) + self._testErosionValidPaddingNonSquareWindow(use_gpu) + self._testErosionSamePaddingRate(use_gpu) + self._testErosionValidPaddingUnevenStride(use_gpu) + + def _ConstructAndTestGradient(self, image_shape, kernel_shape, strides, rates, + padding, use_gpu): + """Verifies the gradients of the erosion function. + + Args: + image_shape: Input shape, [batch, in_height, in_width, channels]. + kernel_shape: Filter shape, [filter_height, filter_width, channels]. + strides: Output strides, specified as [stride_height, stride_width]. + rates: Atrous rates, specified as [rate_height, rate_width]. + padding: Padding type. + use_gpu: Whether we are running on GPU. + """ + assert image_shape[3] == kernel_shape[2] + + np.random.seed(1) # Make it reproducible. + image = np.random.random_sample(image_shape).astype(np.float32) + kernel = np.random.random_sample(kernel_shape).astype(np.float32) + image_init = np.random.random_sample(image_shape).astype(np.float32) + kernel_init = np.random.random_sample(kernel_shape).astype(np.float32) + + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(use_gpu=use_gpu): + image_tensor = tf.constant(image, shape=image_shape, name="input") + kernel_tensor = tf.constant(kernel, shape=kernel_shape, name="filter") + out_tensor = tf.nn.erosion2d(image_tensor, + kernel_tensor, + strides=strides, + rates=rates, + padding=padding, + name="erosion2d") + out_shape = out_tensor.eval().shape + + # Small delta is necessary for argmax to remain the same. + err = tf.test.compute_gradient_error([image_tensor, kernel_tensor], + [image_shape, kernel_shape], + out_tensor, + out_shape, [image_init, kernel_init], + delta=1e-3) + + print("Erosion gradient error = %f" % err) + self.assertLess(err, 1e-4) + + def _testErosionGradValidPadding_1x1x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[1, 1, 1], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + use_gpu=use_gpu) + + def _testErosionGradSamePadding_1x1x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[1, 1, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testErosionGradSamePadding_1x1x2(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 2], + kernel_shape=[1, 1, 2], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testErosionGradValidPadding_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + use_gpu=use_gpu) + + def _testErosionGradSamePadding_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testErosionGradSamePaddingBatch_2x2x1(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[4, 3, 3, 1], + kernel_shape=[2, 2, 1], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def _testErosionGradSamePadding_2x2x4(self, use_gpu): + self._ConstructAndTestGradient(image_shape=[1, 3, 3, 4], + kernel_shape=[2, 2, 4], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + use_gpu=use_gpu) + + def testErosionGrad(self): + for use_gpu in True, False: + self._testErosionGradValidPadding_1x1x1(use_gpu) + self._testErosionGradSamePadding_1x1x1(use_gpu) + self._testErosionGradSamePadding_1x1x2(use_gpu) + self._testErosionGradValidPadding_2x2x1(use_gpu) + self._testErosionGradSamePadding_2x2x1(use_gpu) + self._testErosionGradSamePaddingBatch_2x2x1(use_gpu) + self._testErosionGradSamePadding_2x2x4(use_gpu) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index bfb8dc987477b6..295afb56e888e8 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -131,6 +131,46 @@ @@avg_pool3d @@max_pool3d +## Morphological filtering + +Morphological operators are non-linear filters used in image processing. + +[Greyscale morphological dilation] +(https://en.wikipedia.org/wiki/Dilation_(morphology)) is the max-sum counterpart +of standard sum-product convolution: + + output[b, y, x, c] = + max_{dy, dx} input[b, + strides[1] * y + rates[1] * dy, + strides[2] * x + rates[2] * dx, + c] + + filter[dy, dx, c] + +The `filter` is usually called structuring function. Max-pooling is a special +case of greyscale morphological dilation when the filter assumes all-zero +values (a.k.a. flat structuring function). + +[Greyscale morphological erosion] +(https://en.wikipedia.org/wiki/Erosion_(morphology)) is the min-sum counterpart +of standard sum-product convolution: + + output[b, y, x, c] = + min_{dy, dx} input[b, + strides[1] * y - rates[1] * dy, + strides[2] * x - rates[2] * dx, + c] - + filter[dy, dx, c] + +Dilation and erosion are dual to each other. The dilation of the input signal +`f` by the structuring signal `g` is equal to the negation of the erosion of +`-f` by the reflected `g`, and vice versa. + +Striding and padding is carried out in exactly the same way as in standard +convolution. Please refer to the `Convolution` section for details. + +@@dilation2d +@@erosion2d + ## Normalization Normalization is useful to prevent neurons from saturating when inputs may diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 606c2fc71a8b40..ec3b424a32a3e5 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -270,6 +270,18 @@ def _DepthwiseConv2dNativeGrad(op, grad): ] +@ops.RegisterGradient("Dilation2D") +def _Dilation2DGrad(op, grad): + return [nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad, + op.get_attr("strides"), + op.get_attr("rates"), + op.get_attr("padding")), + nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad, + op.get_attr("strides"), + op.get_attr("rates"), + op.get_attr("padding"))] + + @ops.RegisterGradient("LRN") def _LRNGrad(op, grad): depth_radius = op.get_attr("depth_radius") diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index c6df23ab2135b7..3e5be3c3fb2861 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1068,4 +1068,142 @@ def conv1d(value, filters, stride, padding, data_format=data_format) return array_ops.squeeze(result, [1]) + +@ops.RegisterShape("Dilation2D") +def _Dilation2DShape(op): + """Shape function for Dilation2D op.""" + input_shape = op.inputs[0].get_shape().with_rank(4) + filter_shape = op.inputs[1].get_shape().with_rank(3) + + batch_size = input_shape[0] + in_rows = input_shape[1] + in_cols = input_shape[2] + depth = input_shape[3] + + filter_rows = filter_shape[0] + filter_cols = filter_shape[1] + # Check that the input depths are compatible. + input_shape[3].assert_is_compatible_with(filter_shape[2]) + + stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") + if stride_b != 1 or stride_d != 1: + raise ValueError("Current implementation does not yet support " + "strides in the batch and depth dimensions.") + + rate_b, rate_r, rate_c, rate_d = op.get_attr("rates") + if rate_b != 1 or rate_d != 1: + raise ValueError("Current implementation does not yet support " + "rates in the batch and depth dimensions.") + + filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1) + filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1) + + padding = op.get_attr("padding") + out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols, + filter_rows_eff, + filter_cols_eff, + stride_r, stride_c, + padding) + + output_shape = [batch_size, out_rows, out_cols, depth] + return [tensor_shape.TensorShape(output_shape)] + + +@ops.RegisterShape("Dilation2DBackpropInput") +def _Dilation2DBackpropInputShape(op): + """Shape function for Dilation2DBackpropInput op.""" + return [op.inputs[0].get_shape()] + + +@ops.RegisterShape("Dilation2DBackpropFilter") +def _Dilation2DBackpropFilterShape(op): + """Shape function for Dilation2DBackpropFilter op.""" + return [op.inputs[1].get_shape()] + + +@ops.RegisterStatistics("Dilation2D", "flops") +def _calc_dilation2d_flops(graph, node): + """Calculates the compute resources needed for Dilation2D.""" + input_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + input_shape.assert_is_fully_defined() + filter_shape = graph_util.tensor_shape_from_node_def_name(graph, + node.input[1]) + filter_shape.assert_is_fully_defined() + output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + output_shape.assert_is_fully_defined() + filter_height = int(filter_shape[0]) + filter_width = int(filter_shape[1]) + output_count = np.prod(output_shape.as_list()) + return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) + + +@ops.RegisterStatistics("Dilation2D", "weight_parameters") +def _calc_dilation2d_weight_params(graph, node): + """Calculates the on-disk size of the weights for Dilation2D.""" + filter_shape = graph_util.tensor_shape_from_node_def_name(graph, + node.input[1]) + filter_shape.assert_is_fully_defined() + filter_height = int(filter_shape[0]) + filter_width = int(filter_shape[1]) + filter_depth = int(filter_shape[2]) + return ops.OpStats("weight_parameters", + (filter_height * filter_width * filter_depth)) + + +def erosion2d(value, kernel, strides, rates, padding, name=None): + """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors. + + The `value` tensor has shape `[batch, in_height, in_width, depth]` and the + `kernel` tensor has shape `[kernel_height, kernel_width, depth]`, i.e., + each input channel is processed independently of the others with its own + structuring function. The `output` tensor has shape + `[batch, out_height, out_width, depth]`. The spatial dimensions of the + output tensor depend on the `padding` algorithm. We currently only support the + default "NHWC" `data_format`. + + In detail, the grayscale morphological 2-D erosion is given by: + + output[b, y, x, c] = + min_{dy, dx} value[b, + strides[1] * y - rates[1] * dy, + strides[2] * x - rates[2] * dx, + c] - + kernel[dy, dx, c] + + Duality: The erosion of `value` by the `kernel` is equal to the negation of + the dilation of `-value` by the reflected `kernel`. + + Args: + value: A `Tensor`. 4-D with shape `[batch, in_height, in_width, depth]`. + kernel: A `Tensor`. Must have the same type as `value`. + 3-D with shape `[kernel_height, kernel_width, depth]`. + strides: A list of `ints` that has length `>= 4`. + 1-D of length 4. The stride of the sliding window for each dimension of + the input tensor. Must be: `[1, stride_height, stride_width, 1]`. + rates: A list of `ints` that has length `>= 4`. + 1-D of length 4. The input stride for atrous morphological dilation. + Must be: `[1, rate_height, rate_width, 1]`. + padding: A `string` from: `"SAME", "VALID"`. + The type of padding algorithm to use. + name: A name for the operation (optional). If not specified "erosion2d" + is used. + + Returns: + A `Tensor`. Has the same type as `value`. + 4-D with shape `[batch, out_height, out_width, depth]`. + + Raises: + ValueError: If the `value` depth does not match `kernel`' shape, or if + padding is other than `'VALID'` or `'SAME'`. + """ + with ops.op_scope([value, kernel], name, "erosion2d") as name: + # Reduce erosion to dilation by duality. + return math_ops.neg(gen_nn_ops.dilation2d(input=math_ops.neg(value), + filter=array_ops.reverse( + kernel, [True, True, False]), + strides=strides, + rates=rates, + padding=padding, + name=name)) + # pylint: enable=invalid-name