Skip to content

Commit f9c3e4c

Browse files
authored
Merge pull request #1821 from Zhalslar/gsv-tts-selfhost
Feature: 新增 GPT_SoVIS 的 TTS 服务商
2 parents 0441b51 + dc62c1f commit f9c3e4c

File tree

3 files changed

+307
-0
lines changed

3 files changed

+307
-0
lines changed

astrbot/core/config/default.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,37 @@
850850
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
851851
"timeout": 20,
852852
},
853+
"GSV TTS(本地加载)": {
854+
"id": "gsv_tts",
855+
"enable": False,
856+
"type": "gsv_tts_selfhost",
857+
"provider_type": "text_to_speech",
858+
"api_base": "http://127.0.0.1:9880",
859+
"gpt_weights_path": "",
860+
"sovits_weights_path": "",
861+
"timeout": 60,
862+
"gsv_default_parms": {
863+
"gsv_ref_audio_path": "",
864+
"gsv_prompt_text": "",
865+
"gsv_prompt_lang": "zh",
866+
"gsv_aux_ref_audio_paths": "",
867+
"gsv_text_lang": "zh",
868+
"gsv_top_k": 5,
869+
"gsv_top_p": 1.0,
870+
"gsv_temperature": 1.0,
871+
"gsv_text_split_method": "cut3",
872+
"gsv_batch_size": 1,
873+
"gsv_batch_threshold": 0.75,
874+
"gsv_split_bucket": True,
875+
"gsv_speed_factor": 1,
876+
"gsv_fragment_interval": 0.3,
877+
"gsv_streaming_mode": False,
878+
"gsv_seed": -1,
879+
"gsv_parallel_infer": True,
880+
"gsv_repetition_penalty": 1.35,
881+
"gsv_media_type": "wav",
882+
},
883+
},
853884
"GSVI TTS(API)": {
854885
"id": "gsvi_tts",
855886
"type": "gsvi_tts_api",
@@ -951,6 +982,130 @@
951982
},
952983
},
953984
"items": {
985+
"gpt_weights_path": {
986+
"description": "GPT模型文件路径",
987+
"type": "string",
988+
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
989+
"obvious_hint": True,
990+
},
991+
"sovits_weights_path": {
992+
"description": "SoVITS模型文件路径",
993+
"type": "string",
994+
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
995+
"obvious_hint": True,
996+
},
997+
"gsv_default_parms": {
998+
"description": "GPT_SoVITS默认参数",
999+
"hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写",
1000+
"type": "object",
1001+
"items": {
1002+
"gsv_ref_audio_path": {
1003+
"description": "参考音频文件路径",
1004+
"type": "string",
1005+
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
1006+
"obvious_hint": True,
1007+
},
1008+
"gsv_prompt_text": {
1009+
"description": "参考音频文本",
1010+
"type": "string",
1011+
"hint": "必填!请填写参考音频讲述的文本",
1012+
"obvious_hint": True,
1013+
},
1014+
"gsv_prompt_lang": {
1015+
"description": "参考音频文本语言",
1016+
"type": "string",
1017+
"hint": "请填写参考音频讲述的文本的语言,默认为中文",
1018+
},
1019+
"gsv_aux_ref_audio_paths": {
1020+
"description": "辅助参考音频文件路径",
1021+
"type": "string",
1022+
"hint": "辅助参考音频文件,可不填",
1023+
},
1024+
"gsv_text_lang": {
1025+
"description": "文本语言",
1026+
"type": "string",
1027+
"hint": "默认为中文",
1028+
},
1029+
"gsv_top_k": {
1030+
"description": "生成语音的多样性",
1031+
"type": "int",
1032+
"hint": "",
1033+
},
1034+
"gsv_top_p": {
1035+
"description": "核采样的阈值",
1036+
"type": "float",
1037+
"hint": "",
1038+
},
1039+
"gsv_temperature": {
1040+
"description": "生成语音的随机性",
1041+
"type": "float",
1042+
"hint": "",
1043+
},
1044+
"gsv_text_split_method": {
1045+
"description": "切分文本的方法",
1046+
"type": "string",
1047+
"hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`:50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切",
1048+
"options": [
1049+
"cut0",
1050+
"cut1",
1051+
"cut2",
1052+
"cut3",
1053+
"cut4",
1054+
"cut5",
1055+
],
1056+
},
1057+
"gsv_batch_size": {
1058+
"description": "批处理大小",
1059+
"type": "int",
1060+
"hint": "",
1061+
},
1062+
"gsv_batch_threshold": {
1063+
"description": "批处理阈值",
1064+
"type": "float",
1065+
"hint": "",
1066+
},
1067+
"gsv_split_bucket": {
1068+
"description": "将文本分割成桶以便并行处理",
1069+
"type": "bool",
1070+
"hint": "",
1071+
},
1072+
"gsv_speed_factor": {
1073+
"description": "语音播放速度",
1074+
"type": "float",
1075+
"hint": "1为原始语速",
1076+
},
1077+
"gsv_fragment_interval": {
1078+
"description": "语音片段之间的间隔时间",
1079+
"type": "float",
1080+
"hint": "",
1081+
},
1082+
"gsv_streaming_mode": {
1083+
"description": "启用流模式",
1084+
"type": "bool",
1085+
"hint": "",
1086+
},
1087+
"gsv_seed": {
1088+
"description": "随机种子",
1089+
"type": "int",
1090+
"hint": "用于结果的可重复性",
1091+
},
1092+
"gsv_parallel_infer": {
1093+
"description": "并行执行推理",
1094+
"type": "bool",
1095+
"hint": "",
1096+
},
1097+
"gsv_repetition_penalty": {
1098+
"description": "重复惩罚因子",
1099+
"type": "float",
1100+
"hint": "",
1101+
},
1102+
"gsv_media_type": {
1103+
"description": "输出媒体的类型",
1104+
"type": "string",
1105+
"hint": "建议用wav",
1106+
},
1107+
},
1108+
},
9541109
"embedding_dimensions": {
9551110
"description": "嵌入维度",
9561111
"type": "int",

astrbot/core/provider/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ async def load_provider(self, provider_config: dict):
225225
from .sources.edge_tts_source import (
226226
ProviderEdgeTTS as ProviderEdgeTTS,
227227
)
228+
case "gsv_tts_selfhost":
229+
from .sources.gsv_selfhosted_source import (
230+
ProviderGSVTTS as ProviderGSVTTS,
231+
)
228232
case "gsvi_tts_api":
229233
from .sources.gsvi_tts_source import (
230234
ProviderGSVITTS as ProviderGSVITTS,
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
import os
3+
import uuid
4+
5+
import aiohttp
6+
from ..provider import TTSProvider
7+
from ..entities import ProviderType
8+
from ..register import register_provider_adapter
9+
from astrbot import logger
10+
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
11+
12+
13+
@register_provider_adapter(
14+
provider_type_name="gsv_tts_selfhost",
15+
desc="GPT-SoVITS TTS(本地加载)",
16+
provider_type=ProviderType.TEXT_TO_SPEECH,
17+
)
18+
class ProviderGSVTTS(TTSProvider):
19+
def __init__(
20+
self,
21+
provider_config: dict,
22+
provider_settings: dict,
23+
) -> None:
24+
super().__init__(provider_config, provider_settings)
25+
26+
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
27+
"/"
28+
)
29+
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
30+
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")
31+
32+
# TTS 请求的默认参数,移除前缀gsv_
33+
self.default_params: dict = {
34+
key.removeprefix("gsv_"): str(value).lower()
35+
for key, value in provider_config.get("gsv_default_parms", {}).items()
36+
}
37+
self.timeout = provider_config.get("timeout", 60)
38+
self._session: aiohttp.ClientSession | None = None
39+
40+
async def initialize(self):
41+
"""异步初始化:在 ProviderManager 中被调用"""
42+
self._session = aiohttp.ClientSession(
43+
timeout=aiohttp.ClientTimeout(total=self.timeout)
44+
)
45+
try:
46+
await self._set_model_weights()
47+
logger.info("[GSV TTS] 初始化完成")
48+
except Exception as e:
49+
logger.error(f"[GSV TTS] 初始化失败:{e}")
50+
raise
51+
52+
def get_session(self) -> aiohttp.ClientSession:
53+
if not self._session or self._session.closed:
54+
raise RuntimeError(
55+
"[GSV TTS] Provider HTTP session is not ready or closed."
56+
)
57+
return self._session
58+
59+
async def _make_request(
60+
self, endpoint: str, params=None, retries: int = 3
61+
) -> bytes | None:
62+
"""发起请求"""
63+
for attempt in range(retries):
64+
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
65+
try:
66+
async with self.get_session().get(endpoint, params=params) as response:
67+
if response.status != 200:
68+
error_text = await response.text()
69+
raise Exception(
70+
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
71+
)
72+
return await response.read()
73+
except Exception as e:
74+
if attempt < retries - 1:
75+
logger.warning(
76+
f"[GSV TTS] 请求 {endpoint}{attempt + 1} 次失败:{e},重试中..."
77+
)
78+
await asyncio.sleep(1)
79+
else:
80+
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
81+
raise
82+
83+
async def _set_model_weights(self):
84+
"""设置模型路径"""
85+
try:
86+
if self.gpt_weights_path:
87+
await self._make_request(
88+
f"{self.api_base}/set_gpt_weights",
89+
{"weights_path": self.gpt_weights_path},
90+
)
91+
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
92+
else:
93+
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")
94+
95+
if self.sovits_weights_path:
96+
await self._make_request(
97+
f"{self.api_base}/set_sovits_weights",
98+
{"weights_path": self.sovits_weights_path},
99+
)
100+
logger.info(
101+
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
102+
)
103+
else:
104+
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
105+
except aiohttp.ClientError as e:
106+
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
107+
except Exception as e:
108+
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")
109+
110+
async def get_audio(self, text: str) -> str:
111+
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
112+
if not text.strip():
113+
raise ValueError("[GSV TTS] TTS 文本不能为空")
114+
115+
endpoint = f"{self.api_base}/tts"
116+
117+
params = self.build_synthesis_params(text)
118+
119+
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
120+
os.makedirs(temp_dir, exist_ok=True)
121+
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
122+
123+
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")
124+
125+
result = await self._make_request(endpoint, params)
126+
if isinstance(result, bytes):
127+
with open(path, "wb") as f:
128+
f.write(result)
129+
return path
130+
else:
131+
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
132+
133+
def build_synthesis_params(self, text: str) -> dict:
134+
"""
135+
构建语音合成所需的参数字典。
136+
137+
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
138+
"""
139+
params = self.default_params.copy()
140+
params["text"] = text
141+
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
142+
return params
143+
144+
async def terminate(self):
145+
"""终止释放资源:在 ProviderManager 中被调用"""
146+
if self._session and not self._session.closed:
147+
await self._session.close()
148+
logger.info("[GSV TTS] Session 已关闭")

0 commit comments

Comments
 (0)