diff --git a/src/modules/hip/kernel/erase.hpp b/src/modules/hip/kernel/erase.hpp index 2591b53f0..f18306a9d 100644 --- a/src/modules/hip/kernel/erase.hpp +++ b/src/modules/hip/kernel/erase.hpp @@ -117,12 +117,34 @@ RppStatus hip_exec_erase_tensor(T *srcPtr, int globalThreads_y = dstDescPtr->h; int globalThreads_z = handle.GetBatchSize(); - if ((srcDescPtr->layout == RpptLayout::NHWC) && (dstDescPtr->layout == RpptLayout::NHWC)) + if (dstDescPtr->layout == RpptLayout::NHWC) { - if (srcDescPtr->dataType == RpptDataType::U8) + // if src layout is NHWC, copy src to dst + if (srcDescPtr->layout == RpptLayout::NHWC) { - hipMemcpyAsync(dstPtr, srcPtr, static_cast(srcDescPtr->n * srcDescPtr->strides.nStride * sizeof(Rpp8u)), hipMemcpyDeviceToDevice, handle.GetStream()); + hipMemcpyAsync(dstPtr, srcPtr, static_cast(srcDescPtr->n * srcDescPtr->strides.nStride * sizeof(T)), hipMemcpyDeviceToDevice, handle.GetStream()); hipStreamSynchronize(handle.GetStream()); + } + // if src layout is NCHW, convert src from NCHW to NHWC + else if (srcDescPtr->layout == RpptLayout::NCHW) + { + globalThreads_x = (dstDescPtr->w + 7) >> 3; + hipLaunchKernelGGL(convert_pln3_pkd3_hip_tensor, + dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)), + dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), + 0, + handle.GetStream(), + srcPtr, + make_uint3(srcDescPtr->strides.nStride, srcDescPtr->strides.cStride, srcDescPtr->strides.hStride), + dstPtr, + make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride), + roiTensorPtrSrc); + globalThreads_x = dstDescPtr->w; + hipStreamSynchronize(handle.GetStream()); + } + + if (srcDescPtr->dataType == RpptDataType::U8) + { hipLaunchKernelGGL(erase_pkd_hip_tensor, dim3(ceil((float)globalThreads_x / LOCAL_THREADS_X), ceil((float)globalThreads_y / LOCAL_THREADS_Y), ceil((float)globalThreads_z / LOCAL_THREADS_Z)), dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), @@ -137,8 +159,6 @@ RppStatus hip_exec_erase_tensor(T *srcPtr, } else if (srcDescPtr->dataType == RpptDataType::F16) { - hipMemcpyAsync(dstPtr, srcPtr, static_cast(srcDescPtr->n * srcDescPtr->strides.nStride * sizeof(Rpp16f)), hipMemcpyDeviceToDevice, handle.GetStream()); - hipStreamSynchronize(handle.GetStream()); hipLaunchKernelGGL(erase_pkd_hip_tensor, dim3(ceil((float)globalThreads_x / LOCAL_THREADS_X), ceil((float)globalThreads_y / LOCAL_THREADS_Y), ceil((float)globalThreads_z / LOCAL_THREADS_Z)), dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), @@ -153,8 +173,6 @@ RppStatus hip_exec_erase_tensor(T *srcPtr, } else if (srcDescPtr->dataType == RpptDataType::F32) { - hipMemcpyAsync(dstPtr, srcPtr, static_cast(srcDescPtr->n * srcDescPtr->strides.nStride * sizeof(Rpp32f)), hipMemcpyDeviceToDevice, handle.GetStream()); - hipStreamSynchronize(handle.GetStream()); hipLaunchKernelGGL(erase_pkd_hip_tensor, dim3(ceil((float)globalThreads_x / LOCAL_THREADS_X), ceil((float)globalThreads_y / LOCAL_THREADS_Y), ceil((float)globalThreads_z / LOCAL_THREADS_Z)), dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), @@ -169,8 +187,6 @@ RppStatus hip_exec_erase_tensor(T *srcPtr, } else if (srcDescPtr->dataType == RpptDataType::I8) { - hipMemcpyAsync(dstPtr, srcPtr, static_cast(srcDescPtr->n * srcDescPtr->strides.nStride * sizeof(Rpp8s)), hipMemcpyDeviceToDevice, handle.GetStream()); - hipStreamSynchronize(handle.GetStream()); hipLaunchKernelGGL(erase_pkd_hip_tensor, dim3(ceil((float)globalThreads_x / LOCAL_THREADS_X), ceil((float)globalThreads_y / LOCAL_THREADS_Y), ceil((float)globalThreads_z / LOCAL_THREADS_Z)), dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), @@ -245,33 +261,6 @@ RppStatus hip_exec_erase_tensor(T *srcPtr, numBoxesTensor, roiTensorPtrSrc); } - else if ((srcDescPtr->layout == RpptLayout::NCHW) && (dstDescPtr->layout == RpptLayout::NHWC)) - { - globalThreads_x = (dstDescPtr->w + 7) >> 3; - hipLaunchKernelGGL(convert_pln3_pkd3_hip_tensor, - dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)), - dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), - 0, - handle.GetStream(), - srcPtr, - make_uint3(srcDescPtr->strides.nStride, srcDescPtr->strides.cStride, srcDescPtr->strides.hStride), - dstPtr, - make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride), - roiTensorPtrSrc); - hipStreamSynchronize(handle.GetStream()); - globalThreads_x = dstDescPtr->w; - hipLaunchKernelGGL(erase_pkd_hip_tensor, - dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)), - dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), - 0, - handle.GetStream(), - dstPtr, - make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride), - anchorBoxInfoTensor, - colorsTensor, - numBoxesTensor, - roiTensorPtrSrc); - } } return RPP_SUCCESS;