Skip to content

Commit

Permalink
RPP Tensor Audio Support - Down Mixing (ROCm#296)
Browse files Browse the repository at this point in the history
* Initial commit - Non slient region detection

Includes unittest setup

* Initial commit - To Decibels

Includes unittest setup

* Intial commit - pre_emphasis_filter

* Intial commit - down_mixing

* Replace vectors with arrays

* Cleanup

* Minor cleanup

* Optimize downmixing Kernel

Includes cleanup

* Replace Rpp64s with Rpp32s

* Cleanup

* Optimize and precompute cutOff

* Fix buffer used

* Fix buffer used

* Additional Cleanup

* Optimize post incrmeent operation

* Optimize post increment operation

* Update testsuite for Audio

* code cleanup

* Add Readme file for Audio test suite

* changes based on review comments

* minor change

* Remove unittest folders and updated README.md

* Remove unit tests

* minor change

* code cleanup

* added common header file for audio helper functions

* removed unncessary audio wav files

fixed bug in ROI updation for audio test suite

resolved issue in summary generation for performance tests in python

* removed log file

* added doxygen support for audio

* added doxygen changes for to_decibels

* updated test suite support for to_decibels

* minor change

* added doxygen changes for preemphasis filter

* updated changes for preemphasis filter in test suite

* removed the usage of getMax function and used std::max_element

* modularized code in test suite

* merge with latest changes

* minor change

* minor change

* minor change

* resolved codacy warnings

* Codacy fix - Remove unused cpuTime

* CMakeLists - Version Update

1.5.0 - TOT Version

* CHANGELOG Updates

Version 1.5.0 placeholder

* resolved issue with file_system dependency in test suite

* Doxygen changes

changed malloc to new in NSR kernel

* RPP RICAP Tensor for HOST and HIP (r-abishek#213)

* Initial commit - Ricap HOST Tensor

Includes testsuite changes

* Add QA tests for RICAP

Used three_images_224x224_src1 folder to create golden outputs

* Add three_images_224x224_src1 into TEST_IMAGES

* Support HIP Backend for RICAP

* Fix HIP pkd3->pkd3 variant

* regenerated golden outputs for RICAP

minor changes in HOST shell script for handling RICAP in QA mode

* minor bug fix in RICAP HIP kernels

* Improve readability and Cleanup

* Additional cleanup

* Cleanup testsuite

Includes new golden outputs

* Additional testuite fixes

* Minor cleanup

* Fix codacy warnings

* Address other codacy warnings

* Update ricap.hpp with reference paper

* Add RICAP dataset path in readme

* Make changes to error codes returned

* Modify roi crop region for unit and perf tests

* RPP Tensor Water Augmentation on HOST and HIP (r-abishek#181)

* added water HOST and HIP codes

* added water case in test suite

* added golden outputs for water

* added omp thread changes for water augmentation

* experimental changes

* fixed output issue with AVX2 instructions

* added AVX2 support for PKD3 load function

minor changes in PLN variant load functions

* nwc commit - added avx2 changes for u8 layout toggle variants but need to add store functions for completion

* Add Avx2 implementation for F32 and U8 toggle variants

* Add AVX2 support for u8 pkd3-pln3 and i8 pkd3-pln3 for water augmentation

* change F32 load and store logic

* optimized the store function for F32 PLN3-PKD3

* reverted back irrelevant changes

* minor change

* optimized load and store functions for water U8 and F32 variants in host

removed commented code

* removed golden outputs for water

* minor changes

* renamed few functions and removed unused functions

updated i8 pln1 load as per the optimized u8 pln1 load

* fixed bug in i8 load function

* changed cast to c++ style

resolved spacing issues and added comments for AVX codes for better understanding

made changes to handle cases where QA Tests are not supported

* added golden outputs for water

* updated golden outputs with latest changes

* modified the u8, i8 pkd3-pln3 function and added comments for the vectorized code

* fixed minor bug in I8 variants

* made to changes to resolve codacy warnings

* changed cast to c++ style in hip kernel

* changed generic nn F32 loads using gather and setr instructions

* added comments for latest changes

* minor change

* added definition for storing 32 and 64 bits from a 128bit register

---------

Co-authored-by: sampath1117 <sampath.rachumallu@multicorewareinc.com>
Co-authored-by: HazarathKumarM <hazarathkumar@multicorewareinc.com>

* Fix build error

* CMakeLists - Version Update

1.5.0 - TOT Version

* CHANGELOG Updates

Version 1.5.0 placeholder

* Boost deps fix for test suite

---------

Co-authored-by: Snehaa Giridharan <snehaa@multicorewareinc.com>
Co-authored-by: sampath1117 <sampath.rachumallu@multicorewareinc.com>
Co-authored-by: Snehaa-Giridharan <118163708+snehaa8@users.noreply.github.com>
Co-authored-by: HazarathKumarM <hazarathkumar@multicorewareinc.com>
Co-authored-by: Kiriti Gowda <kiritigowda@gmail.com>

* Documentation - Readme & changelog updates (r-abishek#251)

* readme and changelog updates for 6.0

* minor update

* added ctests for audio test suite for CI

made changes to add more clarity on the QA Tests results

* Cmake mods for ctest

* HOST-only build error bugfix

* added qa mode paramter to python audio script

added golden output map for QA testing of Non silent region detection

* minor change

* Documentation - Bump rocm-docs-core[api_reference] from 0.26.0 to 0.27.0 in /docs/sphinx (r-abishek#253)

Bumps [rocm-docs-core[api_reference]](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.26.0 to 0.27.0.
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases)
- [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md)
- [Commits](ROCm/rocm-docs-core@v0.26.0...v0.27.0)

---
updated-dependencies:
- dependency-name: rocm-docs-core[api_reference]
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* RPP Resize Mirror Normalize Bugfix (r-abishek#252)

* added fix for hipMemset

* remove pixel check for U8-F32 and U8-F16 for HOST codes

---------

Co-authored-by: sampath1117 <sampath.rachumallu@multicorewareinc.com>

* added example for MMS calculation in comments for better understanding

* Sphinx - updates (r-abishek#257)

* Sphinx - updates

* Doxygen - Updates

* Docs - Remove index.md

* updated info used to for running audio test suite

* removed bitdepth variable from audio test suite

* added more information on computing NSR outputs in the example added

* Fix doxygen for decibels

Also removes extra QA reference files

* move tensor_host_audio.cpp to host folder

* Fix build errors and qa tests in Audio Test suite

* Fix build errors and qa tests in Audio Test suite

* Add reference output and test samples for downmix

* Add down_mix in augmentation list and supported cases

* Remove auto-merge repeated funcs

* Improve clarity of header docs

* Remove blank line

* Improve clarity on header docs

* Add Doxygen comments

* minor change

* converted golden outputs to binary file for downmixing

* removed old golden output file for preemphasis and todecibels

* modified info for downmixing as per new changes

used handle memory for temporary buffers

* formatting changes

* moved the common code for SSE and AVX to outside

* Update down_mixing.hpp license

* Update rppt_tensor_audio_augmentations.h

* combined the srcLength and channels tensors into single tensor

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Snehaa Giridharan <snehaa@multicorewareinc.com>
Co-authored-by: HazarathKumarM <hazarathkumar@multicorewareinc.com>
Co-authored-by: sampath1117 <sampath.rachumallu@multicorewareinc.com>
Co-authored-by: Kiriti Gowda <kiritigowda@gmail.com>
Co-authored-by: Snehaa-Giridharan <118163708+snehaa8@users.noreply.github.com>
Co-authored-by: Lisa <lisajdelaney@gmail.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Sundarrajan98 <sundarrajan@multicorewareinc.com>
  • Loading branch information
9 people committed Mar 5, 2024
1 parent 12cb43c commit 1c4c366
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 6 deletions.
17 changes: 16 additions & 1 deletion include/rppt_tensor_audio_augmentations.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,22 @@ RppStatus rppt_to_decibels_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_
*/
RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcLengthTensor, Rpp32f *coeffTensor, RpptAudioBorderType borderType, rppHandle_t rppHandle);

/*! \brief Down Mixing augmentation on HOST backend
* \details Down Mixing augmentation for audio data
* \param[in] srcPtr source tensor in HOST memory
* \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[out] dstPtr destination tensor in HOST memory
* \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32)
* \param[in] srcDimsTensor source audio buffer length and number of channels (1D tensor in HOST memory, of size batchSize * 2)
* \param[in] normalizeWeights bool flag to specify if normalization of weights is needed
* \param[in] rppHandle RPP HOST handle created with <tt>\ref rppCreateWithBatchSize()</tt>
* \return A <tt> \ref RppStatus</tt> enumeration.
* \retval RPP_SUCCESS Successful completion.
* \retval RPP_ERROR* Unsuccessful completion.
*/
RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDimsTensor, bool normalizeWeights, rppHandle_t rppHandle);

#ifdef __cplusplus
}
#endif
#endif // RPPT_TENSOR_AUDIO_AUGMENTATIONS_H
#endif // RPPT_TENSOR_AUDIO_AUGMENTATIONS_H
23 changes: 23 additions & 0 deletions src/include/cpu/rpp_cpu_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2438,6 +2438,29 @@ static inline __m128 log_ps(__m128 x)
return x;
}

inline Rpp32f rpp_hsum_ps(__m128 x)
{
__m128 shuf = _mm_movehdup_ps(x); // broadcast elements 3,1 to 2,0
__m128 sums = _mm_add_ps(x, shuf);
shuf = _mm_movehl_ps(shuf, sums); // high half -> low half
sums = _mm_add_ss(sums, shuf);
return _mm_cvtss_f32(sums);
}

inline Rpp32f rpp_hsum_ps(__m256 x)
{
__m128 p0 = _mm256_extractf128_ps(x, 1); // Contains x7, x6, x5, x4
__m128 p1 = _mm256_castps256_ps128(x); // Contains x3, x2, x1, x0
__m128 sum = _mm_add_ps(p0, p1); // Contains x3 + x7, x2 + x6, x1 + x5, x0 + x4
p0 = sum; // Contains -, -, x1 + x5, x0 + x4
p1 = _mm_movehl_ps(sum, sum); // Contains -, -, x3 + x7, x2 + x6
sum = _mm_add_ps(p0, p1); // Contains -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6
p0 = sum; // Contains -, -, -, x0 + x2 + x4 + x6
p1 = _mm_shuffle_ps(sum, sum, 0x1); // Contains -, -, -, x1 + x3 + x5 + x7
sum = _mm_add_ss(p0, p1); // Contains -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7
return _mm_cvtss_f32(sum);
}

static inline void fast_matmul4x4_sse(float *A, float *B, float *C)
{
__m128 row1 = _mm_load_ps(&B[0]); // Row 0 of B
Expand Down
1 change: 1 addition & 0 deletions src/modules/cpu/host_tensor_audio_augmentations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ SOFTWARE.
#include "kernel/non_silent_region_detection.hpp"
#include "kernel/to_decibels.hpp"
#include "kernel/pre_emphasis_filter.hpp"
#include "kernel/down_mixing.hpp"

#endif // HOST_TENSOR_AUDIO_AUGMENTATIONS_HPP
122 changes: 122 additions & 0 deletions src/modules/cpu/kernel/down_mixing.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
MIT License
Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#include "rppdefs.h"
#include <omp.h>

RppStatus down_mixing_host_tensor(Rpp32f *srcPtr,
RpptDescPtr srcDescPtr,
Rpp32f *dstPtr,
RpptDescPtr dstDescPtr,
Rpp32s *srcDimsTensor,
bool normalizeWeights,
rpp::Handle& handle)
{
Rpp32u numThreads = handle.GetNumThreads();

omp_set_dynamic(0);
#pragma omp parallel for num_threads(numThreads)
for(int batchCount = 0; batchCount < srcDescPtr->n; batchCount++)
{
Rpp32f *srcPtrTemp = srcPtr + batchCount * srcDescPtr->strides.nStride;
Rpp32f *dstPtrTemp = dstPtr + batchCount * dstDescPtr->strides.nStride;

Rpp32s samples = srcDimsTensor[batchCount * 2];
Rpp32s channels = srcDimsTensor[batchCount * 2 + 1];
bool flagAVX = 0;

if(channels == 1)
{
// No need of downmixing, do a direct memcpy
memcpy(dstPtrTemp, srcPtrTemp, (size_t)(samples * sizeof(Rpp32f)));
}
else
{
Rpp32f *weights = handle.GetInitHandle()->mem.mcpu.tempFloatmem + batchCount * channels;
std::fill(weights, weights + channels, 1.f / channels);

if(normalizeWeights)
{
// Compute sum of the weights
Rpp32f sum = 0.0;
for(int i = 0; i < channels; i++)
sum += weights[i];

// Normalize the weights
Rpp32f invSum = 1.0 / sum;
for(int i = 0; i < channels; i++)
weights[i] *= invSum;
}

Rpp32s channelIncrement = 4;
Rpp32s alignedChannels = (channels / 4) * 4;
if(channels > 7)
{
flagAVX = 1;
channelIncrement = 8;
alignedChannels = (channels / 8) * 8;
}

// use weights to downmix to mono
for(int64_t dstIdx = 0; dstIdx < samples; dstIdx++)
{
Rpp32s channelLoopCount = 0;
// if number of channels are greater than or equal to 8, use AVX implementation
if(flagAVX)
{
__m256 pDst = avx_p0;
for(; channelLoopCount < alignedChannels; channelLoopCount += channelIncrement)
{
__m256 pSrc, pWeights;
pWeights = _mm256_setr_ps(weights[channelLoopCount], weights[channelLoopCount + 1], weights[channelLoopCount + 2], weights[channelLoopCount + 3],
weights[channelLoopCount + 4], weights[channelLoopCount + 5], weights[channelLoopCount + 6], weights[channelLoopCount + 7]);
pSrc = _mm256_loadu_ps(srcPtrTemp);
pSrc = _mm256_mul_ps(pSrc, pWeights);
pDst = _mm256_add_ps(pDst, pSrc);
srcPtrTemp += channelIncrement;
}
dstPtrTemp[dstIdx] = rpp_hsum_ps(pDst);
}
else
{
__m128 pDst = xmm_p0;
for(; channelLoopCount < alignedChannels; channelLoopCount += channelIncrement)
{
__m128 pSrc, pWeights;
pWeights = _mm_setr_ps(weights[channelLoopCount], weights[channelLoopCount + 1], weights[channelLoopCount + 2], weights[channelLoopCount + 3]);
pSrc = _mm_loadu_ps(srcPtrTemp);
pSrc = _mm_mul_ps(pSrc, pWeights);
pDst = _mm_add_ps(pDst, pSrc);
srcPtrTemp += channelIncrement;
}
dstPtrTemp[dstIdx] = rpp_hsum_ps(pDst);
}
for(; channelLoopCount < channels; channelLoopCount++)
dstPtrTemp[dstIdx] += ((*srcPtrTemp++) * weights[channelLoopCount]);
}
}
}

return RPP_SUCCESS;
}
28 changes: 28 additions & 0 deletions src/modules/rppt_tensor_audio_augmentations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,31 @@ RppStatus rppt_pre_emphasis_filter_host(RppPtr_t srcPtr,
return RPP_ERROR_NOT_IMPLEMENTED;
}
}

/******************** down_mixing ********************/

RppStatus rppt_down_mixing_host(RppPtr_t srcPtr,
RpptDescPtr srcDescPtr,
RppPtr_t dstPtr,
RpptDescPtr dstDescPtr,
Rpp32s *srcDimsTensor,
bool normalizeWeights,
rppHandle_t rppHandle)
{
if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32))
{
down_mixing_host_tensor(static_cast<Rpp32f*>(srcPtr),
srcDescPtr,
static_cast<Rpp32f*>(dstPtr),
dstDescPtr,
srcDimsTensor,
normalizeWeights,
rpp::deref(rppHandle));

return RPP_SUCCESS;
}
else
{
return RPP_ERROR_NOT_IMPLEMENTED;
}
}
2 changes: 1 addition & 1 deletion utilities/test_suite/HOST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ else()
endif()

if(NOT libsnd_LIBS)
message("-- ${Yellow}Warning: libsndfile must be installed to install ${PROJECT_NAME}/Tensor_voxel_host successfully!${ColourReset}")
message("-- ${Yellow}Warning: libsndfile must be installed to install ${PROJECT_NAME}/Tensor_audio_host successfully!${ColourReset}")
else()
message("-- ${Green}${PROJECT_NAME} set to build with rpp and libsndfile ${ColourReset}")
include_directories(${ROCM_PATH}/include ${ROCM_PATH}/include/rpp /usr/local/include)
Expand Down
21 changes: 20 additions & 1 deletion utilities/test_suite/HOST/Tensor_host_audio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,25 @@ int main(int argc, char **argv)

break;
}
case 3:
{
testCaseName = "down_mixing";
bool normalizeWeights = false;
Rpp32s srcDimsTensor[batchSize * 2];

for (int i = 0, j = 0; i < batchSize; i++, j += 2)
{
srcDimsTensor[j] = srcLengthTensor[i];
srcDimsTensor[j + 1] = channelsTensor[i];
dstDims[i].height = srcLengthTensor[i];
dstDims[i].width = 1;
}

startWallTime = omp_get_wtime();
rppt_down_mixing_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcDimsTensor, normalizeWeights, handle);

break;
}
default:
{
missingFuncFlag = 1;
Expand Down Expand Up @@ -263,4 +282,4 @@ int main(int argc, char **argv)
free(inputf32);
free(outputf32);
return 0;
}
}
22 changes: 20 additions & 2 deletions utilities/test_suite/HOST/runAudioTests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
outFolderPath = os.getcwd()
buildFolderPath = os.getcwd()
caseMin = 0
caseMax = 2
caseMax = 3

# Checks if the folder path is empty, or is it a root folder, or if it exists, and remove its contents
def validate_and_remove_files(path):
Expand Down Expand Up @@ -235,13 +235,31 @@ def rpp_test_suite_parser_and_validator():
exit(0)

for case in caseList:
if "--input_path" not in sys.argv:
if case == "3":
srcPath = scriptPath + "/../TEST_AUDIO_FILES/three_sample_multi_channel_src1"
else:
srcPath = inFilePath
if int(case) < 0 or int(case) > 3:
print(f"Invalid case number {case}. Case number must be 0-3 range!")
continue

run_unit_test(srcPath, case, numRuns, testType, batchSize, outFilePath)
else:
for case in caseList:
if "--input_path" not in sys.argv:
if case == "3":
srcPath = scriptPath + "/../TEST_AUDIO_FILES/three_sample_multi_channel_src1"
else:
srcPath = inFilePath
if int(case) < 0 or int(case) > 3:
print(f"Invalid case number {case}. Case number must be 0-3 range!")
continue

run_performance_test(loggingFolder, srcPath, case, numRuns, testType, batchSize, outFilePath)

# print the results of qa tests
supportedCaseList = ['0', '1', '2']
supportedCaseList = ['0', '1', '2', '3']
nonQACaseList = [] # Add cases present in supportedCaseList, but without QA support

if testType == 0:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
3 changes: 2 additions & 1 deletion utilities/test_suite/rpp_test_suite_audio.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ std::map<int, string> audioAugmentationMap =
{0, "non_silent_region_detection"},
{1, "to_decibels"},
{2, "pre_emphasis_filter"},
{3, "down_mixing"},
};

// Golden outputs for Non Silent Region Detection
Expand Down Expand Up @@ -137,7 +138,7 @@ void verify_output(Rpp32f *dstPtr, RpptDescPtr dstDescPtr, RpptImagePatchPtr dst
// read data from golden outputs
Rpp64u oBufferSize = dstDescPtr->n * dstDescPtr->strides.nStride;
Rpp32f *refOutput = static_cast<Rpp32f *>(malloc(oBufferSize * sizeof(float)));
string outFile = scriptPath + testCase + "/" + testCase + ".bin";
string outFile = scriptPath + "/../REFERENCE_OUTPUTS_AUDIO/" + testCase + "/" + testCase + ".bin";
std::fstream fin(outFile, std::ios::in | std::ios::binary);
if(fin.is_open())
{
Expand Down

0 comments on commit 1c4c366

Please sign in to comment.