
feat: 😽 Add a retry mechanism to resolve errors caused by unstable models #1344


Closed
wants to merge 5 commits into from

Changes from 3 commits
11 changes: 11 additions & 0 deletions astrbot/core/config/default.py
@@ -54,6 +54,8 @@
"dequeue_context_length": 1,
"streaming_response": False,
"streaming_segmented": False,
"max_retries": 3,
"retry_delay": 1.0,
},
"provider_stt_settings": {
"enable": False,
@@ -1035,6 +1037,15 @@
"type": "bool",
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
},
"max_retries": {
"description": "最大重试次数",
"type": "int",
"hint": "当 LLM 请求失败时, 最大重试次数。",
},
"retry_delay": {
"description": "重试延迟时间",
"type": "float",
"hint": "当 LLM 请求失败时, 重试延迟时间。单位为秒。",
},
},
"persona": {
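For clarity, both new keys sit under provider_settings, next to the defaults added above. A minimal sketch (illustrative values, not part of the diff) of how a deployment could loosen them for a flaky provider, using the same config path the pipeline code below reads from:

    # Illustrative only: give a flaky provider a larger retry budget.
    astrbot_config["provider_settings"]["max_retries"] = 5    # attempts before giving up
    astrbot_config["provider_settings"]["retry_delay"] = 2.0  # seconds to wait between attempts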
217 changes: 133 additions & 84 deletions astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -45,6 +45,10 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
self.max_retries = ctx.astrbot_config["provider_settings"].get("max_retries", 3)
self.retry_delay = ctx.astrbot_config["provider_settings"].get(
"retry_delay", 1.0
)

for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -149,7 +153,14 @@ async def process(
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# Find the index of the first item whose role is "user" to keep the context format valid
Removed:

    index = next((i for i, item in enumerate(req.contexts) if item.get("role") == "user"), None)

Added:

    index = next(
        (
            i
            for i, item in enumerate(req.contexts)
            if item.get("role") == "user"
        ),
        None,
    )
if index is not None and index > 0:
req.contexts = req.contexts[index:]
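As a side note, here is a minimal self-contained sketch (illustrative messages, not from the PR) of what this trimming does: everything before the first "user" entry is dropped, so the truncated context always starts with a user turn.

    contexts = [
        {"role": "assistant", "content": "stale reply"},
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    index = next((i for i, item in enumerate(contexts) if item.get("role") == "user"), None)
    if index is not None and index > 0:
        contexts = contexts[index:]  # now starts with the first "user" entry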

@@ -158,94 +169,132 @@
req.session_id = event.unified_msg_origin

async def requesting(req: ProviderRequest):
Contributor

issue (complexity): Consider refactoring the requesting function by extracting the request execution and retry logic into separate helper functions to improve readability and reduce nesting complexity without altering the functionality.

Consider extracting the retry logic and inner handling into separate helper functions. This would flatten the nested loops and try/except blocks to improve readability without changing behavior. For example, move the logic that executes a single request into its own function and then wrap that with the retry loop:

async def _execute_request(self, req: ProviderRequest, event, provider) -> Optional[LLMResponse]:
    logger.debug(f"提供商请求 Payload: {req}")
    final_llm_response = None
    if self.streaming_response:
        stream = provider.text_chat_stream(**req.__dict__)
        async for llm_response in stream:
            if llm_response.is_chunk:
                if llm_response.result_chain:
                    yield llm_response.result_chain  # MessageChain
                else:
                    yield MessageChain().message(llm_response.completion_text)
            else:
                final_llm_response = llm_response
    else:
        final_llm_response = await provider.text_chat(**req.__dict__)
    if not final_llm_response:
        raise Exception("LLM response is None.")
    # Execute post-response event hooks
    await self._handle_event_hooks(event, final_llm_response)
    # Handle functions/streaming responses
    if self.streaming_response:
        async for result in self._handle_llm_stream_response(event, req, final_llm_response):
            yield result
    else:
        async for result in self._handle_llm_response(event, req, final_llm_response):
            yield result

Then wrap this execution with retry logic:

async def requesting(self, req: ProviderRequest, event, provider):
    retry_count = 0
    while True:
        try:
            async for result in self._execute_request(req, event, provider):
                if isinstance(result, ProviderRequest):
                    req = result  # new LLM request
                    break  # re-enter execution with modified req
                else:
                    yield result
            else:
                # Only break out if no inner loop reset happened.
                break
            retry_count = 0  # Reset retry_count if a new req was processed successfully.
        except Exception as e:
            retry_count += 1
            logger.error(f"LLM请求失败 (尝试 {retry_count}/{self.max_retries}): {type(e).__name__} : {str(e)}")
            logger.error(traceback.format_exc())
            if retry_count < self.max_retries and any(err in str(e).lower() for err in ["timeout", "connection", "rate limit", "server error", "500", "503"]):
                logger.info(f"将在 {self.retry_delay} 秒后重试 LLM 请求 >﹏<")
                await asyncio.sleep(self.retry_delay)
            else:
                logger.error(f"LLM 请求失败, 重试次数({retry_count - 1})用尽: {type(e).__name__} : {str(e)}")
                break

Finally, update your call sites to use this refactored requesting function, keeping all functionality intact while reducing nesting.

Removed (previous single-attempt implementation; the request, hook, and function-call handling inside the loop is identical to the added version below, one indentation level shallower):

    try:
        need_loop = True
        while need_loop:
            need_loop = False
            ...  # same request, hook, and function-call handling as in the added block

        asyncio.create_task(
            Metric.upload(
                llm_tick=1,
                model_name=provider.get_model(),
                provider_type=provider.meta().type,
            )
        )

        # Save to conversation history
        await self._save_to_history(event, req, final_llm_response)

    except BaseException as e:
        logger.error(traceback.format_exc())
        event.set_result(
            MessageEventResult().message(
                f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
            )
        )

Added (the same handling wrapped in a retry loop):

    retry_count = 0
    while True:
        try:
            need_loop = True
            while need_loop:
                need_loop = False
                logger.debug(f"提供商请求 Payload: {req}")

                final_llm_response = None

                if self.streaming_response:
                    stream = provider.text_chat_stream(**req.__dict__)
                    async for llm_response in stream:
                        if llm_response.is_chunk:
                            if llm_response.result_chain:
                                yield llm_response.result_chain  # MessageChain
                            else:
                                yield MessageChain().message(
                                    llm_response.completion_text
                                )
                        else:
                            final_llm_response = llm_response
                else:
                    final_llm_response = await provider.text_chat(
                        **req.__dict__
                    )  # Request the LLM

                if not final_llm_response:
                    raise Exception("LLM response is None.")

                # Run the event hooks that follow an LLM response.
                handlers = star_handlers_registry.get_handlers_by_event_type(
                    EventType.OnLLMResponseEvent
                )
                for handler in handlers:
                    try:
                        logger.debug(
                            f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
                        )
                        await handler.handler(event, final_llm_response)
                    except BaseException:
                        logger.error(traceback.format_exc())

                    if event.is_stopped():
                        logger.info(
                            f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
                        )
                        return

                # ==========================================
                # Execute function (tool) calls
                # ==========================================
                if self.streaming_response:
                    # Streaming response handling
                    async for result in self._handle_llm_stream_response(
                        event, req, final_llm_response
                    ):
                        if isinstance(result, ProviderRequest):
                            # A function tool was called and returned a result; we need to request the LLM again
                            req = result
                            need_loop = True
                        else:
                            yield
                else:
                    # Non-streaming response handling
                    async for result in self._handle_llm_response(
                        event, req, final_llm_response
                    ):
                        if isinstance(result, ProviderRequest):
                            # A function tool was called and returned a result; we need to request the LLM again
                            req = result
                            need_loop = True
                        else:
                            yield

            # ==========================================
            # Upload non-sensitive metrics after the request
            # ==========================================
            asyncio.create_task(
                Metric.upload(
                    llm_tick=1,
                    model_name=provider.get_model(),
                    provider_type=provider.meta().type,
                )
            )

            # Save to conversation history
            await self._save_to_history(event, req, final_llm_response)
            break

        # ==========================================
        # Retry when the request fails
        # ==========================================
        except Exception as e:
            retry_count += 1
            logger.error(
                f"LLM请求失败 (尝试 {retry_count}/{self.max_retries}): {type(e).__name__} : {str(e)}"
            )
            logger.error(traceback.format_exc())

            # Decide whether to keep retrying
            if retry_count < self.max_retries:
                should_retry = any(
                    err in str(e).lower()
                    for err in [
                        "timeout",
                        "connection",
                        "rate limit",
                        "server error",
                        "500",
                        "503",
                    ]
                )

                if should_retry:
                    logger.info(
                        f"将在 {self.retry_delay} 秒后重试 LLM 请求 >﹏<"
                    )
                    await asyncio.sleep(self.retry_delay)
                    continue

            # The error cannot be fixed by retrying, or the retry budget is exhausted
            logger.error(
                f"LLM 请求失败, 重试次数({retry_count - 1})用尽: {type(e).__name__} : {str(e)}"
            )
            break

if not self.streaming_response:
event.set_extra("tool_call_result", None)
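For reference, the retry strategy added above reduces to a fixed-delay loop that only retries errors whose message looks transient. A minimal standalone sketch of that pattern (the with_retries helper and call_llm callable are hypothetical, not part of the PR):

    import asyncio

    TRANSIENT_MARKERS = ["timeout", "connection", "rate limit", "server error", "500", "503"]

    async def with_retries(call_llm, *, max_retries: int = 3, retry_delay: float = 1.0):
        """Retry call_llm on transient-looking errors, waiting a fixed delay between attempts."""
        retry_count = 0
        while True:
            try:
                return await call_llm()
            except Exception as e:
                retry_count += 1
                transient = any(marker in str(e).lower() for marker in TRANSIENT_MARKERS)
                if retry_count < max_retries and transient:
                    await asyncio.sleep(retry_delay)  # fixed delay, mirroring retry_delay above
                    continue
                raise  # non-transient error, or retry budget exhausted

Something like await with_retries(lambda: provider.text_chat(**req.__dict__)) captures the idea; the PR keeps the loop inline because requesting is an async generator that has to keep yielding partial results while it runs.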
@@ -388,8 +437,8 @@ async def _handle_function_tools(
platform_id = event.get_platform_id()
star_md = star_map.get(func_tool.handler_module_path)
Removed:

    if (
        star_md and
        platform_id in star_md.supported_platforms
        and not star_md.supported_platforms[platform_id]
    ):

Added:

    if (
        star_md
        and platform_id in star_md.supported_platforms
        and not star_md.supported_platforms[platform_id]
    ):
logger.debug(