-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
79 lines (59 loc) · 1.97 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
import os
import numpy as np
from scipy.io.wavfile import read
class DataLoader(object):
    """Load ``.wav`` files, cut them into fixed-length chunks and serve
    random mini-batches of ``(x, y)`` pairs.

    Every entry of ``data_dir`` whose name starts with ``'dr'`` is treated as
    a set directory and scanned for ``.wav`` files. Each file is split into
    non-overlapping chunks of ``num_steps`` samples. A chunk coming from a set
    directory whose name contains ``target_dir`` is paired with a label array
    of ones; every other chunk gets a label array of zeros. The label array
    always has the same shape as its chunk.
    """

    def __init__(self, config, data_dir='./data/timit/', target_dir='dr2-faem0'):
        """Load, label and shuffle the whole dataset.

        Args:
            config: object exposing ``batch_size`` and ``num_steps`` attributes.
            data_dir: root directory holding the ``dr*`` set directories.
            target_dir: substring of a set-directory name; chunks from matching
                directories are labelled with ones, all others with zeros.
        """
        self.data_dir = data_dir
        self.target_dir = target_dir
        self.batch_size = config.batch_size
        self.num_steps = config.num_steps
        self.data = self.load_data()
        # Floor division: a partial trailing batch is not counted, and the
        # result stays an int on Python 3 (where ``/`` would give a float).
        self.num_batches = len(self.data) // self.batch_size
        np.random.shuffle(self.data)

    def load_data(self):
        """Return every labelled chunk as one array of ``[x, y]`` pairs.

        Returns:
            ``np.ndarray`` built from a list of ``[x, y]`` pairs. When all chunks
            share one shape ``S`` the result has shape ``(n_chunks, 2) + S``;
            with no chunks at all it is an empty array of shape ``(0,)``.
        """
        data = []
        # Sorted so the pre-shuffle order (and hence results under a fixed
        # random seed) does not depend on the filesystem's listing order.
        dirs = sorted(
            f for f in os.listdir(self.data_dir)
            if f.startswith('dr') and os.path.isdir(os.path.join(self.data_dir, f))
        )
        for d in dirs:
            is_target = self.target_dir in d
            # One chunk array per file; files may yield different chunk counts.
            for chunks in self.load_wav_set(d):
                for x in chunks:
                    y = np.ones(x.shape) if is_target else np.zeros(x.shape)
                    data.append([x, y])
        return np.array(data)

    def load_wav_set(self, path):
        """Load and chunk every ``.wav`` file in one set directory.

        Args:
            path: set-directory name, relative to ``self.data_dir``.

        Returns:
            A list with one chunk array (see ``chunk_data``) per ``.wav`` file,
            in sorted filename order. A list is returned rather than a NumPy
            array because files generally produce different numbers of chunks.
        """
        set_dir = os.path.join(self.data_dir, path)
        wav_files = sorted(f for f in os.listdir(set_dir) if f.endswith('.wav'))
        return [self.chunk_data(self.load_wav(os.path.join(path, f)))
                for f in wav_files]

    def load_wav(self, path):
        """Read one ``.wav`` file and return its samples as floats.

        Args:
            path: file path relative to ``self.data_dir``.

        Returns:
            The sample array with a new axis inserted after the first one:
            ``(n_samples, 1)`` for a mono file, ``(n_samples, 1, n_channels)``
            for a multi-channel file. The sample rate is discarded.
        """
        _, samples = read(os.path.join(self.data_dir, path))
        samples = np.array(samples, dtype=float)
        return samples[:, np.newaxis]

    def chunk_data(self, data):
        """Split ``data`` into consecutive, non-overlapping ``num_steps`` chunks.

        Trailing samples that do not fill a whole chunk are dropped; a final
        chunk that is exactly full is kept.

        Args:
            data: sequence (list or array) indexed along its first axis.

        Returns:
            ``np.ndarray`` of stacked chunks; empty array of shape ``(0,)`` when
            ``data`` is shorter than ``num_steps``.

        Raises:
            ValueError: if ``num_steps`` is not positive (the chunking would
                otherwise never advance).
        """
        step = self.num_steps
        if step <= 0:
            raise ValueError('num_steps must be positive, got %r' % (step,))
        n_chunks = len(data) // step
        return np.array([data[i * step:(i + 1) * step] for i in range(n_chunks)])

    def next_batch(self):
        """Draw ``batch_size`` distinct examples uniformly at random.

        Returns:
            ``(x_batch, y_batch)``: two lists of length ``batch_size`` holding
            the chunks and their label arrays.

        Raises:
            IndexError: if the dataset holds fewer than ``batch_size`` examples.
        """
        if self.batch_size > len(self.data):
            raise IndexError('batch_size %d exceeds dataset size %d'
                             % (self.batch_size, len(self.data)))
        # Permute indices instead of shuffling self.data in place: the batch
        # is still sampled without replacement, but the (large) data array is
        # no longer moved around on every call.
        idx = np.random.permutation(len(self.data))[:self.batch_size]
        x_batch = [self.data[i][0] for i in idx]
        y_batch = [self.data[i][1] for i in idx]
        return x_batch, y_batch