In [None]:
!git clone https://github.com/vibevoice-community/VibeVoice

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!uv pip install --system -e /content/VibeVoice --force-reinstall

In [None]:
!uv pip --quiet install --system -e /content/VibeVoice

In [None]:
!uv pip install wetext

In [None]:
!pip install flash-attn triton accelerate

In [None]:
import types
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from transformers.utils import logging
from wetext import Normalizer
from typing import Dict
from pathlib import Path
import torch
import re
import os

# `logging` here is the helper imported from `transformers.utils` above, not the
# stdlib module. Switch its verbosity to INFO and create the module-level
# logger that the rest of this file uses.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Punctuation marks treated as sentence terminators when a paragraph is split
# into sentences (the duplicate entry is kept from the original table).
sentence_splitter = [
    "！",
    "；",
    "？",
    "～",
    "?",
    "!",
    "：",
    "～",
    "…",
    "……",
    "。",
]

# Replacement table applied to text before synthesis: mostly full-width /
# Chinese punctuation mapped to ASCII punctuation, plus a few special symbols.
char_rep_map = {
    # dashes, colons, semicolons, commas, stops
    "——": ".",
    "：": ",",
    "；": ",",
    ";": ",",
    "，": ",",
    "。": ".",
    "！": "!",
    "？": "?",
    "·": "-",
    "、": ",",
    # ellipsis variants
    "...": "…",
    ",,,": "…",
    "，，，": "…",
    "……": "…",
    # quotation marks
    "“": "'",
    "”": "'",
    '"': "'",
    "‘": "'",
    "’": "'",
    # brackets of every flavour become a plain apostrophe
    "（": "'",
    "）": "'",
    "(": "'",
    ")": "'",
    "《": "'",
    "》": "'",
    "【": "'",
    "】": "'",
    "[": "'",
    "]": "'",
    # single dash and tildes
    "—": "-",
    "～": "-",
    "~": "-",
    # corner brackets and ASCII colon
    "「": "'",
    "」": "'",
    ":": ",",
    # special symbols mapped to Chinese characters
    "〇": "零",
    "○": "零",
    "卐": "万",
}

def replace_chars(full_script, char_rep_map):
    """Return *full_script* with every key of *char_rep_map* replaced by its value.

    Keys may be longer than one character (e.g. ``"——"`` or ``"..."``). The
    text is scanned once from left to right and, at each position, the longest
    matching key wins, so ``"——"`` is replaced as a unit instead of as two
    separate ``"—"`` characters. Replacement output is never re-scanned.

    Args:
        full_script: Text to transform.
        char_rep_map: Mapping of source substring -> replacement string.
            Empty keys are ignored.

    Returns:
        The transformed text; *full_script* unchanged when the map has no
        usable keys.
    """
    # Longest keys first so multi-character keys take precedence over their
    # single-character prefixes inside the regex alternation.
    keys = sorted((key for key in char_rep_map if key), key=len, reverse=True)
    if not keys:
        return full_script
    pattern = re.compile("|".join(re.escape(key) for key in keys))
    return pattern.sub(lambda match: char_rep_map[match.group(0)], full_script)

class BookAudioGenerator:
    """Generate speech audio from text files with a VibeVoice model.

    Entry points:

    * ``generate`` - split a whole text file into batches and write numbered
      ``.wav`` / ``.txt`` pairs into an output directory.
    * ``generate_single`` / ``generate_single_batch`` - synthesize one or
      several text files, writing each ``.wav`` next to its source file.
    * ``generate_single_dialog`` - synthesize a multi-speaker dialog file.

    The methods rely on names defined at module level elsewhere in this file:
    ``sentence_splitter``, ``char_rep_map``, ``replace_chars`` and ``logger``.
    """

    # Parenthesised asides (full-width or ASCII brackets) removed before synthesis.
    _PAREN_PATTERN = re.compile(r"（.*?）|\([^)]*?\)")
    # "<speaker name>:" at the start of a dialog line (ASCII colon only).
    _SPEAKER_PATTERN = re.compile(r'^([^:]+):')

    def __init__(self, tts_model, device, attn_implementation=None) -> None:
        """Load the processor and model and set generation defaults.

        Args:
            tts_model: Model name or local path given to ``from_pretrained``.
            device: Device string, used for the processor and as ``device_map``.
            attn_implementation: Optional attention backend override. When
                omitted, ``'flash_attention_2'`` is used for CUDA devices (the
                previous hard-coded value) and ``'sdpa'`` otherwise, because
                FlashAttention-2 kernels are GPU-only and cannot load on
                e.g. ``"mps"``.
        """
        if attn_implementation is None:
            on_cuda = str(device).startswith('cuda')
            attn_implementation = 'flash_attention_2' if on_cuda else 'sdpa'

        self.processor = VibeVoiceProcessor.from_pretrained(
                tts_model,
                device=device
            )
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                tts_model,
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_implementation,
                device_map=device,)
        model.eval()
        model.set_ddpm_inference_steps(num_steps=40)
        self.model = model

        # Prefix put in front of every line of a single-speaker script.
        self.default_prefix = "Speaker 1:"
        # Speaker name assumed for dialog lines before any "name:" prefix is seen.
        self.default_speaker = "旁白"
        # Forwarded to ``model.generate`` as ``max_length_times``.
        self.max_length_times = 2.7
        self.normalizer = Normalizer(lang="zh", operator="tn", remove_erhua=True, traditional_to_simple=False)

    def batch_process(self, i_file, batch_size, process_size):
        """Yield ``(batch, batch_index)`` tuples for the text in *i_file*.

        The file is read line by line; blank lines are dropped. Whole lines
        are accumulated into a chunk until the chunk holds at least
        *process_size* characters, and up to *batch_size* chunks form a batch.
        Lines are never cut in the middle, so a paragraph always stays in one
        chunk. Before a batch is yielded, every line is split into sentences
        with ``split_sentence``.

        Args:
            i_file: Path of the UTF-8 text file to read.
            batch_size: Maximum number of chunks per yielded batch.
            process_size: Character count at which a chunk is closed.

        Yields:
            ``(batch, batch_index)`` where ``batch`` is a list of chunks, each
            chunk a list of sentence strings, and ``batch_index`` counts from 0.

        Raises:
            ValueError: When the file has no lines (raised on first iteration).
        """
        def _read_file():
            with open(i_file, 'r', encoding='utf-8') as f:
                _lines = f.read().splitlines()
            if not _lines:
                raise ValueError(f'no content in {i_file}')

            results = []        # chunks collected for the current batch
            current_lines = []  # lines of the chunk being accumulated
            current_length = 0  # total characters in ``current_lines``

            for _line in _lines:
                # Whitespace-only lines would later yield an empty script.
                if not _line.strip():
                    continue
                current_lines.append(_line)
                current_length += len(_line)
                if current_length >= process_size:
                    results.append(current_lines)
                    current_lines = []
                    current_length = 0
                    if len(results) == batch_size:
                        yield results
                        results = []
            if current_lines:
                results.append(current_lines)
            if results:
                yield results

        for batch_index, batch in enumerate(_read_file()):
            processed_batch = []
            for sub_list in batch:
                processed_sub_list = []
                for item in sub_list:
                    # Split each paragraph into sentences; per the original
                    # author, very long single utterances make the generated
                    # speech speed up.
                    processed_sub_list.extend(self.split_sentence(item.strip()))
                processed_batch.append(processed_sub_list)
            yield processed_batch, batch_index

    def split_sentence(self, sentence, splitters=None):
        """Split *sentence* into sentences, keeping terminators attached.

        A run of consecutive terminator characters (``"……"``, ``"？！"``) is
        treated as one terminator and stays with the sentence it ends, so no
        punctuation-only sentence is produced in the middle of the text.
        Whitespace after a terminator is dropped.

        Args:
            sentence: Text to split.
            splitters: Optional iterable of terminator strings; defaults to
                the module-level ``sentence_splitter``. Every character of
                every entry counts as a terminator character.

        Returns:
            List of non-empty sentence strings (empty list for blank input).
        """
        if splitters is None:
            splitters = sentence_splitter
        terminator_chars = sorted(set("".join(splitters)))
        if not terminator_chars:
            stripped = sentence.strip()
            return [stripped] if stripped else []

        pattern = '([' + re.escape("".join(terminator_chars)) + r']+)\s*'
        # With exactly one capture group re.split alternates:
        # text, terminators, text, terminators, ..., text.
        parts = re.split(pattern, sentence)

        sentences = []
        for index in range(0, len(parts), 2):
            text = parts[index]
            terminators = parts[index + 1] if index + 1 < len(parts) else ""
            if not text.strip():
                if not terminators:
                    continue
                if sentences:
                    # Terminators separated only by whitespace from the
                    # previous ones still belong to the previous sentence.
                    sentences[-1] += terminators
                else:
                    sentences.append(terminators)
                continue
            sentences.append((text + terminators).strip())
        return sentences

    def _tts_generate(self, to_tts_batch, voice_sample):
        """Run the processor and the model on the given script(s).

        Args:
            to_tts_batch: Script text (or list of scripts) passed as ``text``.
            voice_sample: Voice reference(s) passed as ``voice_samples``.

        Returns:
            Whatever ``self.model.generate`` returns; callers in this class
            read ``speech_outputs`` and ``reach_max_step_sample`` from it.
        """
        inputs = self.processor(
            text=to_tts_batch,
            voice_samples=voice_sample,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=1.3,
            tokenizer=self.processor.tokenizer,
            # Sampling was disabled; a previously tried alternative was
            # {'do_sample': True, 'temperature': 0.99, 'top_p': 0.99, 'top_k': 3}.
            generation_config={'do_sample': False},
            verbose=True,
            max_length_times=self.max_length_times,  # original note: default is 2
        )
        return outputs

    def txt_normlize(self, txt):
        """Return *txt* passed through the configured ``wetext`` normalizer."""
        return self.normalizer.normalize(txt)

    def tts_txt_preprocess(self, txt):
        """Prepare one line of text for synthesis.

        Removes parenthesised asides, replaces punctuation via the
        module-level ``replace_chars`` / ``char_rep_map`` (the original author
        notes the VibeVoice docs recommend English punctuation), and prepends
        ``self.default_prefix``.
        """
        _txt = self._PAREN_PATTERN.sub("", txt)
        return self.default_prefix + replace_chars(_txt, char_rep_map)

    def _script_from_file(self, to_tts_file):
        """Read *to_tts_file* and return its non-blank lines as one script."""
        with open(to_tts_file, 'r', encoding='utf-8') as f:
            _lines = f.read().splitlines()
        # Blank lines would otherwise become prefix-only lines with no text.
        return "\n".join(
            self.tts_txt_preprocess(self.txt_normlize(item))
            for item in _lines
            if item.strip()
        )

    def gererator_speech_with_default_voice(
            self,
            chunk,
            batch_index,
            single_speaker,
            output_dir,
            output_prefix=None
            ):
        """Synthesize one batch and write ``.wav`` / ``.txt`` files.

        Files are named ``{output_dir}/{output_prefix}-{batch_index}_{n}``.
        If the first ``.wav`` of the batch already exists the whole batch is
        skipped, which lets an interrupted run be resumed.

        Args:
            chunk: List of sentence lists, as yielded by ``batch_process``.
            batch_index: Index of the batch, used in the file names.
            single_speaker: Voice reference used for every script in the batch.
            output_dir: Directory that receives the output files.
            output_prefix: File-name prefix. When omitted, the module-level
                ``project_name`` global is used, as older callers relied on.
        """
        if output_prefix is None:
            output_prefix = project_name

        output_stem = f"{output_dir}/{output_prefix}-{batch_index}"
        first_wav = f"{output_stem}_0.wav"
        # Check before doing any normalization work for a batch that is skipped.
        if os.path.exists(first_wav):
            logger.warning(f'⚠️ file {first_wav} exists, so batch will not process.')
            return

        _tts_text = [
            "\n".join(
                self.tts_txt_preprocess(self.txt_normlize(item)) for item in s_batch
            )
            for s_batch in chunk
        ]

        outputs = self._tts_generate(_tts_text, [[single_speaker]] * len(chunk))
        for check in outputs.reach_max_step_sample.tolist():
            if check:
                logger.warning(f'⚠️  reach max length, audio may cut up, you may increase [max_length_times] and current is [{self.max_length_times}]')

        Path(output_stem).parent.mkdir(parents=True, exist_ok=True)
        for _index, (output_speech, txt) in enumerate(zip(outputs.speech_outputs, chunk)):
            output_path_wav = f"{output_stem}_{_index}.wav"
            output_path_txt = f"{output_stem}_{_index}.txt"

            self.processor.save_audio(
                output_speech,
                output_path=output_path_wav,
            )
            # The .txt keeps the sentences as received (before normalization),
            # one per line, so it can be fed to generate_single* again.
            Path(output_path_txt).write_text("\n".join(txt), encoding='utf-8')
            logger.info(f'finish process ouput file : {output_path_wav} \n {output_path_txt}')

    def generate(self, to_tts_file, output_dir, single_speaker, batch_size = 4, process_size = 9000):
        """Synthesize the whole of *to_tts_file* batch by batch into *output_dir*.

        Output files are prefixed with the stem of *to_tts_file*.
        """
        output_prefix = Path(to_tts_file).stem
        for _b, _i in self.batch_process(to_tts_file, batch_size, process_size):
            self.gererator_speech_with_default_voice(
                _b, _i, single_speaker, output_dir, output_prefix=output_prefix)

    def generate_single_dialog(self, to_tts_file, txt_speeker, speeker_voice):
        """Synthesize a multi-speaker dialog file into one ``.wav``.

        Each non-empty line may start with ``"<name>:"``. Names listed in
        *txt_speeker* map, in order, to ``Speaker 1``, ``Speaker 2``, ...
        A line without a name prefix continues the previous speaker. Unknown
        names fall back to the label of ``self.default_speaker`` when it is
        listed, otherwise to the label in ``self.default_prefix``.

        Args:
            to_tts_file: UTF-8 dialog file; the ``.wav`` is written next to it.
            txt_speeker: Speaker names as they appear in the file.
            speeker_voice: Voice references passed to the processor, expected
                to be in the same order as *txt_speeker*.
        """
        with open(to_tts_file, 'r', encoding='utf-8') as f:
            _lines = f.read().splitlines()
        output_path_wav = Path(to_tts_file).with_suffix(".wav")

        speaker_map = {
            name: f"Speaker {position + 1}"
            for position, name in enumerate(txt_speeker)
        }
        # Previously the fallback was ``self.default_speaker[0]`` - a single
        # character of the speaker name, which is not a "Speaker N" label.
        fallback_label = speaker_map.get(
            self.default_speaker, self.default_prefix.rstrip(":"))

        to_tts_batch = []
        pre_speaker = self.default_speaker
        for item in _lines:
            if not item:
                continue
            match = self._SPEAKER_PATTERN.match(item)
            if match:
                pre_speaker = match.group(1).strip()
                content = item[match.end():].strip()  # text after the colon
            else:
                content = item
            label = speaker_map.get(pre_speaker, fallback_label)
            to_tts_batch.append(label + ": " + content)

        to_tts_batch = ["\n".join(to_tts_batch)]

        outputs = self._tts_generate(to_tts_batch, speeker_voice)
        self.processor.save_audio(
            outputs.speech_outputs[0],
            output_path=output_path_wav,
        )

    def generate_single_batch(self, to_tts_files, voice_samples):
        """Synthesize several text files in one model call.

        Each file yields one script; every script uses the same voice
        reference (*voice_samples* is passed as ``[[voice_samples]]`` per
        file). Each ``.wav`` is written next to its source file.
        """
        if not to_tts_files:
            return
        to_tts_texts = [self._script_from_file(path) for path in to_tts_files]
        outputs = self._tts_generate(to_tts_texts, [[voice_samples]] * len(to_tts_files))
        for output_speech, tts_file in zip(outputs.speech_outputs, to_tts_files):
            output_path_wav = Path(tts_file).with_suffix(".wav")
            self.processor.save_audio(
                output_speech,
                output_path=output_path_wav,
            )
            logger.info(f'saved file {output_path_wav}')

    def generate_single(self, to_tts_file, voice_samples):
        """Synthesize one text file; the ``.wav`` is written next to it.

        *voice_samples* is passed to the processor as ``[voice_samples]``.
        """
        output_path_wav = Path(to_tts_file).with_suffix(".wav")
        to_tts_txt = self._script_from_file(to_tts_file)
        outputs = self._tts_generate(to_tts_txt, [voice_samples])
        self.processor.save_audio(
            outputs.speech_outputs[0],
            output_path=output_path_wav,
        )

# Target runtime environment: one of "colab", "modelscope" or "local".
env_type = "colab"

# Per-environment settings: where the drive/workspace is mounted, which model
# to load, and which device to run it on.
env_config = dict(
    local=dict(
        drive_dir="/Volumes/sw/MyDrive",
        # alternative: "/Volumes/sw/pretrained_models/VibeVoice-1.5B"
        model_name="/Volumes/sw/hf_models/VibeVoice-1.5B-ft",
        device="mps",
    ),
    modelscope=dict(
        drive_dir="/mnt/workspace",
        # alternative: "/mnt/workspace/pretrained_models/VibeVoice-1.5B"
        model_name="/mnt/workspace/pretrained_models/VibeVoice-1.5B-ft",
        device="cuda",
    ),
    colab=dict(
        drive_dir="/content/drive/MyDrive",
        # alternatives: "tardigrade-doc/VibeVoice-1.5B-ft", "vibevoice/VibeVoice-7B"
        model_name="microsoft/VibeVoice-1.5B",
        device="cuda",
    ),
)

# Fail fast on an unknown environment name.
config_dict = env_config.get(env_type)
if config_dict is None:
    raise Exception(f"not supported env {env_type}")

# Attribute-style access to the selected settings, plus flat aliases used below.
config = types.SimpleNamespace(**config_dict)
drive_dir, model_name, device = config.drive_dir, config.model_name, config.device

# Source text and reference voice, both located under the environment's drive.
input_file = f"{drive_dir}/data_src/zhipeiyudikang.txt"
speaker_phi0 = f"{drive_dir}/data_src/qinsheng.wav"

# The project name is the input file's stem; it names the output directory and
# prefixes every generated file.
input_file_path = Path(input_file)
project_name = input_file_path.stem

output_dir = f"{drive_dir}/{project_name}"
bookAudioGen = BookAudioGenerator(
    model_name,
    device)

# Full run over the whole input file:
# bookAudioGen.generate(input_file, output_dir, speaker_phi0, 6, 8000)
# Single-file run:
# bookAudioGen.generate_single("/content/drive/MyDrive/tianchaoyaoyuan1/tianchaoyaoyuan1-6_1.txt", speaker_phi0)

# Batch version of the single-file run. The paths are derived from
# output_dir / project_name instead of being hard-coded for one environment,
# so they follow env_type; under "colab" they are the same strings as before.
bookAudioGen.generate_single_batch(
    [
        f"{output_dir}/{project_name}-{suffix}.txt"
        for suffix in ("2_3", "2_4", "2_5", "3_0", "3_1")
    ],
    speaker_phi0)

# Regenerate the wav for a single txt when the wav produced by the batch
# processing above turned out to be faulty.
# bookAudioGen.generate_single("/Volumes/sw/MyDrive/zhengzhi1/output/zhengzhi1-4_2.txt", [speaker_phi0])

# bookAudioGen.generate_single("/Volumes/sw/tmp/zhengzhi1-5_4.txt", [speaker_phi0])

# bookAudioGen.generate_single_dialog(
#     "/Users/larry/github.com/tardigrade-dot/colab-script/data_src/sugeladizhisi_part1.txt",
#     ["旁白", "欧", "苏"],
#     [f"{drive_dir}/data_src/youyi.wav", f"{drive_dir}/data_src/sample_zhongdong.wav", f"{drive_dir}/data_src/gdg_voice_06.wav"])