-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
audio_feature.py
404 lines (337 loc) · 13 KB
/
audio_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Audio feature."""
from __future__ import annotations
import abc
import enum
import functools
import io
import os
from typing import BinaryIO, Optional, Union
import wave
from absl import logging
from etils import epath
import numpy as np
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.proto import feature_pb2
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
# Module-level aliases used in annotations/defaults below.
# JSON-like value type; appears in `Audio.from_json_content`'s signature.
Json = type_utils.Json
# Internal encoding enum; type and default of `Audio(encoding=...)`.
Encoding = tensor_feature.Encoding
class _TfioFileFormat(enum.Enum):
"""Format of the audio files supported for decoding by `tensorflow_io`.
Attributes:
FLAC: Free Lossless Audio Codec
"""
FLAC = 'flac'
@functools.lru_cache(maxsize=None)
def _tfio_decode_fn():
  """Maps each `_TfioFileFormat` to the `tensorflow_io` decode function.

  The result is computed once per process (the cache takes no arguments),
  so `tensorflow_io` is only touched on first use.
  """
  tfio_audio = lazy_imports_lib.lazy_imports.tensorflow_io.audio
  return {_TfioFileFormat.FLAC: tfio_audio.decode_flac}
@functools.lru_cache(maxsize=None)
def _tfio_acceptable_dtypes():
  """Maps each `_TfioFileFormat` to the dtypes allowed for lazy decoding.

  `_LazyDecoder` rejects any feature dtype missing from this list.
  """
  flac_dtypes = [tf.uint8, tf.int16, tf.int32]
  return {_TfioFileFormat.FLAC: flac_dtypes}
class _AudioDecoder(abc.ABC):
  """Base class for the encode/decode strategies used by `Audio`.

  A subclass chooses whether audio is decoded when the example is written
  (`encode_audio`) or when it is read back (`decode_audio`).
  """

  def __init__(
      self, file_format: Optional[str], np_dtype: np.dtype, shape: utils.Shape
  ):
    """Records the configuration shared by every decoder.

    Args:
      file_format: `str`, the audio file format. Can be any format ffmpeg
        understands. If in `_TfioFileFormat`, then will use `tensorflow_io`
      np_dtype: The numpy dtype of the data.
      shape: `tuple`, shape of the data.
    """
    self._file_format = file_format
    self._np_dtype = np_dtype
    # TensorFlow counterpart of `np_dtype`, used by the tf ops in subclasses.
    self._dtype = tf.dtypes.as_dtype(self._np_dtype)
    self._shape = shape
    # A rank-1 shape means mono; otherwise the second dimension holds the
    # channel count (possibly `None` if the shape leaves it unspecified).
    self._channels = 1 if len(shape) <= 1 else shape[1]

  @abc.abstractmethod
  def encode_audio(
      self, fobj: BinaryIO, file_format: Optional[str]
  ) -> np.ndarray:
    """Turns the file object into the numpy array stored in the tf-example."""
    raise NotImplementedError

  @abc.abstractmethod
  def decode_audio(self, audio_tensor: tf.Tensor) -> tf.Tensor:
    """Turns the tensor read from the tf-example into the final audio."""
    raise NotImplementedError
def _pydub_load_audio(
    fobj: BinaryIO, file_format: Optional[str], channels: int
) -> np.ndarray:
  """Reads `fobj` with pydub and returns its samples as a numpy array.

  Args:
    fobj: Binary file object holding the encoded audio.
    file_format: Format hint forwarded to pydub (`None` lets pydub decide).
    channels: Channel count requested; the segment is remixed when it differs.

  Returns:
    A 1-D array for mono audio, else an array of shape `(-1, channels)`.
  """
  segment = lazy_imports_lib.lazy_imports.pydub.AudioSegment.from_file(
      fobj, format=file_format
  )
  if channels != segment.channels:
    logging.info(
        'Modifying audio segment from %s to %s channel(s).',
        segment.channels,
        channels,
    )
    segment = segment.set_channels(channels)
  samples = np.array(segment.get_array_of_samples())
  # pydub returns interleaved samples; reshape into (frames, channels).
  if segment.channels <= 1:
    return samples
  return samples.reshape((-1, segment.channels))
def _pydub_decode_audio(
    audio_tensor: tf.Tensor,
    file_format_tensor: tf.experimental.Optional,
    channels: int,
) -> np.ndarray:
  """Decodes encoded audio bytes with pydub; runs inside `tf.py_function`.

  Args:
    audio_tensor: Scalar string tensor holding the encoded audio file bytes.
    file_format_tensor: Optional scalar string tensor with the file format.
      When empty, `None` is forwarded and pydub decides.
    channels: Number of channels expected. When called through
      `tf.py_function` this arrives as a scalar tensor, so it is converted to
      a Python `int` before reaching pydub.

  Returns:
    An `np.int32` array of decoded samples, matching the `tf.int32` output
    type declared by the `tf.py_function` caller.
  """
  fobj = io.BytesIO(audio_tensor.numpy())
  if file_format_tensor.has_value():
    # `.numpy()` on a string tensor yields `bytes`; pydub expects a `str`
    # format name.
    file_format = file_format_tensor.get_value().numpy().decode('utf-8')
  else:
    file_format = None
  samples = _pydub_load_audio(fobj, file_format, int(channels))
  # pydub may produce int8/int16/int32 samples (type codes `b`, `h`, `i`)
  # depending on the sample width; normalize to the declared int32.
  return samples.astype(np.int32)
class _LazyDecoder(_AudioDecoder):
  """Stores the encoded audio bytes as-is and decodes them when read.

  Formats listed in `_TfioFileFormat` are decoded with `tensorflow_io`.
  Every other format is decoded by pydub inside a `tf.py_function`, which a
  warning flags as potentially very slow.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    try:
      self._tfio_file_format = _TfioFileFormat(self._file_format)
    except ValueError:
      # Not a tfio-supported format (including `None`): fall back to pydub.
      self._tfio_file_format = None
      logging.warning(
          (
              'Using lazy encoding with `file_format=%s` might be very slow'
              ' when reading prepared dataset. Consider using one of these file'
              ' formats: %s'
          ),
          self._file_format,
          [item.value for item in _TfioFileFormat],
      )
    if self._tfio_file_format:
      acceptable_dtypes = _tfio_acceptable_dtypes()[self._tfio_file_format]
      if self._dtype not in acceptable_dtypes:
        raise ValueError(
            f'Acceptable `dtype` for lazy loading {self._file_format}: '
            f'{acceptable_dtypes} (was {self._dtype})'
        )

  def encode_audio(
      self, fobj: BinaryIO, file_format: Optional[str]
  ) -> np.ndarray:
    # Keep the raw file content as a scalar string; decoding is deferred to
    # `decode_audio`.
    return np.array(fobj.read(), dtype=tf.string.as_numpy_dtype)

  def decode_audio(self, audio_tensor: tf.Tensor) -> tf.Tensor:
    if self._tfio_file_format:
      decoded_audio_tensor = _tfio_decode_fn()[self._tfio_file_format](
          audio_tensor, dtype=self._dtype
      )
      if len(self._shape) == 1:
        # tfio decoders return `(samples, channels)`. For 1-D features drop
        # only the trailing channel axis: a bare `tf.squeeze` would also
        # collapse the sample axis of a one-sample clip, and would break
        # `(length, 1)` features that must keep their channel axis.
        decoded_audio_tensor = tf.squeeze(decoded_audio_tensor, axis=-1)
    else:
      if self._file_format:
        file_format_tensor = tf.experimental.Optional.from_value(
            self._file_format
        )
      else:
        file_format_tensor = tf.experimental.Optional.empty(
            tf.TensorSpec(shape=(), dtype=tf.string)
        )
      # pydub.AudioSegment.get_array_of_samples returns an array with type code
      # `b`, `h` or `i` which can be all converted to `tf.int32`
      decoded_audio_tensor = tf.py_function(
          _pydub_decode_audio,
          [audio_tensor, file_format_tensor, self._channels],
          tf.int32,
      )
    decoded_audio_tensor.set_shape(self._shape)
    return tf.cast(decoded_audio_tensor, self._dtype)
class _EagerDecoder(_AudioDecoder):
  """Decodes the audio file when the example is written (encoding time)."""

  def encode_audio(
      self, fobj: BinaryIO, file_format: Optional[str]
  ) -> np.ndarray:
    # Decode with pydub right away and store samples in the feature's dtype.
    return _pydub_load_audio(fobj, file_format, self._channels).astype(
        self._np_dtype
    )

  def decode_audio(self, audio_tensor: tf.Tensor) -> tf.Tensor:
    # The stored tensor already holds decoded samples; pass it through.
    return audio_tensor
class Audio(tensor_feature.Tensor):
  """`tfds.features.FeatureConnector` for audio.

  In `_generate_examples`, Audio accepts:

   * A `np.ndarray` of shape `(length,)` or `(length, channels)`
   * A path to a `.mp3`, `.wav`,... file.
   * A file-object (e.g. `with path.open('rb') as fobj:`)

  By default, Audio features are decoded as the raw integer wave form
  `tf.Tensor(shape=(None,), dtype=tf.int64)`.

  When encoding an audio with a different number of channels than expected by
  the feature, TFDS automatically tries to correct the number of channels.
  """

  def __init__(
      self,
      *,
      file_format: Optional[str] = None,
      shape: utils.Shape = (None,),
      dtype: type_utils.TfdsDType = np.int64,
      sample_rate: Optional[int] = None,
      encoding: Union[str, Encoding] = Encoding.NONE,
      doc: feature_lib.DocArg = None,
      lazy_decode: bool = False,
  ):
    """Constructs the connector.

    Args:
      file_format: `str`, the audio file format. Can be any format ffmpeg
        understands. If `None`, will attempt to infer from the file extension.
      shape: `tuple`, shape of the data: `(length,)` or
        `(length, num_channels)`.
      dtype: The dtype of the data.
      sample_rate: `int`, additional metadata exposed to the user through
        `info.features['audio'].sample_rate`. This value is used in neither
        encoding nor decoding; `repr_html` uses it as the playback rate.
      encoding: Internal encoding. See `tfds.features.Encoding` for available
        values.
      doc: Documentation of this feature (e.g. description).
      lazy_decode: `bool`, if set `True` then stores audio as is and decodes it
        to numpy array when loaded. Otherwise saves decoded audio.

    Raises:
      ValueError: If `shape` has more than two dimensions.
    """  # fmt:skip
    self._file_format = file_format
    self._sample_rate = sample_rate
    self._lazy_decode = lazy_decode
    if len(shape) > 2:
      raise ValueError(
          'Audio shape should be either (length,) or '
          f'(length, num_channels), got {shape}.'
      )
    if self._lazy_decode:
      # Lazy mode serializes the encoded file bytes as a scalar string.
      serialized_dtype = tf.string
      serialized_shape = ()
      decoder_cls = _LazyDecoder
    else:
      # `None` defers to the parent `Tensor` feature's defaults (presumably
      # the feature's own shape/dtype) — confirm in `tensor_feature`.
      serialized_dtype = None
      serialized_shape = None
      decoder_cls = _EagerDecoder
    super().__init__(
        shape=shape,
        dtype=dtype,
        encoding=encoding,
        doc=doc,
        serialized_dtype=serialized_dtype,
        serialized_shape=serialized_shape,
    )
    self._audio_decoder = decoder_cls(
        file_format=self._file_format, np_dtype=self._dtype, shape=self._shape
    )

  def encode_example(self, audio_or_path_or_fobj):
    """Convert the given audio into a dict convertible to tf example.

    Args:
      audio_or_path_or_fobj: A `np.ndarray`/`list` of samples, a path-like
        pointing to an audio file, or a binary file object.

    Returns:
      Arrays/lists unchanged; for paths and file objects, the output of the
      audio decoder passed through the parent `Tensor.encode_example`.
    """
    if isinstance(audio_or_path_or_fobj, (np.ndarray, list)):
      # NOTE(review): raw samples skip `super().encode_example` (and whatever
      # validation/`encoding` it applies), even with `lazy_decode=True`
      # where a string is serialized — confirm this is intended.
      return audio_or_path_or_fobj
    elif isinstance(audio_or_path_or_fobj, epath.PathLikeCls):
      filename = os.fspath(audio_or_path_or_fobj)
      file_format = _infer_file_format(self._file_format, filename)
      with tf.io.gfile.GFile(filename, 'rb') as audio_f:
        try:
          audio = self._audio_decoder.encode_audio(audio_f, file_format)
        except Exception as e:  # pylint: disable=broad-except
          # Prefix the error with the offending file for easier debugging.
          utils.reraise(e, prefix=f'Error for {filename}: ')
    else:
      audio = self._audio_decoder.encode_audio(
          audio_or_path_or_fobj, self._file_format
      )
    return super().encode_example(audio)

  def decode_example(self, tfexample_data):
    """See base class for details."""
    audio = super().decode_example(tfexample_data)
    return self._audio_decoder.decode_audio(audio)

  @property
  def sample_rate(self):
    """Returns the `sample_rate` metadata associated with the dataset."""
    return self._sample_rate

  def repr_html(self, ex: np.ndarray) -> str:
    """Audio are displayed in the player."""
    if self.sample_rate:
      rate = self.sample_rate
    else:
      # TODO: warn the user that the sample rate was defaulted to 16 kHz.
      # Requirements:
      # * Should appear only once (even though repr_html is called once per
      #   example)
      # * Ideally should appear on Colab (while `logging.warning` is hidden
      #   by default)
      rate = 16000
    audio_str = utils.get_base64(lambda buff: _save_wav(buff, ex, rate))
    # `_save_wav` writes a WAV container, so the data URI must advertise
    # `audio/wav` (not `audio/ogg`).
    return (
        f'<audio controls src="data:audio/wav;base64,{audio_str}" '
        ' controlsList="nodownload" />'
    )

  @classmethod
  def from_json_content(
      cls, value: Union[Json, feature_pb2.AudioFeature]
  ) -> 'Audio':
    """Rebuilds the feature from its legacy dict or `AudioFeature` proto."""
    if isinstance(value, dict):
      # For backwards compatibility
      return cls(
          file_format=value['file_format'],
          shape=tuple(value['shape']),
          dtype=feature_lib.dtype_from_str(value['dtype']),
          sample_rate=value['sample_rate'],
          lazy_decode=value.get('lazy_decode', False),
      )
    return cls(
        shape=feature_lib.from_shape_proto(value.shape),
        dtype=feature_lib.dtype_from_str(value.dtype),
        file_format=value.file_format or None,
        sample_rate=value.sample_rate,
        encoding=value.encoding,
        lazy_decode=value.lazy_decode or False,
    )

  def to_json_content(self) -> feature_pb2.AudioFeature:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Serializes the feature configuration into an `AudioFeature` proto."""
    return feature_pb2.AudioFeature(
        shape=feature_lib.to_shape_proto(self.shape),
        dtype=feature_lib.dtype_to_str(self.dtype),
        file_format=self._file_format,
        sample_rate=self._sample_rate,
        encoding=self._encoding.name,
        lazy_decode=self._lazy_decode,
    )
def _save_wav(buff, data, rate) -> None:
"""Transform a numpy array to a PCM bytestring."""
# Code inspired from `IPython.display.Audio`
data = np.array(data, dtype=float)
bit_depth = 16
max_sample_value = int(2 ** (bit_depth - 1)) - 1
num_channels = data.shape[1] if len(data.shape) > 1 else 1
scaled = np.int16(data / np.max(np.abs(data)) * max_sample_value)
# The WAVE spec expects little-endian integers of "sampwidth" bytes each.
# Numpy's `astype` accepts array-protocol type strings, so we specify:
# - '<' to indicate little endian
# - 'i' to specify signed integer
# - the number of bytes used to represent each integer
# See: https://numpy.org/doc/stable/reference/arrays.dtypes.html
encoded_wav = scaled.astype(f'<i{bit_depth // 8}', copy=False).tobytes()
with wave.open(buff, mode='wb') as waveobj:
waveobj.setnchannels(num_channels)
waveobj.setframerate(rate)
waveobj.setsampwidth(bit_depth // 8)
waveobj.setcomptype('NONE', 'NONE')
waveobj.writeframes(encoded_wav)
def _infer_file_format(
file_format: Optional[str], filename: str
) -> Optional[str]:
"""Simple heuristics to infer file format. Pydub will use FFMPEG otherwise."""
if file_format is not None:
return file_format
suffix = epath.Path(filename).suffix
if suffix.startswith('.'):
return suffix[1:]
return None