Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix audio spectrogram FPE.
Do input validation in shape function.

PiperOrigin-RevId: 503481241
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jan 20, 2023
1 parent 430626d commit d0d4e77
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
Expand Up @@ -5922,6 +5922,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:status_matchers",
],
)

Expand Down
41 changes: 41 additions & 0 deletions tensorflow/core/kernels/spectrogram_op_test.cc
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/audio_ops.h"
#include "tensorflow/cc/ops/const_op.h"
Expand All @@ -29,6 +31,9 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tsl/lib/core/status_test_util.h"
#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/status_matchers.h"

namespace tensorflow {
namespace ops {
Expand Down Expand Up @@ -140,6 +145,42 @@ TEST(SpectrogramOpTest, MultichannelTest) {
}
}

TEST(SpectrogramOpTest, InvalidWindowSize) {
Scope root = Scope::NewRootScope();
const int audio_size = 8;
const int channel_size = 2;
Tensor audio_tensor(DT_FLOAT, TensorShape({audio_size, channel_size}));
test::FillValues<float>(
&audio_tensor, {-1.0f, -1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, -1.0f,
-1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f});
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
Input::Initializer(audio_tensor));
AudioSpectrogram spectrogram_op =
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op,
/*window_size=*/1, /*stride=*/1);
EXPECT_THAT(root.status(),
tsl::testing::StatusIs(tsl::error::Code::INVALID_ARGUMENT,
::testing::ContainsRegex("window size")));
}

TEST(SpectrogramOpTest, InvalidStride) {
Scope root = Scope::NewRootScope();
const int audio_size = 8;
const int channel_size = 2;
Tensor audio_tensor(DT_FLOAT, TensorShape({audio_size, channel_size}));
test::FillValues<float>(
&audio_tensor, {-1.0f, -1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f, -1.0f,
-1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 0.0f});
Output audio_const_op = Const(root.WithOpName("audio_const_op"),
Input::Initializer(audio_tensor));
AudioSpectrogram spectrogram_op =
AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op,
/*window_size=*/2, /*stride=*/0);
EXPECT_THAT(root.status(),
tsl::testing::StatusIs(tsl::error::Code::INVALID_ARGUMENT,
::testing::ContainsRegex("stride")));
}

} // namespace
} // namespace ops
} // namespace tensorflow
10 changes: 10 additions & 0 deletions tensorflow/core/ops/audio_ops.cc
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

Expand Down Expand Up @@ -72,8 +73,17 @@ Status SpectrogramShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
int32_t window_size;
TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
if (window_size <= 1) {
return errors::InvalidArgument("window size must be > 1, got ",
window_size);
}

int32_t stride;
TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
if (stride <= 0) {
return errors::InvalidArgument("stride must be strictly positive, got ",
stride);
}

DimensionHandle input_length = c->Dim(input, 0);
DimensionHandle input_channels = c->Dim(input, 1);
Expand Down

0 comments on commit d0d4e77

Please sign in to comment.