-
-
Notifications
You must be signed in to change notification settings - Fork 737
Feature: 新增 GPT_SoVIS 的 TTS 服务商 #1821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9405ba7
825e3db
14c29f0
b251ee9
1789393
bee5d35
dc62c1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import asyncio | ||
import os | ||
import uuid | ||
|
||
import aiohttp | ||
from ..provider import TTSProvider | ||
from ..entities import ProviderType | ||
from ..register import register_provider_adapter | ||
from astrbot import logger | ||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path | ||
|
||
|
||
@register_provider_adapter( | ||
provider_type_name="gsv_tts_selfhost", | ||
desc="GPT-SoVITS TTS(本地加载)", | ||
provider_type=ProviderType.TEXT_TO_SPEECH, | ||
) | ||
class ProviderGSVTTS(TTSProvider): | ||
def __init__( | ||
self, | ||
provider_config: dict, | ||
provider_settings: dict, | ||
) -> None: | ||
super().__init__(provider_config, provider_settings) | ||
|
||
self.api_base = provider_config.get("api_base", "http://127.0.0.1:9880").rstrip( | ||
"/" | ||
) | ||
self.gpt_weights_path: str = provider_config.get("gpt_weights_path", "") | ||
self.sovits_weights_path: str = provider_config.get("sovits_weights_path", "") | ||
|
||
# TTS 请求的默认参数,移除前缀gsv_ | ||
self.default_params: dict = { | ||
key.removeprefix("gsv_"): str(value).lower() | ||
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+33
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): 将所有默认参数值转换为小写可能会导致类型问题。 将所有值转换为小写可能会导致非字符串参数的类型不匹配。仅将字符串值转换为小写,以避免意外行为。 Original comment in Englishissue (bug_risk): Lowercasing all default parameter values may cause type issues. Lowercasing all values may cause type mismatches for non-string parameters. Only lowercase string values to avoid unexpected behavior. |
||
for key, value in provider_config.get("gsv_default_parms", {}).items() | ||
} | ||
self.timeout = provider_config.get("timeout", 60) | ||
self._session: aiohttp.ClientSession | None = None | ||
|
||
async def initialize(self): | ||
"""异步初始化:在 ProviderManager 中被调用""" | ||
self._session = aiohttp.ClientSession( | ||
timeout=aiohttp.ClientTimeout(total=self.timeout) | ||
Comment on lines
+42
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): 在 init 中创建 aiohttp.ClientSession 可能会导致事件循环问题。 考虑在异步方法中初始化 ClientSession,或者使用异步工厂,以避免在异步上下文之外创建 provider 时出现事件循环问题。 Original comment in Englishissue (bug_risk): Creating aiohttp.ClientSession in init may cause event loop issues. Consider initializing ClientSession in an async method or using an async factory to avoid event loop issues when the provider is created outside an async context. |
||
) | ||
try: | ||
await self._set_model_weights() | ||
logger.info("[GSV TTS] 初始化完成") | ||
except Exception as e: | ||
logger.error(f"[GSV TTS] 初始化失败:{e}") | ||
raise | ||
|
||
def get_session(self) -> aiohttp.ClientSession: | ||
if not self._session or self._session.closed: | ||
raise RuntimeError( | ||
"[GSV TTS] Provider HTTP session is not ready or closed." | ||
) | ||
return self._session | ||
|
||
async def _make_request( | ||
self, endpoint: str, params=None, retries: int = 3 | ||
) -> bytes | None: | ||
"""发起请求""" | ||
for attempt in range(retries): | ||
logger.debug(f"[GSV TTS] 请求地址:{endpoint},参数:{params}") | ||
try: | ||
async with self.get_session().get(endpoint, params=params) as response: | ||
if response.status != 200: | ||
error_text = await response.text() | ||
raise Exception( | ||
f"[GSV TTS] Request to {endpoint} failed with status {response.status}: {error_text}" | ||
) | ||
Comment on lines
+69
to
+71
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (code-quality): 提出一个特定的错误,而不是一般的 解释如果一段代码引发一个特定的异常类型, 而不是通用的 [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) 或 [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), 调用代码可以:
这样,代码的调用者可以适当地处理错误。 您如何解决这个问题? 因此,与其让代码引发 if incorrect_input(value):
raise Exception("输入不正确") 您可以让代码引发一个特定的错误,例如 if incorrect_input(value):
raise ValueError("输入不正确") 或者 class IncorrectInputError(Exception):
pass
if incorrect_input(value):
raise IncorrectInputError("输入不正确") Original comment in Englishissue (code-quality): Raise a specific error instead of the general ExplanationIf a piece of code raises a specific exception type rather than the generic [`BaseException`](https://docs.python.org/3/library/exceptions.html#BaseException) or [`Exception`](https://docs.python.org/3/library/exceptions.html#Exception), the calling code can:
This way, callers of the code can handle the error appropriately. How can you solve this?
So instead of having code raising if incorrect_input(value):
raise Exception("The input is incorrect") you can have code raising a specific error like if incorrect_input(value):
raise ValueError("The input is incorrect") or class IncorrectInputError(Exception):
pass
if incorrect_input(value):
raise IncorrectInputError("The input is incorrect")
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return await response.read() | ||
except Exception as e: | ||
if attempt < retries - 1: | ||
logger.warning( | ||
f"[GSV TTS] 请求 {endpoint} 第 {attempt + 1} 次失败:{e},重试中..." | ||
) | ||
await asyncio.sleep(1) | ||
else: | ||
logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") | ||
raise | ||
|
||
async def _set_model_weights(self): | ||
"""设置模型路径""" | ||
try: | ||
if self.gpt_weights_path: | ||
await self._make_request( | ||
f"{self.api_base}/set_gpt_weights", | ||
{"weights_path": self.gpt_weights_path}, | ||
) | ||
logger.info(f"[GSV TTS] 成功设置 GPT 模型路径:{self.gpt_weights_path}") | ||
else: | ||
logger.info("[GSV TTS] GPT 模型路径未配置,将使用内置 GPT 模型") | ||
|
||
if self.sovits_weights_path: | ||
await self._make_request( | ||
f"{self.api_base}/set_sovits_weights", | ||
{"weights_path": self.sovits_weights_path}, | ||
) | ||
logger.info( | ||
f"[GSV TTS] 成功设置 SoVITS 模型路径:{self.sovits_weights_path}" | ||
) | ||
else: | ||
logger.info("[GSV TTS] SoVITS 模型路径未配置,将使用内置 SoVITS 模型") | ||
except aiohttp.ClientError as e: | ||
logger.error(f"[GSV TTS] 设置模型路径时发生网络错误:{e}") | ||
except Exception as e: | ||
logger.error(f"[GSV TTS] 设置模型路径时发生未知错误:{e}") | ||
|
||
async def get_audio(self, text: str) -> str: | ||
"""实现 TTS 核心方法,根据文本内容自动切换情绪""" | ||
if not text.strip(): | ||
raise ValueError("[GSV TTS] TTS 文本不能为空") | ||
|
||
endpoint = f"{self.api_base}/tts" | ||
|
||
params = self.build_synthesis_params(text) | ||
|
||
temp_dir = os.path.join(get_astrbot_data_path(), "temp") | ||
os.makedirs(temp_dir, exist_ok=True) | ||
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav") | ||
|
||
logger.debug(f"[GSV TTS] 正在调用语音合成接口,参数:{params}") | ||
|
||
result = await self._make_request(endpoint, params) | ||
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if isinstance(result, bytes): | ||
with open(path, "wb") as f: | ||
f.write(result) | ||
return path | ||
else: | ||
raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") | ||
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+126
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (code-quality): 我们发现了这些问题:
Original comment in Englishissue (code-quality): We've found these issues:
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
Soulter marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def build_synthesis_params(self, text: str) -> dict: | ||
""" | ||
构建语音合成所需的参数字典。 | ||
|
||
当前仅包含默认参数 + 文本,未来可在此基础上动态添加如情绪、角色等语义控制字段。 | ||
""" | ||
params = self.default_params.copy() | ||
params["text"] = text | ||
# TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) | ||
return params | ||
|
||
async def terminate(self): | ||
"""终止释放资源:在 ProviderManager 中被调用""" | ||
if self._session and not self._session.closed: | ||
await self._session.close() | ||
logger.info("[GSV TTS] Session 已关闭") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): 将所有默认参数值转换为小写可能会导致类型问题。
仅将字符串值转换为小写,以避免与期望非字符串类型的下游 API 发生类型不匹配。
Original comment in English
issue (bug_risk): Lowercasing all default parameter values may cause type issues.
Only lowercase string values to avoid type mismatches with downstream APIs that expect non-string types.