Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audio HIP PR3 - Downmixing HIP Support #261

Open
wants to merge 58 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
c33af22
Bump rocm-docs-core[api_reference] from 0.35.0 to 0.35.1 in /docs/sph…
dependabot[bot] Mar 6, 2024
14f6334
Bump rocm-docs-core[api_reference] from 0.35.1 to 0.36.0 in /docs/sph…
dependabot[bot] Mar 12, 2024
95c3272
Merge branch 'master' into develop
kiritigowda Mar 12, 2024
3973c34
added api support for ToDecibels HIP kernel
sampath1117 Mar 19, 2024
3f08f90
added test suite support for audio in HIP
sampath1117 Mar 8, 2024
b9c0788
added profiler support for hip test suite
sampath1117 Mar 18, 2024
641f653
Docs - Bump rocm-docs-core[api_reference] from 0.36.0 to 0.37.0 in /d…
dependabot[bot] Mar 20, 2024
5568573
Link cleanup (#326)
LisaDelaney Mar 20, 2024
a6749ba
Update notes
LisaDelaney Mar 20, 2024
a255906
Docs - Bump rocm-docs-core[api_reference] from 0.37.0 to 0.37.1 in /d…
dependabot[bot] Mar 22, 2024
d3df761
RPP Voxel Flip on HIP and HOST (#285)
r-abishek Mar 23, 2024
ebecb42
RPP Vignette Tensor on HOST and HIP (#311)
r-abishek Mar 23, 2024
fc1410b
Bump rocm-docs-core[api_reference] from 0.37.1 to 0.38.0 in /docs/sph…
dependabot[bot] Mar 27, 2024
cb3d539
added initial api and test suite support for downmixing hip kernel
sampath1117 Mar 29, 2024
7fe2a0e
Merge branch 'develop' into sr/downmixing_hip
sampath1117 Mar 29, 2024
e80533a
initial working commit for downmixing hip kernel
sampath1117 Mar 29, 2024
50a743f
added support fo copy input to output when channels is 1
sampath1117 Mar 29, 2024
fe06106
minor code cleanup
sampath1117 Apr 1, 2024
3ebd7c3
RPP Tensor Audio Support - Resample (#310)
r-abishek Apr 3, 2024
76f31df
Docs - Missing input and output images for Doxygen (#331)
r-abishek Apr 3, 2024
b83f910
Scratch buffers rename for HOST and HIP (#324)
r-abishek Apr 3, 2024
ebeb131
Update CMakeLists.txt
kiritigowda Apr 3, 2024
7c194b2
Merge branch 'develop' into sr/downmixing_hip
sampath1117 Apr 8, 2024
cb06f7f
added missing hipDeviceSynchronize() in test suite
sampath1117 Apr 8, 2024
0e51993
removed f16 includes since not needed for audio
sampath1117 Apr 4, 2024
004e1d6
restructured python test suite
sampath1117 Apr 4, 2024
b8f5c60
fixed spacing in Doxygen
sampath1117 Apr 11, 2024
1147bfe
Update CMakeLists.txt
kiritigowda Apr 12, 2024
9d48447
Merge remote-tracking branch 'develop' into sr/downmixing_hip
sampath1117 Apr 16, 2024
5e3fc7a
Bump rocm-docs-core[api_reference] from 0.38.1 to 1.0.0 in /docs/sphi…
dependabot[bot] Apr 18, 2024
b6b7cc5
Bump rocm-docs-core[api_reference] from 1.0.0 to 1.1.0 in /docs/sphin…
dependabot[bot] Apr 25, 2024
e16ad7a
RPP Gaussian Noise Voxel Tensor on HOST and HIP (#323)
r-abishek Apr 26, 2024
9394c78
Merge branch 'develop' into sr/downmixing_hip
sampath1117 Apr 30, 2024
06263a5
modify CHECK to CHECK_RETURN_STATUS
sampath1117 Apr 30, 2024
a7e71a3
Merge branch 'develop' into sr/downmixing_hip
sampath1117 May 2, 2024
77e14ef
Minor common-fixes for HIP (#345)
r-abishek May 7, 2024
34f3f6d
Readme Updates: --usecase=rocm (#349)
kiritigowda May 8, 2024
ab52683
RPP Tensor Audio Support - Spectrogram (#312)
r-abishek May 8, 2024
ee0d6fe
Update CHANGELOG.md (#352)
r-abishek May 8, 2024
2decd32
RPP Tensor Audio Support - Slice (#325)
r-abishek May 8, 2024
30ce1d6
RPP Tensor Audio Support - MelFilterBank (#332)
r-abishek May 8, 2024
64ae74f
RPP Tensor Normalize ND on HOST and HIP (#335)
r-abishek May 9, 2024
1a3015c
SWDEV-459739 - Remove the package obsolete setting (#353)
raramakr May 9, 2024
b926816
Merge branch 'develop' into sr/downmixing_hip
sampath1117 May 9, 2024
7cb3c03
changed globalThreads_z to use batchsize from description pointer
sampath1117 May 9, 2024
4cb8d4b
Audio support merge commit fixes (#354)
r-abishek May 9, 2024
8aab10c
Merge branch 'develop' into sr/downmixing_hip
sampath1117 May 17, 2024
9d68c49
removed if else block based on channels inside kernel
sampath1117 May 17, 2024
195e4a4
rename instances of tensor_hip_audio to tensor_audio_hip
sampath1117 May 17, 2024
807bbe8
modified verify_output to have different cutoff for HIP and HOST back…
sampath1117 May 15, 2024
b8264f7
vectorized channels loop in hip kernel
sampath1117 May 19, 2024
5cf7360
Merge branch 'develop' into sr/downmixing_hip
sampath1117 May 30, 2024
d62e9cb
Merge branch 'develop' into sr/downmixing_hip
sampath1117 Jun 18, 2024
239bf76
removed the multiplication with normalizeWeight for every channel value
sampath1117 Jun 18, 2024
fd21921
vectorized the writes for hip kernel
sampath1117 Jun 18, 2024
23072da
moved constant compute outside the loop
sampath1117 Jun 18, 2024
3b00c64
reverted back to unvectorized writes version
sampath1117 Jun 21, 2024
0c71365
Merge branch 'develop' into sr/downmixing_hip
sampath1117 Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 60 additions & 43 deletions include/rppt_tensor_audio_augmentations.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ extern "C" {
* \details Non Silent Region Detection augmentation for 1D audio buffer
\n Finds the starting index and length of non silent region in the audio buffer by comparing the
calculated short-term power with cutoff value passed
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param[out] detectedIndexTensor beginning index of non silent region (1D tensor in HOST memory, of size batchSize)
* \param[out] detectionLengthTensor length of non silent region (1D tensor in HOST memory, of size batchSize)
* \param[in] cutOffDB cutOff in dB below which the signal is considered silent
* \param[in] windowLength window length used for computing short-term power of the signal
* \param[in] referencePower reference power that is used to convert the signal to dB
* \param[in] resetInterval number of samples after which the moving mean average is recalculated to avoid precision loss
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param [out] detectedIndexTensor beginning index of non silent region (1D tensor in HOST memory, of size batchSize)
* \param [out] detectionLengthTensor length of non silent region (1D tensor in HOST memory, of size batchSize)
* \param [in] cutOffDB cutOff in dB below which the signal is considered silent
* \param [in] windowLength window length used for computing short-term power of the signal
* \param [in] referencePower reference power that is used to convert the signal to dB
* \param [in] resetInterval number of samples after which the moving mean average is recalculated to avoid precision loss
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand All @@ -64,15 +64,15 @@ RppStatus rppt_non_silent_region_detection_host(RppPtr_t srcPtr, RpptDescPtr src

/*! \brief To Decibels augmentation on HOST backend
* \details To Decibels augmentation for 1D audio buffer converts magnitude values to decibel values
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcDims source tensor sizes for each element in batch (2D tensor in HOST memory, of size batchSize * 2)
* \param[in] cutOffDB minimum or cut-off ratio in dB
* \param[in] multiplier factor by which the logarithm is multiplied
* \param[in] referenceMagnitude Reference magnitude if not provided maximum value of input used as reference
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcDims source tensor sizes for each element in batch (2D tensor in HOST memory, of size batchSize * 2)
* \param [in] cutOffDB minimum or cut-off ratio in dB
* \param [in] multiplier factor by which the logarithm is multiplied
* \param [in] referenceMagnitude Reference magnitude if not provided maximum value of input used as reference
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand All @@ -81,14 +81,14 @@ RppStatus rppt_to_decibels_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_

/*! \brief Pre Emphasis Filter augmentation on HOST backend
* \details Pre Emphasis Filter augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param[in] coeffTensor preemphasis coefficient (1D tensor in HOST memory, of size batchSize)
* \param[in] borderType border value policy
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcLengthTensor source audio buffer length (1D tensor in HOST memory, of size batchSize)
* \param [in] coeffTensor preemphasis coefficient (1D tensor in HOST memory, of size batchSize)
* \param [in] borderType border value policy
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand All @@ -97,19 +97,36 @@ RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr,

/*! \brief Down Mixing augmentation on HOST backend
* \details Down Mixing augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param [in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle);

#ifdef GPU_SUPPORT
/*! \brief Down Mixing augmentation on HIP backend
* \details Down Mixing augmentation for audio data
* \param [in] srcPtr source tensor in HIP memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HIP memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 2, offsetInBytes >= 0, dataType = F32)
* \param [in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HIP/Pinned memory, of size batchSize * 2)
* \param [in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param [in] rppHandle RPP HIP handle created with <tt>\ref rppCreateWithStreamAndBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
RppStatus rppt_down_mixing_gpu(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle);
#endif // GPU_SUPPORT

/*! \brief Produces a spectrogram from a 1D audio buffer on HOST backend
* \details Spectrogram for 1D audio buffer
* \param [in] srcPtr source tensor in HOST memory
Expand Down Expand Up @@ -153,15 +170,15 @@ RppStatus rppt_mel_filter_bank_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, Rpp

/*! \brief Resample augmentation on HOST backend
* \details Resample augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] inRate Input sampling rate (1D tensor in HOST memory, of size batchSize)
* \param[in] outRate Output sampling rate (1D tensor in HOST memory, of size batchSize)
* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] window Resampling window (struct of type RpptRpptResamplingWindow)
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \param [in] srcPtr source tensor in HOST memory
* \param [in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [out] dstPtr destination tensor in HOST memory
* \param [in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param [in] inRate Input sampling rate (1D tensor in HOST memory, of size batchSize)
* \param [in] outRate Output sampling rate (1D tensor in HOST memory, of size batchSize)
* \param [in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param [in] window Resampling window (struct of type RpptRpptResamplingWindow)
* \param [in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
Expand Down
30 changes: 30 additions & 0 deletions src/modules/hip/hip_tensor_audio_augmentations.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
MIT License

Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#ifndef HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP
#define HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP

#include "kernel/down_mixing.hpp"

#endif // HIP_TENSOR_AUDIO_AUGMENTATIONS_HPP
73 changes: 73 additions & 0 deletions src/modules/hip/kernel/down_mixing.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <hip/hip_runtime.h>
#include "rpp_hip_common.hpp"

__global__ void down_mixing_hip_tensor(float *srcPtr,
uint srcStride,
float *dstPtr,
uint dstStride,
int2 *srcDimsTensor)

{
int id_x = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not vectorizeed?

Copy link
Author

@sampath1117 sampath1117 May 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can only vectorize the channel code as how it is done in host code
Initially felt it might not be efficient to vectorize the channel loop in HIP
But explored further and vectorized the code now and checked the performance. It looks fine

Could not see improvement in performance though with current inputs
since we are testing with 2 channel input and vectorization only works with at least 8 channel input

int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z;
int srcLength = srcDimsTensor[id_z].x;
int channels = srcDimsTensor[id_z].y;

if (id_x >= srcLength)
return;

float nomalizedWeight = 1.f / channels;
float outVal = 0.0f;
uint srcIdx = id_z * srcStride + id_x * channels;
int i = 0;
int alignedChannels = (channels / 8) * 8;
// if number of channels is a multiple of 8, do 8 pixel load till alignedChannels value
if (alignedChannels)
{
d_float8 outVal_f8;
outVal_f8.f4[0] = static_cast<float4>(0.0f);
outVal_f8.f4[1] = outVal_f8.f4[0];
float4 normalizedWeight_f4 = static_cast<float4>(nomalizedWeight);
for(; i < alignedChannels; i += 8, srcIdx += 8)
{
d_float8 src_f8;
rpp_hip_load8_and_unpack_to_float8(srcPtr + srcIdx, &src_f8);
rpp_hip_math_multiply8_const(&src_f8, &src_f8, normalizedWeight_f4);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sampath1117 normalizedWeight is a constant which is just 1/channels. You don't need this multiply channels times in the loop.
Just add everything into outVal and divide once by channels right at the end.
Remove this normalizedWeight variable too.

rpp_hip_math_add8(&outVal_f8, &src_f8, &outVal_f8);
}
outVal_f8.f4[0] += outVal_f8.f4[1];
outVal += (outVal_f8.f1[0] + outVal_f8.f1[1] + outVal_f8.f1[2] + outVal_f8.f1[3]);
}
// process remaining channels
for(; i < channels; i++, srcIdx++)
outVal += srcPtr[srcIdx] * nomalizedWeight;

uint dstIdx = id_z * dstStride + id_x;
dstPtr[dstIdx] = outVal;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the range of values channels we can have? Any ways to vectorize the write here?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that the RPPT_MAX_AUDIO_CHANNELS = 16 in the Resample PR. If thats the case, possibly running the 1 liner un-vectorized loop on L43 to do a "outVal += srcPtr[srcIdx]" should be fine followed by doing a "outVal *= nomalizedWeight" at the end.
What could possibly be tried is if you:

  • reduce threads launched by a factor of 8 (the usual >> 3 with d_float8s, or try with a >> 2 launch with float4s)
  • vectorize the dst writes here so we write 8 outVal_f8
  • to compute outVal_f8, do a "outVal_f8 += srcVals_f8" in the loop
  • resetting srcVals_f8 each time (channels times) inside the loop from global mem could take time, but I suggest trying this

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried this experiment and it is giving a good 30% improvement in performance

Previously
I did not vectorize the writes because, for the higher channel cases (16 channel input), we need to load 128 elements per thread (8 * 16) and felt it increases the load too much per thread

Currently we have only 2 channel input in test suite and it goes to non vectorized loop, so it is loading 16 elements per thread and giving a 30% boost compared to unvectorized writes case

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this performance improvement was an anomoly
Reran the tests for updating numbers in MS, but observing 3x more time for vectorized writes version compared to unvectorized writes

Reran multiple times today and confirmed there was a degradation in performance
So reverted back to single pixel writes version

}

RppStatus hip_exec_down_mixing_tensor(Rpp32f *srcPtr,
RpptDescPtr srcDescPtr,
Rpp32f *dstPtr,
RpptDescPtr dstDescPtr,
Rpp32s *srcDimsTensor,
bool normalizeWeights,
rpp::Handle& handle)
{
Rpp32s globalThreads_x = dstDescPtr->strides.nStride;
Rpp32s globalThreads_y = 1;
Rpp32s globalThreads_z = dstDescPtr->n;

hipLaunchKernelGGL(down_mixing_hip_tensor,
dim3(ceil((Rpp32f)globalThreads_x/LOCAL_THREADS_X_1DIM), ceil((Rpp32f)globalThreads_y/LOCAL_THREADS_Y_1DIM), ceil((Rpp32f)globalThreads_z/LOCAL_THREADS_Z_1DIM)),
dim3(LOCAL_THREADS_X_1DIM, LOCAL_THREADS_Y_1DIM, LOCAL_THREADS_Z_1DIM),
0,
handle.GetStream(),
srcPtr,
srcDescPtr->strides.nStride,
dstPtr,
dstDescPtr->strides.nStride,
reinterpret_cast<int2 *>(srcDimsTensor));

return RPP_SUCCESS;
}
44 changes: 44 additions & 0 deletions src/modules/rppt_tensor_audio_augmentations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ SOFTWARE.
#include "rppt_tensor_audio_augmentations.h"
#include "cpu/host_tensor_audio_augmentations.hpp"

#ifdef HIP_COMPILE
#include "hip/hip_tensor_audio_augmentations.hpp"
#endif // HIP_COMPILE

/******************** non_silent_region_detection ********************/

RppStatus rppt_non_silent_region_detection_host(RppPtr_t srcPtr,
Expand Down Expand Up @@ -268,3 +272,43 @@ RppStatus rppt_resample_host(RppPtr_t srcPtr,
return RPP_ERROR_NOT_IMPLEMENTED;
}
}

/********************************************************************************************************************/
/*********************************************** RPP_GPU_SUPPORT = ON ***********************************************/
/********************************************************************************************************************/

#ifdef GPU_SUPPORT

/******************** down_mixing ********************/

RppStatus rppt_down_mixing_gpu(RppPtr_t srcPtr,
RpptDescPtr srcDescPtr,
RppPtr_t dstPtr,
RpptDescPtr dstDescPtr,
Rpp32s *srcDimsTensor,
bool normalizeWeights,
rppHandle_t rppHandle)
{
#ifdef HIP_COMPILE
if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32))
{
hip_exec_down_mixing_tensor(static_cast<Rpp32f*>(srcPtr),
srcDescPtr,
static_cast<Rpp32f*>(dstPtr),
dstDescPtr,
srcDimsTensor,
normalizeWeights,
rpp::deref(rppHandle));
}
else
{
return RPP_ERROR_NOT_IMPLEMENTED;
}

return RPP_SUCCESS;
#elif defined(OCL_COMPILE)
return RPP_ERROR_NOT_IMPLEMENTED;
#endif // backend
}

#endif // GPU_SUPPORT
25 changes: 24 additions & 1 deletion utilities/test_suite/HIP/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ find_package(hip QUIET)
find_package(OpenCV QUIET)
find_package(TurboJpeg QUIET)
find_package(NIFTI QUIET)
find_library(libsnd_LIBS
NAMES sndfile libsndfile
PATHS ${CMAKE_SYSTEM_PREFIX_PATH} ${LIBSND_ROOT_DIR} "/usr/local"
PATH_SUFFIXES lib lib64)

# OpenMP
find_package(OpenMP REQUIRED)
Expand Down Expand Up @@ -102,4 +106,23 @@ if(NIFTI_FOUND AND OpenCV_FOUND)
target_link_libraries(Tensor_voxel_hip ${OpenCV_LIBS} -lturbojpeg -lrpp ${hip_LIBRARIES} pthread ${LINK_LIBRARY_LIST} hip::device ${NIFTI_PACKAGE_PREFIX}NIFTI::${NIFTI_PACKAGE_PREFIX}niftiio)
else()
message("-- ${Yellow}Warning: libniftiio must be installed to install ${PROJECT_NAME}/Tensor_voxel_hip successfully!${ColourReset}")
endif()
endif()

if(NOT libsnd_LIBS)
message("-- ${Yellow}Warning: libsndfile must be installed to install ${PROJECT_NAME}/Tensor_audio_host successfully!${ColourReset}")
else()
message("-- ${Green}${PROJECT_NAME} set to build with rpp and libsndfile ${ColourReset}")
set(COMPILER_FOR_HIP ${ROCM_PATH}/bin/hipcc)
set(CMAKE_CXX_COMPILER ${COMPILER_FOR_HIP})
include_directories(${ROCM_PATH}/include ${ROCM_PATH}/include/rpp /usr/local/include)
link_directories(${ROCM_PATH}/lib /usr/local/lib)
include_directories(${SndFile_INCLUDE_DIRS})
link_directories(${SndFile_LIBRARIES_DIR} /usr/local/lib/)

add_executable(Tensor_audio_hip Tensor_audio_hip.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++17")
if(NOT APPLE)
set(LINK_LIBRARY_LIST ${LINK_LIBRARY_LIST} stdc++fs)
endif()
target_link_libraries(Tensor_audio_hip ${libsnd_LIBS} -lsndfile -lrpp pthread ${LINK_LIBRARY_LIST})
endif()
Loading