Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decode MP3 from Memory #815

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,10 @@ http_archive(
http_archive(
name = "minimp3",
build_file = "//third_party:minimp3.BUILD",
sha256 = "53dd89dbf235c3a282b61fec07eb29730deb1a828b0c9ec95b17b9bd4b22cc3d",
strip_prefix = "minimp3-2b9a0237547ca5f6f98e28a850237cc68f560f7a",
sha256 = "09395758f4c964fb158875f3cc9b9a65f36e9f5b2a27fb10f99519a0a6aef664",
strip_prefix = "minimp3-55da78cbeea5fb6757f8df672567714e1e8ca3e9",
urls = [
"https://github.com/lieff/minimp3/archive/2b9a0237547ca5f6f98e28a850237cc68f560f7a.tar.gz",
"https://github.com/lieff/minimp3/archive/55da78cbeea5fb6757f8df672567714e1e8ca3e9.tar.gz",
],
)

Expand Down
134 changes: 124 additions & 10 deletions tensorflow_io/core/kernels/audio_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,86 @@ limitations under the License.

namespace tensorflow {
namespace data {

// DecodedAudio
size_t DecodedAudio::data_size() {
return channels * samples_perchannel * sizeof(int16);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Google's code style prefers CamelCase method names, so DataSize instead?

Also, not all audio files stay with int16 so this one has to at least take into consideration the data types. But that is a larger discussion.

}

// Builds a decode result from its individual fields.
//
//   success            - whether decoding succeeded.
//   channels           - number of audio channels.
//   samples_perchannel - number of samples in each channel.
//   sampling_rate      - sample rate reported by the decoder.
//   data               - sample buffer; the destructor releases it with
//                        std::free, so a non-null buffer must have been
//                        obtained from a malloc-compatible allocator.
//
// NOTE(review): the parameters are size_t while the corresponding members are
// declared as const int in the header, so values are narrowed on
// initialization.
DecodedAudio::DecodedAudio(bool success, size_t channels,
size_t samples_perchannel, size_t sampling_rate,
int16 *data)
: success(success), channels(channels),
samples_perchannel(samples_perchannel), sampling_rate(sampling_rate),
data(data) {}

// Releases the sample buffer held by this object.
DecodedAudio::~DecodedAudio() {
  // std::free(nullptr) is a no-op, so no explicit null check is needed.
  // The member is a pointer-to-const, hence the const_cast before freeing.
  std::free(const_cast<void *>(static_cast<const void *>(data)));
}

// DecodeAudioBaseOp
// Forwards the construction context to OpKernel; the base op keeps no
// additional state of its own.
DecodeAudioBaseOp::DecodeAudioBaseOp(OpKernelConstruction *context) : OpKernel(context) {}

void DecodeAudioBaseOp::Compute(OpKernelContext *context) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit unhappy with the fact that these methods are in tensorflow::data while everything else is in this nested nameless namespace. I'm not experienced enough with C++ to know what this is about, so happy about feedback/ideas.

// get the input data, i.e. encoded audio data
const Tensor &input_tensor = context->input(0);
const string &input_data = input_tensor.scalar<tstring>()();
StringPiece data (input_data.data(), input_data.size());

// decode audio
std::unique_ptr<DecodedAudio> decoded = decode(data, nullptr);

// make sure decoding was successful
OP_REQUIRES(
context, decoded->success,
errors::InvalidArgument("Audio data could not be decoded"));

// output 1: samples
Tensor *output_tensor = nullptr;
TensorShape output_shape {decoded->channels, decoded->samples_perchannel};
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));

// copy data from decoder buffer into output tensor
auto output_flat = output_tensor->flat<int16>();
std::memcpy(output_flat.data(), decoded->data, decoded->data_size());

// output 2: sample rate
Tensor *sample_rate_output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
&sample_rate_output));
sample_rate_output->flat<int32>()(0) = decoded->sampling_rate;
}

namespace {

// Guesses the audio container format of `data` from its leading bytes.
//
// Inputs shorter than 8 bytes are reported as UnknownFormat. The fixed magic
// numbers are probed first ("RIFF", "OggS" and "fLaC" at offset 0, "ftyp" at
// offset 4); if none matches, IsMP3() is consulted because MP3 cannot be
// reliably recognised from a fixed header. Anything else is UnknownFormat.
AudioFileFormat ClassifyAudioFileFormat(StringPiece &data) {
  // Every probe below reads within the first 8 bytes; bailing out early keeps
  // substr() from being called with an out-of-range position.
  if (data.size() < 8) {
    return UnknownFormat;
  }

  const StringPiece riff_magic("RIFF");
  const StringPiece ogg_magic("OggS");
  const StringPiece flac_magic("fLaC");
  const StringPiece ftyp_magic("ftyp");

  // True when `magic` appears in `data` starting at byte `offset`.
  const auto has_magic_at = [&data](size_t offset, const StringPiece &magic) {
    return data.substr(offset, magic.size()) == magic;
  };

  if (has_magic_at(0, riff_magic)) {
    return WavFormat;
  }
  if (has_magic_at(0, ogg_magic)) {
    return OggFormat;
  }
  if (has_magic_at(0, flac_magic)) {
    return FlacFormat;
  }
  if (has_magic_at(4, ftyp_magic)) {
    return Mp4Format;
  }
  // MP3 can not reliably be detected by the header alone, so it is tried last
  // through the dedicated detector.
  if (IsMP3(data)) {
    return Mp3Format;
  }
  return UnknownFormat;
}

class AudioReadableResource : public AudioReadableResourceBase {
public:
AudioReadableResource(Env* env) : env_(env), resource_(nullptr) {}
Expand All @@ -39,25 +117,29 @@ class AudioReadableResource : public AudioReadableResourceBase {
char header[8];
StringPiece result;
TF_RETURN_IF_ERROR(file->Read(0, sizeof(header), &result, header));
if (memcmp(header, "RIFF", 4) == 0) {
StringPiece header_str(header, sizeof(header));
switch (ClassifyAudioFileFormat(header_str)) {
case WavFormat:
return WAVReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
} else if (memcmp(header, "OggS", 4) == 0) {
case OggFormat:
return OggReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
} else if (memcmp(header, "fLaC", 4) == 0) {
case FlacFormat:
return FlacReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
}
Status status = MP3ReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
if (status.ok()) {
return status;
}
if (memcmp(&header[4], "ftyp", 4) == 0) {
case Mp4Format:
LOG(ERROR) << "MP4A file is not fully supported!";
return MP4ReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
default:
// mp3 is not always easily identifiable by the header alone
// currently we are trying MP3 as a default option
Status status = MP3ReadableResourceInit(env_, filename, optional_memory,
optional_length, resource_);
if (status.ok()) {
return status;
}
}
return errors::InvalidArgument("unknown file type: ", filename);
}
Expand Down Expand Up @@ -276,6 +358,34 @@ class AudioResampleOp : public OpKernel {
int64 quality_;
};

class DecodeAudioOp : public DecodeAudioBaseOp {
public:
DecodeAudioOp(OpKernelConstruction *context) : DecodeAudioBaseOp(context) {}

std::unique_ptr<DecodedAudio> decode(StringPiece &data, void *config) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Google's style is to return a Status as the function's return value; any values/pointers that also need to be returned are passed as pointer (*) output parameters at the end of the parameter list, so something like:

Status Decode(StringPiece &data, DecodeAudio** audio);

Then you could use:

DecodeAudio* audio;
Status status = Decode(data, &audio);

std::unique_ptr<DecodedAudio> d;
d.reset(audio);

You might also pass unique_ptr as well I think:

Status Decode(StringPiece &data, std::unique_ptr<DecodeAudio>* audio);

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a pointer to a unique pointer sounds pretty hacky, do you think that's a good idea? Probably I'd rather go with the first option. Thx for pointing me to the coding style!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I don't see the decoding here — doesn't it only classify the format?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto error = std::unique_ptr<DecodedAudio>(new DecodedAudio(false, 0, 0, 0, nullptr));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this does not capture the error situation.

switch (ClassifyAudioFileFormat(data)) {
case WavFormat:
LOG(ERROR) << "Direct decoding of WAV not yet supported.";
return error;
case OggFormat:
LOG(ERROR) << "Direct decoding of Ogg not yet supported.";
return error;
case FlacFormat:
LOG(ERROR) << "Direct decoding of Flac not yet supported.";
return error;
case Mp4Format:
LOG(ERROR) << "Direct decoding of Mp4 not yet supported.";
return error;
case Mp3Format:
return DecodeMP3(data);
default:
LOG(ERROR) << "Unsupported audio format.";
return error;
}
}
};

class AudioDecodeWAVOp : public OpKernel {
public:
explicit AudioDecodeWAVOp(OpKernelConstruction* context) : OpKernel(context) {
Expand Down Expand Up @@ -339,6 +449,10 @@ REGISTER_KERNEL_BUILDER(Name("IO>AudioReadableRead").Device(DEVICE_CPU),

REGISTER_KERNEL_BUILDER(Name("IO>AudioResample").Device(DEVICE_CPU),
AudioResampleOp);

// Registers DecodeAudioOp as the CPU kernel for the "IO>AudioDecode" op.
REGISTER_KERNEL_BUILDER(Name("IO>AudioDecode").Device(DEVICE_CPU),
DecodeAudioOp);

REGISTER_KERNEL_BUILDER(Name("IO>AudioDecodeWAV").Device(DEVICE_CPU),
AudioDecodeWAVOp);
} // namespace
Expand Down
44 changes: 44 additions & 0 deletions tensorflow_io/core/kernels/audio_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace data {

// Audio container formats distinguished by the audio kernels.
// Enumerators are numbered consecutively from 0 in declaration order.
enum AudioFileFormat {
  UnknownFormat = 0,  // no supported format was recognised
  WavFormat,          // = 1
  FlacFormat,         // = 2
  OggFormat,          // = 3
  Mp4Format,          // = 4
  Mp3Format,          // = 5
};

class AudioReadableResourceBase : public ResourceBase {
public:
virtual Status Init(const string& filename,
Expand Down Expand Up @@ -52,5 +61,40 @@ Status MP4ReadableResourceInit(
const size_t optional_length,
std::unique_ptr<AudioReadableResourceBase>& resource);

// Container for decoded audio.
class DecodedAudio {
public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we need a class here, as it is just a struct with one function that does a

channels * samples_perchannel * sizeof(int16);

maybe we don't need this class after all?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think you're right. I think a struct is enough.

const bool success;
const int channels;
const int samples_perchannel;
const int sampling_rate;
// should first contain all samples of the left channel
// followed by the right channel
const int16 *data;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could avoid allocating the memory here, since we can create the output Tensor, which will hold the memory. Then the output Tensor can be used directly to get the data?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this for a while, but I'm not sure how it would work. The problem is I need a shape to allocate the output tensor, which I get by decoding the MP3, for which in turn I need to have memory.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of a shape, in:

Status Read(const int64 start, const int64 stop,

a callback-style lambda is passed, which allows the allocation to be done once the shape is ready. This would be helpful when we only want to call read once.


size_t data_size();

DecodedAudio(bool success, size_t channels, size_t samples_perchannel,
size_t sampling_rate, int16 *data);
~DecodedAudio();
};

// Base class for simple, in-memory audio decoding kernels.
//
// Compute() reads the encoded audio from input 0, hands it to decode(), and
// on success emits two outputs: the int16 samples shaped
// {channels, samples_perchannel} and the scalar int32 sample rate. If the
// returned DecodedAudio has success == false, the op fails with an
// InvalidArgument error.
//
// Subclasses must implement decode().
class DecodeAudioBaseOp : public OpKernel {
public:
explicit DecodeAudioBaseOp(OpKernelConstruction* context);
void Compute(OpKernelContext* context) override;
// Decodes the encoded bytes in `data` into a DecodedAudio.
// Compute() always passes nullptr for `config`.
virtual std::unique_ptr<DecodedAudio> decode(StringPiece &data, void *config) = 0;
};

// TODO: it seems a bit odd to declare the MP3 helpers here; maybe we should
// create an audio_mp3_kernels header for them.
// Returns true if `data` is detected as MP3 audio.
bool IsMP3(StringPiece &data);

// Decodes in-memory MP3 data. The result has success == false when the data
// could not be decoded.
std::unique_ptr<DecodedAudio> DecodeMP3(StringPiece &data);

} // namespace data
} // namespace tensorflow
40 changes: 40 additions & 0 deletions tensorflow_io/core/kernels/audio_mp3_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,45 @@ Status MP3ReadableResourceInit(
return status;
}

// Reports whether `data` is detected as MP3 by minimp3.
// A return value of 0 from mp3dec_detect_buf is treated as "detected".
bool IsMP3(StringPiece &data) {
  // The detector takes a byte pointer, while StringPiece exposes const char *.
  uint8_t *bytes =
      const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(data.data()));
  const auto detect_status = mp3dec_detect_buf(bytes, data.size());
  return detect_status == 0;
}

// Decodes a complete MP3 stream held in memory.
//
// When the decoder reports zero channels, the result is a DecodedAudio with
// success == false and no buffer. Otherwise the result carries the channel
// count, the number of samples per channel, the sample rate and the sample
// buffer produced by mp3dec_load_buf; that buffer is later released by the
// DecodedAudio destructor.
std::unique_ptr<DecodedAudio> DecodeMP3(StringPiece &data) {
  // Set up the decoder state.
  mp3dec_t decoder;
  mp3dec_init(&decoder);

  // Start from an all-zero info struct so `channels` stays 0 unless the
  // decoder fills it in.
  mp3dec_file_info_t info;
  memset(&info, 0x00, sizeof(info));

  const uint8_t *encoded = reinterpret_cast<const uint8_t *>(data.data());
  mp3dec_load_buf(&decoder, encoded, data.size(), &info,
                  nullptr /* progress callback */, nullptr /* user data */);

  // A channel count of zero means decoding was not successful.
  const bool decoded_ok = info.channels != 0;
  if (!decoded_ok) {
    return std::unique_ptr<DecodedAudio>(
        new DecodedAudio(false, 0, 0, 0, nullptr));
  }

  // TODO: could channels have different numbers of samples?
  const int per_channel = info.samples / info.channels;

  return std::unique_ptr<DecodedAudio>(new DecodedAudio(
      true, info.channels, per_channel, info.hz, info.buffer));
}


// Kernel that always decodes its input as MP3, without format detection.
// Tensor handling is inherited from DecodeAudioBaseOp.
class DecodeMP3Op : public DecodeAudioBaseOp {
 public:
  explicit DecodeMP3Op(OpKernelConstruction *context)
      : DecodeAudioBaseOp(context) {}

  // Decodes `data` with DecodeMP3(); the `config` argument is not used.
  std::unique_ptr<DecodedAudio> decode(StringPiece &data,
                                       void * /*config*/) override {
    std::unique_ptr<DecodedAudio> decoded = DecodeMP3(data);
    return decoded;
  }
};

// Registers DecodeMP3Op as the CPU kernel for the "IO>AudioDecodeMp3" op.
REGISTER_KERNEL_BUILDER(Name("IO>AudioDecodeMp3").Device(DEVICE_CPU),
DecodeMP3Op);

} // namespace data
} // namespace tensorflow
20 changes: 20 additions & 0 deletions tensorflow_io/core/ops/audio_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ REGISTER_OP("IO>AudioResample")
return Status::OK();
});

// Op "IO>AudioDecode": takes a string input `contents` and produces int16
// `samples` and int32 `sample_rate`.
// Shape inference: `samples` is rank 2 with both dimensions unknown, and
// `sample_rate` is a scalar.
REGISTER_OP("IO>AudioDecode")
.Input("contents: string")
.Output("samples: int16")
.Output("sample_rate: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
c->set_output(1, c->Scalar());
return Status::OK();
});

// Op "IO>AudioDecodeMp3": takes a string input `contents` and produces int16
// `samples` and int32 `sample_rate`.
// Shape inference: `samples` is rank 2 with both dimensions unknown, and
// `sample_rate` is a scalar.
REGISTER_OP("IO>AudioDecodeMp3")
.Input("contents: string")
.Output("samples: int16")
.Output("sample_rate: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
c->set_output(1, c->Scalar());
return Status::OK();
});

REGISTER_OP("IO>AudioDecodeWAV")
.Input("input: string")
.Input("shape: int64")
Expand Down
1 change: 1 addition & 0 deletions tensorflow_io/core/python/api/v0/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from tensorflow_io.core.python.api.v0 import genome
from tensorflow_io.core.python.api.v0 import image
from tensorflow_io.core.python.api.v0 import audio
18 changes: 18 additions & 0 deletions tensorflow_io/core/python/api/v0/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tensorflow_io.audio"""

from tensorflow_io.core.python.ops.audio_ops import decode # pylint: disable=unused-import
from tensorflow_io.core.python.ops.audio_ops import decode_mp3 # pylint: disable=unused-import
Loading