Skip to content

Feature: 新增 GPT_SoVIS 的 TTS 服务商 #1821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,37 @@
"edge-tts-voice": "zh-CN-XiaoxiaoNeural",
"timeout": 20,
},
"GSV TTS(本地加载)": {
"id": "gsv_tts",
"enable": False,
"type": "gsv_tts_selfhost",
"provider_type": "text_to_speech",
"api_base": "http://127.0.0.1:9880",
"gpt_weights_path": "",
"sovits_weights_path": "",
"timeout": 60,
"gsv_default_parms": {
"gsv_ref_audio_path": "",
"gsv_prompt_text": "",
"gsv_prompt_lang": "zh",
"gsv_aux_ref_audio_paths": "",
"gsv_text_lang": "zh",
"gsv_top_k": 5,
"gsv_top_p": 1.0,
"gsv_temperature": 1.0,
"gsv_text_split_method": "cut3",
"gsv_batch_size": 1,
"gsv_batch_threshold": 0.75,
"gsv_split_bucket": True,
"gsv_speed_factor": 1,
"gsv_fragment_interval": 0.3,
"gsv_streaming_mode": False,
"gsv_seed": -1,
"gsv_parallel_infer": True,
"gsv_repetition_penalty": 1.35,
"gsv_media_type": "wav",
},
},
"GSVI TTS(API)": {
"id": "gsvi_tts",
"type": "gsvi_tts_api",
Expand Down Expand Up @@ -901,6 +932,130 @@
},
},
"items": {
"gpt_weights_path": {
"description": "GPT模型文件路径",
"type": "string",
"hint": "即“.ckpt”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
},
"sovits_weights_path": {
"description": "SoVITS模型文件路径",
"type": "string",
"hint": "即“.pth”后缀的文件,请使用绝对路径,路径两端不要带双引号,不填则默认用GPT_SoVITS内置的SoVITS模型(建议直接在GPT_SoVITS中改默认模型)",
"obvious_hint": True,
},
"gsv_default_parms": {
"description": "GPT_SoVITS默认参数",
"hint": "参考音频文件路径、参考音频文本必填,其他参数根据个人爱好自行填写",
"type": "object",
"items": {
"gsv_ref_audio_path": {
"description": "参考音频文件路径",
"type": "string",
"hint": "必填!请使用绝对路径!路径两端不要带双引号!",
"obvious_hint": True,
},
"gsv_prompt_text": {
"description": "参考音频文本",
"type": "string",
"hint": "必填!请填写参考音频讲述的文本",
"obvious_hint": True,
},
"gsv_prompt_lang": {
"description": "参考音频文本语言",
"type": "string",
"hint": "请填写参考音频讲述的文本的语言,默认为中文",
},
"gsv_aux_ref_audio_paths": {
"description": "辅助参考音频文件路径",
"type": "string",
"hint": "辅助参考音频文件,可不填",
},
"gsv_text_lang": {
"description": "文本语言",
"type": "string",
"hint": "默认为中文",
},
"gsv_top_k": {
"description": "生成语音的多样性",
"type": "int",
"hint": "",
},
"gsv_top_p": {
"description": "核采样的阈值",
"type": "float",
"hint": "",
},
"gsv_temperature": {
"description": "生成语音的随机性",
"type": "float",
"hint": "",
},
"gsv_text_split_method": {
"description": "切分文本的方法",
"type": "string",
"hint": "可选值: `cut0`:不切分 `cut1`:四句一切 `cut2`:50字一切 `cut3`:按中文句号切 `cut4`:按英文句号切 `cut5`:按标点符号切",
"options": [
"cut0",
"cut1",
"cut2",
"cut3",
"cut4",
"cut5",
],
},
"gsv_batch_size": {
"description": "批处理大小",
"type": "int",
"hint": "",
},
"gsv_batch_threshold": {
"description": "批处理阈值",
"type": "float",
"hint": "",
},
"gsv_split_bucket": {
"description": "将文本分割成桶以便并行处理",
"type": "bool",
"hint": "",
},
"gsv_speed_factor": {
"description": "语音播放速度",
"type": "float",
"hint": "1为原始语速",
},
"gsv_fragment_interval": {
"description": "语音片段之间的间隔时间",
"type": "float",
"hint": "",
},
"gsv_streaming_mode": {
"description": "启用流模式",
"type": "bool",
"hint": "",
},
"gsv_seed": {
"description": "随机种子",
"type": "int",
"hint": "用于结果的可重复性",
},
"gsv_parallel_infer": {
"description": "并行执行推理",
"type": "bool",
"hint": "",
},
"gsv_repetition_penalty": {
"description": "重复惩罚因子",
"type": "float",
"hint": "",
},
"gsv_media_type": {
"description": "输出媒体的类型",
"type": "string",
"hint": "建议用wav",
},
},
},
"embedding_dimensions": {
"description": "嵌入维度",
"type": "int",
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ async def load_provider(self, provider_config: dict):
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
Expand Down
148 changes: 148 additions & 0 deletions astrbot/core/provider/sources/gsv_selfhosted_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import asyncio
import os
import uuid

import aiohttp
from ..provider import TTSProvider
from ..entities import ProviderType
from ..register import register_provider_adapter
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path


@register_provider_adapter(
provider_type_name="gsv_tts_selfhost",
desc="GPT-SoVITS TTS(本地加载)",
provider_type=ProviderType.TEXT_TO_SPEECH,
)
class ProviderGSVTTS(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)

self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip(
"/"
)
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "")
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "")

# TTS 请求的默认参数,移除前缀gsv_
self.default_params: dict = {
key.removeprefix("gsv_"): str(value).lower()
Comment on lines +33 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): 将所有默认参数值转换为小写可能会导致类型问题。

仅将字符串值转换为小写,以避免与期望非字符串类型的下游 API 发生类型不匹配。

Original comment in English

issue (bug_risk): Lowercasing all default parameter values may cause type issues.

Only lowercase string values to avoid type mismatches with downstream APIs that expect non-string types.

Comment on lines +33 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): 将所有默认参数值转换为小写可能会导致类型问题。

将所有值转换为小写可能会导致非字符串参数的类型不匹配。仅将字符串值转换为小写,以避免意外行为。

Original comment in English

issue (bug_risk): Lowercasing all default parameter values may cause type issues.

Lowercasing all values may cause type mismatches for non-string parameters. Only lowercase string values to avoid unexpected behavior.

for key, value in provider_config.get("gsv_default_parms", {}).items()
}
self.timeout = provider_config.get("timeout", 60)
self._session: aiohttp.ClientSession | None = None

async def initialize(self):
"""异步初始化:在 ProviderManager 中被调用"""
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.timeout)
Comment on lines +42 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk):init 中创建 aiohttp.ClientSession 可能会导致事件循环问题。

考虑在异步方法中初始化 ClientSession,或者使用异步工厂,以避免在异步上下文之外创建 provider 时出现事件循环问题。

Original comment in English

issue (bug_risk): Creating aiohttp.ClientSession in init may cause event loop issues.

Consider initializing ClientSession in an async method or using an async factory to avoid event loop issues when the provider is created outside an async context.

)
try:
await self._set_model_weights()
logger.info("[GSV TTS] 初始化完成")
except Exception as e:
logger.error(f"[GSV TTS] 初始化失败:{e}")
raise

def get_session(self) -> aiohttp.ClientSession:
if not self._session or self._session.closed:
raise RuntimeError(
"[GSV TTS] Provider HTTP session is not ready or closed."
)
return self._session

async def _make_request(
self, endpoint: str, params=None, retries: int = 3
) -> bytes | None:
"""发起请求"""
for attempt in range(retries):
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}")
try:
async with self.get_session().get(endpoint, params=params) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}"
)
Comment on lines +69 to +71
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): 提出一个特定的错误,而不是一般的 ExceptionBaseException (raise-specific-error)

解释如果一段代码引发一个特定的异常类型, 而不是通用的 [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) 或 [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), 调用代码可以:
  • 获取更多关于错误类型的信息
  • 为其定义特定的异常处理

这样,代码的调用者可以适当地处理错误。

您如何解决这个问题?

因此,与其让代码引发 ExceptionBaseException,例如

if incorrect_input(value):
    raise Exception("输入不正确")

您可以让代码引发一个特定的错误,例如

if incorrect_input(value):
    raise ValueError("输入不正确")

或者

class IncorrectInputError(Exception):
    pass


if incorrect_input(value):
    raise IncorrectInputError("输入不正确")
Original comment in English

issue (code-quality): Raise a specific error instead of the general Exception or BaseException (raise-specific-error)

ExplanationIf a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:
  • get more information about what type of error it is
  • define specific exception handling for it

This way, callers of the code can handle the error appropriately.

How can you solve this?

So instead of having code raising Exception or BaseException like

if incorrect_input(value):
    raise Exception("The input is incorrect")

you can have code raising a specific error like

if incorrect_input(value):
    raise ValueError("The input is incorrect")

or

class IncorrectInputError(Exception):
    pass


if incorrect_input(value):
    raise IncorrectInputError("The input is incorrect")

return await response.read()
except Exception as e:
if attempt < retries - 1:
logger.warning(
f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..."
)
await asyncio.sleep(1)
else:
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}")
raise

async def _set_model_weights(self):
"""设置模型路径"""
try:
if self.gpt_weights_path:
await self._make_request(
f"{self.api_base}/set_gpt_weights",
{"weights_path": self.gpt_weights_path},
)
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}")
else:
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型")

if self.sovits_weights_path:
await self._make_request(
f"{self.api_base}/set_sovits_weights",
{"weights_path": self.sovits_weights_path},
)
logger.info(
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}"
)
else:
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型")
except aiohttp.ClientError as e:
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}")
except Exception as e:
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}")

async def get_audio(self, text: str) -> str:
"""实现 TTS 核心方法,根据文本内容自动切换情绪"""
if not text.strip():
raise ValueError("[GSV TTS] TTS 文本不能为空")

endpoint = f"{self.api_base}/tts"

params = self.build_synthesis_params(text)

temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")

logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}")

result = await self._make_request(endpoint, params)
if isinstance(result, bytes):
with open(path, "wb") as f:
f.write(result)
return path
else:
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}")
Comment on lines +126 to +131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): 我们发现了这些问题:

Original comment in English

issue (code-quality): We've found these issues:


def build_synthesis_params(self, text: str) -> dict:
"""
构建语音合成所需的参数字典。

当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。
"""
params = self.default_params.copy()
params["text"] = text
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text)
return params

async def terminate(self):
"""终止释放资源:在 ProviderManager 中被调用"""
if self._session and not self._session.closed:
await self._session.close()
logger.info("[GSV TTS] Session 已关闭")