Arbitrary dim for slice #11140

Merged: 33 commits, Nov 6, 2017

Changes from all commits (33 commits):
35f6a8f
finish refactor
yanchen036 Jun 27, 2017
7219b0d
delete impl files
yanchen036 Jun 27, 2017
d9a4827
compile success
yanchen036 Jun 27, 2017
fed6f55
modify slice interface used in strided_slice
yanchen036 Jun 28, 2017
44bf2f6
delete impl files
yanchen036 Jun 28, 2017
113a9b2
fix undefined error of helper function
yanchen036 Jun 28, 2017
23fd81d
finish refactor
yanchen036 Jun 27, 2017
c57744a
delete impl files
yanchen036 Jun 27, 2017
d9f8662
compile success
yanchen036 Jun 27, 2017
3fc4cb9
modify slice interface used in strided_slice
yanchen036 Jun 28, 2017
d299d1c
delete impl files
yanchen036 Jun 28, 2017
dfb0810
fix undefined error of helper function
yanchen036 Jun 28, 2017
b53900f
type unsupported
yanchen036 Jun 29, 2017
099ceba
Merge branch 'arbitrary_dim_for_slice' of https://github.com/yanchen0…
yanchen036 Jun 29, 2017
d7678d0
remove changes in py file
yanchen036 Jun 29, 2017
ebf6e3e
move SliceSimple function into header
yanchen036 Jun 30, 2017
5ebca51
fix compiling problem
yanchen036 Jun 30, 2017
f928df6
add python test
yanchen036 Jun 30, 2017
ff02620
add python test
yanchen036 Jul 4, 2017
8fe173c
change type
yanchen036 Jul 5, 2017
201280d
Merge branch 'master' into arbitrary_dim_for_slice
yanchen036 Jul 12, 2017
2f5f0f5
compile each dim of slice seperately
yanchen036 Jul 17, 2017
294a354
add files in tensorflow/contrib/makefile/tf_op_files.txt
yanchen036 Jul 17, 2017
a5ed7f6
add some const
yanchen036 Jul 20, 2017
72cca14
add benchmark
yanchen036 Jul 31, 2017
efdcd76
capitalize and punctuate comment
yanchen036 Aug 24, 2017
9591af7
Merge remote-tracking branch 'upstream/master' into arbitrary_dim_for…
yanchen036 Sep 5, 2017
0272f63
Merge remote-tracking branch 'origin/arbitrary_dim_for_slice' into ar…
yanchen036 Sep 5, 2017
bb155d3
uncollapse the for loop by Duff's device
yanchen036 Sep 5, 2017
d393eb4
Merge remote-tracking branch 'upstream/master' into arbitrary_dim_for…
yanchen036 Sep 28, 2017
18dcbb6
Merge remote-tracking branch 'upstream/master' into arbitrary_dim_for…
yanchen036 Oct 9, 2017
3259994
Merge remote-tracking branch 'upstream/master' into arbitrary_dim_for…
yanchen036 Nov 2, 2017
23f08f5
remove ">>>>>>>>>>"
yanchen036 Nov 2, 2017
119 changes: 42 additions & 77 deletions tensorflow/core/kernels/slice_op.cc
@@ -190,42 +190,26 @@ class SliceOp : public OpKernel {
}
return;
}
#define HANDLE_DIM(NDIM) \
if (input_dims == NDIM) { \
HandleCase<NDIM>(context, begin, size, result); \
return; \
#define HANDLE_DIM(NDIM) \
if (input_dims == NDIM) { \
functor::Slice<Device, T, NDIM>()( \
context->eigen_device<Device>(), result, input, begin, size); \
return; \
}

HANDLE_DIM(1);
HANDLE_DIM(2);
HANDLE_DIM(3);
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
HANDLE_DIM(7);
Contributor:

Why not 7?

Contributor Author:

Case 7 is handled at line 208; it covers all the cases where dim >= 7.
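
For illustration, a minimal sketch of this dispatch pattern (hypothetical names, not the TensorFlow sources): ranks 1 through 6 each get their own template instantiation, and anything higher falls through to a single rank-agnostic call, mirroring the hunk above.

// Illustrative sketch only: compile-time specialization for small ranks,
// one shared fallback for everything else.
#include <cstdio>

template <int NDIM>
void SliceForRank(int runtime_rank) {
  std::printf("sliced with NDIM=%d (runtime rank %d)\n", NDIM, runtime_rank);
}

void DispatchSlice(int input_dims) {
#define HANDLE_DIM(NDIM)            \
  if (input_dims == NDIM) {         \
    SliceForRank<NDIM>(input_dims); \
    return;                         \
  }
  HANDLE_DIM(1);
  HANDLE_DIM(2);
  HANDLE_DIM(3);
  HANDLE_DIM(4);
  HANDLE_DIM(5);
  HANDLE_DIM(6);
#undef HANDLE_DIM
  // Everything with input_dims >= 7 takes the same path.
  SliceForRank<7>(input_dims);
}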


#undef HANDLE_DIM

OP_REQUIRES(context, false, errors::Unimplemented(
"SliceOp : Unhandled input dimensions"));
// Handle the cases in which dim >= 7.
functor::Slice<Device, T, 7>()(
context->eigen_device<Device>(), result, input, begin, size);
}
}

private:
template <int NDIM>
void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
indices[i] = begin[i];
sizes[i] = size[i];
}

functor::Slice<Device, T, NDIM>()(
context->eigen_device<Device>(), result->tensor<T, NDIM>(),
context->input(0).tensor<T, NDIM>(), indices, sizes);
}
};

#ifdef INTEL_MKL
@@ -264,24 +248,13 @@ class MklSliceOp : public OpKernel {
}
return;
}
#define HANDLE_DIM(NDIM) \
if (input_dims == NDIM) { \
HandleCase<NDIM>(context, begin, size, result); \
return; \
}

HANDLE_DIM(1);
HANDLE_DIM(2);
HANDLE_DIM(3);
HANDLE_DIM(4);
HANDLE_DIM(5);
HANDLE_DIM(6);
HANDLE_DIM(7);

#undef HANDLE_DIM

OP_REQUIRES(context, false, errors::Unimplemented(
"SliceOp : Unhandled input dimensions"));
// Special case for handling 4-D tensor slice.
if (input_dims == 4) {
HandleCase4D(context, begin, size, result);
} else {
functor::Slice<Device, T, input_dims>()(
context->eigen_device<Device>(), result, input, begin, size);
}
}
}

@@ -328,8 +301,7 @@ class MklSliceOp : public OpKernel {
return false;
}

template <int NDIM>
void HandleCase(OpKernelContext* context,
void HandleCase4D(OpKernelContext* context,
const gtl::ArraySlice<int64>& begin,
const gtl::ArraySlice<int64>& size, Tensor* result) {
int slice_dim = -1;
@@ -338,8 +310,7 @@ class MklSliceOp : public OpKernel {
// differs from the input tensor in only 1 out of 4 dimensions.
// This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
// format over channel dimension.
if (NDIM == 4 &&
DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
if (DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
size_t in_strides[4] = { (size_t) in_shape.dim_size(1) *
in_shape.dim_size(2) *
in_shape.dim_size(3),
@@ -403,30 +374,22 @@ class MklSliceOp : public OpKernel {
// slice_dim is not 1 or 3, then we fallback to Eigen implementation.
}

Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
for (int i = 0; i < NDIM; ++i) {
indices[i] = begin[i];
sizes[i] = size[i];
}

functor::Slice<Device, T, NDIM>()(
context->eigen_device<Device>(), result->tensor<T, NDIM>(),
context->input(0).tensor<T, NDIM>(), indices, sizes);
functor::Slice<Device, T, 4>()(
context->eigen_device<Device>(), result, input, begin, size);
}
};
#endif

// Forward declarations of the functor specializations declared in the
// sharded source files.
namespace functor {
#define DECLARE_CPU_SPEC(T, NDIM) \
template <> \
void Slice<CPUDevice, T, NDIM>::operator()( \
const CPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
typename TTypes<T, NDIM>::ConstTensor input, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
#define DECLARE_CPU_SPEC(T, NDIM) \
template <> \
void Slice<CPUDevice, T, NDIM>::operator()( \
const CPUDevice& d, Tensor* output, \
const Tensor& input, \
const gtl::ArraySlice<int64>& slice_indices, \
const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<CPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T) \
@@ -476,13 +439,14 @@ REGISTER_SLICE(bfloat16);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, NDIM) \
template <> \
void Slice<GPUDevice, T, NDIM>::operator()( \
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
typename TTypes<T, NDIM>::ConstTensor input, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
#define DECLARE_GPU_SPEC(T, NDIM) \
template <> \
void Slice<GPUDevice, T, NDIM>::operator()( \
const GPUDevice& d, \
Tensor* output, \
const Tensor& input, \
const gtl::ArraySlice<int64>& slice_indices, \
const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<GPUDevice, T, NDIM>;

#define DECLARE_FOR_N(T) \
@@ -536,13 +500,14 @@ REGISTER_KERNEL_BUILDER(Name("Slice")
#ifdef TENSORFLOW_USE_SYCL
// Forward declarations of the functor specializations for SYCL.
namespace functor {
#define DECLARE_SYCL_SPEC(T, NDIM) \
template <> \
void Slice<SYCLDevice, T, NDIM>::operator()( \
const SYCLDevice& d, typename TTypes<T, NDIM>::Tensor output,\
typename TTypes<T, NDIM>::ConstTensor input, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes); \
#define DECLARE_SYCL_SPEC(T, NDIM) \
template <> \
void Slice<SYCLDevice, T, NDIM>::operator()( \
const SYCLDevice& d, \
Tensor* output, \
const Tensor& input, \
const gtl::ArraySlice<int64>& slice_indices, \
const gtl::ArraySlice<int64>& slice_sizes); \
extern template struct Slice<SYCLDevice, T, NDIM>;

#define DECLARE_FOR_N(T) \
109 changes: 91 additions & 18 deletions tensorflow/core/kernels/slice_op.h
@@ -19,31 +19,104 @@ limitations under the License.
// Functor definition for SliceOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/ops_util.h"

namespace tensorflow {
namespace functor {

namespace internal {

template <typename Device, typename T>
void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices);
template <typename Device, typename T>
void SliceSimpleGpu(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices);

template <typename Device, typename T>
void SliceSimple(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices) {
const int ndims = in.dims();
const int64 nelem = out->NumElements();
const gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
const gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
const T* p = in.flat<T>().data();
T* q = out->flat<T>().data();

std::vector<int64> i_idx(nelem, 0);
std::vector<int64> t(nelem, 0);

for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
t[o_idx] = o_idx;
}
for (int i = 0; i < ndims; ++i) {
int64 n = (nelem + 7) / 8;
int64 o_idx = 0;
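// Duff's device: the switch jumps into the middle of the 8-way unrolled
// do/while loop, so the unrolling also works when nelem is not a multiple
// of 8; each CALC_INPUT_IDX expansion advances exactly one output element.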
switch (nelem % 8) {
#define CALC_INPUT_IDX \
i_idx[o_idx] += (t[o_idx] / out_strides[i] + slice_indices[i]) * in_strides[i]; \
t[o_idx] %= out_strides[i]; \
++o_idx;
case 0: do { CALC_INPUT_IDX;
case 7: CALC_INPUT_IDX;
case 6: CALC_INPUT_IDX;
case 5: CALC_INPUT_IDX;
case 4: CALC_INPUT_IDX;
case 3: CALC_INPUT_IDX;
case 2: CALC_INPUT_IDX;
case 1: CALC_INPUT_IDX;
#undef CALC_INPUT_IDX
} while (--n > 0);
}
}
for (int64 o_idx = 0; o_idx < nelem; ++o_idx) {
q[o_idx] = p[i_idx[o_idx]];
}
}
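
To make the stride arithmetic above concrete: for each output element, its flat index is decomposed digit-by-digit with the output strides, each coordinate is offset by the slice start, and the result is recomposed with the input strides. A small self-contained example (illustrative only, plain C++ with a single loop instead of the unrolled one) slicing a row-major 4x5 matrix at begin=(1,2) with size=(2,3):

// Illustrative only: the same output-index -> input-index mapping as
// SliceSimple above, on a 4x5 matrix filled with 0..19.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> in_shape = {4, 5};
  const std::vector<int64_t> begin = {1, 2};
  const std::vector<int64_t> in_strides = {5, 1};   // row-major strides of 4x5
  const std::vector<int64_t> out_strides = {3, 1};  // row-major strides of 2x3

  std::vector<int64_t> in(in_shape[0] * in_shape[1]);
  for (int64_t i = 0; i < static_cast<int64_t>(in.size()); ++i) in[i] = i;

  std::vector<int64_t> out(2 * 3);
  for (int64_t o_idx = 0; o_idx < static_cast<int64_t>(out.size()); ++o_idx) {
    int64_t t = o_idx;
    int64_t i_idx = 0;
    for (size_t d = 0; d < begin.size(); ++d) {
      i_idx += (t / out_strides[d] + begin[d]) * in_strides[d];
      t %= out_strides[d];
    }
    out[o_idx] = in[i_idx];
  }
  for (int64_t v : out) std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");  // prints: 7 8 9 12 13 14
  return 0;
}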

template <typename Device, typename T, int NDIMS>
void SliceUsingEigen(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices,
const gtl::ArraySlice<int64>& slice_sizes) {
auto input = in.tensor<T, NDIMS>();
auto output = out->tensor<T, NDIMS>();
Eigen::DSizes<int, NDIMS> indices;
for (int i = 0; i < NDIMS; ++i) {
indices[i] = slice_indices[i];
}
Eigen::DSizes<int, NDIMS> sizes;
for (int i = 0; i < NDIMS; ++i) {
sizes[i] = slice_sizes[i];
}
const bool use_64bit = input.size() > Eigen::NumTraits<int>::highest();
if (!use_64bit &&
Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
} else {
output.device(d) = input.slice(indices, sizes);
}
}
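
For reference, the Eigen call this helper wraps can be exercised on its own; a minimal sketch, assuming the unsupported Eigen Tensor module is available (shapes and values here are arbitrary):

// Minimal standalone use of Eigen's tensor slice, as wrapped by SliceUsingEigen.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> in(4, 5);
  in.setRandom();
  Eigen::DSizes<Eigen::DenseIndex, 2> indices(1, 2);  // slice start per dim
  Eigen::DSizes<Eigen::DenseIndex, 2> sizes(2, 3);    // slice extent per dim
  Eigen::Tensor<float, 2> out = in.slice(indices, sizes);
  std::cout << out << std::endl;  // 2x3 block of `in` starting at (1, 2)
  return 0;
}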

} // namespace internal

namespace functor {

// Template parameter NDIM is not necessary here. The aim of keeping it
// is to compile struct Slice separately for each dim, which minimizes the compile time.
template <typename Device, typename T, int NDIM>
struct Slice {
void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
typename TTypes<T, NDIMS>::ConstTensor input,
const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
bool use_64bit = (input.size() > Eigen::NumTraits<int>::highest());
if (!use_64bit &&
Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
Eigen::DSizes<int, NDIMS> indices;
for (int i = 0; i < NDIMS; ++i) {
indices[i] = slice_indices[i];
}
Eigen::DSizes<int, NDIMS> sizes;
for (int i = 0; i < NDIMS; ++i) {
sizes[i] = slice_sizes[i];
}
To32Bit(output).device(d) = To32Bit(input).slice(indices, sizes);
void operator()(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices,
const gtl::ArraySlice<int64>& slice_sizes) {
if (in.dims() == NDIM) {
internal::SliceUsingEigen<Device, T, NDIM>(d, out, in, slice_indices, slice_sizes);
} else {
output.device(d) = input.slice(slice_indices, slice_sizes);
if (Eigen::internal::is_same<Device, Eigen::GpuDevice>::value) {
internal::SliceSimpleGpu<Device, T>(d, out, in, slice_indices);
} else {
internal::SliceSimple<Device, T>(d, out, in, slice_indices);
}
}
}
};
56 changes: 56 additions & 0 deletions tensorflow/core/kernels/slice_op_gpu.cu.cc
@@ -21,9 +21,65 @@ limitations under the License.

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace internal {

template <typename T>
__global__ void SliceKernel(int nthreads, const T* src, const int32* buf,
const int32 ndims, T* dst) {
const int32* in_strides = buf;
const int32* out_strides = buf + ndims;
const int32* slice_indices = buf + ndims * 2;
CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
int32 i_idx = 0;
int32 t = o_idx;
for (int i = 0; i < ndims; ++i) {
i_idx += (t / out_strides[i] + slice_indices[i]) * in_strides[i];
t %= out_strides[i];
}
dst[o_idx] = ldg(src + i_idx);
}
}
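
CUDA_1D_KERNEL_LOOP and ldg above come from TensorFlow's cuda_kernel_helper.h (included in this file). For reference, a self-contained sketch of the same per-element mapping written as a plain grid-stride loop (illustrative only: float elements, hypothetical kernel name, no __ldg caching):

// Illustrative only: the index mapping of SliceKernel without TF helpers.
// buf holds [in_strides | out_strides | slice begin], ndims entries each.
__global__ void SliceKernelSketch(int nthreads, const float* src,
                                  const int* buf, int ndims, float* dst) {
  const int* in_strides = buf;
  const int* out_strides = buf + ndims;
  const int* slice_indices = buf + 2 * ndims;
  for (int o_idx = blockIdx.x * blockDim.x + threadIdx.x; o_idx < nthreads;
       o_idx += blockDim.x * gridDim.x) {
    int i_idx = 0;
    int t = o_idx;
    for (int i = 0; i < ndims; ++i) {
      i_idx += (t / out_strides[i] + slice_indices[i]) * in_strides[i];
      t %= out_strides[i];
    }
    dst[o_idx] = src[i_idx];
  }
}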

template <typename Device, typename T>
void SliceSimpleGpu(const Device& d, Tensor* out, const Tensor& in,
const gtl::ArraySlice<int64>& slice_indices) {
// Ensures we can use 32-bit index.
const int64 in_nelem = in.NumElements();
CHECK_LT(in_nelem, kint32max) << "Tensor too large to slice on GPU";
const int64 out_nelem = out->NumElements();
CHECK_LT(out_nelem, kint32max) << "Tensor too large to slice on GPU";
// Pack strides and slice indices into one buffer.
const int32 ndims = in.dims();
gtl::InlinedVector<int32, 24> host_buf(ndims * 3);
gtl::InlinedVector<int32, 8> in_strides = ComputeStride<int32>(in.shape());
gtl::InlinedVector<int32, 8> out_strides = ComputeStride<int32>(out->shape());
for (int i = 0; i < ndims; ++i) {
host_buf[i] = in_strides[i];
host_buf[ndims + i] = out_strides[i];
host_buf[ndims * 2 + i] = slice_indices[i];
}
auto num_bytes = sizeof(int64) * host_buf.size();
auto dev_buf = d.allocate(num_bytes);
// NOTE: host_buf is not allocated by CudaHostAllocator, and
// therefore we are doing a sync copy effectively.
d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes);
// Launch kernel to q[...] = p[...].
const T* p = in.flat<T>().data();
T* q = out->flat<T>().data();
CudaLaunchConfig cfg = GetCudaLaunchConfig(out_nelem, d);
SliceKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
cfg.virtual_thread_count, p, reinterpret_cast<const int32*>(dev_buf),
ndims, q);
// Safe to deallocate immediately after the kernel launch.
d.deallocate(dev_buf);
}

} // namespace internal

typedef Eigen::GpuDevice GPUDevice;
