From ed1281767a86c5474670955fcad46716bf7230f1 Mon Sep 17 00:00:00 2001 From: fabiocat93 Date: Fri, 17 May 2024 10:00:11 -0400 Subject: [PATCH] adding data augmentation with torch-audiomentation --- poetry.lock | 87 ++++++++++++++++++- pyproject.toml | 2 + scripts/experiment5.py | 42 +++++++++ src/senselab/audio/tasks/data_augmentation.py | 37 ++++++++ 4 files changed, 167 insertions(+), 1 deletion(-) create mode 100644 scripts/experiment5.py create mode 100644 src/senselab/audio/tasks/data_augmentation.py diff --git a/poetry.lock b/poetry.lock index aa60c79..cc97f89 100644 --- a/poetry.lock +++ b/poetry.lock @@ -279,6 +279,26 @@ audmath = ">=1.3.0" numpy = "*" soundfile = ">=0.12.1" +[[package]] +name = "audiomentations" +version = "0.35.0" +description = "A Python library for audio data augmentation. Inspired by albumentations. Useful for machine learning." +optional = false +python-versions = ">=3.8" +files = [ + {file = "audiomentations-0.35.0-py3-none-any.whl", hash = "sha256:363904f30b0a3f30162e43a303b3d63b1ffd0737386739f9193c827667dab42a"}, + {file = "audiomentations-0.35.0.tar.gz", hash = "sha256:4c6c8c1b8ca5d3eb6621b5c1cda3f86dbf9fdc070022fc0cd645334c66153944"}, +] + +[package.dependencies] +librosa = ">=0.8.0,<0.10.0 || >0.10.0,<0.11.0" +numpy = ">=1.21.0" +scipy = ">=1.4.0,<2" +soxr = ">=0.3.2,<1.0.0" + +[package.extras] +extras = ["cylimiter (==0.3.0)", "lameenc (>=1.2.0,<2)", "pydub (>=0.22.0,<1)", "pyloudnorm (>=0.1.0)", "pyroomacoustics (>=0.6.0)"] + [[package]] name = "audioread" version = "3.0.1" @@ -1294,6 +1314,22 @@ files = [ [package.dependencies] referencing = ">=0.31.0" +[[package]] +name = "julius" +version = "0.2.7" +description = "Nice DSP sweets: resampling, FFT Convolutions. All with PyTorch, differentiable and with CUDA support." 
+optional = false +python-versions = ">=3.6.0" +files = [ + {file = "julius-0.2.7.tar.gz", hash = "sha256:3c0f5f5306d7d6016fcc95196b274cae6f07e2c9596eed314e4e7641554fbb08"}, +] + +[package.dependencies] +torch = ">=1.7.0" + +[package.extras] +dev = ["coverage", "flake8", "mypy", "onnxruntime", "pdoc3", "resampy (==0.2.2)"] + [[package]] name = "jupyter-client" version = "8.6.1" @@ -2440,6 +2476,17 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "primepy" +version = "1.3" +description = "This module contains several useful functions to work with prime numbers. from primePy import primes" +optional = false +python-versions = "*" +files = [ + {file = "primePy-1.3-py3-none-any.whl", hash = "sha256:5ed443718765be9bf7e2ff4c56cdff71b42140a15b39d054f9d99f0009e2317a"}, + {file = "primePy-1.3.tar.gz", hash = "sha256:25fd7e25344b0789a5984c75d89f054fcf1f180bef20c998e4befbac92de4669"}, +] + [[package]] name = "prompt-toolkit" version = "3.0.43" @@ -3786,6 +3833,44 @@ typing-extensions = ">=4.8.0" opt-einsum = ["opt-einsum (>=3.3)"] optree = ["optree (>=0.9.1)"] +[[package]] +name = "torch-audiomentations" +version = "0.11.1" +description = "A Pytorch library for audio data augmentation. Inspired by audiomentations. Useful for deep learning." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "torch-audiomentations-0.11.1.tar.gz", hash = "sha256:ee4d0c0f937552b4d63dbccdd3c509f5f7747f76c2e13e528ea8ef30e13e1000"}, + {file = "torch_audiomentations-0.11.1-py3-none-any.whl", hash = "sha256:530b419b61c4ffd7b137828ec4ad4c2ce937db205230b5679947941f6b05242e"}, +] + +[package.dependencies] +julius = ">=0.2.3,<0.3" +librosa = ">=0.6.0" +torch = ">=1.7.0" +torch-pitch-shift = ">=1.2.2" +torchaudio = ">=0.9.0" + +[package.extras] +extras = ["PyYAML"] + +[[package]] +name = "torch-pitch-shift" +version = "1.2.4" +description = "" +optional = false +python-versions = ">=3.4" +files = [ + {file = "torch_pitch_shift-1.2.4-py3-none-any.whl", hash = "sha256:6cde2bddd7388e6da05e354fb84cf8e95a33e79d196e4c8ebd655faa4296cc42"}, + {file = "torch_pitch_shift-1.2.4.tar.gz", hash = "sha256:c173fc808184a684c1ecd99d5744573e55a667f238e2268adaf15c9467d99db9"}, +] + +[package.dependencies] +packaging = ">=21.3" +primePy = ">=1.3" +torch = ">=1.7.0" +torchaudio = ">=0.7.0" + [[package]] name = "torchaudio" version = "2.2.2" @@ -4319,4 +4404,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "9a6a902f473533ac897708896e700b3d18526f16d7e9a6dd17061e529cf989a9" +content-hash = "c185ef549ed5a9b1bd5a3d296688aec4aaf196e7b957fce06dab9b33c38b4922" diff --git a/pyproject.toml b/pyproject.toml index cab5ac4..c65b5ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ huggingface-hub = "^0.23.0" praat-parselmouth = "^0.4.3" iso-639 = {git = "https://github.com/noumar/iso639.git", tag = "0.4.5"} opensmile = "^2.5.0" +audiomentations = "^0.35.0" +torch-audiomentations = "^0.11.1" [tool.poetry.group.dev] optional = true diff --git a/scripts/experiment5.py b/scripts/experiment5.py new file mode 100644 index 0000000..aeb8070 --- /dev/null +++ b/scripts/experiment5.py @@ -0,0 +1,42 @@ +"""This script is used to test 
the audio tasks.""" + +from typing import Any, Dict + +from torch_audiomentations import Compose, Gain, PolarityInversion + +from senselab.audio.tasks.data_augmentation import augment_hf_dataset +from senselab.utils.decorators import get_response_time +from senselab.utils.tasks.input_output import read_files_from_disk + + +@get_response_time +def workflow(data: Dict[str, Any], augmentation: Compose) -> None: + """This function reads audio files from disk and augments them with the given augmentation.""" + print("Starting to read files from disk...") + dataset = read_files_from_disk(data["files"]) + print(f"Dataset loaded with {len(dataset)} records.") + + print("Augmenting dataset...") + dataset = augment_hf_dataset(dataset, augmentation) + print("Augmented dataset.") + + + +# Initialize augmentation callable +apply_augmentation = Compose( + transforms=[ + Gain( + min_gain_in_db=-15.0, + max_gain_in_db=5.0, + p=0.5, + ), + PolarityInversion(p=0.5) + ] +) + +data = {"files": + ["/Users/fabiocat/Documents/git/sensein/senselab/src/tests/data_for_testing/audio_48khz_mono_16bits.wav", + "/Users/fabiocat/Documents/git/sensein/senselab/src/tests/data_for_testing/audio_48khz_mono_16bits.wav"] + } + +workflow(data, apply_augmentation) \ No newline at end of file diff --git a/src/senselab/audio/tasks/data_augmentation.py b/src/senselab/audio/tasks/data_augmentation.py new file mode 100644 index 0000000..f46d46e --- /dev/null +++ b/src/senselab/audio/tasks/data_augmentation.py @@ -0,0 +1,37 @@ +"""This module implements some utilities for the audio data augmentation task.""" +from typing import Any, Dict + +import torch +from datasets import Dataset +from torch_audiomentations import Compose + +from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict + + +def augment_hf_dataset(dataset: Dict[str, Any], augmentation: Compose, audio_column: str = 'audio') -> Dict[str, Any]: + """Augments the audio column of a Hugging Face `Dataset` object.""" + hf_dataset = 
_from_dict_to_hf_dataset(dataset) + + def _augment_hf_row(row: Dataset, augmentation: Compose, audio_column: str) -> Dict[str, Any]: + waveform = row[audio_column]['array'] + sampling_rate = row[audio_column]['sampling_rate'] + + # Ensure waveform is a PyTorch tensor + if not isinstance(waveform, torch.Tensor): + waveform = torch.tensor(waveform) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0).unsqueeze(0) # [num_samples] -> [1, 1, num_samples] + elif waveform.dim() == 2: + waveform = waveform.unsqueeze(1) # [batch_size, num_samples] -> [batch_size, 1, num_samples] + + augmented_hf_row = augmentation(waveform, sample_rate=sampling_rate).squeeze() + + return { "augmented_audio": { + "array": augmented_hf_row, + "sampling_rate": sampling_rate + } + } + + augmented_hf_dataset = hf_dataset.map(lambda x: _augment_hf_row(x, augmentation, audio_column)) + augmented_hf_dataset = augmented_hf_dataset.remove_columns([audio_column]) + return _from_hf_dataset_to_dict(augmented_hf_dataset)