In [None]:
%matplotlib inline
import sys
import matplotlib.pyplot as plt
import IPython.display as ipd
from scipy.io.wavfile import write

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from sem_data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from sem_models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence


def get_text(text, hps):
    """Turn a raw text string into a 1-D ``torch.LongTensor`` of symbol ids.

    Args:
        text: The input sentence to encode.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded to
            ``text_to_sequence`` and ``hps.data.add_blank`` toggles blank
            insertion.

    Returns:
        A ``torch.LongTensor`` built from the resulting id sequence.
    """
    token_ids = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Weave the id 0 into the sequence via commons.intersperse.
        token_ids = commons.intersperse(token_ids, 0)
    return torch.LongTensor(token_ids)

def load_data_from_json(file_path):
    """Load a JSON object and convert every value to a ``torch.Tensor``.

    The file is decoded explicitly as UTF-8 so that non-ASCII keys are read
    the same way on every platform instead of depending on the locale's
    default encoding.

    Args:
        file_path: Path to a JSON file whose top-level value is an object
            mapping string keys to (nested) lists of numbers.

    Returns:
        A dict with the same keys, in file order, whose values are the result
        of ``torch.tensor(value)``.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        raw_entries = json.load(file)
    return {key: torch.tensor(value) for key, value in raw_entries.items()}

def ensure_directory_exists(dir_path):
    """Create ``dir_path`` (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` rather than a separate existence check, so there is
    no window in which another process can create the directory between the
    check and the ``makedirs`` call.

    Args:
        dir_path: Directory path to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

## 单一话者

In [None]:
# Choose the pretrained model and the checkpoint step to synthesise from.
# `model` must be one of the keys of `model_to_test_text_sem_dic` below
# (the 'ljs_*', 'onehour_ljs_*', 'librif_*' and 'emovdb_*' entries).
model = 'emovdb_sem_mat_bert_text_pretrained16'

# Checkpoint file stem under the model's log directory.
# Alternatives kept for reference: 'G_50000', 'G_100000'.
step = "G_150000"


common_dir = 'vits/'
log_dir = f'{common_dir}sem_vits/logs/'
save_dir = f'{log_dir}{model}/{step}/source_model_test_wav'

# Maps each model name to the JSON file (under `{common_dir}filelists/`) that
# holds its test sentences together with their semantic embeddings.
model_to_test_text_sem_dic = {
    'ljs_sem_mat_phone': 'ljs_text_sem_mat_phone_t5120.json',
    'ljs_sem_mat_text': 'ljs_text_sem_mat_text_t5120.json',
    'ljs_sem_mat_bert_text': 'ljs_text_bert_text_768.json',
    'ljs_sem_mat_bert_phone': 'ljs_text_bert_phone_768.json',
    'onehour_ljs_sem_mat_phone': 'ljs_text_sem_mat_phone_t5120.json',
    'onehour_ljs_sem_mat_text': 'ljs_text_sem_mat_text_t5120.json',
    'onehour_ljs_sem_mat_bert_text': 'ljs_text_bert_text_768.json',
    'onehour_ljs_sem_mat_bert_phone': 'ljs_text_bert_phone_768.json',
    'librif_sem_mat_phone': 'librif_text_sem_mat_phone_t5120.json',
    'librif_sem_mat_text': 'librif_text_sem_mat_text_t5120.json',
    'librif_sem_mat_bert_text': 'librif_text_bert_text_768.json',
    'librif_sem_mat_bert_phone': 'librif_text_bert_phone_768.json',
    'emovdb_sem_mat_phone_pretrained': 'emovdb_text_sem_mat_phone_t5120.json',
    'emovdb_sem_mat_phone_pretrained16': 'emovdb_text_sem_mat_phone_t5120.json',
    'emovdb_sem_mat_text_pretrained': 'emovdb_text_sem_mat_text_t5120.json',
    'emovdb_sem_mat_text_pretrained16': 'emovdb_text_sem_mat_text_t5120.json',
    'emovdb_sem_mat_bert_text_pretrained': 'emovdb_text_bert_text_768.json',
    'emovdb_sem_mat_bert_text_pretrained16': 'emovdb_text_bert_text_768.json',
    'emovdb_sem_mat_bert_phone_pretrained': 'emovdb_text_bert_phone_768.json',
    'emovdb_sem_mat_bert_phone_pretrained16': 'emovdb_text_bert_phone_768.json',
}

# Validate the selection before touching the filesystem, so a typo in `model`
# fails with a helpful message instead of leaving a stray output directory.
if model not in model_to_test_text_sem_dic:
    raise KeyError(
        f"Unknown model {model!r}; expected one of: "
        f"{', '.join(sorted(model_to_test_text_sem_dic))}"
    )
test_text_sem_dic_file = model_to_test_text_sem_dic[model]

# Load the model's hyper-parameters; this also fails early if the log
# directory for `model` does not contain a config.json.
hps = utils.get_hparams_from_file(f"{log_dir}{model}/config.json")
sem_embedding = hps.data.sem_embedding
print(f"configs: {sem_embedding}")

# Only now create the directory the generated wav files are written to.
ensure_directory_exists(save_dir)

# Test sentences (keys) and their embedding tensors (values) for this model.
test_text_sem_dic = load_data_from_json(f"{common_dir}filelists/{test_text_sem_dic_file}")

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the synthesiser from the loaded hyper-parameters.
# NOTE(review): the 2nd/3rd positional arguments look like the number of
# spectrogram bins and the training segment length in frames — confirm
# against SynthesizerTrn's signature.
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
).to(device)
_ = net_g.eval()  # inference mode

# Load the weights for the selected training step into the network. The third
# argument is passed as None, as in the original call — presumably no
# optimizer state is restored; verify in utils.load_checkpoint.
_ = utils.load_checkpoint(f"{log_dir}{model}/{step}.pth", net_g, None)

In [None]:
# Each key of `test_text_sem_dic` is an input sentence; its value is the
# matching embedding tensor passed to the model as `emb_sem`.
keys = list(test_text_sem_dic.keys())
# Number of sentences to synthesise. Lower this to process only a subset; the
# slice in the loop below clamps values larger than len(keys).
range_limit = len(keys)
print(f"range_limit: {range_limit}")
# The first `m` generated utterances are also played inline in the notebook.
m = 5

# Put the inputs on whatever device the model's parameters live on rather
# than hard-coding CUDA.
device = next(net_g.parameters()).device

for i, key in enumerate(keys[:range_limit]):
    s = get_text(key, hps)  # symbol-id tensor for this sentence
    print(key)
    with torch.no_grad():
        # Add a leading batch dimension of size 1 to both inputs.
        x_tst = s.to(device).unsqueeze(0)
        x_tst_lengths = torch.tensor([s.size(0)], dtype=torch.long, device=device)
        emb_sem = test_text_sem_dic[key].to(device).unsqueeze(0)
        emb_sem_lengths = torch.tensor([emb_sem.size(1)], dtype=torch.long, device=device)
        print(f"x_tst: {x_tst.shape}")
        print(f"x_tst_lengths: {x_tst_lengths}")
        print(f"emb_sem: {emb_sem.shape}")
        print(f"emb_sem_lengths: {emb_sem_lengths}")
        outputs = net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1,
            emb_sem=emb_sem,
            emb_sem_lengths=emb_sem_lengths,
        )
        # First output, first batch item, first channel -> 1-D float waveform.
        audio = outputs[0][0, 0].detach().cpu().float().numpy()
    if i < m:
        ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))
    write(f"{save_dir}/output_sem_{i}.wav", hps.data.sampling_rate, audio)
