
torchaudio's First Official Release (v0.2.0)

Pre-release
@jamarshon released this 08 Aug 15:56
e3c7784

Background

The goal of this release is to fix the current API, as there will be future changes that break backward compatibility in order to improve the library as more thought is given to its design, capabilities, and usability.

While this release is compatible with all currently known PyTorch versions (<=1.2.0), the available binaries only require PyTorch 1.1.0. Installation commands:

# Wheels for Python 2 are NOT supported
# Python 3.5
$ pip3 install http://download.pytorch.org/whl/torchaudio-0.2-cp35-cp35m-linux_x86_64.whl
# Python 3.6
$ pip3 install http://download.pytorch.org/whl/torchaudio-0.2-cp36-cp36m-linux_x86_64.whl
# Python 3.7
$ pip3 install http://download.pytorch.org/whl/torchaudio-0.2-cp37-cp37m-linux_x86_64.whl

What's new?

  • Fixed broken tests and set up an automatic testing environment
  • Read in Kaldi files (“.ark”, “.scp”)
  • Separation of state and computation into transforms.py and functional.py
  • Loading and saving to file
  • Datasets VCTK and YESNO
  • SoxEffects and SoxEffectsChain in torchaudio.sox_effects

CI and Testing

Continuous integration (Travis CI) was set up in #117. This means all the tests have been fixed, and their status can be checked at https://travis-ci.org/pytorch/audio. The test files have to be run separately via build_tools/travis/test_script.sh, because closing sox after a test file completes prevents it from being reopened. The testing framework is pytest.

# Run the whole test suite
$ build_tools/travis/test_script.sh
# Run an individual test
$ python -m pytest test/test_transforms.py

Kaldi IO

Kaldi IO has been added as an optional dependency in #111. torchaudio provides a simple wrapper around it, converting the np.ndarray into a torch.Tensor. The available functions are read_vec_int_ark, read_vec_flt_scp, read_vec_flt_ark, read_mat_scp, and read_mat_ark.

>>> # read ark to a 'dictionary'
>>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_int_ark(file) }
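
The matrix readers follow the same pattern. As a sketch (the file name feats.ark is an illustrative assumption):

>>> # read a matrix ark into a 'dictionary' of torch.Tensor values
>>> d = { u:m for u,m in torchaudio.kaldi_io.read_mat_ark('feats.ark') }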

Separation of State and Computation

In #105, the computations were moved into functional.py. The reasoning is that tracking state is a problem of its own and should be kept separate from computing a function. The split also lets us annotate the functionals as weakly scriptable, which in turn allows us to use the JIT and generate efficient code. A functional can then be used by other functionals, which is much easier and more efficient than having one Module create an instance of another class. It also makes it easier to implement performance improvements and to define a generic API: if someone implements a function that adheres to the contract of a functional, it can be an immediate drop-in replacement. This is important if we want to support different backends (e.g. move a functional entirely into C++).

>>> torchaudio.transforms.Spectrogram(n_fft=...)(waveform)
>>> torchaudio.functional.spectrogram(waveform, …)
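
A fuller sketch of this equivalence (the parameter values are illustrative and mirror the transform's defaults):

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # hypothetical mono signal, channels first

# stateful: parameters are stored on the transform object
spec1 = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)

# stateless: the same parameters are passed explicitly to the functional
window = torch.hann_window(400)
spec2 = torchaudio.functional.spectrogram(
    waveform, pad=0, window=window, n_fft=400,
    hop=200, ws=400, power=2, normalize=False)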

Loading and saving to file

Tensors can be read from and written to various file formats (e.g. “mp3”, “wav”, etc.) through torchaudio.

sound, sample_rate = torchaudio.load('input.wav')
torchaudio.save('output.wav', sound, sample_rate)
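
A small sketch for inspecting what load returns (the file name is illustrative, and the exact shape layout of the returned tensor may vary):

sound, sample_rate = torchaudio.load('input.wav')
print(sound.size(), sound.dtype, sample_rate)  # tensor shape, dtype, and the file's sample rate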

Transforms and functionals

Transforms

class Compose(object):
    def __init__(self, transforms):
    def __call__(self, audio):
        
class Scale(object):
    def __init__(self, factor=2**31):
    def __call__(self, tensor):
        
class PadTrim(object):
    def __init__(self, max_len, fill_value=0, channels_first=True):
    def __call__(self, tensor):
       
class DownmixMono(object):
    def __init__(self, channels_first=None):
    def __call__(self, tensor):

class LC2CL(object):
    def __call__(self, tensor):

def SPECTROGRAM(*args, **kwargs):

class Spectrogram(object):
    def __init__(self, n_fft=400, ws=None, hop=None,
                 pad=0, window=torch.hann_window,
                 power=2, normalize=False, wkwargs=None):
    def __call__(self, sig):
        
def F2M(*args, **kwargs):

class MelScale(object):
    def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
    def __call__(self, spec_f):

class SpectrogramToDB(object):
    def __init__(self, stype="power", top_db=None):
    def __call__(self, spec):
       
class MFCC(object):
    def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
    def __call__(self, sig):

class MelSpectrogram(object):
    def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
    def __call__(self, sig):

def MEL(*args, **kwargs):

class BLC2CBL(object):
    def __call__(self, tensor):

class MuLawEncoding(object):
    def __init__(self, quantization_channels=256):
    def __call__(self, x):

class MuLawExpanding(object):
    def __init__(self, quantization_channels=256):
    def __call__(self, x_mu):
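
These transforms can be chained with Compose. A minimal sketch (the parameter values and the channels-first waveform shape are illustrative assumptions):

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # hypothetical mono signal

transform = torchaudio.transforms.Compose([
    torchaudio.transforms.PadTrim(max_len=16000),    # pad or trim to a fixed length
    torchaudio.transforms.MelSpectrogram(sr=16000),  # waveform -> mel spectrogram
])
mel = transform(waveform)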

Functional

def scale(tensor, factor):
    # type: (Tensor, int) -> Tensor

def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
    # type: (Tensor, int, int, int, float) -> Tensor

def downmix_mono(tensor, ch_dim):
    # type: (Tensor, int) -> Tensor

def LC2CL(tensor):
    # type: (Tensor) -> Tensor

def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
    # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor

def create_fb_matrix(n_stft, f_min, f_max, n_mels):
    # type: (int, float, float, int) -> Tensor

def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
    # type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]

def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
    # type: (Tensor, float, float, float, Optional[float]) -> Tensor

def create_dct(n_mfcc, n_mels, norm):
    # type: (int, int, str) -> Tensor

def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
    # type: (Tensor, MelSpectrogram, bool, SpectrogramToDB, Tensor) -> Tensor

def BLC2CBL(tensor):
    # type: (Tensor) -> Tensor

def mu_law_encoding(x, qc):
    # type: (Tensor, int) -> Tensor

def mu_law_expanding(x_mu, qc):
    # type: (Tensor, int) -> Tensor
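
The functionals can be called directly as well; a small sketch of a mu-law round trip (the input values are illustrative):

import torch
import torchaudio

x = torch.rand(1, 1000) * 2 - 1                            # hypothetical waveform in [-1, 1]
x_mu = torchaudio.functional.mu_law_encoding(x, 256)       # encode with 256 quantization channels
x_hat = torchaudio.functional.mu_law_expanding(x_mu, 256)  # approximately invert the encoding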

Datasets VCTK and YESNO

All datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

yesno_data = torchaudio.datasets.YESNO('.', download=True)
data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=4)  # e.g. args.nThreads

The two available datasets are VCTK and YESNO. Both download their data and preprocess it so that the loaded data is in a convenient format.

SoxEffects and SoxEffectsChain

SoxEffects and SoxEffectsChain in torchaudio.sox_effects expose sox operations through a Python interface. Useful effects, such as downmixing a multichannel signal or resampling it, can be applied here.

torchaudio.initialize_sox()
E = torchaudio.sox_effects.SoxEffectsChain()
E.append_effect_to_chain("rate", [16000])    # resample to 16000 Hz
E.append_effect_to_chain("channels", ["1"])  # downmix to a mono signal
E.set_input_file(fn)                         # fn: path to the input audio file
waveform, sample_rate = E.sox_build_flow_effects()
torchaudio.shutdown_sox()