New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add wsj0mix dataset #895
Merged
Merged
Add wsj0mix dataset #895
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from . import ( | ||
dataset, | ||
metrics, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from . import utils, wsj0mix |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from typing import List | ||
from functools import partial | ||
from collections import namedtuple | ||
|
||
import torch | ||
|
||
from . import wsj0mix | ||
|
||
# Collated batch container: the stacked mixtures, the stacked per-speaker
# sources, and the stacked validity masks, in that order.
Batch = namedtuple("Batch", "mix src mask")
|
||
|
||
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
    """Build the train / validation / evaluation datasets.

    Args:
        dataset_type (str): Name of the dataset. Only ``"wsj0mix"`` is handled.
        root_dir (Path): Directory holding the ``tr``, ``cv`` and ``tt``
            sub-directories. It is combined with the ``/`` operator, so it
            has to be path-like.
        num_speakers (int): Forwarded to ``wsj0mix.WSJ0Mix``.
        sample_rate (int): Forwarded to ``wsj0mix.WSJ0Mix``.

    Returns:
        tuple: ``(train, validation, evaluation)`` datasets.

    Raises:
        ValueError: If ``dataset_type`` is not a known dataset.
    """
    if dataset_type != "wsj0mix":
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    # tr -> train, cv -> validation, tt -> evaluation
    return tuple(
        wsj0mix.WSJ0Mix(root_dir / subset, num_speakers, sample_rate)
        for subset in ("tr", "cv", "tt")
    )
|
||
|
||
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_start=False):
    """Ensure waveform has exact number of frames by slicing or padding.

    Args:
        sample: ``(sample_rate, mix, sources)`` tuple. ``mix`` is used as a
            ``[1, num_frames]`` tensor and ``sources`` is a list of tensors
            that are concatenated along the first dimension.
        target_num_frames (int): Number of frames of the returned tensors.
        random_start (bool): When the waveform is longer than
            ``target_num_frames``, crop from a random offset instead of from
            the beginning.

    Returns:
        tuple: ``(mix, src, mask)``. ``mask`` has the shape of ``mix`` and is
        1 on frames that come from the original audio and 0 on frames that
        were added as zero padding.
    """
    mix = sample[1]  # [1, num_frames]
    src = torch.cat(sample[2], 0)  # [num_sources, num_frames]

    num_channels, num_frames = src.shape
    if num_frames >= target_num_frames:
        if random_start and num_frames > target_num_frames:
            # The upper bound of ``torch.randint`` is exclusive. ``+ 1`` makes
            # the last valid offset (the crop ending on the final frame)
            # reachable. Convert to a plain int before using it in a slice.
            max_start = num_frames - target_num_frames
            start_frame = int(torch.randint(max_start + 1, [1]))
            mix = mix[:, start_frame:]
            src = src[:, start_frame:]
        mix = mix[:, :target_num_frames]
        src = src[:, :target_num_frames]
        mask = torch.ones_like(mix)
    else:
        num_padding = target_num_frames - num_frames
        pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
        mix = torch.cat([mix, pad], 1)
        src = torch.cat([src, pad.expand(num_channels, -1)], 1)
        mask = torch.ones_like(mix)
        # Everything past the original length is padding.
        mask[..., num_frames:] = 0
    return mix, src, mask
|
||
|
||
|
||
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
    """Collate training samples into fixed-length, randomly cropped batches.

    Every sample is cropped (from a random offset) or zero-padded to
    ``int(duration * sample_rate)`` frames and the results are stacked along
    a new leading batch dimension.

    Args:
        samples: Samples as produced by the dataset.
        sample_rate: Sample rate used to convert ``duration`` into frames.
        duration: Target length in seconds.

    Returns:
        Batch: Stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    target_num_frames = int(duration * sample_rate)
    fixed = [
        _fix_num_frames(item, target_num_frames, random_start=True)
        for item in samples
    ]
    mixes = [entry[0] for entry in fixed]
    srcs = [entry[1] for entry in fixed]
    masks = [entry[2] for entry in fixed]
    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
|
||
|
||
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]):
    """Collate evaluation samples by padding to the longest mixture.

    The target length is the largest last-dimension size among the mixture
    tensors in ``samples``; shorter samples are zero-padded, and nothing is
    cropped at a random offset.

    Args:
        samples: Samples as produced by the dataset.

    Returns:
        Batch: Stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    max_num_frames = max(item[1].shape[-1] for item in samples)
    fixed = [
        _fix_num_frames(item, max_num_frames, random_start=False)
        for item in samples
    ]
    mixes = [entry[0] for entry in fixed]
    srcs = [entry[1] for entry in fixed]
    masks = [entry[2] for entry in fixed]
    return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))
|
||
|
||
def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
    """Return the collate function for the given dataset and mode.

    Args:
        dataset_type (str): Name of the dataset. Only ``"wsj0mix"`` is handled.
        mode (str): Either ``"train"`` or ``"test"``.
        sample_rate (int or None): Required when ``mode`` is ``"train"``;
            used to turn ``duration`` into a number of frames.
        duration: Length in seconds of each training sample. (default: 4)

    Returns:
        Callable: A function that turns a list of samples into a ``Batch``.

    Raises:
        ValueError: If ``mode`` or ``dataset_type`` is not recognised, or if
            ``sample_rate`` is missing in train mode.
    """
    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O`` and an invalid mode would then silently fall
    # through to the test collate function.
    if mode not in ("train", "test"):
        raise ValueError(f"Unexpected mode: {mode}")
    if dataset_type != "wsj0mix":
        raise ValueError(f"Unexpected dataset: {dataset_type}")
    if mode == "test":
        return collate_fn_wsj0mix_test
    if sample_rate is None:
        raise ValueError("sample_rate is not given.")
    return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from pathlib import Path | ||
from typing import Union, Tuple, List | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
import torchaudio | ||
|
||
# One dataset item: (sample_rate, mixture waveform, list of source waveforms).
SampleType = Tuple[
    int,
    torch.Tensor,
    List[torch.Tensor],
]
|
||
|
||
class WSJ0Mix(Dataset):
    """Create a Dataset for wsj0-mix.

    Args:
        root (str or Path): Path to the directory where the dataset is found.
        num_speakers (int): The number of speakers, which determines the directories
            to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
            N source audios.
        sample_rate (int): Expected sample rate of audio files. If any of the audio has a
            different sample rate, raises ``ValueError``.
        audio_ext (str): The extension of audio files to find. (default: ".wav")
    """
    def __init__(
            self,
            root: Union[str, Path],
            num_speakers: int,
            sample_rate: int,
            audio_ext: str = ".wav",
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.mix_dir = (self.root / "mix").resolve()
        self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]

        # Only file names are kept: the same name is looked up in the mix
        # directory and in every source directory. Sorting makes the sample
        # order deterministic.
        self.files = sorted(p.name for p in self.mix_dir.glob(f"*{audio_ext}"))

    def _load_audio(self, path) -> torch.Tensor:
        """Load one audio file and verify its sample rate.

        This is ``torchaudio.load`` plus a check that the file matches the
        sample rate the dataset was created with.

        Raises:
            ValueError: If the file's sample rate differs from ``self.sample_rate``.
        """
        waveform, sample_rate = torchaudio.load(path)
        if sample_rate != self.sample_rate:
            # Both literals need the ``f`` prefix; without it the second
            # placeholder is printed verbatim instead of being interpolated.
            raise ValueError(
                f"The dataset contains audio file of sample rate {sample_rate}. "
                f"Where the requested sample rate is {self.sample_rate}."
            )
        return waveform

    def _load_sample(self, filename) -> SampleType:
        """Load the mixture and every source waveform stored under ``filename``.

        Raises:
            ValueError: If a source waveform's shape differs from the mixture's.
        """
        mixed = self._load_audio(str(self.mix_dir / filename))
        srcs = []
        for i, dir_ in enumerate(self.src_dirs):
            src = self._load_audio(str(dir_ / filename))
            if mixed.shape != src.shape:
                raise ValueError(
                    f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
                )
            srcs.append(src)
        return self.sample_rate, mixed, srcs

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, key: int) -> SampleType:
        """Load the n-th sample from the dataset.

        Args:
            key (int): The index of the sample to be loaded

        Returns:
            tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
        """
        return self._load_sample(self.files[key])
111 changes: 111 additions & 0 deletions
111
test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import os | ||
|
||
from torchaudio_unittest.common_utils import ( | ||
TempDirMixin, | ||
TorchaudioTestCase, | ||
get_whitenoise, | ||
save_wav, | ||
normalize_wav, | ||
) | ||
|
||
from source_separation.utils.dataset import wsj0mix | ||
|
||
|
||
_FILENAMES = [ | ||
"012c0207_1.9952_01cc0202_-1.9952.wav", | ||
"01co0302_1.63_014c020q_-1.63.wav", | ||
"01do0316_0.24011_205a0104_-0.24011.wav", | ||
"01lc020x_1.1301_027o030r_-1.1301.wav", | ||
"01mc0202_0.34056_205o0106_-0.34056.wav", | ||
"01nc020t_0.53821_018o030w_-0.53821.wav", | ||
"01po030f_2.2136_40ko031a_-2.2136.wav", | ||
"01ra010o_2.4098_403a010f_-2.4098.wav", | ||
"01xo030b_0.22377_016o031a_-0.22377.wav", | ||
"02ac020x_0.68566_01ec020b_-0.68566.wav", | ||
"20co010m_0.82801_019c0212_-0.82801.wav", | ||
"20da010u_1.2483_017c0211_-1.2483.wav", | ||
"20oo010d_1.0631_01ic020s_-1.0631.wav", | ||
"20sc0107_2.0222_20fo010h_-2.0222.wav", | ||
"20tc010f_0.051456_404a0110_-0.051456.wav", | ||
"407c0214_1.1712_02ca0113_-1.1712.wav", | ||
"40ao030w_2.4697_20vc010a_-2.4697.wav", | ||
"40pa0101_1.1087_40ea0107_-1.1087.wav", | ||
] | ||
|
||
|
||
def _mock_dataset(root_dir, num_speaker):
    """Write a mock wsj0-mix directory tree and return the expected samples.

    For every name in ``_FILENAMES`` one white-noise file is saved under
    ``mix`` and under each of ``s1`` .. ``sN``. Every file gets its own seed.

    Returns:
        list: ``(sample_rate, mix, sources)`` tuples holding the normalized
        waveforms, in ``_FILENAMES`` order.
    """
    subdirs = ["mix", *(f"s{i+1}" for i in range(num_speaker))]
    for subdir in subdirs:
        os.makedirs(os.path.join(root_dir, subdir), exist_ok=True)

    sample_rate = 8000
    expected = []
    for file_idx, filename in enumerate(_FILENAMES):
        waveforms = []
        for dir_idx, subdir in enumerate(subdirs):
            # Seeds run 0, 1, 2, ... across all (file, directory) pairs.
            noise = get_whitenoise(
                sample_rate=sample_rate,
                duration=1,
                n_channels=1,
                dtype="int16",
                seed=file_idx * len(subdirs) + dir_idx,
            )
            save_wav(os.path.join(root_dir, subdir, filename), noise, sample_rate)
            waveforms.append(normalize_wav(noise))
        # "mix" is always the first directory; the rest are the sources.
        expected.append((sample_rate, waveforms[0], waveforms[1:]))
    return expected
|
||
|
||
class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
    """Check ``WSJ0Mix`` against a mock two-speaker dataset."""
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 2)

    def test_wsj0mix(self):
        """Every sample matches the mock data and none is missing."""
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)

        n_ite = 0
        for idx, (_, sample_mix, sample_src) in enumerate(dataset):
            _, expected_mix, expected_src = self.expected[idx]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            for spk in range(2):
                self.assertEqual(sample_src[spk], expected_src[spk], atol=5e-5, rtol=1e-8)
            n_ite = idx + 1
        assert n_ite == len(self.expected)
|
||
|
||
class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
    """Check ``WSJ0Mix`` against a mock three-speaker dataset."""
    backend = "default"
    root_dir = None
    expected = None

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 3)

    def test_wsj0mix(self):
        """Every sample matches the mock data and none is missing."""
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)

        n_ite = 0
        for idx, (_, sample_mix, sample_src) in enumerate(dataset):
            _, expected_mix, expected_src = self.expected[idx]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            for spk in range(3):
                self.assertEqual(sample_src[spk], expected_src[spk], atol=5e-5, rtol=1e-8)
            n_ite = idx + 1
        assert n_ite == len(self.expected)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name may be misleading?