Skip to content

Commit

Permalink
Merge pull request #18 from sensein/dev
Browse files Browse the repository at this point in the history
adding data augmentation with torch-audiomentations
  • Loading branch information
fabiocat93 committed May 17, 2024
2 parents ebbfa6d + ed12817 commit 4f93d83
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 1 deletion.
87 changes: 86 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ huggingface-hub = "^0.23.0"
praat-parselmouth = "^0.4.3"
iso-639 = {git = "https://github.com/noumar/iso639.git", tag = "0.4.5"}
opensmile = "^2.5.0"
audiomentations = "^0.35.0"
torch-audiomentations = "^0.11.1"

[tool.poetry.group.dev]
optional = true
Expand Down
42 changes: 42 additions & 0 deletions scripts/experiment5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""This script is used to test the audio tasks."""

from typing import Any, Dict

from torch_audiomentations import Compose, Gain, PolarityInversion

from senselab.audio.tasks.data_augmentation import augment_hf_dataset
from senselab.utils.decorators import get_response_time
from senselab.utils.tasks.input_output import read_files_from_disk


@get_response_time
def workflow(data: Dict[str, Any], augmentation: Compose) -> None:
    """Read audio files from disk and run them through an augmentation pipeline.

    Progress is reported with `print`; the augmented dataset is not returned.

    Args:
        data: Mapping whose "files" entry is handed to `read_files_from_disk`.
        augmentation: torch-audiomentations `Compose` forwarded to
            `augment_hf_dataset`.
    """
    print("Starting to read files from disk...")
    dataset = read_files_from_disk(data["files"])
    print(f"Dataset loaded with {len(dataset)} records.")

    print("Augmenting dataset...")
    # The result is kept only so the call mirrors real usage; this script
    # exercises the augmentation path and does nothing further with it.
    dataset = augment_hf_dataset(dataset, augmentation)
    print("Augmented dataset.")



# Build the augmentation pipeline: a random gain change followed by a random
# polarity inversion, each applied with probability 0.5.
_random_gain = Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5)
_random_polarity = PolarityInversion(p=0.5)
apply_augmentation = Compose(transforms=[_random_gain, _random_polarity])

# The same test recording is listed twice so the dataset holds two records.
_TEST_WAV = "/Users/fabiocat/Documents/git/sensein/senselab/src/tests/data_for_testing/audio_48khz_mono_16bits.wav"
data = {"files": [_TEST_WAV, _TEST_WAV]}

workflow(data, apply_augmentation)
37 changes: 37 additions & 0 deletions src/senselab/audio/tasks/data_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""This module implements some utilities for the audio data augmentation task."""
from typing import Any, Dict

import torch
from datasets import Dataset
from torch_audiomentations import Compose

from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict


def augment_hf_dataset(dataset: Dict[str, Any], augmentation: Compose, audio_column: str = 'audio') -> Dict[str, Any]:
    """Apply an audio augmentation pipeline to every row of a dataset.

    The input is converted with `_from_dict_to_hf_dataset`, each row's audio is
    passed through `augmentation`, and the result is stored in a new
    `augmented_audio` column. The original `audio_column` is then removed and
    the dataset is converted back with `_from_hf_dataset_to_dict`.

    Args:
        dataset: Dataset in the form accepted by `_from_dict_to_hf_dataset`.
            Each row's `audio_column` entry must provide an `array` and a
            `sampling_rate`.
        augmentation: torch-audiomentations `Compose` to apply to each row.
        audio_column: Name of the column holding the source audio.

    Returns:
        The result of `_from_hf_dataset_to_dict` on the augmented dataset,
        which has an `augmented_audio` column (with `array` and
        `sampling_rate`) in place of `audio_column`.
    """
    hf_dataset = _from_dict_to_hf_dataset(dataset)

    def _augment_hf_row(row: Dict[str, Any]) -> Dict[str, Any]:
        """Augment the audio of a single row and return the new column."""
        waveform = row[audio_column]['array']
        sampling_rate = row[audio_column]['sampling_rate']

        # Ensure waveform is a PyTorch tensor
        if not isinstance(waveform, torch.Tensor):
            waveform = torch.tensor(waveform)

        # torch-audiomentations works on 3-D input (batch, channels, samples).
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0).unsqueeze(0)  # [num_samples] -> [1, 1, num_samples]
        elif waveform.dim() == 2:
            # NOTE(review): the first axis is treated as a batch axis here. If a
            # row can hold multi-channel audio as [channels, num_samples], each
            # channel would be augmented as an independent example - confirm.
            waveform = waveform.unsqueeze(1)  # [batch_size, num_samples] -> [batch_size, 1, num_samples]

        augmented = augmentation(waveform, sample_rate=sampling_rate)
        # With output_type="dict" the pipeline returns an ObjectDict that
        # carries the audio in `.samples` instead of a bare tensor; accept both.
        if not isinstance(augmented, torch.Tensor):
            augmented = augmented.samples

        return {
            "augmented_audio": {
                # squeeze() drops every singleton axis added above.
                "array": augmented.squeeze(),
                "sampling_rate": sampling_rate,
            }
        }

    augmented_hf_dataset = hf_dataset.map(_augment_hf_row)
    augmented_hf_dataset = augmented_hf_dataset.remove_columns([audio_column])
    return _from_hf_dataset_to_dict(augmented_hf_dataset)

0 comments on commit 4f93d83

Please sign in to comment.