-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
75 lines (61 loc) · 2.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import warnings
warnings.filterwarnings("ignore") # ignore warnings in this notebook
import argparse
import os, sys
import numpy as np
import torch
from tqdm import *
from hparams import HParams as hp
from audio import save_to_wav
from models import Text2Mel, SSRN
from datasets.lj_speech import vocab, idx2char, get_test_data
if __name__ == "__main__":
    # CLI: synthesize speech for --message and write a WAV named --filename
    # into the server's public TTS directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--message', type=str, default='')
    parser.add_argument('--filename', type=str, default=None)
    args = parser.parse_args()

    # Deployment-specific locations (hard-coded for this install).
    project_path = "/home/alotaima/Projects/side/onlysudo/src/ai/tts"
    args.outdir = '/src/server/public/ai/tts'
    sys.path.append(project_path)
    root_path = '/'.join(os.path.abspath(os.getcwd()).split('/')[:-1])
    output_path = f"{root_path}{args.outdir}/{args.filename}"

    # Validate inputs before paying for model loading.
    if args.message == '':
        print('Need message!')
        sys.exit(1)  # fix: exit() is a REPL helper; sys.exit sets a proper nonzero status
    if args.filename is None:
        print('Need filename!')
        sys.exit(1)  # fix: previously an unset --filename silently produced a file named "None"

    print('start')
    if os.path.exists(output_path):
        os.remove(output_path)  # overwrite any previous synthesis output

    torch.set_grad_enabled(False)  # inference only — skip autograd bookkeeping

    # Load pretrained checkpoints. The .pth files contain pickled whole-model
    # objects, hence torch.load(...).state_dict() rather than a raw state dict.
    text2mel = Text2Mel(vocab)
    text2mel.load_state_dict(torch.load(f"{project_path}/ljspeech-text2mel.pth").state_dict())
    text2mel = text2mel.eval()
    ssrn = SSRN()
    ssrn.load_state_dict(torch.load(f"{project_path}/ljspeech-ssrn.pth").state_dict())
    ssrn = ssrn.eval()

    # Keep only characters the model vocabulary knows; drop the rest.
    normalized_sentence = "".join([c if c.lower() in vocab else '' for c in args.message])
    print(normalized_sentence)
    sentences = [normalized_sentence]
    max_N = len(normalized_sentence)
    L = torch.from_numpy(get_test_data(sentences, max_N))

    # Autoregressive mel generation: start from one zero frame and append
    # one predicted frame per step until attention reaches the EOS token.
    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))
    Y = zeros
    A = None
    # fix: bound the loop with hp.max_T — the old `while True` hung forever
    # whenever monotonic attention never landed on 'E'.
    for _ in range(hp.max_T):
        _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
        Y = torch.cat((zeros, Y_t), -1)
        _, attention = torch.max(A[0, :, -1], 0)
        attention = attention.item()
        if L[0, attention] == vocab.index('E'):  # EOS reached
            break

    # SSRN upsamples the mel spectrogram to a full linear spectrogram,
    # which is then converted to a waveform and saved.
    _, Z = ssrn(Y)
    Z = Z.cpu().detach().numpy()
    save_to_wav(Z[0, :, :].T, output_path)
    print('Done!')