[ROCm] NonMaxSuppression #39562

Merged
66 changes: 36 additions & 30 deletions tensorflow/core/kernels/non_max_suppression_op.cu.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include <limits>

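Note (not part of the diff): the gpuprim:: and Gpu* spellings introduced below are the platform-neutral counterparts of the CUDA-only cub:: and Cuda* names. A minimal sketch of how such an alias header is commonly arranged, assuming the usual CUB/hipCUB pairing; the exact TensorFlow header and include paths are not shown in this diff:

// Illustrative alias header, not the actual TensorFlow source.
#if GOOGLE_CUDA
#include <cub/cub.cuh>
namespace gpuprim = ::cub;      // gpuprim::DeviceSelect -> cub::DeviceSelect
#elif TENSORFLOW_USE_ROCM
#include <hipcub/hipcub.hpp>
namespace gpuprim = ::hipcub;   // gpuprim::DeviceSelect -> hipcub::DeviceSelect
#endif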
@@ -28,7 +28,12 @@ limitations under the License.
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/stream_executor/stream_executor.h"

-struct __align__(16) Box {
+
+struct
+#if GOOGLE_CUDA
+__align__(16)
+#endif
+Box {
float x1, y1, x2, y2;
};

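Note: wrapping __align__(16) in #if GOOGLE_CUDA simply drops the CUDA-specific alignment attribute on the ROCm build. For comparison only, a sketch using standard C++ alignment, which both nvcc and hipcc accept; this is not what the PR does:

// Illustrative alternative: standard alignas instead of the __align__ macro.
struct alignas(16) Box {
  float x1, y1, x2, y2;
};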
@@ -114,7 +119,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
char* result_mask) {
extern __shared__ int local[];
// set global mask to accept all boxes
-for (int box : CudaGridRangeX(bit_mask_len)) {
+for (int box : GpuGridRangeX(bit_mask_len)) {
local[box] = 0xFFFFFFFF;
}
__syncthreads();
@@ -127,15 +132,15 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
accepted_boxes += 1;
int offset = box * bit_mask_len;
// update global mask with current box's mask
-for (int b : CudaGridRangeX(bit_mask_len)) {
+for (int b : GpuGridRangeX(bit_mask_len)) {
local[b] &= ~bitmask[offset + b];
}
__syncthreads();
if (accepted_boxes > max_boxes) break;
}
// copy global mask to result_max char array. char array is needed for
// cub::DeviceSelect later.
-for (int box : CudaGridRangeX(num_boxes)) {
+for (int box : GpuGridRangeX(num_boxes)) {
result_mask[box] = CheckBit(local, box);
}
}
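Note: GpuGridRangeX is the renamed, platform-neutral form of CudaGridRangeX; both express a grid-stride loop over the index range. A self-contained sketch of that pattern (kernel name and output buffer are illustrative):

// Grid-stride loop: each thread starts at its global index and hops by the
// total number of threads in the grid until the range is covered.
__global__ void IotaKernel(int n, int* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = i;
  }
}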
@@ -232,14 +237,14 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
template <typename Index, typename T, typename... Args>
__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
const T* original, T* selected, Args... args) {
-for (const int idx : CudaGridRangeX(num_elements)) {
+for (const int idx : GpuGridRangeX(num_elements)) {
SelectHelper(idx, indices[idx], original, selected, args...);
}
}

template <typename T>
__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
-for (int idx : CudaGridRangeX(num_elements)) {
+for (int idx : GpuGridRangeX(num_elements)) {
to_fill[idx] = static_cast<T>(idx) + offset;
}
}
@@ -322,13 +327,13 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
// do Cub::deviceSelect::flagged
size_t flagged_buffer_size = 0;
-cub::DeviceSelect::Flagged(static_cast<void*>(nullptr), // temp_storage
-flagged_buffer_size,
-static_cast<int*>(nullptr), // input
-static_cast<char*>(nullptr), // selection flag
-static_cast<int*>(nullptr), // selected items
-static_cast<int*>(nullptr), // num_selected
-num_boxes, device.stream());
+gpuprim::DeviceSelect::Flagged(static_cast<void*>(nullptr), // temp_storage
+flagged_buffer_size,
+static_cast<int*>(nullptr), // input
+static_cast<char*>(nullptr), // selection flag
+static_cast<int*>(nullptr), // selected items
+static_cast<int*>(nullptr), // num_selected
+num_boxes, device.stream());
Tensor cub_scratch;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
@@ -337,22 +342,22 @@
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &d_num_selected));

-cub::DeviceSelect::Flagged(
+gpuprim::DeviceSelect::Flagged(
(void*)cub_scratch.flat<int8>().data(), // temp_storage
flagged_buffer_size,
d_indices.flat<int>().data(), // input
selected, // selection flag
d_selected_indices, // selected items
d_num_selected.flat<int>().data(), num_boxes, device.stream());
-cudaEvent_t copy_done;
+gpuEvent_t copy_done;
TF_RETURN_IF_CUDA_ERROR(
-cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
+gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
sizeof(int));
-TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
*h_nkeep = *h_selected_count;
-cudaEventDestroy(copy_done);
+gpuEventDestroy(copy_done);
return Status::OK();
}

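Note: the paired DeviceSelect::Flagged calls above follow the standard CUB/hipCUB two-pass convention, which the port keeps intact: a first call with a null temp-storage pointer only reports the scratch size, and the second call does the selection. A standalone sketch of that convention in plain CUB (pointer names are illustrative; TensorFlow allocates the scratch with allocate_temp rather than cudaMalloc):

#include <cub/cub.cuh>

// Copies d_in[i] to the output for every i where d_flags[i] != 0 and writes
// the number of kept elements to d_num_selected.
void SelectFlagged(const int* d_in, const char* d_flags, int* d_out,
                   int* d_num_selected, int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage, so only the required scratch size is written.
  cub::DeviceSelect::Flagged(nullptr, temp_bytes, d_in, d_flags, d_out,
                             d_num_selected, num_items, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: identical arguments plus the real scratch buffer.
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, d_in, d_flags, d_out,
                             d_num_selected, num_items, stream);
  cudaFree(d_temp);
}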
@@ -375,27 +380,28 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
size_t workspace_size = 0;
auto cuda_stream = tensorflow::GetGpuStream(context);
auto device = context->eigen_gpu_device();
-cub::DeviceSelect::If(nullptr, workspace_size, static_cast<float*>(nullptr),
-static_cast<float*>(nullptr),
-static_cast<int*>(nullptr), num_elements, op);
+gpuprim::DeviceSelect::If(nullptr, workspace_size,
+static_cast<float*>(nullptr),
+static_cast<float*>(nullptr),
+static_cast<int*>(nullptr), num_elements, op);

TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &element_count));
-cudaEvent_t copy_done;
+gpuEvent_t copy_done;
TF_RETURN_IF_CUDA_ERROR(
-cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
-TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If(
+gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
+TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
workspace.flat<int8>().data(), workspace_size, dev_array,
scratch_output.flat<float>().data(), element_count.flat<int32>().data(),
num_elements, op, cuda_stream));
device.memcpyDeviceToHost(result, element_count.flat<int32>().data(),
sizeof(int));
-TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
return Status::OK();
}

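Note: as in NmsGpu, the host-side count is read only after an event recorded behind the asynchronous device-to-host copy has been synchronized. The gpuEvent* names are assumed to resolve to the cudaEvent* API on CUDA and the matching hipEvent* API on ROCm; the underlying pattern in plain CUDA looks roughly like this (function and buffer names are illustrative):

#include <cuda_runtime.h>

// Read a device-side int on the host once an async copy has completed.
// (host_count should be pinned memory for the copy to be truly asynchronous.)
void ReadCount(int* host_count, const int* device_count, cudaStream_t stream) {
  cudaEvent_t copy_done;
  cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming);
  cudaMemcpyAsync(host_count, device_count, sizeof(int),
                  cudaMemcpyDeviceToHost, stream);
  cudaEventRecord(copy_done, stream);  // mark the point after the copy
  cudaEventSynchronize(copy_done);     // block the host until that point passes
  cudaEventDestroy(copy_done);
}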
@@ -418,7 +424,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
return Status::OK();
}

-cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+cudaError_t cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
nullptr, cub_sort_temp_storage_bytes,
static_cast<float*>(nullptr), // scores
static_cast<float*>(nullptr), // sorted scores
@@ -458,7 +464,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
config.virtual_thread_count, 0,
d_indices.flat<int>().data()));
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
-cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
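Note: the radix-sort calls sort the scores in descending order while carrying the box indices along as values, again with the two-pass scratch-size query shown earlier. A standalone sketch with plain CUB (names are illustrative):

#include <cub/cub.cuh>

// Sorts float keys in descending order and permutes the int values with them.
void SortScoresDescending(const float* d_scores, float* d_sorted_scores,
                          const int* d_indices, int* d_sorted_indices,
                          int num_boxes, cudaStream_t stream) {
  size_t temp_bytes = 0;
  cub::DeviceRadixSort::SortPairsDescending(
      nullptr, temp_bytes, d_scores, d_sorted_scores, d_indices,
      d_sorted_indices, num_boxes, /*begin_bit=*/0,
      /*end_bit=*/8 * sizeof(float), stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_scores, d_sorted_scores, d_indices,
      d_sorted_indices, num_boxes, /*begin_bit=*/0,
      /*end_bit=*/8 * sizeof(float), stream);
  cudaFree(d_temp);
}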
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/non_max_suppression_op.h
@@ -35,7 +35,7 @@ struct NonMaxSuppression {

} // namespace functor

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
extern const int kNmsBoxesPerTread;

// Given descending sorted box list, apply non-maximal-suppression with given
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc
@@ -35,7 +35,7 @@ limitations under the License.

namespace tensorflow {

-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// These tests are copied from non_max_suppression_op_test.cc file and modified
// to use GPU ops. See other file for test details.

9 changes: 5 additions & 4 deletions tensorflow/core/kernels/ops_testutil.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/node_properties.h"
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
#endif
@@ -112,7 +112,7 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
thread_pool_.get());

device_type_ = device_type;
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (device_type == DEVICE_GPU) {
managed_allocator_.reset(new GpuManagedAllocator());
allocator_ = managed_allocator_.get();
@@ -122,7 +122,8 @@
}
#else
CHECK_NE(device_type, DEVICE_GPU)
<< "Requesting GPU on binary compiled without GOOGLE_CUDA.";
<< "Requesting GPU on binary compiled without GOOGLE_CUDA or "
"TENSORFLOW_USE_ROCM.";
allocator_ = device_->GetAllocator(AllocatorAttributes());
#endif
}
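Note: the allocator switch matters because these tests read op outputs directly from host code. GpuManagedAllocator is presumably backed by unified (managed) memory, which both runtimes expose (hipMallocManaged is the ROCm counterpart). A bare sketch of the underlying idea in plain CUDA, with an illustrative size and function name:

#include <cuda_runtime.h>

// Unified (managed) memory is addressable from both the host and the device,
// so test code can read GPU-written results without an explicit memcpy.
float ReadFirstElement() {
  float* data = nullptr;
  cudaMallocManaged(&data, 128 * sizeof(float));
  // ... launch kernels that write `data` on the GPU ...
  cudaDeviceSynchronize();   // make device writes visible to the host
  float first = data[0];     // plain host read, no cudaMemcpy needed
  cudaFree(data);
  return first;
}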
@@ -195,7 +196,7 @@ TensorValue OpsTestBase::mutable_input(int input_index) {
Tensor* OpsTestBase::GetOutput(int output_index) {
CHECK_LT(output_index, context_->num_outputs());
Tensor* output = context_->mutable_output(output_index);
-#ifdef GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
if (device_type_ == DEVICE_GPU) {
managed_outputs_.resize(context_->num_outputs());
// Copy the output tensor to managed memory if we haven't done so.