generated from sensein/python-package-template
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from sensein/dev
adding data augmentation with torch-audiomentation
- Loading branch information
Showing
4 changed files
with
167 additions
and
1 deletion.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""This script is used to test the audio tasks.""" | ||
|
||
from typing import Any, Dict | ||
|
||
from torch_audiomentations import Compose, Gain, PolarityInversion | ||
|
||
from senselab.audio.tasks.data_augmentation import augment_hf_dataset | ||
from senselab.utils.decorators import get_response_time | ||
from senselab.utils.tasks.input_output import read_files_from_disk | ||
|
||
|
||
@get_response_time
def workflow(data: Dict[str, Any], augmentation: Compose) -> None:
    """Read audio files from disk and run the given augmentation over them.

    Progress is reported on stdout. The augmented dataset is not returned or
    stored; this function only exercises the read + augment pipeline.

    Args:
        data: Mapping whose ``"files"`` entry is handed to ``read_files_from_disk``.
        augmentation: Augmentation pipeline forwarded to ``augment_hf_dataset``.
    """
    print("Starting to read files from disk...")
    dataset = read_files_from_disk(data["files"])
    print(f"Dataset loaded with {len(dataset)} records.")

    print("Augmenting dataset...")
    # The result is intentionally discarded: nothing downstream consumes it here.
    augment_hf_dataset(dataset, augmentation)
    print("Augmented dataset.")
|
||
|
||
|
||
# Augmentation pipeline handed to the workflow: a Gain transform and a
# PolarityInversion transform, each configured with probability p=0.5.
apply_augmentation = Compose(
    transforms=[
        Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5),
        PolarityInversion(p=0.5),
    ],
)

# The same test recording is listed twice so the dataset holds two entries.
# NOTE(review): this is an absolute path on one developer's machine; the
# script will not find the file elsewhere until it is adjusted.
data = {
    "files": [
        "/Users/fabiocat/Documents/git/sensein/senselab/src/tests/data_for_testing/audio_48khz_mono_16bits.wav",
    ] * 2,
}

workflow(data, apply_augmentation)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""This module implements some utilities for the audio data augmentation task.""" | ||
from typing import Any, Dict | ||
|
||
import torch | ||
from datasets import Dataset | ||
from torch_audiomentations import Compose | ||
|
||
from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict | ||
|
||
|
||
def augment_hf_dataset(dataset: Dict[str, Any], augmentation: Compose, audio_column: str = 'audio') -> Dict[str, Any]:
    """Apply an audio augmentation to every row of a dataset.

    The input dictionary is converted to a Hugging Face ``Dataset``, each row's
    audio is passed through ``augmentation``, and the result is converted back
    to a dictionary. The original ``audio_column`` is dropped; the augmented
    audio is stored under ``"augmented_audio"``.

    Args:
        dataset: Dictionary representation of the dataset, as accepted by
            ``_from_dict_to_hf_dataset``.
        augmentation: Callable augmentation pipeline, invoked as
            ``augmentation(waveform, sample_rate=...)``.
        audio_column: Name of the column holding the audio. Each entry must
            expose ``"array"`` and ``"sampling_rate"`` keys.

    Returns:
        Dictionary representation of the dataset in which ``audio_column`` has
        been replaced by an ``"augmented_audio"`` column with ``"array"`` and
        ``"sampling_rate"`` keys.
    """
    hf_dataset = _from_dict_to_hf_dataset(dataset)

    def _augment_hf_row(row: Dict[str, Any]) -> Dict[str, Any]:
        """Augment the audio of a single dataset row."""
        waveform = row[audio_column]['array']
        sampling_rate = row[audio_column]['sampling_rate']

        # Ensure waveform is a PyTorch tensor
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)  # [num_samples] -> [1, 1, num_samples]
        elif waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)  # [batch_size, num_samples] -> [batch_size, 1, num_samples]

        augmented = augmentation(waveform, sample_rate=sampling_rate)
        # NOTE(review): torch-audiomentations can be configured to return a
        # dict-like object (output_type="dict") whose "samples" entry holds the
        # audio - confirm against the installed version.
        if not isinstance(augmented, torch.Tensor):
            augmented = augmented["samples"]

        # squeeze() drops every size-1 dimension, which removes the batch and
        # channel dimensions added above.
        return {
            "augmented_audio": {
                "array": augmented.squeeze(),
                "sampling_rate": sampling_rate,
            }
        }

    # Dropping the source column through map() happens before the function's
    # output is merged in, so the "augmented_audio" column is never removed,
    # even if audio_column happens to carry that same name.
    augmented_hf_dataset = hf_dataset.map(_augment_hf_row, remove_columns=[audio_column])
    return _from_hf_dataset_to_dict(augmented_hf_dataset)