-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Audio HIP PR6 - Mel Filter Bank HIP Support #286
Changes from all commits
1147bfe
5e3fc7a
fe1a3e6
77e14ef
34f3f6d
ab52683
ee0d6fe
2decd32
30ce1d6
64ae74f
1a3015c
e5865f9
c87f98b
64ca5a3
5eeb4b1
708160c
7e4f3f1
7e7af14
8c39e81
5089694
9122db9
c63d026
5f59f4e
a495814
4955aaa
db2a979
fc744f5
298d08b
261f93f
497288f
ea592a2
5825741
f9e70ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,12 @@ SOFTWARE. | |
} \ | ||
} while (0) | ||
|
||
// RPP_HOST_DEVICE expands to the `__host__ __device__` qualifiers when HIP_COMPILE is defined | ||
// and to nothing otherwise, so a single declaration can be used in both build configurations. | ||
#ifdef HIP_COMPILE | ||
#define RPP_HOST_DEVICE __host__ __device__ | ||
#else | ||
#define RPP_HOST_DEVICE | ||
#endif | ||
|
||
// Precomputed single-precision reciprocals of 6, 3 and 255. | ||
const float ONE_OVER_6 = 1.0f / 6; | ||
const float ONE_OVER_3 = 1.0f / 3; | ||
const float ONE_OVER_255 = 1.0f / 255; | ||
|
@@ -145,7 +151,9 @@ typedef enum | |
/*! \brief Scratch memory size needed is beyond the bounds (Needs to adhere to function specification.) \ingroup group_rppdefs */ | ||
RPP_ERROR_OUT_OF_BOUND_SCRATCH_MEMORY_SIZE = -22, | ||
/*! \brief Number of src dims is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */ | ||
RPP_ERROR_INVALID_SRC_DIMS = -23 | ||
RPP_ERROR_INVALID_SRC_DIMS = -23, | ||
/*! \brief Number of dst dims is invalid. (Needs to adhere to function specification.) \ingroup group_rppdefs */ | ||
RPP_ERROR_INVALID_DST_DIMS = -24 | ||
} RppStatus; | ||
|
||
/*! \brief RPP rppStatus_t type enums | ||
|
@@ -738,6 +746,67 @@ typedef struct RpptResamplingWindow | |
__m128 pCenter, pScale; | ||
} RpptResamplingWindow; | ||
|
||
/*! \brief Base class for Mel scale conversions. | ||
* \ingroup group_rppdefs | ||
*/ | ||
struct BaseMelScale | ||
{ | ||
public: | ||
inline RPP_HOST_DEVICE virtual Rpp32f hz_to_mel(Rpp32f hz) = 0; | ||
inline RPP_HOST_DEVICE virtual Rpp32f mel_to_hz(Rpp32f mel) = 0; | ||
virtual ~BaseMelScale() = default; | ||
}; | ||
|
||
/*! \brief Derived class for HTK Mel scale conversions. | ||
* \ingroup group_rppdefs | ||
*/ | ||
struct HtkMelScale : public BaseMelScale | ||
{ | ||
inline RPP_HOST_DEVICE Rpp32f hz_to_mel(Rpp32f hz) { return 1127.0f * std::log(1.0f + (hz / 700.0f)); } | ||
inline RPP_HOST_DEVICE Rpp32f mel_to_hz(Rpp32f mel) { return 700.0f * (std::exp(mel / 1127.0f) - 1.0f); } | ||
public: | ||
~HtkMelScale() {}; | ||
}; | ||
|
||
/*! \brief Derived class for Slaney Mel scale conversions. | ||
* \ingroup group_rppdefs | ||
*/ | ||
// Piecewise Mel scale: hz_to_mel / mel_to_hz use a linear mapping below minLogHz (minLogMel) and a logarithmic one at or above it. | ||
struct SlaneyMelScale : public BaseMelScale | ||
{ | ||
// Offset in Hz subtracted from the input in the linear branch | ||
const Rpp32f freqLow = 0; | ||
// Hz spanned by one Mel in the linear branch (mel_to_hz multiplies mel by fsp) | ||
const Rpp32f fsp = 66.666667f; | ||
// Frequency in Hz at or above which hz_to_mel takes the logarithmic branch | ||
const Rpp32f minLogHz = 1000.0; | ||
// Mel value that the linear mapping assigns to minLogHz; threshold used by mel_to_hz | ||
const Rpp32f minLogMel = (minLogHz - freqLow) / fsp; | ||
const Rpp32f stepLog = 0.068751777; // Equivalent to std::log(6.4) / 27.0; | ||
|
||
// Reciprocal of minLogHz so that hz_to_mel multiplies instead of dividing | ||
const Rpp32f invMinLogHz = 0.001f; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @HazarathKumarM Please ensure to run HOST QA tests and ensure it passes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. passing |
||
// Reciprocals of stepLog and fsp, precomputed for the same reason | ||
const Rpp32f invStepLog = 1.0f / stepLog; | ||
const Rpp32f invFsp = 1.0f / fsp; | ||
|
||
// Converts Hz to Mel: logarithmic for hz >= minLogHz, linear otherwise. | ||
inline RPP_HOST_DEVICE Rpp32f hz_to_mel(Rpp32f hz) | ||
{ | ||
Rpp32f mel = 0.0f; | ||
if (hz >= minLogHz) | ||
mel = minLogMel + std::log(hz * invMinLogHz) * invStepLog; | ||
else | ||
mel = (hz - freqLow) * invFsp; | ||
|
||
return mel; | ||
} | ||
|
||
// Converts Mel to Hz: exponential for mel >= minLogMel, linear otherwise (inverse of hz_to_mel). | ||
inline RPP_HOST_DEVICE Rpp32f mel_to_hz(Rpp32f mel) | ||
{ | ||
Rpp32f hz = 0.0f; | ||
if (mel >= minLogMel) | ||
hz = minLogHz * std::exp(stepLog * (mel - minLogMel)); | ||
else | ||
hz = freqLow + mel * fsp; | ||
return hz; | ||
} | ||
public: | ||
~SlaneyMelScale() {}; | ||
}; | ||
|
||
/******************** HOST memory typedefs ********************/ | ||
|
||
/*! \brief RPP HOST 32-bit float memory | ||
|
@@ -1055,7 +1124,7 @@ typedef struct | |
Rpp64u* dstBatchIndex; | ||
Rpp32u* inc; | ||
Rpp32u* dstInc; | ||
hipMemRpp32u scratchBuf; | ||
hipMemRpp32f scratchBufferPinned; | ||
} memGPU; | ||
|
||
/*! \brief RPP HIP-HOST memory management | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
#include <hip/hip_runtime.h> | ||
#include "rpp_hip_common.hpp" | ||
|
||
// Computes one Mel filter bank output: the weighted sum of the FFT bins covered by filter `melBin`.
//   srcPtr      - input base pointer; bin b is read from srcPtr[b * fftStrides.x + fftStrides.y]
//   melBin      - filter index. Bins in [intervals[melBin], intervals[melBin + 1]) use the rising
//                 weight (1 - weightsDown[b]); bins in [intervals[melBin + 1], intervals[melBin + 2])
//                 use the falling weight weightsDown[b]
//   weightsDown - per-FFT-bin falling weights
//   intervals   - FFT bin boundaries of the filters
//   fftStrides  - .x is the element stride between consecutive FFT bins, .y a fixed element offset
//   normFactor  - scale applied to every weight
//   dstVal      - receives the result; it is written exactly once, at the end
__device__ __forceinline__ void compute_mel(float *srcPtr, int melBin, float *weightsDown, int *intervals, int2 fftStrides, float normFactor, float &dstVal)
{
    // Accumulate in a local: the caller may bind dstVal to a memory location, and summing through
    // the reference would turn every bin into a read-modify-write of that location.
    float melSum = 0.0f;

    int fftBin = intervals[melBin];
    const int risingEnd = intervals[melBin + 1];
    const int fallingEnd = intervals[melBin + 2];

    // The pointer advances in lock-step with fftBin, so it stays valid across both loops.
    float *srcPtrTemp = srcPtr + fftBin * fftStrides.x + fftStrides.y;

    // Rising edge of the triangle: weight is the complement of the falling weight
    for (; fftBin < risingEnd; fftBin++, srcPtrTemp += fftStrides.x)
    {
        float weightUp = 1.0f - weightsDown[fftBin];
        weightUp *= normFactor;
        melSum += *srcPtrTemp * weightUp;
    }

    // Falling edge of the triangle
    for (; fftBin < fallingEnd; fftBin++, srcPtrTemp += fftStrides.x)
    {
        float weightDown = weightsDown[fftBin];
        weightDown *= normFactor;
        melSum += *srcPtrTemp * weightDown;
    }

    dstVal = melSum;
}
|
||
// Kernel: each thread produces one output element dstPtr[id_z * dstStridesNH.x + id_y * dstStridesNH.y + id_x]. | ||
// id_y selects the Mel filter (bounded by numFilter); id_x is bounded by srcDimsTensor[id_z * 2 + 1]. | ||
// The host launcher in this file passes (nStride, hStride) for srcStridesNH and dstStridesNH. | ||
// NOTE(review): presumably id_z is the batch index and srcDimsTensor holds two entries per batch item - confirm against the caller. | ||
__global__ void mel_filter_bank_tensor(float *srcPtr, | ||
uint2 srcStridesNH, | ||
float *dstPtr, | ||
uint2 dstStridesNH, | ||
int *srcDimsTensor, | ||
int numFilter, | ||
bool normalize, | ||
float *normFactors, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check if you can combine multiple params into single param There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed some of the un necessary params |
||
float *weightsDown, | ||
int *intervals) | ||
{ | ||
int id_x = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; | ||
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y; | ||
int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z; | ||
|
||
// Threads outside the valid x / filter range have nothing to do. | ||
// NOTE(review): id_z is not range-checked here - confirm the launch grid never exceeds the valid range in z. | ||
if (id_x >= srcDimsTensor[id_z * 2 + 1] || id_y >= numFilter) | ||
return; | ||
|
||
// srcIdx only applies the outer (.x) stride; compute_mel adds the per-bin stride and the id_x offset. | ||
uint dstIdx = id_z * dstStridesNH.x + id_y * dstStridesNH.y + id_x; | ||
uint srcIdx = id_z * srcStridesNH.x; | ||
|
||
// Without normalization every weight is scaled by 1. | ||
float normFactor = (normalize) ? normFactors[id_y] : 1; | ||
// make_int2 packs (stride between FFT bins, offset of this thread's element) for compute_mel. | ||
compute_mel(srcPtr + srcIdx, id_y, weightsDown, intervals, make_int2(srcStridesNH.y, id_x), normFactor, dstPtr[dstIdx]); | ||
} |
|
||
// Host-side driver for the Mel filter bank on HIP.
// Builds the filter description (per-filter normalization factors, per-FFT-bin falling weights and
// the FFT-bin boundaries of every filter) in the handle's scratchBufferPinned buffer, then launches
// mel_filter_bank_tensor over a dstDescPtr->w x dstDescPtr->h x dstDescPtr->n thread grid.
//   srcPtr / srcDescPtr - input tensor and descriptor; srcDescPtr->h is used as the number of FFT bins
//   dstPtr / dstDescPtr - output tensor and descriptor
//   srcDimsTensor       - forwarded unchanged to the kernel
//   maxFreqVal          - currently not used (see NOTE(review) below)
//   minFreqVal          - lower edge of the filter bank in Hz
//   melFormula          - HTK selects HtkMelScale; every other value selects SlaneyMelScale
//   numFilter           - number of Mel filters
//   sampleRate          - sampling rate in Hz; the upper edge of the filter bank is sampleRate / 2
//   normalize           - when true, filter i is scaled by 2 / (f2 - f0) of its triangle
//   handle              - supplies the scratch buffer and the HIP stream
// Always returns RPP_SUCCESS.
RppStatus hip_exec_mel_filter_bank_tensor(Rpp32f *srcPtr,
                                          RpptDescPtr srcDescPtr,
                                          Rpp32f *dstPtr,
                                          RpptDescPtr dstDescPtr,
                                          Rpp32s* srcDimsTensor,
                                          Rpp32f maxFreqVal,
                                          Rpp32f minFreqVal,
                                          RpptMelScaleFormula melFormula,
                                          Rpp32s numFilter,
                                          Rpp32f sampleRate,
                                          bool normalize,
                                          rpp::Handle& handle)
{
    // Select the Mel scale implementation. Both candidates live on the stack, so there is no
    // heap allocation to release on any exit path.
    HtkMelScale htkScale;
    SlaneyMelScale slaneyScale;
    BaseMelScale *melScalePtr = &slaneyScale;
    if (melFormula == RpptMelScaleFormula::HTK)
        melScalePtr = &htkScale;

    // NOTE(review): maxFreqVal is ignored; the upper edge is always sampleRate / 2. Confirm this is intended.
    Rpp32f maxFreq = sampleRate / 2;
    Rpp32f minFreq = minFreqVal;

    // Convert the frequency range to the Mel scale and derive the spacing between filter centres
    Rpp64f melLow = melScalePtr->hz_to_mel(minFreq);
    Rpp64f melHigh = melScalePtr->hz_to_mel(maxFreq);
    Rpp64f melStep = (melHigh - melLow) / (numFilter + 1);

    // Scratch layout: [normFactors: numFilter floats][weightsDown: srcDescPtr->h floats][intervals: numFilter + 2 ints]
    Rpp32f *scratchMem = handle.GetInitHandle()->mem.mgpu.scratchBufferPinned.floatmem;
    Rpp32f *normFactors = scratchMem;
    Rpp32f *weightsDown = scratchMem + numFilter;
    Rpp32s *intervals = reinterpret_cast<Rpp32s *>(weightsDown + srcDescPtr->h);

    // FFT size and frequency resolution implied by the number of input bins
    Rpp32s nfft = (srcDescPtr->h - 1) * 2;
    Rpp32s numBins = nfft / 2 + 1;
    Rpp64f hzStep = static_cast<Rpp64f>(sampleRate) / nfft;
    Rpp64f invHzStep = 1.0 / hzStep;

    // First and one-past-last FFT bin covered by the filter bank
    Rpp32s fftBinStart = std::ceil(minFreq * invHzStep);
    Rpp32s fftBinEnd = std::ceil(maxFreq * invHzStep);
    fftBinEnd = std::min(fftBinEnd, numBins);

    // Reset the three scratch arrays. All srcDescPtr->h weights are cleared so that any bin the
    // weight loop below does not reach contributes a defined value of 0.
    std::fill(normFactors, normFactors + numFilter, 1.0f);
    std::fill(weightsDown, weightsDown + srcDescPtr->h, 0.0f);
    std::fill(intervals, intervals + numFilter + 2, -1);

    // Walk the numFilter + 1 Mel intervals, recording their FFT bin boundaries and falling weights
    Rpp32s fftBin = fftBinStart;
    Rpp64f mel0 = melLow, mel1 = melLow + melStep;
    Rpp64f fIter = fftBin * hzStep;

    intervals[0] = fftBinStart;
    intervals[numFilter + 1] = fftBinEnd;

    for (int index = 0; index < numFilter + 1; index++, mel0 = mel1, mel1 += melStep)
    {
        Rpp64f f0 = melScalePtr->mel_to_hz(mel0);
        // The last interval ends exactly at melHigh rather than at the accumulated mel1
        Rpp64f f1 = melScalePtr->mel_to_hz(index == numFilter ? melHigh : mel1);
        Rpp64f slope = 1.0 / (f1 - f0);
        // NOTE(review): on the last iteration this overwrites intervals[numFilter + 1] (set to the
        // clamped fftBinEnd above) with an unclamped value - confirm it can never exceed numBins.
        intervals[index + 1] = std::ceil(f1 / hzStep);

        if (normalize && index < numFilter)
        {
            Rpp64f f2 = melScalePtr->mel_to_hz(mel1 + melStep);
            normFactors[index] = 2.0 / (f2 - f0);
        }

        // Falling weight of every FFT bin whose frequency lies inside this interval
        for (; fftBin < fftBinEnd && fIter < f1; fftBin++, fIter = fftBin * hzStep)
            weightsDown[fftBin] = (f1 - fIter) * slope;
    }

    // One thread per output element: x = dst width, y = dst height, z = batch
    Rpp32s globalThreads_x = dstDescPtr->w;
    Rpp32s globalThreads_y = dstDescPtr->h;
    Rpp32s globalThreads_z = dstDescPtr->n;
    hipLaunchKernelGGL(mel_filter_bank_tensor,
                       dim3(ceil(static_cast<float>(globalThreads_x) / LOCAL_THREADS_X), ceil(static_cast<float>(globalThreads_y) / LOCAL_THREADS_Y), ceil(static_cast<float>(globalThreads_z) / LOCAL_THREADS_Z)),
                       dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z),
                       0,
                       handle.GetStream(),
                       srcPtr,
                       make_uint2(srcDescPtr->strides.nStride, srcDescPtr->strides.hStride),
                       dstPtr,
                       make_uint2(dstDescPtr->strides.nStride, dstDescPtr->strides.hStride),
                       srcDimsTensor,
                       numFilter,
                       normalize,
                       normFactors,
                       weightsDown,
                       intervals);

    return RPP_SUCCESS;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we sure we need these structs, with all this compute code, in a high-level rppdefs.h header that's used for just one functionality?
@sampath1117 Please do a round of review on this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@r-abishek
We moved this to rppdefs.h because the same code is used in both HOST and HIP
Please let us know if there is any other place where we can move this to
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, for now just add the doxygen tag for this struct like in RpptResamplingWindow above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done