import os
import random

import torch
import torch.utils.data
import torchaudio
from utils import *
from natsort import natsorted
from torch.utils.data.distributed import DistributedSampler

# Allow duplicate OpenMP runtimes to coexist (works around libiomp conflicts on some systems).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

class DemandDataset(torch.utils.data.Dataset):
    """Paired clean/noisy dataset; each item is a fixed-length (cut_len) waveform pair."""

    def __init__(self, data_dir, cut_len=16000 * 2):
        self.cut_len = cut_len
        self.clean_dir = os.path.join(data_dir, "clean")
        self.noisy_dir = os.path.join(data_dir, "noisy")
        # Clean and noisy files share the same filenames, so one sorted list indexes both.
        self.clean_wav_name = os.listdir(self.clean_dir)
        self.clean_wav_name = natsorted(self.clean_wav_name)

    def __len__(self):
        return len(self.clean_wav_name)

    def __getitem__(self, idx):
        clean_file = os.path.join(self.clean_dir, self.clean_wav_name[idx])
        noisy_file = os.path.join(self.noisy_dir, self.clean_wav_name[idx])

        clean_ds, _ = torchaudio.load(clean_file)
        noisy_ds, _ = torchaudio.load(noisy_file)
        clean_ds = clean_ds.squeeze()
        noisy_ds = noisy_ds.squeeze()
        length = len(clean_ds)
        assert length == len(noisy_ds)

        if length < self.cut_len:
            # Clip is shorter than cut_len: tile the whole waveform as many times as it
            # fits, then append a leading slice to pad up to exactly cut_len samples.
            units = self.cut_len // length
            clean_ds_final = []
            noisy_ds_final = []
            for i in range(units):
                clean_ds_final.append(clean_ds)
                noisy_ds_final.append(noisy_ds)
            clean_ds_final.append(clean_ds[: self.cut_len % length])
            noisy_ds_final.append(noisy_ds[: self.cut_len % length])
            clean_ds = torch.cat(clean_ds_final, dim=-1)
            noisy_ds = torch.cat(noisy_ds_final, dim=-1)
        else:
            # Clip is long enough: randomly cut a cut_len-sample segment (2 seconds at
            # 16 kHz by default), using the same offset for the clean and noisy waveforms.
            wav_start = random.randint(0, length - self.cut_len)
            noisy_ds = noisy_ds[wav_start : wav_start + self.cut_len]
            clean_ds = clean_ds[wav_start : wav_start + self.cut_len]

        return clean_ds, noisy_ds, length
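
# Example of the padding branch above (illustrative numbers, not from the original file):
# with cut_len = 32000 and a 14000-sample clip, units = 2, so the waveform is repeated
# twice and topped up with its first 4000 samples, yielding exactly 32000 samples for
# both the clean and the noisy signal.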

def load_data(ds_dir, batch_size, n_cpu, cut_len):
    torchaudio.set_audio_backend("sox_io")  # sox_io backend is used on Linux
    train_dir = os.path.join(ds_dir, "train")
    test_dir = os.path.join(ds_dir, "test")

    train_ds = DemandDataset(train_dir, cut_len)
    test_ds = DemandDataset(test_dir, cut_len)

    # Both loaders rely on DistributedSampler for shuffling/sharding across processes,
    # so shuffle stays False; only the training loader drops the last incomplete batch.
    train_dataset = torch.utils.data.DataLoader(
        dataset=train_ds,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(train_ds),
        drop_last=True,
        num_workers=n_cpu,
    )
    test_dataset = torch.utils.data.DataLoader(
        dataset=test_ds,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(test_ds),
        drop_last=False,
        num_workers=n_cpu,
    )

    return train_dataset, test_dataset
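

# Usage sketch (not part of the original file): how load_data might be called from a
# training script. The dataset root "./VCTK-DEMAND" is a hypothetical path; because both
# loaders use DistributedSampler, torch.distributed must be initialized first (e.g. via
# torchrun), otherwise constructing the samplers raises an error.
if __name__ == "__main__":
    import torch.distributed as dist

    # Single-process group for local testing; real training would launch multiple
    # processes with torchrun and let it set rank/world_size.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://127.0.0.1:29500",
            rank=0,
            world_size=1,
        )

    train_loader, test_loader = load_data(
        ds_dir="./VCTK-DEMAND",  # hypothetical dataset root with train/ and test/ subdirs
        batch_size=4,
        n_cpu=2,
        cut_len=16000 * 2,  # 2-second segments at 16 kHz
    )
    clean, noisy, length = next(iter(train_loader))
    print(clean.shape, noisy.shape, length)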