Skip to content
Permalink
Browse files Browse the repository at this point in the history
Make MfccMelFilterbank fail initialization if num_channels is > max int value.

Also initialize MfccDct only if MfccMelFilterbank initialization was successful.

PiperOrigin-RevId: 477844246
  • Loading branch information
swachhandl authored and tensorflower-gardener committed Sep 29, 2022
1 parent 21b567d commit 39ec7ea
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/mfcc.cc
Expand Up @@ -38,8 +38,10 @@ bool Mfcc::Initialize(int input_length, double input_sample_rate) {
bool initialized = mel_filterbank_.Initialize(
input_length, input_sample_rate, filterbank_channel_count_,
lower_frequency_limit_, upper_frequency_limit_);
initialized &=
dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_);
if (initialized) {
initialized =
dct_.Initialize(filterbank_channel_count_, dct_coefficient_count_);
}
initialized_ = initialized;
return initialized;
}
Expand Down
14 changes: 13 additions & 1 deletion tensorflow/core/kernels/mfcc_mel_filterbank.cc
Expand Up @@ -32,6 +32,8 @@ limitations under the License.

#include <math.h>

#include <limits>

#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
Expand Down Expand Up @@ -74,7 +76,17 @@ bool MfccMelFilterbank::Initialize(int input_length, double input_sample_rate,

// An extra center frequency is computed at the top to get the upper
// limit on the high side of the final triangular filter.
center_frequencies_.resize(num_channels_ + 1);
std::size_t center_frequencies_size = std::size_t(num_channels_) + 1;
if (center_frequencies_size >= std::numeric_limits<int>::max() ||
center_frequencies_size > center_frequencies_.max_size()) {
LOG(ERROR) << "Number of filterbank channels must be less than "
<< std::numeric_limits<int>::max()
<< " and less than or equal to "
<< center_frequencies_.max_size();
return false;
}
center_frequencies_.resize(center_frequencies_size);

const double mel_low = FreqToMel(lower_frequency_limit);
const double mel_hi = FreqToMel(upper_frequency_limit);
const double mel_span = mel_hi - mel_low;
Expand Down
34 changes: 34 additions & 0 deletions tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"

#include <limits>
#include <vector>

#include "tensorflow/core/platform/test.h"
Expand Down Expand Up @@ -85,4 +86,37 @@ TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) {
}
}

TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxIntValue) {
  // Regression test: the filterbank's vector used to throw a length_error
  // when asked for a size it considered larger than its max_size. For now,
  // initialization is expected to fail whenever the requested channel count
  // is >= the largest value an int can hold (num_channels_ is an int).
  constexpr int kSampleCount = 513;
  constexpr double kSampleRate = 2;
  constexpr double kLowerFrequencyLimit = 1.0;
  constexpr double kUpperFrequencyLimit = 5.0;
  const std::size_t requested_channels = std::numeric_limits<int>::max();

  MfccMelFilterbank filterbank;
  const bool initialized =
      filterbank.Initialize(kSampleCount, kSampleRate, requested_channels,
                            kLowerFrequencyLimit, kUpperFrequencyLimit);

  EXPECT_FALSE(initialized);
}

TEST(MfccMelFilterbankTest, FailsWhenChannelsGreaterThanMaxSize) {
  // Regression test: the filterbank's vector used to throw a length_error
  // when asked for a size it considered larger than its max_size. For now,
  // initialization is expected to fail whenever the requested channel count
  // is greater than std::vector<double>::max_size().
  constexpr int kSampleCount = 513;
  constexpr double kSampleRate = 2;
  constexpr double kLowerFrequencyLimit = 1.0;
  constexpr double kUpperFrequencyLimit = 5.0;

  // One more element than a vector of doubles can theoretically hold.
  // NOTE(review): this value likely does not fit in an int. If Initialize
  // takes the channel count as an int, the argument is narrowed at the call
  // site -- confirm the max_size check is what actually rejects it.
  const std::size_t requested_channels = std::vector<double>().max_size() + 1;

  MfccMelFilterbank filterbank;
  const bool initialized =
      filterbank.Initialize(kSampleCount, kSampleRate, requested_channels,
                            kLowerFrequencyLimit, kUpperFrequencyLimit);

  EXPECT_FALSE(initialized);
}

} // namespace tensorflow
12 changes: 7 additions & 5 deletions tensorflow/core/kernels/mfcc_op.cc
Expand Up @@ -25,7 +25,7 @@ limitations under the License.

namespace tensorflow {

// Create a speech fingerpring from spectrogram data.
// Create a speech fingerprint from spectrogram data.
class MfccOp : public OpKernel {
public:
explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) {
Expand Down Expand Up @@ -60,10 +60,12 @@ class MfccOp : public OpKernel {
mfcc.set_lower_frequency_limit(lower_frequency_limit_);
mfcc.set_filterbank_channel_count(filterbank_channel_count_);
mfcc.set_dct_coefficient_count(dct_coefficient_count_);
OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate),
errors::InvalidArgument(
"Mfcc initialization failed for channel count ",
spectrogram_channels, " and sample rate ", sample_rate));
OP_REQUIRES(
context, mfcc.Initialize(spectrogram_channels, sample_rate),
errors::InvalidArgument("Mfcc initialization failed for channel count ",
spectrogram_channels, ", sample rate ",
sample_rate, " and filterbank_channel_count ",
filterbank_channel_count_));

Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
Expand Down

0 comments on commit 39ec7ea

Please sign in to comment.