# In [None]:
# Voice cloning: synthesize speech in the voice of a reference recording (uses the reference audio and its transcript)
import os
import string

import soundfile as sf
import torch
from qwen_tts import Qwen3TTSModel

# ---- Paths and run configuration ----
# Local Qwen3-TTS checkpoint (the Base variant, per the directory name) and
# the directory where generated .wav files are written.
model_dir = "/Volumes/sw/pretrained_models/Qwen3-TTS-12Hz-0.6B-Base"
output_dir = "/Users/larry/github.com/colab-script/output"

# Reference clip whose voice is cloned; ref_text is presumably its transcript
# (it is passed alongside the audio when building the clone prompt).
ref_audio = "/Volumes/sw/MyDrive/data_src/youyi-15s.wav"
ref_text = "不过事实上也不完全如此,从政治体制的角度看,他们俩都在为君主制寻找新的基础.或者说都希望出现一种不同于以往的新的君主制."

# Text chunking knobs: chunks per batch, and the character count at which a
# chunk is closed (used as process_size for batch_process).
# NOTE(review): batch_size is not used by the visible driver loop, which
# passes batch_size=1 explicitly — confirm which value is intended.
batch_size = 2
max_single_size = 3000

# ---- Model ----
model = Qwen3TTSModel.from_pretrained(
    model_dir,
    device_map="mps",  # PyTorch Metal (Apple-silicon GPU) backend
    dtype=torch.bfloat16,
    attn_implementation="sdpa",
)

# Build the voice-clone prompt once from the reference clip and its text so
# it can be reused for every generation call.
r = model.create_voice_clone_prompt(
    ref_audio=ref_audio,
    ref_text=ref_text,
    x_vector_only_mode=False,
)

def batch_process(i_file, batch_size, process_size):
    """Split a UTF-8 text file into line-aligned chunks and yield them in batches.

    Lines are never split, so a paragraph written on one line always ends up
    in a single chunk. Non-blank lines (stripped of surrounding whitespace) are
    accumulated into the current chunk until its total character count reaches
    ``process_size``; the chunk is then closed. Closed chunks are grouped into
    batches of ``batch_size`` chunks; the last batch may be shorter.

    Args:
        i_file: path of the UTF-8 text file to read.
        batch_size: number of chunks per yielded batch. A value below 1 never
            matches the batch length, so all chunks end up in a single batch.
        process_size: character count at which a chunk is closed. A chunk can
            exceed it, because the line that crosses the threshold is kept whole.

    Yields:
        ``(batch, batch_index)``: ``batch`` is a list of chunks, each chunk a
        non-empty list of stripped, non-blank lines; ``batch_index`` counts
        from 0.

    Raises:
        Exception: if the file contains no non-blank line. Because this is a
            generator, the error surfaces on the first iteration.
    """

    def _iter_batches():
        # Split on line boundaries so that one paragraph is never cut in the
        # middle of a TTS input.
        with open(i_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f.read().splitlines()]
        # Drop blank / whitespace-only lines up front: they carry no text to
        # synthesize, must not count toward a chunk's length, and must not be
        # able to form an empty chunk.
        lines = [line for line in lines if line]
        if not lines:
            raise Exception(f'not content in {i_file}')

        batch = []        # chunks collected for the current batch
        chunk = []        # lines collected for the current chunk
        chunk_length = 0  # total characters in `chunk`

        for line in lines:
            chunk.append(line)
            chunk_length += len(line)
            if chunk_length >= process_size:
                batch.append(chunk)
                chunk = []
                chunk_length = 0
                if len(batch) == batch_size:
                    yield batch
                    batch = []

        # Flush the trailing partial chunk and partial batch.
        if chunk:
            batch.append(chunk)
        if batch:
            yield batch

    for batch_index, batch in enumerate(_iter_batches()):
        yield batch, batch_index

# Characters accepted as the end of a sentence or clause. ASCII punctuation
# already includes closing `"`, `'`, `)`, `]`, `}`; for consistency the
# full-width Chinese marks are accepted as well, including closing
# quotes/brackets and the ellipsis, so text such as `他说：“好。”` does not get
# a redundant trailing `。`.
_SENTENCE_END_CHARS = frozenset(string.punctuation + '。！？；：，、…”’」』）》】〉')


def ensure_punctuation_end(text: str, default='。') -> str:
    """Return ``text`` guaranteed to end with a punctuation mark.

    Args:
        text: the string to check.
        default: mark appended when ``text`` does not already end with a
            character from ``_SENTENCE_END_CHARS``.

    Returns:
        ``default`` alone if ``text`` is empty; ``text`` unchanged if its last
        character is in ``_SENTENCE_END_CHARS``; otherwise ``text + default``.
    """
    if not text:
        return default
    if text[-1] in _SENTENCE_END_CHARS:
        return text
    return text + default

text_file = "/Volumes/sw/tts_result/wangquanyudikangquan/wangquanyudikangquan-0_0.txt"

# Create the destination up front: otherwise sf.write would only fail after a
# (slow) generation call has already completed.
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): batch_size is passed as 1 here while the module-level
# `batch_size` is 2 — confirm which is intended.
for _b, _i in batch_process(i_file=text_file, batch_size=1, process_size=max_single_size):

    print(f"process batch {_i}")
    # Each chunk is a list of lines. Give every line a sentence-final mark,
    # then join the lines into one utterance per chunk. ensure_punctuation_end
    # takes a str; passing the whole list raised TypeError (list + str).
    texts = [
        ''.join(ensure_punctuation_end(line) for line in chunk)
        for chunk in _b
    ]
    # Never send an empty utterance to the model.
    texts = [text for text in texts if text]
    if not texts:
        continue

    wavs, sr = model.generate_voice_clone(
        text=texts,
        language=["Chinese"] * len(texts),
        voice_clone_prompt=r,
        non_streaming_mode=True,
    )
    # One .wav per chunk: output-<batch index>-<index within batch>.wav
    for i, wav in enumerate(wavs):
        sf.write(f"{output_dir}/output-{_i}-{i}.wav", wav, sr)