-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdata_utils.py
executable file
·46 lines (34 loc) · 1.7 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
from random import sample
def split_dataset(x, y, ratio=None):
    """Split ``x``/``y`` into train, test and validation parts.

    With the default ``ratio`` of ``[0.7, 0.15, 0.15]`` the data is split
    70% train, 15% test, 15% validation. Each split size is
    ``int(len(x) * r)`` (truncated), so when the products are not whole
    numbers a few examples may fall into none of the splits. Train is taken
    from the front of the data, test immediately after it, and validation
    from the end of the data.

    Args:
        x: Sliceable sequence of inputs (e.g. a list or numpy array).
        y: Sliceable sequence of targets; assumed to have the same length
            as ``x``.
        ratio: Sequence of fractions; ``ratio[0]`` is the train share,
            ``ratio[1]`` the test share and ``ratio[-1]`` the validation
            share. ``None`` selects ``[0.7, 0.15, 0.15]``.

    Returns:
        ``((train_x, train_y), (test_x, test_y), (valid_x, valid_y))``.
    """
    if ratio is None:
        ratio = [0.7, 0.15, 0.15]
    # number of examples
    data_len = len(x)
    lens = [int(data_len * item) for item in ratio]
    test_end = lens[0] + lens[1]
    # Use a non-negative start for the validation slice: ``x[-lens[-1]:]``
    # turns into ``x[0:]`` (the whole dataset) when the validation size
    # truncates to 0, whereas ``x[data_len - 0:]`` is correctly empty.
    valid_start = data_len - lens[-1]
    train_x, train_y = x[:lens[0]], y[:lens[0]]
    test_x, test_y = x[lens[0]:test_end], y[lens[0]:test_end]
    valid_x, valid_y = x[valid_start:], y[valid_start:]
    return (train_x, train_y), (test_x, test_y), (valid_x, valid_y)
def batch_gen(x, y, batch_size):
    """Yield consecutive full-size batches forever, cycling over the data.

    Batches are ``x[i:i + batch_size]`` for ``i = 0, batch_size, ...``.
    A trailing partial batch is skipped. After the last full batch the
    generator starts again from the beginning. Each batch is transposed with
    ``.T`` before being yielded.

    Args:
        x: Array-like supporting slicing and ``.T`` (e.g. a numpy array).
        y: Array-like aligned with ``x``.
        batch_size: Number of examples per batch; must be between 1 and
            ``len(x)``.

    Yields:
        ``(x_batch.T, y_batch.T)``.

    Raises:
        ValueError: On the first ``next()`` if ``batch_size`` is not in
            ``1..len(x)``. No full batch could ever be produced, and the
            infinite loop would otherwise spin without yielding.
    """
    n_items = len(x)
    if batch_size <= 0 or batch_size > n_items:
        raise ValueError(
            "batch_size must be between 1 and len(x) (%d), got %r"
            % (n_items, batch_size)
        )
    # infinite while: cycle over the data indefinitely
    while True:
        # Start offsets of every batch that fits entirely inside the data.
        for start in range(0, n_items - batch_size + 1, batch_size):
            stop = start + batch_size
            yield x[start:stop].T, y[start:stop].T
def rand_batch_gen(x, y, batch_size):
    """Yield randomly sampled batches forever.

    Each batch draws ``batch_size`` distinct row indices with
    ``random.sample``, so results follow the global ``random`` seed. Rows
    never repeat within one batch but may repeat across batches. Each batch
    is transposed with ``.T`` before being yielded.

    Args:
        x: Array-like supporting fancy indexing with a list of ints and
            ``.T`` (e.g. a numpy array).
        y: Array-like aligned with ``x``.
        batch_size: Number of examples per batch.

    Yields:
        ``(x[idx].T, y[idx].T)`` for a fresh random index list ``idx``.

    Raises:
        ValueError: Propagated from ``random.sample`` on the first
            ``next()`` when ``batch_size`` is negative or exceeds ``len(x)``.
    """
    while True:
        # Sampling from a ``range`` selects exactly the same indices as
        # sampling from ``list(np.arange(len(x)))`` for the same RNG state,
        # without materialising an O(len(x)) list for every batch.
        sample_idx = sample(range(len(x)), batch_size)
        yield x[sample_idx].T, y[sample_idx].T
def decode_word(alpha_seq, idx2alpha):
    """Turn a sequence of alphabet indices back into a word.

    Falsy indices (e.g. ``0``, presumably the padding index) are dropped.
    Every remaining index is translated through ``idx2alpha``, and the
    resulting characters are concatenated with no separator.

    Args:
        alpha_seq: Iterable of alphabet indices.
        idx2alpha: Mapping (or indexable) from index to character string.

    Returns:
        The decoded word as a ``str``.
    """
    letters = [idx2alpha[code] for code in filter(None, alpha_seq)]
    return ''.join(letters)
def decode_phonemes(pho_seq, idx2pho):
    """Turn a sequence of phoneme indices into a space-separated string.

    Args:
        pho_seq: Iterable of phoneme indices. Falsy entries (e.g. ``0``,
            presumably the padding index) are skipped.
        idx2pho: Mapping (or indexable) from index to phoneme string.

    Returns:
        The decoded phonemes joined by single spaces, as a ``str``.
    """
    phonemes = []
    for code in pho_seq:
        if not code:
            continue  # falsy entries are not decoded
        phonemes.append(idx2pho[code])
    return ' '.join(phonemes)
def decode(sequence, lookup, separator=''):
    """Decode a sequence of indices into a string via a lookup table.

    Args:
        sequence: Iterable of indices. ``0`` is used for padding; it is
            ignored, as is any other falsy entry.
        lookup: Mapping (or indexable) translating each index to a string.
        separator: Text placed between consecutive decoded tokens.

    Returns:
        The decoded tokens joined by ``separator``.
    """
    decoded = (lookup[token] for token in sequence if token)
    return separator.join(decoded)