-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Decode MP3 from Memory #815
Changes from all commits
83158c6
055c139
e539365
ded863b
c29c828
b9a7ca2
8d015d1
f0fcd18
6799c01
1c9995f
cba8bf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,8 +19,86 @@ limitations under the License. | |
|
||
namespace tensorflow { | ||
namespace data { | ||
|
||
// DecodedAudio | ||
size_t DecodedAudio::data_size() { | ||
return channels * samples_perchannel * sizeof(int16); | ||
} | ||
|
||
// Builds a decode result from its individual parts.
//
//   success            - whether decoding succeeded
//   channels           - number of audio channels
//   samples_perchannel - number of samples in each channel
//   sampling_rate      - sampling rate reported by the decoder
//   data               - sample buffer; the destructor releases it with
//                        std::free, so it is stored here without copying
//
// The size_t arguments are converted to the int members they initialize.
DecodedAudio::DecodedAudio(bool success, size_t channels,
                           size_t samples_perchannel, size_t sampling_rate,
                           int16 *data)
    : success(success),
      channels(channels),
      samples_perchannel(samples_perchannel),
      sampling_rate(sampling_rate),
      data(data) {}
|
||
// Releases the sample buffer with std::free. std::free is a no-op for a null
// pointer, so no explicit null check is required. The const qualifier on the
// member is cast away only to hand the pointer back to the allocator.
DecodedAudio::~DecodedAudio() { std::free(const_cast<int16 *>(data)); }
|
||
// DecodeAudioBaseOp | ||
// Forwards the construction context to OpKernel; this base op reads nothing
// else from it at construction time.
DecodeAudioBaseOp::DecodeAudioBaseOp(OpKernelConstruction *context)
    : OpKernel(context) {}
|
||
void DecodeAudioBaseOp::Compute(OpKernelContext *context) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit unhappy with the fact that these methods are in |
||
// get the input data, i.e. encoded audio data | ||
const Tensor &input_tensor = context->input(0); | ||
const string &input_data = input_tensor.scalar<tstring>()(); | ||
StringPiece data (input_data.data(), input_data.size()); | ||
|
||
// decode audio | ||
std::unique_ptr<DecodedAudio> decoded = decode(data, nullptr); | ||
|
||
// make sure decoding was successful | ||
OP_REQUIRES( | ||
context, decoded->success, | ||
errors::InvalidArgument("Audio data could not be decoded")); | ||
|
||
// output 1: samples | ||
Tensor *output_tensor = nullptr; | ||
TensorShape output_shape {decoded->channels, decoded->samples_perchannel}; | ||
OP_REQUIRES_OK(context, | ||
context->allocate_output(0, output_shape, &output_tensor)); | ||
|
||
// copy data from decoder buffer into output tensor | ||
auto output_flat = output_tensor->flat<int16>(); | ||
std::memcpy(output_flat.data(), decoded->data, decoded->data_size()); | ||
|
||
// output 2: sample rate | ||
Tensor *sample_rate_output = nullptr; | ||
OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), | ||
&sample_rate_output)); | ||
sample_rate_output->flat<int32>()(0) = decoded->sampling_rate; | ||
} | ||
|
||
namespace { | ||
|
||
// Guesses the audio container format of `data` from its leading bytes.
//
// Checks, in order: "RIFF" at offset 0 (WAV), "OggS" at offset 0 (Ogg),
// "fLaC" at offset 0 (FLAC), "ftyp" at offset 4 (MP4), and finally falls back
// to IsMP3(). Returns UnknownFormat when nothing matches or when `data` holds
// fewer than 8 bytes.
AudioFileFormat ClassifyAudioFileFormat(StringPiece &data) {
  // The "ftyp" probe below reads bytes 4..7, so anything shorter than 8 bytes
  // is rejected up front instead of calling substr past the end of the input.
  if (data.size() < 8) {
    return UnknownFormat;
  }

  const StringPiece wav_magic("RIFF");
  const StringPiece ogg_magic("OggS");
  const StringPiece flac_magic("fLaC");
  const StringPiece mp4_subtype_magic("ftyp");

  // True when `magic` occurs in `data` starting at byte `offset`.
  const auto has_magic_at = [&data](size_t offset, const StringPiece &magic) {
    return data.substr(offset, magic.size()) == magic;
  };

  if (has_magic_at(0, wav_magic)) {
    return WavFormat;
  }
  if (has_magic_at(0, ogg_magic)) {
    return OggFormat;
  }
  if (has_magic_at(0, flac_magic)) {
    return FlacFormat;
  }
  if (has_magic_at(4, mp4_subtype_magic)) {
    return Mp4Format;
  }

  // MP3 can not reliably be detected by the header alone, so it is probed
  // last through the dedicated detector.
  return IsMP3(data) ? Mp3Format : UnknownFormat;
}
|
||
class AudioReadableResource : public AudioReadableResourceBase { | ||
public: | ||
AudioReadableResource(Env* env) : env_(env), resource_(nullptr) {} | ||
|
@@ -39,25 +117,29 @@ class AudioReadableResource : public AudioReadableResourceBase { | |
char header[8]; | ||
StringPiece result; | ||
TF_RETURN_IF_ERROR(file->Read(0, sizeof(header), &result, header)); | ||
if (memcmp(header, "RIFF", 4) == 0) { | ||
StringPiece header_str(header, sizeof(header)); | ||
switch (ClassifyAudioFileFormat(header_str)) { | ||
case WavFormat: | ||
return WAVReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
} else if (memcmp(header, "OggS", 4) == 0) { | ||
case OggFormat: | ||
return OggReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
} else if (memcmp(header, "fLaC", 4) == 0) { | ||
case FlacFormat: | ||
return FlacReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
} | ||
Status status = MP3ReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
if (status.ok()) { | ||
return status; | ||
} | ||
if (memcmp(&header[4], "ftyp", 4) == 0) { | ||
case Mp4Format: | ||
LOG(ERROR) << "MP4A file is not fully supported!"; | ||
return MP4ReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
default: | ||
// mp3 is not always easily identifiable by the header alone | ||
// currently we are trying MP3 as a default option | ||
Status status = MP3ReadableResourceInit(env_, filename, optional_memory, | ||
optional_length, resource_); | ||
if (status.ok()) { | ||
return status; | ||
} | ||
} | ||
return errors::InvalidArgument("unknown file type: ", filename); | ||
} | ||
|
@@ -276,6 +358,34 @@ class AudioResampleOp : public OpKernel { | |
int64 quality_; | ||
}; | ||
|
||
class DecodeAudioOp : public DecodeAudioBaseOp { | ||
public: | ||
DecodeAudioOp(OpKernelConstruction *context) : DecodeAudioBaseOp(context) {} | ||
|
||
std::unique_ptr<DecodedAudio> decode(StringPiece &data, void *config) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Google's style is to return
Then you could use:
You might also pass unique_ptr as well I think:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing a pointer to a unique pointer sounds pretty hacky, do you think that's a good idea? Probably I'd rather go with the first option. Thx for pointing me to the coding style! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, I don't see the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
auto error = std::unique_ptr<DecodedAudio>(new DecodedAudio(false, 0, 0, 0, nullptr)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this does not capture the error situation. |
||
switch (ClassifyAudioFileFormat(data)) { | ||
case WavFormat: | ||
LOG(ERROR) << "Direct decoding of WAV not yet supported."; | ||
return error; | ||
case OggFormat: | ||
LOG(ERROR) << "Direct decoding of Ogg not yet supported."; | ||
return error; | ||
case FlacFormat: | ||
LOG(ERROR) << "Direct decoding of Flac not yet supported."; | ||
return error; | ||
case Mp4Format: | ||
LOG(ERROR) << "Direct decoding of Mp4 not yet supported."; | ||
return error; | ||
case Mp3Format: | ||
return DecodeMP3(data); | ||
default: | ||
LOG(ERROR) << "Unsupported audio format."; | ||
return error; | ||
} | ||
} | ||
}; | ||
|
||
class AudioDecodeWAVOp : public OpKernel { | ||
public: | ||
explicit AudioDecodeWAVOp(OpKernelConstruction* context) : OpKernel(context) { | ||
|
@@ -339,6 +449,10 @@ REGISTER_KERNEL_BUILDER(Name("IO>AudioReadableRead").Device(DEVICE_CPU), | |
|
||
// Registers AudioResampleOp as the CPU kernel for the "IO>AudioResample" op.
REGISTER_KERNEL_BUILDER(Name("IO>AudioResample").Device(DEVICE_CPU), | ||
AudioResampleOp); | ||
|
||
// Registers DecodeAudioOp as the CPU kernel for the "IO>AudioDecode" op.
REGISTER_KERNEL_BUILDER(Name("IO>AudioDecode").Device(DEVICE_CPU), | ||
DecodeAudioOp); | ||
|
||
// Registers AudioDecodeWAVOp as the CPU kernel for the "IO>AudioDecodeWAV" op.
REGISTER_KERNEL_BUILDER(Name("IO>AudioDecodeWAV").Device(DEVICE_CPU), | ||
AudioDecodeWAVOp); | ||
} // namespace | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -19,6 +19,15 @@ limitations under the License. | |||
namespace tensorflow { | ||||
namespace data { | ||||
|
||||
// Audio container formats that the classification helper can report.
// The numeric values are spelled out explicitly and must stay stable.
enum AudioFileFormat {
  UnknownFormat = 0,  // nothing matched, or the input was too short
  WavFormat = 1,      // data starts with "RIFF"
  FlacFormat = 2,     // data starts with "fLaC"
  OggFormat = 3,      // data starts with "OggS"
  Mp4Format = 4,      // "ftyp" found at byte offset 4
  Mp3Format = 5,      // accepted by the MP3 detector
};
|
||||
class AudioReadableResourceBase : public ResourceBase { | ||||
public: | ||||
virtual Status Init(const string& filename, | ||||
|
@@ -52,5 +61,40 @@ Status MP4ReadableResourceInit( | |||
const size_t optional_length, | ||||
std::unique_ptr<AudioReadableResourceBase>& resource); | ||||
|
||||
// Container for decoded audio. | ||||
class DecodedAudio { | ||||
public: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure we need a
maybe we don't need this class after all? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think you're right. I think a struct is enough. |
||||
const bool success; | ||||
const int channels; | ||||
const int samples_perchannel; | ||||
const int sampling_rate; | ||||
// should first contain all samples of the left channel | ||||
// followed by the right channel | ||||
const int16 *data; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could avoid allocating the memory here, as we can create the output There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about this for a while, but I'm not sure how it would work. The problem is I need a shape to allocate the output tensor, which I get by decoding the MP3, for which in turn I need to have memory. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case of a shape, in:
a callback-style lambda is passed, which allows the allocation to be done once the shape is ready. This would be helpful when we only want to call once to |
||||
|
||||
size_t data_size(); | ||||
|
||||
DecodedAudio(bool success, size_t channels, size_t samples_perchannel, | ||||
size_t sampling_rate, int16 *data); | ||||
~DecodedAudio(); | ||||
}; | ||||
|
||||
// Base class for simple, in-memory audio data decoding. | ||||
// Handles creating output tensors of the right shape. | ||||
// Subclasses must implement the decode method. | ||||
class DecodeAudioBaseOp : public OpKernel { | ||||
public: | ||||
explicit DecodeAudioBaseOp(OpKernelConstruction* context); | ||||
void Compute(OpKernelContext* context) override; | ||||
virtual std::unique_ptr<DecodedAudio> decode(StringPiece &data, void *config) = 0; | ||||
}; | ||||
|
||||
// TODO: it seems a bit weird to have this here; maybe we should create an audio_mp3_kernels header. | ||||
// Detect MP3 data.
// Returns true when `data` is recognized as MP3. The format classifier calls
// this only after the WAV/Ogg/FLAC/MP4 magic-byte checks have failed.
bool IsMP3(StringPiece &data); | ||||
|
||||
// Decode MP3 data.
// Returns the decoded audio as a DecodedAudio. The decode op treats the result
// as a failure when its `success` flag is false.
// NOTE(review): presumably `data` holds a complete encoded MP3 stream; confirm
// against the implementation, which is not part of this header.
std::unique_ptr<DecodedAudio> DecodeMP3(StringPiece &data); | ||||
|
||||
} // namespace data | ||||
} // namespace tensorflow |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""tensorflow_io.audio""" | ||
|
||
from tensorflow_io.core.python.ops.audio_ops import decode # pylint: disable=unused-import | ||
from tensorflow_io.core.python.ops.audio_ops import decode_mp3 # pylint: disable=unused-import |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Google's code style prefers CamelCase method names, so DataSize instead?
Also, not all audio files stay with
int16
so this one has to at least take into consideration the data types. But that is a larger discussion.