From 8ee7a481510969b3de999efa53703d3a1be55dff Mon Sep 17 00:00:00 2001 From: JS00000 Date: Sun, 9 Apr 2023 18:00:34 +0800 Subject: [PATCH 1/2] fix: wechatmp's deadloop when reply is None --- channel/chat_channel.py | 5 +++++ channel/wechatmp/ServiceAccount.py | 9 +++++---- channel/wechatmp/SubscribeAccount.py | 29 +++++++++++++++++----------- channel/wechatmp/wechatmp_channel.py | 12 ++++++------ 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 5000ae3de..9ef3d4182 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -233,6 +233,9 @@ def _send(self, reply: Reply, context: Context, retry_cnt = 0): time.sleep(3+3*retry_cnt) self._send(reply, context, retry_cnt+1) + def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 + pass + def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 logger.exception("Worker return exception: {}".format(exception)) @@ -242,6 +245,8 @@ def func(worker:Future): worker_exception = worker.exception() if worker_exception: self._fail_callback(session_id, exception = worker_exception, **kwargs) + else: + self._success_callback(session_id, **kwargs) except CancelledError as e: logger.info("Worker cancelled, session_id = {}".format(session_id)) except Exception as e: diff --git a/channel/wechatmp/ServiceAccount.py b/channel/wechatmp/ServiceAccount.py index db9dff3e0..60d40db6c 100644 --- a/channel/wechatmp/ServiceAccount.py +++ b/channel/wechatmp/ServiceAccount.py @@ -16,7 +16,7 @@ def GET(self): def POST(self): # Make sure to return the instance that first created, @singleton will do that. - channel_instance = WechatMPChannel() + channel = WechatMPChannel() try: webData = web.data() # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) @@ -27,14 +27,15 @@ def POST(self): message_id = wechatmp_msg.msg_id logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) - context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) + context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) if context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key - channel_instance.produce(context) - # The reply will be sent by channel_instance.send() in another thread + channel.produce(context) + channel.running.add(from_user) + # The reply will be sent by channel.send() in another thread return "success" elif wechatmp_msg.msg_type == 'event': diff --git a/channel/wechatmp/SubscribeAccount.py b/channel/wechatmp/SubscribeAccount.py index 745ef0e46..b1047c394 100644 --- a/channel/wechatmp/SubscribeAccount.py +++ b/channel/wechatmp/SubscribeAccount.py @@ -41,7 +41,8 @@ def POST(self): context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) if message_id in channel.received_msgs: # received and finished - return + # no return because of bandwords or other reasons + return "success" if supported and context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel @@ -71,11 +72,12 @@ def POST(self): channel.query1[cache_key] = False channel.query2[cache_key] = False channel.query3[cache_key] = False - # Request again + # User request again, and the answer is not ready elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. channel.query3[cache_key] = False + # User request again, and the answer is ready elif cache_key in channel.cache_dict: # Skip the waiting phase channel.query1[cache_key] = True @@ -89,7 +91,7 @@ def POST(self): logger.debug("[wechatmp] query1 {}".format(cache_key)) channel.query1[cache_key] = True cnt = 0 - while cache_key not in channel.cache_dict and cnt < 45: + while cache_key in channel.running and cnt < 45: cnt = cnt + 1 time.sleep(0.1) if cnt == 45: @@ -104,7 +106,7 @@ def POST(self): logger.debug("[wechatmp] query2 {}".format(cache_key)) channel.query2[cache_key] = True cnt = 0 - while cache_key not in channel.cache_dict and cnt < 45: + while cache_key in channel.running and cnt < 45: cnt = cnt + 1 time.sleep(0.1) if cnt == 45: @@ -119,7 +121,7 @@ def POST(self): logger.debug("[wechatmp] query3 {}".format(cache_key)) channel.query3[cache_key] = True cnt = 0 - while cache_key not in channel.cache_dict and cnt < 40: + while cache_key in channel.running and cnt < 40: cnt = cnt + 1 time.sleep(0.1) if cnt == 40: @@ -132,12 +134,17 @@ def POST(self): else: pass - if float(time.time()) - float(query_time) > 4.8: - reply_text = "【正在思考中,回复任意文字尝试获取回复】" - logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) - replyPost = reply.TextMsg(from_user, to_user, reply_text).send() - return replyPost - + + if cache_key not in channel.cache_dict and cache_key not in channel.running: + # no return because of bandwords or other reasons + return "success" + + # if float(time.time()) - float(query_time) > 4.8: + # reply_text = "【正在思考中,回复任意文字尝试获取回复】" + # logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id)) + # replyPost = reply.TextMsg(from_user, to_user, reply_text).send() + # return replyPost + if cache_key in channel.cache_dict: content = channel.cache_dict[cache_key] if len(content.encode('utf8'))<=MAX_UTF8_LEN: diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 49f45e013..940f9e37b 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -97,8 +97,7 @@ def send(self, reply: Reply, context: Context): if self.passive_reply: receiver = context["receiver"] self.cache_dict[receiver] = reply.content - self.running.remove(receiver) - logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply)) + logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) else: receiver = context["receiver"] reply_text = reply.content @@ -115,11 +114,12 @@ def send(self, reply: Reply, context: Context): logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) return + def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 + self.running.remove(session_id) - def _fail_callback(self, session_id, exception, context, **kwargs): + def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) - assert session_id not in self.cache_dict + if self.passive_reply: + assert session_id not in self.cache_dict self.running.remove(session_id) - - From f687b2b6f4580b00c9d58e2cf48ff43e59811cd8 Mon Sep 17 00:00:00 2001 From: JS00000 Date: Sun, 9 Apr 2023 18:32:09 +0800 Subject: [PATCH 2/2] remove _success_callback --- channel/chat_channel.py | 5 ----- channel/wechatmp/ServiceAccount.py | 1 - channel/wechatmp/wechatmp_channel.py | 5 ++--- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 9ef3d4182..5000ae3de 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -233,9 +233,6 @@ def _send(self, reply: Reply, context: Context, retry_cnt = 0): time.sleep(3+3*retry_cnt) self._send(reply, context, retry_cnt+1) - def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 - pass - def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 logger.exception("Worker return exception: {}".format(exception)) @@ -245,8 +242,6 @@ def func(worker:Future): worker_exception = worker.exception() if worker_exception: self._fail_callback(session_id, exception = worker_exception, **kwargs) - else: - self._success_callback(session_id, **kwargs) except CancelledError as e: logger.info("Worker cancelled, session_id = {}".format(session_id)) except Exception as e: diff --git a/channel/wechatmp/ServiceAccount.py b/channel/wechatmp/ServiceAccount.py index 60d40db6c..ae535ea0b 100644 --- a/channel/wechatmp/ServiceAccount.py +++ b/channel/wechatmp/ServiceAccount.py @@ -34,7 +34,6 @@ def POST(self): user_data = conf().get_user_data(from_user) context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key channel.produce(context) - channel.running.add(from_user) # The reply will be sent by channel.send() in another thread return "success" diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 940f9e37b..04576060d 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -97,6 +97,7 @@ def send(self, reply: Reply, context: Context): if self.passive_reply: receiver = context["receiver"] self.cache_dict[receiver] = reply.content + self.running.remove(receiver) logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) else: receiver = context["receiver"] @@ -114,12 +115,10 @@ def send(self, reply: Reply, context: Context): logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) return - def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 - self.running.remove(session_id) def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) if self.passive_reply: assert session_id not in self.cache_dict - self.running.remove(session_id) + self.running.remove(session_id)