# 1.提取文本特征

In [None]:
from text import chinese
import os 
import numpy as np
# Pinyin finals; generate_token appends the tone number to any phone found here.
finals = (
    "a ai an ang ao "
    "e ei en eng er "
    "i i0 ia ian iang iao ie in ing iong ir iu "
    "o ong ou "
    "u ua uai uan uang ui un uo "
    "v van ve vn"
).split()

def generate_token(text):
    """Convert a sentence into a NumPy array of phone-token ids.

    The text is normalized and converted to phones and tones by the
    project's ``text.chinese`` module. Phones listed in ``finals`` get their
    tone appended, ``_`` entries are dropped, ``!`` / ``…`` are folded into
    ``.`` and ``-`` into ``,``. Each resulting phone is then looked up in the
    phone set read from ``./text/chinese_dict``.

    Args:
        text: Raw input sentence.

    Returns:
        ``np.ndarray`` with one id per phone. Ids start at 1 and follow the
        line order of the phone-set file.

    Raises:
        AssertionError: If the phone sequence contains ``iong4``, which the
            default phone set does not include.
        KeyError: If any other phone is missing from the phone set.
    """
    phone_set = "./text/chinese_dict"
    # Explicit encoding so the lookup table does not depend on the platform's
    # default locale (the phones handled below include non-ASCII marks).
    with open(phone_set, encoding="utf-8") as ps:
        phone_map = {pho.strip(): idx + 1 for idx, pho in enumerate(ps)}

    norm_text = chinese.text_normalize(text)
    phones, tones, _word2ph = chinese.g2p(norm_text)

    final_phone_lst = []
    for phone, tone in zip(phones, tones):
        if phone == '_':
            continue
        if phone in finals:
            # Finals carry their tone as a suffix, e.g. "ang" + 4 -> "ang4".
            phone = phone + str(tone)
        if phone in ('!', '…'):
            phone = '.'
        elif phone == '-':
            phone = ','
        final_phone_lst.append(phone.strip())

    # The default phone set does not contain "iong4". Raise explicitly rather
    # than with a bare ``assert`` so the check survives ``python -O``; the
    # exception type is kept as AssertionError for existing callers.
    if 'iong4' in final_phone_lst:
        raise AssertionError(
            "phone 'iong4' is not part of the default phone set"
        )

    unknown = [phone for phone in final_phone_lst if phone not in phone_map]
    if unknown:
        # Same exception type as a failed dict lookup, with a usable message.
        raise KeyError(f"phones missing from {phone_set}: {unknown}")

    token_lst = [phone_map[phone] for phone in final_phone_lst]

    return np.array(token_lst)
    

In [None]:
# Transcript of the prompt audio.
prompt_text = (
    '然后在家看书的时候，就发现那些什么心灵鸡汤啊、心灵鸡血呀，还有什么毒鸡汤啊，还挺管用的。'
)
# Text to be synthesized.
target_text = (
    '你好，我是郭钊。'
)

In [None]:
# Tokenize both sentences and store the ids as .npy files for the synthesis
# cells further down.
for _npy_path, _sentence in (
    ("prompt_text_token.npy", prompt_text),
    ("target_text_token.npy", target_text),
):
    token = generate_token(_sentence)
    np.save(_npy_path, token)

# 2.以下代码需要在[Amphion](https://github.com/open-mmlab/Amphion)根目录下运行
建议将本文件复制到Amphion根目录，具体运行环境见[Amphion](https://github.com/open-mmlab/Amphion)

In [None]:
import argparse
import os
import torch
import torchaudio
import numpy as np
from utils.io import save_audio
from tqdm import tqdm
from utils.tokenizer import AudioTokenizer,extract_encodec_token
from utils.util import load_config
from models.tts.valle.valle import VALLE
from encodec import EncodecModel
from encodec.utils import convert_audio

# Environment variables set for this session: working directory, Python
# search path, and UTF-8 for standard I/O.
os.environ.update(
    {
        'WORK_DIR': '.',
        'PYTHONPATH': '.',
        'PYTHONIOENCODING': 'UTF-8',
    }
)

In [None]:
in_config = './egs/tts/VALLE/exp_config.json'
in_ckpt_path = 'pytorch_model.bin'                          # path to the model weights


# Build the VALLE model from the experiment config and restore its weights.
cfg = load_config(in_config)
model = VALLE(cfg.model)
audio_tokenizer = AudioTokenizer()
# The model still lives on the CPU here, so map the checkpoint there as well
# instead of materializing it on whatever device it was saved from.
ckpt = torch.load(in_ckpt_path, map_location='cpu')
model.load_state_dict(ckpt)
# The model is only used for inference below: switch off training-only
# behaviour (e.g. dropout) so generation is not perturbed by it.
model.eval()

# EnCodec 24 kHz codec at 6.0 kbps target bandwidth, used to tokenize the
# prompt audio.
enc_model = EncodecModel.encodec_model_24khz()
enc_model.set_target_bandwidth(6.0)
enc_model = enc_model.cuda()
enc_model.eval()

In [None]:
# Path of the prompt audio.
prompt_audio_path = "prompt.wav"
# Path of the prompt-text features extracted in step 1.
prompt_text_token_path = "prompt_text_token.npy"
# Path of the target-text (to be synthesized) features extracted in step 1.
target_text_token_path = "target_text_token.npy"

In [None]:
# Load the prompt recording, convert it to the codec's sample rate and channel
# count, add a leading batch axis and move it to the GPU.
wav, sr = torchaudio.load(prompt_audio_path)
wav = convert_audio(
    wav, sr, enc_model.sample_rate, enc_model.channels
).unsqueeze(0).cuda()

with torch.no_grad():
    encoded_frames = enc_model.encode(wav)
    # Join the code tensors of all frames along the last axis -> [B, n_q, T]
    frame_codes = [frame[0] for frame in encoded_frames]
    codes_ = torch.cat(frame_codes, dim=-1)
    # First batch item, transposed -> [T, 8]
    codes = codes_.cpu().numpy()[0].T

# EnCodec features of the prompt audio
prompt_audio_token = codes

In [None]:
# Read back the phone-token arrays written in step 1 (prompt first, then target).
prompt_text_token, target_text_token = (
    np.load(_token_path)
    for _token_path in (prompt_text_token_path, target_text_token_path)
)

In [None]:
# Text side: prompt phones followed by the target phones, plus the total length.
all_phone = np.concatenate((prompt_text_token, target_text_token))
semantic_token = torch.from_numpy(all_phone)

semantic_len = semantic_token.shape[0]
semantic_len = torch.IntTensor([semantic_len])

# Audio side: codec tokens of the prompt recording.
prompt_token = torch.from_numpy(prompt_audio_token)

device = 'cuda:0'
model = model.to(device)
semantic_token = semantic_token.to(device)
semantic_len = semantic_len.to(device)
prompt_token = prompt_token.to(device)

# Generation and decoding need no gradients; running them under no_grad avoids
# building an autograd graph and the memory that goes with it.
with torch.no_grad():
    encoded_frames = model.inference(
        semantic_token.unsqueeze(0),
        semantic_len,
        prompt_token.unsqueeze(0),
        # Number of leading text tokens that belong to the prompt.
        enroll_x_lens=len(prompt_text_token),
        top_k=100,
        temperature=1.0,
    )
    samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
audio = samples[0].squeeze(0).cpu().detach()

In [None]:
# Write the synthesized waveform to target.wav. 24000 is presumably the sample
# rate in Hz (the codec loaded earlier is the 24 kHz EnCodec model) — confirm
# against utils.io.save_audio.
save_audio('target.wav', audio, 24000)