
max_pool_with_argmax GPU kernel supports include_batch_in_index #26562

Merged
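In short, include_batch_in_index controls whether the flattened index written to argmax is relative to a single image (((h * W) + w) * C + c) or to the whole batch (((b * H + h) * W + w) * C + c). Before this change the GPU kernels rejected the batch-absolute form; the diff below threads the flag through the launchers and kernels. A minimal standalone C++ sketch of the two index conventions (names and shapes are illustrative only, not part of the patch):

#include <cstdint>
#include <iostream>

// Flattened NHWC argmax index for one max location. When
// include_batch_in_index is false the index is relative to the start of
// image `b`; when true it is relative to the start of the whole batch.
int64_t ArgmaxIndexNHWC(int b, int h, int w, int c, int H, int W, int C,
                        bool include_batch_in_index) {
  const int64_t within_image = (static_cast<int64_t>(h) * W + w) * C + c;
  return include_batch_in_index
             ? static_cast<int64_t>(b) * H * W * C + within_image
             : within_image;
}

int main() {
  // Example: batch item 1, max at (h=2, w=3, c=0) in a 4x4x1 feature map.
  std::cout << ArgmaxIndexNHWC(1, 2, 3, 0, 4, 4, 1, false) << "\n";  // 11
  std::cout << ArgmaxIndexNHWC(1, 2, 3, 0, 4, 4, 1, true) << "\n";   // 27
  return 0;
}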
83 changes: 33 additions & 50 deletions tensorflow/core/kernels/maxpooling_op.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/kernels/maxpooling_op.h"

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
@@ -39,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#if GOOGLE_CUDA
#include "cuda/include/cudnn.h"
@@ -914,13 +914,6 @@ class MaxPoolingWithArgmaxOp : public OpKernel {
"Pooling is not yet supported on the batch dimension."));
OP_REQUIRES_OK(context, context->GetAttr("include_batch_in_index",
&include_batch_in_index_));
if (context->device_type() == DeviceType(DEVICE_GPU)) {
OP_REQUIRES(context, include_batch_in_index_ == false,
errors::Unimplemented(
"include_batch_in_index=true is not yet supported "
"on the GPU kernel."));
}

TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_MAXPOOL_NANPROP", false,
&propagate_nans_));
}
@@ -1313,7 +1306,7 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
params.out_width, params.window_rows, params.window_cols,
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
output->flat<T>().data(), nullptr, context->eigen_gpu_device(),
propagate_nans);
propagate_nans, false);
if (!status) {
context->SetStatus(
errors::Internal("Failed launching MaxPoolForwardNoMask"));
@@ -1326,18 +1319,14 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output, Tensor* argmax,
bool propagate_nans, bool include_batch_in_index) {
OP_REQUIRES(context, include_batch_in_index == false,
errors::Unimplemented(
"include_batch_in_index=true is not yet supported "
"on the GPU kernel."));
bool status = functor::MaxPoolForwardWithOptionalArgmax<T>()(
input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
params.tensor_in_cols, params.depth, params.out_height,
params.out_width, params.window_rows, params.window_cols,
params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
output->flat<T>().data(),
reinterpret_cast<int64*>(argmax->flat<int64>().data()),
context->eigen_gpu_device(), propagate_nans);
context->eigen_gpu_device(), propagate_nans, include_batch_in_index);
if (!status) {
context->SetStatus(
errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
@@ -1350,10 +1339,6 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& grad_in, const Tensor& argmax,
Tensor* grad_out, const bool include_batch_in_index) {
OP_REQUIRES(context, include_batch_in_index == false,
errors::Unimplemented(
"include_batch_in_index=true is not yet supported "
"on the GPU kernel."));
const int input_size = params.tensor_in_batch * params.tensor_in_rows *
params.tensor_in_cols * params.depth;
const int output_size = params.tensor_in_batch * params.out_height *
@@ -1364,7 +1349,8 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
bool status = functor::MaxPoolBackwardWithArgmax<T>()(
output_size, input_size, grad_in.flat<T>().data(),
reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device(),
include_batch_in_index);
if (!status) {
context->SetStatus(
errors::Internal("Failed launching MaxPoolBackwardWithArgmax"));
@@ -1377,10 +1363,6 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& grad_in, const Tensor& argmax,
Tensor* grad_out, const bool include_batch_in_index) {
OP_REQUIRES(context, include_batch_in_index == false,
errors::Unimplemented(
"include_batch_in_index=true is not yet supported "
"on the GPU kernel."));
const int input_size = params.tensor_in_batch * params.tensor_in_rows *
params.tensor_in_cols * params.depth;
const int output_size = params.tensor_in_batch * params.out_height *
@@ -1392,7 +1374,8 @@ struct LaunchMaxPoolingGradGradWithArgmax<Eigen::GpuDevice, T> {
bool status = functor::MaxPoolGradBackwardWithArgmax<T>()(
output_size, input_size, grad_in.flat<T>().data(),
reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device(),
include_batch_in_index);
if (!status) {
context->SetStatus(
errors::Internal("Failed launching MaxPoolGradBackwardWithArgmax"));
@@ -1473,32 +1456,32 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
// default Eigen implementation so we are using the custom kernel as the
// default. However, you can explicitly invoke the eigen version using
// kernel_label_map.
#define REGISTER_GPU_ONLY_POOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("MaxPool") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.Label("eigen_tensor"), \
MaxPoolingOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
.Device(DEVICE_GPU) \
.HostMemory("ksize") \
.HostMemory("strides") \
.TypeConstraint<T>("T") \
.Label("eigen_tensor"), \
MaxPoolingV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
MaxPoolingNoMaskOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
.Device(DEVICE_GPU) \
.HostMemory("ksize") \
.HostMemory("strides") \
.TypeConstraint<T>("T"), \
MaxPoolingNoMaskV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<int64>("Targmax"), \
#define REGISTER_GPU_ONLY_POOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("MaxPool") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.Label("eigen_tensor"), \
MaxPoolingOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
.Device(DEVICE_GPU) \
.HostMemory("ksize") \
.HostMemory("strides") \
.TypeConstraint<T>("T") \
.Label("eigen_tensor"), \
MaxPoolingV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
MaxPoolingNoMaskOp<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") \
.Device(DEVICE_GPU) \
.HostMemory("ksize") \
.HostMemory("strides") \
.TypeConstraint<T>("T"), \
MaxPoolingNoMaskV2Op<GPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax") \
.Device(DEVICE_GPU) \
.TypeConstraint<T>("T") \
.TypeConstraint<int64>("Targmax"), \
MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);

81 changes: 46 additions & 35 deletions tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -54,21 +54,21 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool IsGreaterThan(dtype a, dtype b) {
// int form, keeping track of the flattened index of the input item that
// produces the max output. If a nullptr is passed in for mask, no mask
// will be produced.
// include_batch_in_index: whether to include batch dimension in flattened
// index of `argmax`.
//
// To call the forward and backward functions, use e.g.:
// const int kThreadsPerBlock = 1024
// const int output_size = batch * channels * pooled_height * pooled_width;
// MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
// kThreadsPerBlock, 0, cuda_stream>>>(...);
template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
const int channels, const int height,
const int width, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h,
const int stride_w, const int pad_t,
const int pad_l, dtype* top_data,
int64* mask) {
__global__ void MaxPoolForwardNCHW(
const int nthreads, const dtype* bottom_data, const int channels,
const int height, const int width, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
dtype* top_data, int64* mask, const bool include_batch_in_index) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
@@ -82,12 +82,13 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
wstart = max(wstart, 0);
dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * channels * height * width;
const int offset = n * channels * height * width;
const dtype* bottom_data_n = bottom_data + offset;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int idx = c * height * width + h * width + w;
if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
maxidx = idx;
maxidx = include_batch_in_index ? idx + offset : idx;
maxval = bottom_data_n[idx];
}
}
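The functional change in this kernel is the maxidx assignment: idx is the offset of the winning element within image n (c * H * W + h * W + w in NCHW layout), and offset = n * C * H * W is the start of that image, so idx + offset is the flattened index over the whole batch. A small standalone check of that arithmetic (the shape and coordinate values are made up for illustration):

#include <cassert>
#include <cstdint>

int main() {
  // Illustrative shapes: C=3, H=4, W=5; element (n=1, c=2, h=3, w=4).
  const int C = 3, H = 4, W = 5;
  const int n = 1, c = 2, h = 3, w = 4;

  const int64_t idx = static_cast<int64_t>(c) * H * W + h * W + w;  // within image n
  const int64_t offset = static_cast<int64_t>(n) * C * H * W;       // start of image n

  // idx + offset equals the flattened NCHW index of element (n, c, h, w).
  const int64_t full = ((static_cast<int64_t>(n) * C + c) * H + h) * W + w;
  assert(idx + offset == full);
  return 0;
}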
@@ -136,14 +137,12 @@ __global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
}

template <bool propagate_nans, typename dtype>
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
const int height, const int width,
const int channels, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h,
const int stride_w, const int pad_t,
const int pad_l, dtype* top_data,
int64* mask) {
__global__ void MaxPoolForwardNHWC(
const int nthreads, const dtype* bottom_data, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
dtype* top_data, int64* mask, const bool include_batch_in_index) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index;
int c = n % channels;
@@ -158,12 +157,13 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
wstart = max(wstart, 0);
dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
const int offset = n * height * width * channels;
const dtype* bottom_data_n = bottom_data + offset;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int idx = (h * width + w) * channels + c;
if (IsGreaterThan<propagate_nans>(bottom_data_n[idx], maxval)) {
maxidx = idx;
maxidx = include_batch_in_index ? idx + offset : idx;
maxval = bottom_data_n[idx];
}
}
@@ -231,17 +231,20 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool input.
// This is equal to H*W*C.
// bottom_diff: the gradient with respect to the input.
// include_batch_in_index: whether to include batch dimension in flattened
// index of `argmax`.
// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
// the kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
const int bottom_offset, dtype* bottom_diff) {
const int bottom_offset, dtype* bottom_diff,
const bool include_batch_in_index) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int image_id = (index / top_offset);
CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
top_diff[index]);
const int offset =
include_batch_in_index ? 0 : (index / top_offset) * bottom_offset;
CudaAtomicAdd(bottom_diff + offset + mask[index], top_diff[index]);
}
}
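When include_batch_in_index is true the mask entries are already batch-absolute, so no per-image base is added; otherwise the base is reconstructed as (index / top_offset) * bottom_offset. A CPU sketch of the same scatter, assuming (per the comments above, with top_offset presumably the per-image pooled-output size Hout*Wout*C and bottom_offset the per-image input size H*W*C):

#include <cstdint>
#include <vector>

// CPU analogue of MaxPoolBackward: scatter each pooled-output gradient onto
// the input location recorded in `mask`.
void MaxPoolBackwardCPU(const std::vector<float>& top_diff,
                        const std::vector<int64_t>& mask, int top_offset,
                        int bottom_offset, bool include_batch_in_index,
                        std::vector<float>& bottom_diff) {
  for (size_t index = 0; index < top_diff.size(); ++index) {
    const int64_t base =
        include_batch_in_index
            ? 0
            : static_cast<int64_t>(index / top_offset) * bottom_offset;
    bottom_diff[base + mask[index]] += top_diff[index];  // CudaAtomicAdd on GPU
  }
}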

Expand Down Expand Up @@ -358,14 +361,17 @@ __global__ void MaxPoolGradBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool output.
// This is equal to Hout*Wout*C.
// bottom_diff: the gradient of the gradient w.r.t. output.
// include_batch_in_index: whether to include batch dimension in flattened
// index of `argmax`.
template <typename dtype>
__global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
const int64* mask, const int top_offset,
const int bottom_offset,
dtype* bottom_diff) {
const int bottom_offset, dtype* bottom_diff,
const bool include_batch_in_index) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int image_id = (index / bottom_offset);
bottom_diff[index] = top_diff[image_id * top_offset + mask[index]];
const int offset =
include_batch_in_index ? 0 : (index / bottom_offset) * top_offset;
bottom_diff[index] = top_diff[offset + mask[index]];
}
}
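The second-order kernel gathers rather than scatters: index walks the pooled output, mask[index] points into the maxpool input, and the base (index / bottom_offset) * top_offset is only needed when the mask is image-relative. A CPU sketch under the same offset assumptions (bottom_offset the per-image pooled-output size as stated above, top_offset taken to be the per-image input size):

#include <cstdint>
#include <vector>

// CPU analogue of MaxPoolGradBackward: for every pooled-output position,
// read the incoming second-order gradient at the recorded argmax location.
void MaxPoolGradBackwardCPU(const std::vector<float>& top_diff,
                            const std::vector<int64_t>& mask, int top_offset,
                            int bottom_offset, bool include_batch_in_index,
                            std::vector<float>& bottom_diff) {
  for (size_t index = 0; index < bottom_diff.size(); ++index) {
    const int64_t base =
        include_batch_in_index
            ? 0
            : static_cast<int64_t>(index / bottom_offset) * top_offset;
    bottom_diff[index] = top_diff[base + mask[index]];
  }
}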

Expand Down Expand Up @@ -399,7 +405,8 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
const int channels, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_t, const int pad_l, T* top_data,
int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) {
int64* mask, const Eigen::GpuDevice& d, bool propagate_nans,
const bool include_batch_in_index) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
if (output_size == 0) return true;
@@ -409,14 +416,14 @@
kThreadsPerBlock, 0, d.stream()>>>(
output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
top_data, mask, include_batch_in_index);
} else {
MaxPoolForwardNHWC<false>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
output_size, bottom_data, height, width, channels, pooled_height,
pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
top_data, mask);
top_data, mask, include_batch_in_index);
}
return d.ok();
}
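For reference, the grid size used by these launches is plain ceiling division of the element count by kThreadsPerBlock; a trivial standalone check of that sizing rule:

#include <cassert>

// Grid sizing used by the launches above: one thread per output element,
// grouped into blocks of threads_per_block, rounded up.
int NumBlocks(int output_size, int threads_per_block) {
  return (output_size + threads_per_block - 1) / threads_per_block;
}

int main() {
  const int kThreadsPerBlock = 1024;
  // e.g. batch=2, channels=3, pooled 5x5 output -> 150 elements -> 1 block.
  assert(NumBlocks(2 * 3 * 5 * 5, kThreadsPerBlock) == 1);
  assert(NumBlocks(1024, kThreadsPerBlock) == 1);
  assert(NumBlocks(1025, kThreadsPerBlock) == 2);
  return 0;
}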
@@ -449,14 +456,16 @@ template <typename T>
bool MaxPoolBackwardWithArgmax<T>::operator()(
const int output_size, const int input_size, const T* top_diff,
const int64* mask, const int top_offset, const int bottom_offset,
T* bottom_diff, const Eigen::GpuDevice& d) {
T* bottom_diff, const Eigen::GpuDevice& d,
const bool include_batch_in_index) {
const int kThreadsPerBlock = 1024;
if (input_size == 0) return true;
SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff,
include_batch_in_index);
return d.ok();
}

@@ -492,12 +501,14 @@ template <typename T>
bool MaxPoolGradBackwardWithArgmax<T>::operator()(
const int output_size, const int input_size, const T* top_diff,
const int64* mask, const int top_offset, const int bottom_offset,
T* bottom_diff, const Eigen::GpuDevice& d) {
T* bottom_diff, const Eigen::GpuDevice& d,
const bool include_batch_in_index) {
if (input_size == 0) return true;
CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d);
MaxPoolGradBackward<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(output_size, top_diff, mask, top_offset,
bottom_offset, bottom_diff);
bottom_offset, bottom_diff,
include_batch_in_index);
return d.ok();
}
