-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Audio HIP PR5 - Preemphasis Filter HIP Support #270
base: develop
Are you sure you want to change the base?
Changes from 3 commits
1147bfe
5e3fc7a
fe1a3e6
77e14ef
34f3f6d
ab52683
ee0d6fe
2decd32
30ce1d6
64ae74f
1a3015c
e5865f9
c87f98b
64ca5a3
5eeb4b1
708160c
7e4f3f1
7e7af14
290449e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,4 @@ python: | |
build: | ||
os: ubuntu-22.04 | ||
tools: | ||
python: "3.8" | ||
python: "3.10" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
rocm-docs-core[api_reference]==0.38.1 | ||
rocm-docs-core[api_reference]==1.0.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#ifndef HIP_TENSOR_AUDIO_HPP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. MACRO should match file name for this files So it should be changed to HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
#define HIP_TENSOR_AUDIO_HPP | ||
|
||
#include "kernel/pre_emphasis_filter.hpp" | ||
|
||
#endif // HIP_TENSOR_AUDIO_HPP |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#include <hip/hip_runtime.h> | ||
#include "rpp_hip_common.hpp" | ||
|
||
__device__ void pre_emphasis_filter_hip_compute(d_float8 *src1_f8, d_float8 *src2_f8, d_float8 *dst_f8, float4 *coeff_f4) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change device helpers to |
||
{ | ||
dst_f8->f4[0] = src1_f8->f4[0] - *coeff_f4 * src2_f8->f4[0]; | ||
dst_f8->f4[1] = src1_f8->f4[1] - *coeff_f4 * src2_f8->f4[1]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please fix double space, and add parentheses around (*coeff_f4 * src2_f8->f4[0]) for clarity. |
||
} | ||
|
||
__global__ void pre_emphasis_filter_tensor(float *srcPtr, | ||
uint2 srcStridesNH, | ||
float *dstPtr, | ||
uint2 dstStridesNH, | ||
RpptImagePatchPtr srcDims, | ||
float *coeffTensor, | ||
RpptAudioBorderType borderType) | ||
{ | ||
int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8 + 1; | ||
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; | ||
int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z; | ||
|
||
if ((id_x >= srcDims[id_z].width) || (id_y >= srcDims[id_z].height)) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove braces since it has only 1 line inside if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
return; | ||
} | ||
|
||
uint srcIdx = (id_z * srcStridesNH.x) + (id_y * srcStridesNH.y) + id_x; | ||
uint dstIdx = (id_z * dstStridesNH.x) + (id_y * dstStridesNH.y) + id_x; | ||
|
||
float4 coeff_f4 = (float4)coeffTensor[id_z]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In all recent HIP kernels, we are using static_cast instead of c style casting
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
d_float8 src1_f8, src2_f8, dst_f8; | ||
rpp_hip_load8_and_unpack_to_float8(srcPtr + srcIdx, &src1_f8); | ||
rpp_hip_load8_and_unpack_to_float8(srcPtr + srcIdx - 1, &src2_f8); | ||
pre_emphasis_filter_hip_compute(&src1_f8, &src2_f8, &dst_f8, &coeff_f4); | ||
rpp_hip_pack_float8_and_store8(dstPtr + dstIdx, &dst_f8); | ||
} | ||
|
||
RppStatus hip_exec_pre_emphasis_filter_tensor(Rpp32f *srcPtr, | ||
RpptDescPtr srcDescPtr, | ||
Rpp32f *dstPtr, | ||
RpptDescPtr dstDescPtr, | ||
RpptImagePatchPtr srcDims, | ||
RpptAudioBorderType borderType, | ||
rpp::Handle& handle) | ||
{ | ||
int globalThreads_x = (dstDescPtr->w + 7) >> 3; | ||
int globalThreads_y = dstDescPtr->h; | ||
int globalThreads_z = handle.GetBatchSize(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use dstDescPtr->n instead of handle.GetBatchSize() in all the new HIP kernels we are going to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
float *coeff = handle.GetInitHandle()->mem.mgpu.floatArr[0].floatmem; | ||
|
||
for(int i = 0; i < srcDescPtr->n; i++) | ||
{ | ||
int id_x = i * srcDescPtr->strides.nStride; | ||
if(borderType == RpptAudioBorderType::ZERO) | ||
dstPtr[id_x] = srcPtr[id_x]; | ||
else | ||
{ | ||
float border = (borderType == RpptAudioBorderType::CLAMP) ? srcPtr[id_x] : srcPtr[id_x + 1]; | ||
dstPtr[id_x] = srcPtr[id_x] - coeff[id_x] * border; | ||
} | ||
} | ||
|
||
hipLaunchKernelGGL(pre_emphasis_filter_tensor, | ||
dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)), | ||
dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z), | ||
0, | ||
handle.GetStream(), | ||
srcPtr, | ||
make_uint2(srcDescPtr->strides.nStride, srcDescPtr->strides.hStride), | ||
dstPtr, | ||
make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride), | ||
srcDims, | ||
coeff, | ||
borderType); | ||
|
||
return RPP_SUCCESS; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,10 @@ SOFTWARE. | |
#include "rppt_tensor_audio_augmentations.h" | ||
#include "cpu/host_tensor_audio_augmentations.hpp" | ||
|
||
#ifdef HIP_COMPILE | ||
#include "hip/hip_tensor_audio_augmentations.hpp" | ||
#endif // HIP_COMPILE | ||
|
||
/******************** non_silent_region_detection ********************/ | ||
|
||
RppStatus rppt_non_silent_region_detection_host(RppPtr_t srcPtr, | ||
|
@@ -186,3 +190,42 @@ RppStatus rppt_resample_host(RppPtr_t srcPtr, | |
return RPP_ERROR_NOT_IMPLEMENTED; | ||
} | ||
} | ||
|
||
/********************************************************************************************************************/ | ||
/*********************************************** RPP_GPU_SUPPORT = ON ***********************************************/ | ||
/********************************************************************************************************************/ | ||
|
||
#ifdef GPU_SUPPORT | ||
|
||
/******************** pre_emphasis_filter ********************/ | ||
|
||
RppStatus rppt_pre_emphasis_filter_gpu(RppPtr_t srcPtr, | ||
RpptDescPtr srcDescPtr, | ||
RppPtr_t dstPtr, | ||
RpptDescPtr dstDescPtr, | ||
RpptImagePatchPtr srcDims, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is a problem with this PR @HazarathKumarM i think this PR wont be able to build There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the commits didn't get pushed. I have now pushed all the commits. |
||
Rpp32f *coeffTensor, | ||
RpptAudioBorderType borderType, | ||
rppHandle_t rppHandle) | ||
{ | ||
#ifdef HIP_COMPILE | ||
Rpp32u paramIndex = 0; | ||
copy_param_float(coeffTensor, rpp::deref(rppHandle), paramIndex++); | ||
|
||
if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32)) | ||
{ | ||
hip_exec_pre_emphasis_filter_tensor(static_cast<Rpp32f*>(srcPtr), | ||
srcDescPtr, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should check numDims = 3 and other such restrictions in the API file for all these PRs |
||
static_cast<Rpp32f*>(dstPtr), | ||
dstDescPtr, | ||
srcDims, | ||
borderType, | ||
rpp::deref(rppHandle)); | ||
} | ||
return RPP_SUCCESS; | ||
#elif defined(OCL_COMPILE) | ||
return RPP_ERROR_NOT_IMPLEMENTED; | ||
#endif // backend | ||
} | ||
|
||
#endif // GPU_SUPPORT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am also not able to see any test suite changes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have now pushed the changes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add blank line at EOF for all new files you are adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes please end all files with one blank line |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the description for HIP API too
I have added similar description for ToDecibels HIP API
https://github.com/sampath1117/rpp/blob/sr/to_decibels_hip/include/rppt_tensor_audio_augmentations.h#L83
In recent RPP changes, we are following this approach of adding description for both api's
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done