In [None]:
%matplotlib inline
import sys
import matplotlib.pyplot as plt
import IPython.display as ipd
from scipy.io.wavfile import write

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence


def get_text(text, hps):
    """Convert a raw text string into a ``torch.LongTensor`` for the model.

    Args:
        text: The input sentence.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded
            to ``text_to_sequence`` and ``hps.data.add_blank`` controls
            whether a ``0`` is interspersed through the sequence.

    Returns:
        A ``torch.LongTensor`` built from the (optionally interspersed)
        sequence returned by ``text_to_sequence``.
    """
    cleaners = hps.data.text_cleaners
    sequence = text_to_sequence(text, cleaners)
    if hps.data.add_blank:
        # Interleave the value 0 with the sequence via the project helper.
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)

def load_data_from_json(file_path):
    """Load a JSON object from disk and convert every value to a tensor.

    Args:
        file_path: Path to a JSON file whose top-level value is an object
            (mapping); each value must be something ``torch.tensor`` accepts
            (e.g. a number or a nested list of numbers).

    Returns:
        A dict with the same keys, in file order, whose values are
        ``torch.Tensor`` objects built from the JSON values.
    """
    # JSON text is UTF-8 by specification; opening with an explicit encoding
    # keeps non-ASCII keys intact regardless of the platform's locale default.
    with open(file_path, 'r', encoding='utf-8') as file:
        raw_entries = json.load(file)
    return {key: torch.tensor(value) for key, value in raw_entries.items()}

def ensure_directory_exists(dir_path):
    """Create ``dir_path`` and any missing parent directories.

    Calling this on a directory that already exists is a no-op.

    Args:
        dir_path: Directory path to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
        OSError: For other failures reported by ``os.makedirs``.
    """
    # exist_ok=True avoids the race between a separate existence check and
    # the creation call when another process creates the directory first.
    os.makedirs(dir_path, exist_ok=True)

## LJ Speech

In [None]:
# Configuration: choose the pretrained model directory and the checkpoint step.
# Keep exactly one `model = ...` and one `step = ...` line uncommented; the
# commented lines are the alternative choices.
model = 'ljs_base'
# model = 'onehour_ljs_base'
# model = 'tenmin_ljs_base'
# step = 'G_50000'
# step = "G_100000"
step = "G_150000"
# step = "G_200000"
# step = "G_250000"
# step = "G_300000"
test_text_sem_dic_file = 'ljs_text_sem_ave_5120.json' # used only as the source of the test sentences, so any text_sem_dic file works.


# Paths are assembled by plain string concatenation, so `common_dir` and
# `log_dir` must keep their trailing slash.
common_dir = '/data/vitsGPT/vits/'
log_dir = f'{common_dir}ori_vits/logs/'
# Output directory for the generated WAV files; created if it is missing.
save_dir = f'{log_dir}{model}/{step}/source_model_test_wav'
ensure_directory_exists(save_dir)
# Hyper-parameters are read from the config.json stored next to the checkpoints.
hps = utils.get_hparams_from_file(f"{log_dir}{model}/config.json")
sem_embedding = hps.data.sem_embedding
print(f"configs: {sem_embedding}")

# Mapping loaded from the JSON file above, with each value converted to a tensor.
test_text_sem_dic = load_data_from_json(f"{common_dir}filelists/{test_text_sem_dic_file}")

In [None]:
# Build the generator from the loaded hyper-parameters, move it to the GPU,
# switch it to evaluation mode and restore the selected checkpoint.
# NOTE(review): the two derived sizes below are presumably the spectrogram
# channel count and the training segment length in frames -- confirm against
# models.SynthesizerTrn.
spec_channels = hps.data.filter_length // 2 + 1
segment_frames = hps.train.segment_size // hps.data.hop_length

net_g = SynthesizerTrn(len(symbols), spec_channels, segment_frames, **hps.model)
net_g = net_g.cuda()
_ = net_g.eval()

# No optimizer is passed (third argument is None); only net_g is handed over.
checkpoint_path = f"{log_dir}{model}/{step}.pth"
_ = utils.load_checkpoint(checkpoint_path, net_g, None)

In [None]:
# Synthesize every test sentence with the loaded generator and write each
# result to `save_dir` as a WAV file; the first `m` results are also rendered
# as inline audio players.
keys = list(test_text_sem_dic.keys())
range_limit = len(keys)  # lower this value to synthesize only the first N sentences
print(f"range_limit: {range_limit}")
m = 5  # number of leading samples to display inline

for i, key in enumerate(keys[:range_limit]):
    # Each key is a string; get_text converts it into the LongTensor `s`.
    s = get_text(key, hps)
    print(key)
    with torch.no_grad():
        x_tst = s.cuda().unsqueeze(0)  # add a leading dimension of size 1
        x_tst_lengths = torch.LongTensor([s.size(0)]).cuda()
        print(f"x_tst: {x_tst.shape}")
        print(f"x_tst_lengths: {x_tst_lengths}")
        # Take element [0, 0] of the first value returned by infer and move
        # it to a float NumPy array on the CPU.
        audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0, 0].detach().cpu().float().numpy()
    if i < m:
        ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))
    write(f"{save_dir}/output_ori_{i}.wav", hps.data.sampling_rate, audio)