Fast 3DY reduction when the middle dim Y is large but X and Z are small #29224

Merged · 3 commits · Jul 17, 2019
227 changes: 222 additions & 5 deletions tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -452,6 +452,155 @@ __global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
  out[plane * num_cols + col] = sum;
}

template <int unroll, typename T, typename IN_T, typename Op>
__device__ __inline__ T ComputeSum(IN_T in_, const int plane,
                                   const int num_out_rows, int num_rows,
                                   int num_cols, const int col, Op op) {
  const int out_rows = num_rows / (2 * unroll);
  const int num_rem_rows = num_rows % (2 * unroll);
  const int elems_per_plane = num_rows * num_cols;
  T reg[2 * unroll];
  T sum;
  int offset = 0;
  if (out_rows != 0) {
    for (int i = 0; i < 2 * unroll; i++) {
      reg[i] =
          in_[plane * elems_per_plane + i * (num_out_rows * num_cols) + col];
    }
    sum = reg[0];
    for (int i = 1; i < 2 * unroll; i++) {
      sum = op(sum, reg[i]);
    }
    offset = 2 * unroll * (num_out_rows * num_cols);
  }

  if (col < num_cols && num_rem_rows > 0) {
    reg[0] = in_[plane * elems_per_plane + offset + 0 * num_cols + col];
    if (out_rows != 0) {
      sum = op(sum, reg[0]);
    } else {
      sum = reg[0];
    }
    for (int i = 1; i < num_rem_rows; i++) {
      reg[0] = in_[plane * elems_per_plane + offset + i * num_cols + col];
      sum = op(sum, reg[0]);
    }
  }
  return sum;
}
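
For intuition, a minimal host-side sketch of the same tiling (standalone, with
hypothetical sizes, and plain addition standing in for Op; not part of the
diff). Each simulated thread owns one slot of the partial-output plane, loads
2 * unroll elements at a stride of num_out_rows * num_cols, and the first
output row also folds in the leftover rows:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int unroll = 2, num_cols = 3, num_rows = 11;
  const int num_out_rows = std::max(1, num_rows / (2 * unroll));  // 2
  const int num_rem_rows = num_rows % (2 * unroll);               // 3
  std::vector<int> in(num_rows * num_cols);
  for (int i = 0; i < num_rows * num_cols; ++i) in[i] = i;

  // One simulated "thread" per element of the partial-output plane.
  std::vector<int> out(num_out_rows * num_cols, 0);
  for (int col = 0; col < num_out_rows * num_cols; ++col) {
    int sum = 0;
    for (int i = 0; i < 2 * unroll; ++i)  // the unrolled, strided loads
      sum += in[i * (num_out_rows * num_cols) + col];
    const int offset = 2 * unroll * num_out_rows * num_cols;
    if (col < num_cols)  // first output row picks up the remainder rows
      for (int i = 0; i < num_rem_rows; ++i)
        sum += in[offset + i * num_cols + col];
    out[col] = sum;
  }

  // The num_out_rows partial sums per column still need one more pass.
  int got = 0, want = 0;
  for (int r = 0; r < num_out_rows; ++r) got += out[r * num_cols];
  for (int r = 0; r < num_rows; ++r) want += in[r * num_cols];
  std::printf("column 0: partials total %d, reference %d\n", got, want);
}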

template <int unroll, typename IN_T, typename Op>
__global__ void ColumnReduceInToTempKernel(void* temp, int temp_in_offset,
[Inline review comment from a Contributor: Change 'T' to 'IN_T' to make things
consistent, here and below.]

                                           int temp_out_offset,
                                           IN_T in, int num_planes,
                                           int num_rows, int num_cols, Op op) {
  typedef typename std::iterator_traits<IN_T>::value_type value_type;

  value_type* t = (value_type*)temp;
  value_type* out_ = t + temp_out_offset;

  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int num_out_rows = max(1, num_rows / (2 * unroll));
  const int plane = gid / (num_out_rows * num_cols);
  const int col = gid % (num_out_rows * num_cols);

  if (plane >= num_planes) return;

  value_type sum;
  if (temp_in_offset == -1) {
    auto in_ = in;
    sum = ComputeSum<unroll, value_type, IN_T, Op>(
        in_, plane, num_out_rows, num_rows, num_cols, col, op);
  } else {
    auto in_ = t + temp_in_offset;
    sum = ComputeSum<unroll, value_type, value_type*, Op>(
        in_, plane, num_out_rows, num_rows, num_cols, col, op);
  }
  out_[plane * num_out_rows * num_cols + col] = sum;
}
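
The launch in Launch3DYReduction below covers extent_x * num_out_rows *
num_cols threads; as a quick sketch (hypothetical sizes, not from the PR), a
flat gid decomposes exactly as in the kernel:

#include <cstdio>

int main() {
  const int num_out_rows = 2, num_cols = 3, num_planes = 2;
  const int per_plane = num_out_rows * num_cols;
  for (int gid = 0; gid < num_planes * per_plane; ++gid) {
    int plane = gid / per_plane;
    int col = gid % per_plane;  // spans the whole partial plane, despite the name
    std::printf("gid %2d -> plane %d, col %2d\n", gid, plane, col);
  }
}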

template <typename T, typename outT, typename Op>
__global__ void ColumnReduceTempToOutKernel(void* temp, int temp_in_offset,
                                            T in, outT out, int num_planes,
                                            int num_rows, int num_cols, Op op) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  value_type* t = (value_type*)temp;
  const int tid = threadIdx.x;
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  int elems_per_plane = num_rows * num_cols;

  if (num_rows == 1) {
    if (gid >= num_planes * num_cols) return;
    if (temp_in_offset == -1) {
      auto in_ = in;
      out[gid] = in_[gid];
    } else {
      auto in_ = t + temp_in_offset;
      out[gid] = in_[gid];
    }
    return;
  }

  const int planes_per_block = 1;
  const int plane = blockIdx.x * planes_per_block + tid / elems_per_plane;
  // A thread block contains one or multiple plane(s),
  // i.e. num_rows * num_cols <= blockDim.x
  const int col = tid % elems_per_plane;
  const int local_plane = plane % planes_per_block;

  if (tid >= planes_per_block * elems_per_plane || plane >= num_planes) return;

  extern __shared__ char ss[];
  value_type* const smem = reinterpret_cast<value_type*>(ss);

  if (temp_in_offset == -1) {
    auto in_ = in;
    smem[local_plane * elems_per_plane + col] =
        in_[plane * elems_per_plane + col];
  } else {
    auto in_ = t + temp_in_offset;
    smem[local_plane * elems_per_plane + col] =
        in_[plane * elems_per_plane + col];
  }
  __syncthreads();

  int num_in_rows = num_rows;
  int num_out_rows;
  int num_rem_rows;

  int in_offset = 0;
  int out_offset = blockDim.x;

  int in_elems_per_plane = elems_per_plane;
  int out_elems_per_plane;

  while (num_in_rows > 1) {
    num_out_rows = num_in_rows / 2;
    num_rem_rows = num_in_rows % 2;
    out_elems_per_plane = num_out_rows * num_cols;

    if (col < out_elems_per_plane) {
      value_type sum;
      sum = op(smem[in_offset + local_plane * in_elems_per_plane + col],
               smem[in_offset + local_plane * in_elems_per_plane +
                    out_elems_per_plane + col]);
      if (num_rem_rows == 1 && col < num_cols) {
        sum = op(sum, smem[in_offset + local_plane * in_elems_per_plane +
                           2 * out_elems_per_plane + col]);
      }
      smem[out_offset + local_plane * out_elems_per_plane + col] = sum;
    }

    num_in_rows = num_out_rows;
    in_elems_per_plane = out_elems_per_plane;
    int t_offset = in_offset;
    in_offset = out_offset;
    out_offset = t_offset;
    __syncthreads();
  }

  if (col < num_cols) {
    out[plane * num_cols + col] =
        smem[in_offset + local_plane * out_elems_per_plane + col];
  }
}
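
To see the in-block tree reduction, here is a host-side sketch of the same
ping-pong loop over the two blockDim.x-sized halves of the shared buffer (one
plane, hypothetical sizes, addition for op; not part of the diff):

#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const int num_cols = 3, num_rows = 5, block_dim = 32;
  std::vector<int> smem(2 * block_dim, 0);
  for (int i = 0; i < num_rows * num_cols; ++i) smem[i] = i;  // "loaded" plane

  int num_in_rows = num_rows;
  int in_offset = 0, out_offset = block_dim;
  int out_elems_per_plane = num_rows * num_cols;
  while (num_in_rows > 1) {
    int num_out_rows = num_in_rows / 2;
    int num_rem_rows = num_in_rows % 2;
    out_elems_per_plane = num_out_rows * num_cols;
    for (int col = 0; col < out_elems_per_plane; ++col) {  // all "threads"
      int sum = smem[in_offset + col] +
                smem[in_offset + out_elems_per_plane + col];
      if (num_rem_rows == 1 && col < num_cols)  // fold in the odd row
        sum += smem[in_offset + 2 * out_elems_per_plane + col];
      smem[out_offset + col] = sum;
    }
    num_in_rows = num_out_rows;
    std::swap(in_offset, out_offset);  // ping-pong the two halves
  }
  for (int col = 0; col < num_cols; ++col)
    std::printf("col %d -> %d\n", col, smem[in_offset + col]);  // 30 35 40
}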

struct RowOffset {
  __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}

@@ -714,9 +863,9 @@ void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
-void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
-                        int extent_y, int extent_z, Op op, T init,
-                        const gpuStream_t& cu_stream) {
+void Launch3DYReductionSimple(OpKernelContext* ctx, OUT_T out, IN_T in,
+                              int extent_x, int extent_y, int extent_z, Op op,
+                              T init, const gpuStream_t& cu_stream) {
  int threads_per_block = 128;
  int num_blocks =
      (extent_x * extent_z + threads_per_block - 1) / threads_per_block;
@@ -728,6 +877,68 @@ void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
      out, extent_x, extent_y, extent_z, op));
}

template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                        int extent_y, int extent_z, Op op, T init,
                        const cudaStream_t& cu_stream) {
  int threads_per_block = 128;

  int n_group_in = extent_y;
  int n_size = extent_z;
  constexpr int unroll = 8;

  // Calculate and allocate temporary space.
  std::size_t temp_storage_bytes = 0;
  // A plane's size is n_group_in * n_size. We make sure no single plane
  // crosses more than one thread block, so in the second stage a thread block
  // handles one whole plane or multiple planes. It may also handle a partial
  // plane when n_size is too large; in that case the while-loop stops at
  // n_group_in = 1, and the next stage simply copies the temp buffer to the
  // output.
  while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) {
    int n_group_out = std::max(1, n_group_in / (2 * unroll));
    temp_storage_bytes += n_group_out * n_size;
    n_group_in = n_group_out;
  }
  temp_storage_bytes *= extent_x * sizeof(T);
  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  // Reduction
  n_group_in = extent_y;
  int temp_in_offset = -1;
  int temp_out_offset = 0;
  int num_blocks;
  while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) {
    int n_group_out = std::max(1, n_group_in / (2 * unroll));
    num_blocks =
        Eigen::divup(extent_x * n_group_out * n_size, threads_per_block);
    ColumnReduceInToTempKernel<unroll, IN_T, Op>
        <<<num_blocks, threads_per_block, 0, cu_stream>>>(
            (void*)(temp_storage.flat<int8_t>().data()), temp_in_offset,
            temp_out_offset, in, extent_x, n_group_in, extent_z, op);

    n_group_in = n_group_out;
    temp_in_offset = temp_out_offset;
    temp_out_offset = temp_in_offset + extent_x * n_group_out * n_size;
  }

  if (n_group_in * n_size <= threads_per_block) {
    num_blocks = extent_x;
  } else {
    DCHECK_EQ(1, n_group_in);
    num_blocks = Eigen::divup(extent_x * n_size, threads_per_block);
  }

  ColumnReduceTempToOutKernel<<<num_blocks, threads_per_block,
                                2 * sizeof(T) * threads_per_block, cu_stream>>>(
      (void*)(temp_storage.flat<int8_t>().data()), temp_in_offset, in, out,
      extent_x, n_group_in, extent_z, op);
}
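
As a rough, hypothetical trace of the staging above (extents and dtype chosen
for illustration only), this mirrors the sizing loop on the host:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const int extent_x = 4, extent_y = 4096, extent_z = 2;
  const int threads_per_block = 128, unroll = 8;
  int n_group_in = extent_y;
  const int n_size = extent_z;
  std::size_t elems = 0;
  while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) {
    int n_group_out = std::max(1, n_group_in / (2 * unroll));
    std::printf("stage: %4d rows -> %3d rows (%d temp elems per plane)\n",
                n_group_in, n_group_out, n_group_out * n_size);
    elems += n_group_out * n_size;
    n_group_in = n_group_out;
  }
  // Two ColumnReduceInToTempKernel stages here: 4096 -> 256 -> 16 rows; the
  // final 16 x 2 plane fits in one 128-thread block, so a single
  // ColumnReduceTempToOutKernel block finishes each plane in shared memory.
  std::printf("temp: %zu bytes for float, final plane: %d elems\n",
              elems * extent_x * sizeof(float), n_group_in * n_size);
}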

template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                         int extent_y, int extent_z, Op op, T init,
@@ -864,8 +1075,14 @@ void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
             reduction_axes[0] == 0) {  // column reduction
    LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
  } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
-    Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
-                       cu_stream);
+    int elems_per_thread = in_dim1 / (in_dim0 * in_dim2);
+    if (elems_per_thread >= 16) {
+      Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
+                         cu_stream);
+    } else {
+      Launch3DYReductionSimple(ctx, out, in, in_dim0, in_dim1, in_dim2, op,
+                               init, cu_stream);
+    }
  } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
             reduction_axes[1] == 2) {
    Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
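
The new dispatch in ReduceImpl picks between the two Y-reduction paths by how
much serial work the simple kernel would give each thread; a minimal sketch of
that heuristic with a hypothetical shape:

#include <cstdio>

int main() {
  // With the simple kernel, only in_dim0 * in_dim2 threads exist and each
  // walks all of Y serially; a large Y with small X*Z underutilizes the GPU.
  const int in_dim0 = 4, in_dim1 = 65536, in_dim2 = 8;  // hypothetical extents
  const int elems_per_thread = in_dim1 / (in_dim0 * in_dim2);  // 2048
  std::printf("%s path chosen (elems_per_thread = %d)\n",
              elems_per_thread >= 16 ? "multi-stage" : "simple",
              elems_per_thread);
}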