Skip to content

Commit

Permalink
fix: extract the language code wrapped in a markup (#494)
Browse files Browse the repository at this point in the history
Fixes #492

Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming committed Apr 24, 2024
1 parent 05b868c commit 1383224
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 74 deletions.
66 changes: 33 additions & 33 deletions README.md

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"google-generativeai",
"numexpr>=2.8.6",
"dashscope>=1.10.0",
"tetos>=0.1.1",
"tetos>=0.2.1",
]
license = {text = "MIT"}
dynamic = ["version", "optional-dependencies"]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ socksio==1.0.0
soupsieve==2.5
sqlalchemy==2.0.25
tenacity==8.2.3
tetos==0.1.1
tetos==0.2.1
tqdm==4.66.1
typing-extensions==4.9.0
typing-inspect==0.9.0
Expand Down
4 changes: 3 additions & 1 deletion xiaogpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ class Config:
start_conversation: str = "开始持续对话"
end_conversation: str = "结束持续对话"
stream: bool = False
tts: Literal["mi", "edge", "azure", "openai", "baidu", "google", "volc"] = "mi"
tts: Literal[
"mi", "edge", "azure", "openai", "baidu", "google", "volc", "minimax"
] = "mi"
tts_options: dict[str, Any] = field(default_factory=dict)
gpt_options: dict[str, Any] = field(default_factory=dict)
bing_cookie_path: str = ""
Expand Down
41 changes: 8 additions & 33 deletions xiaogpt/tts/tetos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path

from miservice import MiNAService
from tetos.base import Speaker

from xiaogpt.config import Config
from xiaogpt.tts.base import AudioFileTTS
Expand All @@ -14,39 +13,15 @@ class TetosTTS(AudioFileTTS):
def __init__(
self, mina_service: MiNAService, device_id: str, config: Config
) -> None:
super().__init__(mina_service, device_id, config)
self.speaker = self._get_speaker()

def _get_speaker(self) -> Speaker:
from tetos.azure import AzureSpeaker
from tetos.baidu import BaiduSpeaker
from tetos.edge import EdgeSpeaker
from tetos.google import GoogleSpeaker
from tetos.openai import OpenAISpeaker
from tetos.volc import VolcSpeaker
from tetos import get_speaker

options = self.config.tts_options
allowed_speakers: list[str] = []
for speaker in (
OpenAISpeaker,
EdgeSpeaker,
AzureSpeaker,
VolcSpeaker,
GoogleSpeaker,
BaiduSpeaker,
):
if (name := speaker.__name__[:-7].lower()) == self.config.tts:
try:
return speaker(**options)
except TypeError as e:
raise ValueError(
f"{e}. Please add them via `tts_options` config"
) from e
else:
allowed_speakers.append(name)
raise ValueError(
f"Unsupported TTS: {self.config.tts}, allowed: {','.join(allowed_speakers)}"
)
super().__init__(mina_service, device_id, config)
assert config.tts and config.tts != "mi"
speaker_cls = get_speaker(config.tts)
try:
self.speaker = speaker_cls(**config.tts_options)
except TypeError as e:
raise ValueError(f"{e}. Please add them via `tts_options` config") from e

async def make_audio_file(self, lang: str, text: str) -> tuple[Path, float]:
output_file = tempfile.NamedTemporaryFile(
Expand Down
8 changes: 7 additions & 1 deletion xiaogpt/xiaogpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,15 @@ async def speak(self, text_stream: AsyncIterator[str]) -> None:
# It is not a legal language code, discard it
lang, first_chunk = "", text

lang = (
matches[0]
if (matches := re.findall(r"([a-z]{2}-[A-Z]{2})", lang))
else "zh-CN"
)

async def gen(): # reconstruct the generator
yield first_chunk
async for text in text_stream:
yield text

await self.tts.synthesize(lang or "zh-CN", gen())
await self.tts.synthesize(lang, gen())

0 comments on commit 1383224

Please sign in to comment.