forked from foamliu/Speech-Transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pre_process.py
77 lines (55 loc) · 1.96 KB
/
pre_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import pickle
from tqdm import tqdm
from config import wav_folder, tran_file, pickle_file
from utils import ensure_folder
def get_data(split):
    """Collect the transcribed ``.wav`` samples for one dataset split.

    Reads the transcript file at ``tran_file`` (one utterance per line: an id
    followed by whitespace-separated text), then walks every sub-directory of
    ``<wav_folder>/<split>`` looking for ``.wav`` files. A file is kept only
    when the part of its name before the first ``.`` is an id present in the
    transcript file.

    Each kept transcript is split into single characters, terminated with
    ``'<eos>'``, registered through ``build_vocab`` and converted to ids using
    the module-level ``VOCAB`` table, which must exist before this is called.

    Args:
        split: Name of the sub-folder of ``wav_folder`` to scan
            (the script uses ``'train'``, ``'dev'`` and ``'test'``).

    Returns:
        A list of ``{'trn': [token ids], 'wave': path}`` dicts.
    """
    print('getting {} data...'.format(split))

    # utterance id -> transcript with the whitespace between tokens removed.
    tran_dict = dict()
    with open(tran_file, 'r', encoding='utf-8') as file:
        for line in file:
            tokens = line.split()
            if not tokens:
                # Blank / whitespace-only line: there is no id to index, and
                # tokens[0] would raise IndexError.
                continue
            tran_dict[tokens[0]] = ''.join(tokens[1:])

    samples = []
    folder = os.path.join(wav_folder, split)
    # NOTE(review): presumably creates the folder when it is missing so that
    # os.listdir below cannot fail -- confirm against utils.ensure_folder.
    ensure_folder(folder)

    candidates = [os.path.join(folder, entry) for entry in os.listdir(folder)]
    sub_dirs = [path for path in candidates if os.path.isdir(path)]

    # NOTE: os.listdir returns entries in arbitrary order, so the ids handed
    # out by build_vocab depend on the directory traversal order.
    for sub_dir in tqdm(sub_dirs):
        for name in os.listdir(sub_dir):
            if not name.endswith('.wav'):
                continue
            key = name.split('.')[0]
            if key not in tran_dict:
                # Audio without a transcript is silently ignored.
                continue

            # The stored transcript contains no whitespace, so it can be
            # split into characters directly.
            tokens = list(tran_dict[key]) + ['<eos>']
            for token in tokens:
                build_vocab(token)

            samples.append({
                'trn': [VOCAB[token] for token in tokens],
                'wave': os.path.join(sub_dir, name),
            })

    print('split: {}, num_files: {}'.format(split, len(samples)))
    return samples
def build_vocab(token):
    """Register ``token`` in the module-level vocabulary tables.

    A token that is not yet known receives the next free id, which is the
    current size of ``VOCAB``, and is recorded in both directions:
    ``VOCAB[token] -> id`` and ``IVOCAB[id] -> token``. A token that is
    already present is left untouched.

    Args:
        token: The token to register; used as a dictionary key.
    """
    # Both tables are only mutated in place, never rebound, so no ``global``
    # declaration is required.
    if token not in VOCAB:
        next_index = len(VOCAB)
        VOCAB[token] = next_index
        IVOCAB[next_index] = token
if __name__ == "__main__":
    # Token <-> id tables, read and extended by get_data / build_vocab as
    # module globals. Ids 0 and 1 are taken by the two sentence markers.
    VOCAB = {'<sos>': 0, '<eos>': 1}
    IVOCAB = {0: '<sos>', 1: '<eos>'}

    # The dict holds references to the tables, so tokens added while the
    # splits are processed are part of what gets pickled.
    data = {'VOCAB': VOCAB, 'IVOCAB': IVOCAB}
    for split_name in ('train', 'dev', 'test'):
        data[split_name] = get_data(split_name)

    # Persist the vocabulary and all three splits in a single pickle.
    with open(pickle_file, 'wb') as out:
        pickle.dump(data, out)

    # Summary of what was written.
    print('num_train: ' + str(len(data['train'])))
    print('num_dev: ' + str(len(data['dev'])))
    print('num_test: ' + str(len(data['test'])))
    print('vocab_size: ' + str(len(data['VOCAB'])))