Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wsj0mix dataset #895

Merged
merged 1 commit into from Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/source_separation/utils/__init__.py
@@ -1,3 +1,4 @@
from . import (
dataset,
metrics,
)
1 change: 1 addition & 0 deletions examples/source_separation/utils/dataset/__init__.py
@@ -0,0 +1 @@
from . import utils, wsj0mix
83 changes: 83 additions & 0 deletions examples/source_separation/utils/dataset/utils.py
@@ -0,0 +1,83 @@
from typing import List
from functools import partial
from collections import namedtuple

import torch

from . import wsj0mix

Batch = namedtuple("Batch", ["mix", "src", "mask"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name may be misleading?



def get_dataset(dataset_type, root_dir, num_speakers, sample_rate):
if dataset_type == "wsj0mix":
train = wsj0mix.WSJ0Mix(root_dir / "tr", num_speakers, sample_rate)
validation = wsj0mix.WSJ0Mix(root_dir / "cv", num_speakers, sample_rate)
evaluation = wsj0mix.WSJ0Mix(root_dir / "tt", num_speakers, sample_rate)
else:
raise ValueError(f"Unexpected dataset: {dataset_type}")
return train, validation, evaluation


def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, random_start=False):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix = sample[1] # [1, num_frames]
src = torch.cat(sample[2], 0) # [num_sources, num_frames]

num_channels, num_frames = src.shape
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
start_frame = torch.randint(num_frames - target_num_frames, [1])
mix = mix[:, start_frame:]
src = src[:, start_frame:]
mix = mix[:, :target_num_frames]
src = src[:, :target_num_frames]
mask = torch.ones_like(mix)
else:
num_padding = target_num_frames - num_frames
pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
mix = torch.cat([mix, pad], 1)
src = torch.cat([src, pad.expand(num_channels, -1)], 1)
mask = torch.ones_like(mix)
mask[..., num_frames:] = 0
return mix, src, mask



def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
target_num_frames = int(duration * sample_rate)

mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, target_num_frames, random_start=True)

mixes.append(mix)
srcs.append(src)
masks.append(mask)

return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType]):
max_num_frames = max(s[1].shape[-1] for s in samples)

mixes, srcs, masks = [], [], []
for sample in samples:
mix, src, mask = _fix_num_frames(sample, max_num_frames, random_start=False)

mixes.append(mix)
srcs.append(src)
masks.append(mask)

return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0), torch.stack(masks, 0))


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type == "wsj0mix":
if mode == 'train':
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
return collate_fn_wsj0mix_test
raise ValueError(f"Unexpected dataset: {dataset_type}")
70 changes: 70 additions & 0 deletions examples/source_separation/utils/dataset/wsj0mix.py
@@ -0,0 +1,70 @@
from pathlib import Path
from typing import Union, Tuple, List

import torch
from torch.utils.data import Dataset

import torchaudio

SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]


class WSJ0Mix(Dataset):
"""Create a Dataset for wsj0-mix.

Args:
root (str or Path): Path to the directory where the dataset is found.
num_speakers (int): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios.
sample_rate (int): Expected sample rate of audio files. If any of the audio has a
different sample rate, raises ``ValueError``.
audio_ext (str): The extension of audio files to find. (default: ".wav")
"""
def __init__(
self,
root: Union[str, Path],
num_speakers: int,
sample_rate: int,
audio_ext: str = ".wav",
):
self.root = Path(root)
self.sample_rate = sample_rate
self.mix_dir = (self.root / "mix").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could use os path join


self.files = [p.name for p in self.mix_dir.glob(f"*{audio_ext}")]
self.files.sort()

def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}. "
"Where the requested sample rate is {self.sample_rate}."
)
return waveform
Comment on lines +39 to +46
Copy link
Contributor

@vincentqb vincentqb Oct 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this function serve beyond wrapping load? Ensures sample rate is the same?


def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
srcs.append(src)
return self.sample_rate, mixed, srcs

def __len__(self) -> int:
return len(self.files)

def __getitem__(self, key: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return self._load_sample(self.files[key])
111 changes: 111 additions & 0 deletions test/torchaudio_unittest/example/souce_sepration/wsj0mix_test.py
@@ -0,0 +1,111 @@
import os

from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
normalize_wav,
)

from source_separation.utils.dataset import wsj0mix


_FILENAMES = [
"012c0207_1.9952_01cc0202_-1.9952.wav",
"01co0302_1.63_014c020q_-1.63.wav",
"01do0316_0.24011_205a0104_-0.24011.wav",
"01lc020x_1.1301_027o030r_-1.1301.wav",
"01mc0202_0.34056_205o0106_-0.34056.wav",
"01nc020t_0.53821_018o030w_-0.53821.wav",
"01po030f_2.2136_40ko031a_-2.2136.wav",
"01ra010o_2.4098_403a010f_-2.4098.wav",
"01xo030b_0.22377_016o031a_-0.22377.wav",
"02ac020x_0.68566_01ec020b_-0.68566.wav",
"20co010m_0.82801_019c0212_-0.82801.wav",
"20da010u_1.2483_017c0211_-1.2483.wav",
"20oo010d_1.0631_01ic020s_-1.0631.wav",
"20sc0107_2.0222_20fo010h_-2.0222.wav",
"20tc010f_0.051456_404a0110_-0.051456.wav",
"407c0214_1.1712_02ca0113_-1.1712.wav",
"40ao030w_2.4697_20vc010a_-2.4697.wav",
"40pa0101_1.1087_40ea0107_-1.1087.wav",
]


def _mock_dataset(root_dir, num_speaker):
dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)]
for dirname in dirnames:
os.makedirs(os.path.join(root_dir, dirname), exist_ok=True)

seed = 0
sample_rate = 8000
expected = []
for filename in _FILENAMES:
mix = None
src = []
for dirname in dirnames:
waveform = get_whitenoise(
sample_rate=8000, duration=1, n_channels=1, dtype="int16", seed=seed
)
seed += 1

path = os.path.join(root_dir, dirname, filename)
save_wav(path, waveform, sample_rate)
waveform = normalize_wav(waveform)

if dirname == "mix":
mix = waveform
else:
src.append(waveform)
expected.append((sample_rate, mix, src))
return expected


class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
expected = None

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.expected = _mock_dataset(cls.root_dir, 2)

def test_wsj0mix(self):
dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)

n_ite = 0
for i, sample in enumerate(dataset):
(_, sample_mix, sample_src) = sample
(_, expected_mix, expected_src) = self.expected[i]
self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
n_ite += 1
assert n_ite == len(self.expected)


class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
backend = "default"
root_dir = None
expected = None

@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.expected = _mock_dataset(cls.root_dir, 3)

def test_wsj0mix(self):
dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)

n_ite = 0
for i, sample in enumerate(dataset):
(_, sample_mix, sample_src) = sample
(_, expected_mix, expected_src) = self.expected[i]
self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[0], expected_src[0], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[1], expected_src[1], atol=5e-5, rtol=1e-8)
self.assertEqual(sample_src[2], expected_src[2], atol=5e-5, rtol=1e-8)
n_ite += 1
assert n_ite == len(self.expected)