Skip to content

Commit

Permalink
Add test collate function
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Sep 17, 2020
1 parent a7f0e2a commit ace5b6b
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions examples/source_separation/utils/dataset_utils.py
Expand Up @@ -35,18 +35,42 @@ def _fix_num_frames(waveform: torch.Tensor, target_num_frames: int):
return torch.cat([waveform, pad], 1)


def collate_fn_wsj0mix(samples: List[wsj0mix.Sample], sample_rate, duration):
def collate_fn_wsj0mix_train(samples: List[wsj0mix.Sample], sample_rate, duration):
target_num_frames = int(duration * sample_rate)

mixed = [_fix_num_frames(s.mix, target_num_frames) for s in samples]
mixed = torch.stack(mixed, 0)
mixes, srcs = [], []
for sample in samples:
mix = sample.mix
src = torch.cat(sample.src, 0)

src = [_fix_num_frames(torch.cat(s.src, 0), target_num_frames) for s in samples]
src = torch.stack(src, 0)
return Batch(mixed, src)
num_frames = mix.shape[-1]
if num_frames > target_num_frames:
start_frame = torch.randint(num_frames - target_num_frames, [1])
mix = mix[..., start_frame:]
src = src[..., start_frame:]

mix = _fix_num_frames(mix, target_num_frames)
src = _fix_num_frames(src, target_num_frames)

def get_collate_fn(dataset_type, sample_rate, duration=4):
mixes.append(mix)
srcs.append(src)

return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0))


def collate_fn_wsj0mix_test(samples: List[wsj0mix.Sample]):
return [Batch(
sample.mix.unsqueeze(0),
torch.cat(sample.src, 0).unsqueeze(0),
) for sample in samples]


def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
assert mode in ["train", "test"]
if dataset_type == "wsj0mix":
return partial(collate_fn_wsj0mix, sample_rate=sample_rate, duration=duration)
if mode == 'train':
if sample_rate is None:
raise ValueError("sample_rate is not given.")
return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
return collate_fn_wsj0mix_test
raise ValueError(f"Unexpected dataset: {dataset_type}")

0 comments on commit ace5b6b

Please sign in to comment.