Fix heap buffer overflow in UnsortedSegmentSum.
When Index=int32, data_size and num_segments were truncated from int64 to int32. This truncation can produce negative numbers, which causes UnsortedSegmentFunctor to access out-of-bounds memory.
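For illustration, a minimal standalone C++ sketch of this failure mode (the sizes and variable names below are hypothetical, not the kernel's own):

#include <cstdint>
#include <iostream>

int main() {
  // A hypothetical tensor with more elements than int32 can represent.
  const int64_t data_size = 3'000'000'000;  // > 2^31 - 1

  // Passing this through a parameter typed as Index=int32 truncates it;
  // on common platforms the value wraps around and comes out negative.
  const int32_t truncated = static_cast<int32_t>(data_size);
  std::cout << "truncated data_size = " << truncated << "\n";  // negative

  // Any size derived from the truncated value (e.g. an inner dimension
  // computed as data_size / N) is negative too, so offsets built from it
  // point outside the tensor's allocation.
  const int64_t N = 1'000'000;
  std::cout << "derived inner_dim_size = " << truncated / N << "\n";
  return 0;
}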

Also:
- Switches some indexing calculations to int64 to avoid signed integer overflow when either the input or output tensors have more than 2**31 - 1 elements.
- Fixes a range-check error in the GPU kernel. The segment ID was checked against an upper bound measured in elements, not segments (see the sketch below).
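For illustration, a minimal standalone C++ sketch of the old versus fixed bounds check (the shapes and names below are hypothetical, not the kernel's own variables):

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative shapes only.
  const int64_t num_segments = 8;      // output outer dimension, in segments
  const int64_t inner_dim_size = 128;  // elements per segment
  const int64_t output_total_size = num_segments * inner_dim_size;  // elements

  const int64_t segment_id = 100;  // invalid: valid ids are [0, 8)

  // Old check: the upper bound was measured in elements, so 100 < 1024
  // passed, and the resulting write landed far past the output's 8 rows.
  const bool old_check_passes =
      segment_id >= 0 && segment_id < output_total_size;

  // Fixed check: the bound is the number of segments.
  const bool new_check_passes = segment_id >= 0 && segment_id < num_segments;

  std::cout << "old check lets id 100 through: " << old_check_passes << "\n";
  std::cout << "new check lets id 100 through: " << new_check_passes << "\n";
  return 0;
}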
PiperOrigin-RevId: 256451663
rryan authored and tensorflower-gardener committed Jul 4, 2019
1 parent f7fe61a commit db4f971
Showing 3 changed files with 32 additions and 33 deletions.
19 changes: 9 additions & 10 deletions tensorflow/core/kernels/segment_reduction_ops.cc
@@ -376,18 +376,17 @@ namespace functor {
 template <typename T, typename Index, typename InitialValueF,
           typename ReductionF>
 struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
-  void operator()(OpKernelContext* ctx, const Index num_segments,
-                  const TensorShape& segment_ids_shape,
+  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
-                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::ConstTensor data,
                   typename TTypes<T, 2>::Tensor output) {
     output.setConstant(InitialValueF()());
-    if (data_size == 0) {
+    if (data.size() == 0) {
       return;
     }
     const int64 N = segment_ids.dimension(0);
+    const int64 num_segments = output.dimension(0);
     ReductionF reduction;
-    auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
     for (int64 i = 0; i < N; ++i) {
       Index j = internal::SubtleMustCopy(segment_ids(i));
       if (j < 0) {
@@ -397,7 +396,7 @@ struct UnsortedSegmentFunctor<CPUDevice, T, Index, InitialValueF, ReductionF> {
           errors::InvalidArgument(
               "segment_ids", SliceDebugString(segment_ids_shape, i),
               " = ", j, " is out of range [0, ", num_segments, ")"));
-      reduction(data_flat.template chip<0>(i), output.template chip<0>(j));
+      reduction(data.template chip<0>(i), output.template chip<0>(j));
     }
   }
 };
@@ -485,7 +484,7 @@ class UnsortedSegmentReductionOp : public OpKernel {
       return;
     }
     const auto segment_flat = segment_ids.flat<Index>();
-    const Index output_rows = internal::SubtleMustCopy(static_cast<Index>(
+    const int64 output_rows = internal::SubtleMustCopy(static_cast<int64>(
         num_segments.dtype() == DT_INT32 ? num_segments.scalar<int32>()()
                                          : num_segments.scalar<int64>()()));
     OP_REQUIRES(context, output_rows >= 0,
@@ -499,9 +498,9 @@ class UnsortedSegmentReductionOp : public OpKernel {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
     auto output_flat = output->flat_outer_dims<T>();
-    auto data_ptr = data.template flat<T>().data();
-    reduction_functor_(context, output_rows, segment_ids.shape(), segment_flat,
-                       data.NumElements(), data_ptr, output_flat);
+    auto data_flat = data.flat_inner_outer_dims<T, 2>(segment_ids.dims() - 1);
+    reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat,
+                       output_flat);
   }
 
  protected:
5 changes: 2 additions & 3 deletions tensorflow/core/kernels/segment_reduction_ops.h
@@ -59,10 +59,9 @@ struct SegmentSumFunctor {
 template <typename Device, typename T, typename Index, typename InitialValueF,
           typename ReductionF>
 struct UnsortedSegmentFunctor {
-  void operator()(OpKernelContext* ctx, const Index num_segments,
-                  const TensorShape& segment_ids_shape,
+  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
-                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::ConstTensor data,
                   typename TTypes<T, 2>::Tensor output);
 };
 
41 changes: 21 additions & 20 deletions tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -106,21 +106,21 @@ __global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size,
 // Each element is mapped from input to output by a combination of its
 // 'segment_ids' mapping and 'inner_dim_size'.
 template <typename T, typename Index, typename KernelReductionFunctor>
-__global__ void UnsortedSegmentCustomKernel(const Index input_outer_dim_size,
-                                            const Index inner_dim_size,
-                                            const Index output_outer_dim_size,
+__global__ void UnsortedSegmentCustomKernel(const int64 input_outer_dim_size,
+                                            const int64 inner_dim_size,
+                                            const int64 output_outer_dim_size,
                                             const Index* segment_ids,
                                             const T* input, T* output) {
-  const Index input_total_size = input_outer_dim_size * inner_dim_size;
-  const Index output_total_size = output_outer_dim_size * inner_dim_size;
-  for (int input_index : GpuGridRangeX(input_total_size)) {
-    const Index input_segment_index = input_index / inner_dim_size;
-    const Index segment_offset = input_index % inner_dim_size;
+  const int64 input_total_size = input_outer_dim_size * inner_dim_size;
+  for (int64 input_index : GpuGridRangeX(input_total_size)) {
+    const int64 input_segment_index = input_index / inner_dim_size;
+    const int64 segment_offset = input_index % inner_dim_size;
     const Index output_segment_index = segment_ids[input_segment_index];
-    if (output_segment_index < 0 || output_segment_index >= output_total_size) {
+    if (output_segment_index < 0 ||
+        output_segment_index >= output_outer_dim_size) {
       continue;
     }
-    const Index output_index =
+    const int64 output_index =
         output_segment_index * inner_dim_size + segment_offset;
     KernelReductionFunctor()(output + output_index, ldg(input + input_index));
   }
@@ -174,10 +174,9 @@ void SegmentSumFunctor<T, Index>::operator()(
 template <typename T, typename Index, typename InitialValueF,
           typename ReductionF>
 struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
-  void operator()(OpKernelContext* ctx, const Index num_segments,
-                  const TensorShape& segment_ids_shape,
+  void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape,
                   typename TTypes<Index>::ConstFlat segment_ids,
-                  const Index data_size, const T* data,
+                  typename TTypes<T, 2>::ConstTensor data,
                   typename TTypes<T, 2>::Tensor output) {
     if (output.size() == 0) {
       return;
@@ -188,6 +187,7 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
     TF_CHECK_OK(GpuLaunchKernel(
         SetToValue<T>, config.block_count, config.thread_per_block, 0,
         d.stream(), output.size(), output.data(), InitialValueF()()));
+    const int64 data_size = data.size();
     if (data_size == 0 || segment_ids_shape.num_elements() == 0) {
       return;
     }
@@ -196,15 +196,16 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
     // *) 'data_size' is the total number of elements to process.
     // *) 'segment_ids.shape' is a prefix of data's shape.
     // *) 'input_outer_dim_size' is the total number of segments to process.
-    const Index input_outer_dim_size = segment_ids.dimension(0);
-    const Index input_inner_dim_size = data_size / input_outer_dim_size;
+    const int64 input_outer_dim_size = segment_ids.dimension(0);
+    const int64 input_inner_dim_size = data.dimension(1);
+    const int64 output_outer_dim_size = output.dimension(0);
     config = GetGpuLaunchConfig(data_size, d);
 
-    TF_CHECK_OK(
-        GpuLaunchKernel(UnsortedSegmentCustomKernel<T, Index, ReductionF>,
-                        config.block_count, config.thread_per_block, 0,
-                        d.stream(), input_outer_dim_size, input_inner_dim_size,
-                        num_segments, segment_ids.data(), data, output.data()));
+    TF_CHECK_OK(GpuLaunchKernel(
+        UnsortedSegmentCustomKernel<T, Index, ReductionF>, config.block_count,
+        config.thread_per_block, 0, d.stream(), input_outer_dim_size,
+        input_inner_dim_size, output_outer_dim_size, segment_ids.data(),
+        data.data(), output.data()));
   }
 };
 