Skip to content

Commit

Permalink
Support encode_mp3 in Linux with lame
Browse files Browse the repository at this point in the history
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang committed Mar 21, 2020
1 parent ea27e51 commit 4371f56
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/build.wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ if [[ $(uname) == "Linux" ]]; then
add-apt-repository -y ppa:deadsnakes/ppa
apt-get -y -qq update
fi
apt-get -y -qq install $PYTHON_VERSION ffmpeg dnsutils
apt-get -y -qq install $PYTHON_VERSION ffmpeg dnsutils libmp3lame0
curl -sSOL https://bootstrap.pypa.io/get-pip.py
$PYTHON_VERSION get-pip.py -q
fi
Expand Down
141 changes: 141 additions & 0 deletions tensorflow_io/core/kernels/audio_video_mp3_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,34 @@ limitations under the License.
#define MINIMP3_FLOAT_OUTPUT
#include "minimp3_ex.h"

#if defined(__linux__)
#include <dlfcn.h>
#endif

typedef void* lame_t;
typedef enum vbr_mode_e {
vbr_off = 0,
vbr_mt,
vbr_rh,
vbr_abr,
vbr_mtrh,
vbr_max_indicator,
vbr_default = vbr_mtrh
} vbr_mode;

static lame_t (*lame_init)(void);
static int (*lame_set_num_channels)(lame_t, int);
static int (*lame_set_in_samplerate)(lame_t, int);
static int (*lame_set_VBR)(lame_t, vbr_mode);
static int (*lame_init_params)(lame_t);
static int (*lame_encode_buffer_interleaved_ieee_float)(lame_t gfp,
const float pcm[],
const int nsamples,
unsigned char* mp3buf,
const int mp3buf_size);
static int (*lame_encode_flush)(lame_t gfp, unsigned char* mp3buf, int size);
static int (*lame_close)(lame_t);

namespace tensorflow {
namespace data {
namespace {
Expand Down Expand Up @@ -200,8 +228,121 @@ class AudioDecodeMP3Op : public OpKernel {
Env* env_ GUARDED_BY(mu_);
};

bool LoadLame() {
#if defined(__linux__)
void* lib = dlopen("libmp3lame.so.0", RTLD_NOW);
if (lib != nullptr) {
*(void**)(&lame_init) = dlsym(lib, "lame_init");
*(void**)(&lame_set_num_channels) = dlsym(lib, "lame_set_num_channels");
*(void**)(&lame_set_in_samplerate) = dlsym(lib, "lame_set_in_samplerate");
*(void**)(&lame_set_VBR) = dlsym(lib, "lame_set_VBR");
*(void**)(&lame_init_params) = dlsym(lib, "lame_init_params");
*(void**)(&lame_encode_buffer_interleaved_ieee_float) =
dlsym(lib, "lame_encode_buffer_interleaved_ieee_float");
*(void**)(&lame_encode_flush) = dlsym(lib, "lame_encode_flush");
*(void**)(&lame_close) = dlsym(lib, "lame_close");
if (lame_init != nullptr && lame_set_num_channels != nullptr &&
lame_set_in_samplerate != nullptr && lame_set_VBR != nullptr &&
lame_init_params != nullptr &&
lame_encode_buffer_interleaved_ieee_float != nullptr &&
lame_encode_flush != nullptr && lame_close != nullptr) {
return true;
}
}
LOG(WARNING) << "libmp3lame.so.0 or lame functions are not available";
#endif
return false;
}

class AudioEncodeMP3Op : public OpKernel {
public:
explicit AudioEncodeMP3Op(OpKernelConstruction* context) : OpKernel(context) {
env_ = context->env();
}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(context, lame_available_,
errors::InvalidArgument("lame library is not available"));
const Tensor* input_tensor;
OP_REQUIRES_OK(context, context->input("input", &input_tensor));

const Tensor* rate_tensor;
OP_REQUIRES_OK(context, context->input("rate", &rate_tensor));

const int64 rate = rate_tensor->scalar<int64>()();
const int64 samples = input_tensor->shape().dim_size(0);
const int64 channels = input_tensor->shape().dim_size(1);

Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(0, TensorShape({}), &output_tensor));

tstring& output = output_tensor->scalar<tstring>()();

std::unique_ptr<void, void (*)(void*)> lame(nullptr, [](void* p) {
if (p != nullptr) {
lame_close(p);
}
});
lame.reset(lame_init());
OP_REQUIRES(context, (lame.get() != nullptr),
errors::InvalidArgument("unable to initialize lame"));

int status;
status = lame_set_num_channels(lame.get(), channels);
OP_REQUIRES(context, (status == 0),
errors::InvalidArgument("unable to set channels: ", status));

status = lame_set_in_samplerate(lame.get(), rate);
OP_REQUIRES(context, (status == 0),
errors::InvalidArgument("unable to set rate: ", status));

status = lame_set_VBR(lame.get(), vbr_default);
OP_REQUIRES(context, (status == 0),
errors::InvalidArgument("unable to set vbr: ", status));

status = lame_init_params(lame.get());
OP_REQUIRES(context, (status == 0),
errors::InvalidArgument("unable to init params ", status));

const float* pcm = input_tensor->flat<float>().data();

// worse case according to lame:
// mp3buf_size in bytes = 1.25*num_samples + 7200
output.resize(samples * 5 / 4 + 7200);
unsigned char* mp3buf = (unsigned char*)&output[0];
int mp3buf_size = output.size();
status = lame_encode_buffer_interleaved_ieee_float(lame.get(), pcm, samples,
mp3buf, mp3buf_size);
OP_REQUIRES(context, (status >= 0),
errors::InvalidArgument("unable to encode: ", status));

int encoded = status;

mp3buf = (unsigned char*)&output[encoded];
mp3buf_size = output.size() - encoded;
status = lame_encode_flush(lame.get(), mp3buf, mp3buf_size);
OP_REQUIRES(context, (status >= 0),
errors::InvalidArgument("unable to flush: ", status));
encoded = encoded + status;
// cur to the encoded length
output.resize(encoded);
}

private:
mutable mutex mu_;
Env* env_ GUARDED_BY(mu_);

static bool lame_available_;
};

bool AudioEncodeMP3Op::lame_available_ = LoadLame();

REGISTER_KERNEL_BUILDER(Name("IO>AudioDecodeMP3").Device(DEVICE_CPU),
AudioDecodeMP3Op);
REGISTER_KERNEL_BUILDER(Name("IO>AudioEncodeMP3").Device(DEVICE_CPU),
AudioEncodeMP3Op);

} // namespace

Status MP3ReadableResourceInit(
Expand Down
9 changes: 9 additions & 0 deletions tensorflow_io/core/ops/audio_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ REGISTER_OP("IO>AudioDecodeMP3")
return Status::OK();
});

REGISTER_OP("IO>AudioEncodeMP3")
.Input("input: float32")
.Input("rate: int64")
.Output("value: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
return Status::OK();
});

} // namespace
} // namespace io
} // namespace tensorflow
1 change: 1 addition & 0 deletions tensorflow_io/core/python/api/experimental/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
decode_ogg,
encode_ogg,
decode_mp3,
encode_mp3,
)
14 changes: 14 additions & 0 deletions tensorflow_io/core/python/experimental/audio_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,17 @@ def decode_mp3(input, shape=None, name=None): # pylint: disable=redefined-built
if shape is None:
shape = tf.constant([-1, -1], tf.int64)
return core_ops.io_audio_decode_mp3(input, shape=shape, name=name)


def encode_mp3(input, rate, name=None): # pylint: disable=redefined-builtin
"""Encode MP3 audio into string.
Args:
input: A `Tensor` of the audio input.
rate: The sample rate of the audio.
name: A name for the operation (optional).
Returns:
output: Encoded audio.
"""
return core_ops.io_audio_encode_mp3(input, rate, name=name)
46 changes: 46 additions & 0 deletions tests/test_audio_ops_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,32 @@ def func(e):
return args, func, expected


@pytest.fixture(name="encode_mp3", scope="module")
def fixture_encode_mp3():
"""fixture_encode_mp3"""
raw_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_audio", "l1-fl6.raw"
)
raw = np.fromfile(raw_path, np.int16)
raw = raw.reshape([-1, 2])
value = tf.cast(raw, tf.float32) / 32768.0

# lame has a delay which will expand the number of samples.
# for that this test simply check the number of samples
args = value

def func(e):
v = tfio.experimental.audio.encode_mp3(e, rate=44100)
v = tfio.experimental.audio.decode_mp3(v)
v = tf.shape(v)
return v

# Should be [18816, 2] but lame expand additional samples
expected = tf.constant([21888, 2], tf.int32)

return args, func, expected


# By default, operations runs in eager mode,
# Note as of now shape inference is skipped in eager mode
@pytest.mark.parametrize(
Expand All @@ -257,6 +283,15 @@ def func(e):
pytest.param("decode_ogg"),
pytest.param("encode_ogg"),
pytest.param("decode_mp3"),
pytest.param(
"encode_mp3",
marks=[
pytest.mark.skipif(
sys.platform in ("win32", "darwin"),
reason="no lame for darwin or win32",
),
],
),
],
ids=[
"resample",
Expand All @@ -267,6 +302,7 @@ def func(e):
"decode_ogg",
"encode_ogg",
"decode_mp3",
"encode_mp3",
],
)
def test_audio_ops(fixture_lookup, io_data_fixture):
Expand All @@ -289,6 +325,15 @@ def test_audio_ops(fixture_lookup, io_data_fixture):
pytest.param("decode_ogg"),
pytest.param("encode_ogg"),
pytest.param("decode_mp3"),
pytest.param(
"encode_mp3",
marks=[
pytest.mark.skipif(
sys.platform in ("win32", "darwin"),
reason="no lame for darwin or win32",
),
],
),
],
ids=[
"resample",
Expand All @@ -299,6 +344,7 @@ def test_audio_ops(fixture_lookup, io_data_fixture):
"decode_ogg",
"encode_ogg",
"decode_mp3",
"encode_mp3",
],
)
def test_audio_ops_in_graph(fixture_lookup, io_data_fixture):
Expand Down

0 comments on commit 4371f56

Please sign in to comment.