From 618c94edb83592a41b3cc8411cbbd9539e88051f Mon Sep 17 00:00:00 2001 From: lanvent Date: Sat, 22 Apr 2023 12:01:29 +0800 Subject: [PATCH] formatting: run precommit on all files --- app.py | 1 + bot/baidu/baidu_unit_bot.py | 12 +-- bot/chatgpt/chat_gpt_bot.py | 29 ++---- bot/chatgpt/chat_gpt_session.py | 22 +---- bot/openai/open_ai_bot.py | 28 ++---- bot/openai/open_ai_image.py | 10 +- bot/openai/open_ai_session.py | 16 +--- bot/session_manager.py | 20 +--- bridge/context.py | 4 +- channel/chat_channel.py | 131 ++++++--------------------- channel/terminal/terminal_channel.py | 4 +- channel/wechat/wechat_channel.py | 39 ++------ channel/wechat/wechat_message.py | 26 ++---- channel/wechat/wechaty_channel.py | 20 +--- channel/wechat/wechaty_message.py | 12 +-- channel/wechatmp/active_reply.py | 25 ++--- channel/wechatmp/common.py | 8 +- channel/wechatmp/passive_reply.py | 54 ++++------- channel/wechatmp/wechatmp_channel.py | 47 ++++------ channel/wechatmp/wechatmp_client.py | 21 ++--- channel/wechatmp/wechatmp_message.py | 19 +--- common/time_check.py | 14 +-- config.py | 4 +- plugins/banwords/banwords.py | 12 +-- plugins/bdunit/bdunit.py | 39 ++------ plugins/dungeon/dungeon.py | 10 +- plugins/godcmd/godcmd.py | 30 ++---- plugins/hello/hello.py | 8 +- plugins/keyword/config.json.template | 4 +- plugins/keyword/keyword.py | 4 +- plugins/plugin_manager.py | 60 +++--------- plugins/role/role.py | 31 ++----- plugins/tool/README.md | 2 +- plugins/tool/tool.py | 14 +-- voice/audio_convert.py | 20 +--- voice/azure/azure_voice.py | 42 ++------- voice/baidu/baidu_voice.py | 4 +- voice/google/google_voice.py | 10 +- voice/openai/openai_voice.py | 6 +- voice/pytts/pytts_voice.py | 12 ++- 40 files changed, 228 insertions(+), 646 deletions(-) diff --git a/app.py b/app.py index e11f46c2a..637b6e462 100644 --- a/app.py +++ b/app.py @@ -19,6 +19,7 @@ def func(_signo, _stack_frame): if callable(old_handler): # check old_handler return old_handler(_signo, _stack_frame) sys.exit(0) + signal.signal(_signo, func) diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index d8a0aca11..f7714e4f4 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -10,10 +10,7 @@ class BaiduUnitBot(Bot): def reply(self, query, context=None): token = self.get_token() - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" - + token - ) + url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token post_data = ( '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' + query @@ -32,12 +29,7 @@ def reply(self, query, context=None): def get_token(self): access_key = "YOUR_ACCESS_KEY" secret_key = "YOUR_SECRET_KEY" - host = ( - "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" - + access_key - + "&client_secret=" - + secret_key - ) + host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key response = requests.get(host) if response: print(response.json()) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index d8e4b0e26..b045311a3 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -30,23 +30,15 @@ def __init__(self): if conf().get("rate_limit_chatgpt"): self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) - self.sessions = SessionManager( - ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" - ) + self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") self.args = { "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 # "max_tokens":4096, # 回复最大的字符数 "top_p": 1, - "frequency_penalty": conf().get( - "frequency_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get( - "presence_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get( - "request_timeout", None - ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 } @@ -87,15 +79,10 @@ def reply(self, query, context=None): reply_content["completion_tokens"], ) ) - if ( - reply_content["completion_tokens"] == 0 - and len(reply_content["content"]) > 0 - ): + if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: reply = Reply(ReplyType.ERROR, reply_content["content"]) elif reply_content["completion_tokens"] > 0: - self.sessions.session_reply( - reply_content["content"], session_id, reply_content["total_tokens"] - ) + self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) reply = Reply(ReplyType.TEXT, reply_content["content"]) else: reply = Reply(ReplyType.ERROR, reply_content["content"]) @@ -126,9 +113,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") # if api_key == None, the default openai.api_key will be used - response = openai.ChatCompletion.create( - api_key=api_key, messages=session.messages, **self.args - ) + response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) return { "total_tokens": response["usage"]["total_tokens"], diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index 525793ffa..e6c319b36 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -25,9 +25,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): precise = False if cur_tokens is None: raise e - logger.debug( - "Exception when counting tokens precisely for query: {}".format(e) - ) + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) while cur_tokens > max_tokens: if len(self.messages) > 2: self.messages.pop(1) @@ -39,16 +37,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): cur_tokens = cur_tokens - max_tokens break elif len(self.messages) == 2 and self.messages[1]["role"] == "user": - logger.warn( - "user message exceed max_tokens. total_tokens={}".format(cur_tokens) - ) + logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) break else: - logger.debug( - "max_tokens={}, total_tokens={}, len(messages)={}".format( - max_tokens, cur_tokens, len(self.messages) - ) - ) + logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: cur_tokens = self.calc_tokens() @@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model): elif model == "gpt-4": return num_tokens_from_messages(messages, model="gpt-4-0314") elif model == "gpt-3.5-turbo-0301": - tokens_per_message = ( - 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - ) + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted elif model == "gpt-4-0314": tokens_per_message = 3 tokens_per_name = 1 else: - logger.warn( - f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." - ) + logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") num_tokens = 0 for message in messages: diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index 1cfbf10d8..160562526 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -28,23 +28,15 @@ def __init__(self): if proxy: openai.proxy = proxy - self.sessions = SessionManager( - OpenAISession, model=conf().get("model") or "text-davinci-003" - ) + self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003") self.args = { "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "max_tokens": 1200, # 回复最大的字符数 "top_p": 1, - "frequency_penalty": conf().get( - "frequency_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get( - "presence_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get( - "request_timeout", None - ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "stop": ["\n\n\n"], } @@ -71,17 +63,13 @@ def reply(self, query, context=None): result["content"], ) logger.debug( - "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( - str(session), session_id, reply_content, completion_tokens - ) + "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) ) if total_tokens == 0: reply = Reply(ReplyType.ERROR, reply_content) else: - self.sessions.session_reply( - reply_content, session_id, total_tokens - ) + self.sessions.session_reply(reply_content, session_id, total_tokens) reply = Reply(ReplyType.TEXT, reply_content) return reply elif context.type == ContextType.IMAGE_CREATE: @@ -96,9 +84,7 @@ def reply(self, query, context=None): def reply_text(self, session: OpenAISession, retry_count=0): try: response = openai.Completion.create(prompt=str(session), **self.args) - res_content = ( - response.choices[0]["text"].strip().replace("<|endoftext|>", "") - ) + res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "") total_tokens = response["usage"]["total_tokens"] completion_tokens = response["usage"]["completion_tokens"] logger.info("[OPEN_AI] reply={}".format(res_content)) diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py index 5dbbd23ed..b188557f3 100644 --- a/bot/openai/open_ai_image.py +++ b/bot/openai/open_ai_image.py @@ -23,9 +23,7 @@ def create_img(self, query, retry_count=0): response = openai.Image.create( prompt=query, # 图片描述 n=1, # 每次生成图片的数量 - size=conf().get( - "image_create_size", "256x256" - ), # 图片大小,可选有 256x256, 512x512, 1024x1024 + size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024 ) image_url = response["data"][0]["url"] logger.info("[OPEN_AI] image_url={}".format(image_url)) @@ -34,11 +32,7 @@ def create_img(self, query, retry_count=0): logger.warn(e) if retry_count < 1: time.sleep(5) - logger.warn( - "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format( - retry_count + 1 - ) - ) + logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) return self.create_img(query, retry_count + 1) else: return False, "提问太快啦,请休息一下再问我吧" diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py index 78cf43900..8f6aa4f5b 100644 --- a/bot/openai/open_ai_session.py +++ b/bot/openai/open_ai_session.py @@ -36,9 +36,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): precise = False if cur_tokens is None: raise e - logger.debug( - "Exception when counting tokens precisely for query: {}".format(e) - ) + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) while cur_tokens > max_tokens: if len(self.messages) > 1: self.messages.pop(0) @@ -50,18 +48,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): cur_tokens = len(str(self)) break elif len(self.messages) == 1 and self.messages[0]["role"] == "user": - logger.warn( - "user question exceed max_tokens. total_tokens={}".format( - cur_tokens - ) - ) + logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) break else: - logger.debug( - "max_tokens={}, total_tokens={}, len(conversation)={}".format( - max_tokens, cur_tokens, len(self.messages) - ) - ) + logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: cur_tokens = self.calc_tokens() diff --git a/bot/session_manager.py b/bot/session_manager.py index 1aff647b5..8d70886e0 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -55,9 +55,7 @@ def build_session(self, session_id, system_prompt=None): return self.sessioncls(session_id, system_prompt, **self.session_args) if session_id not in self.sessions: - self.sessions[session_id] = self.sessioncls( - session_id, system_prompt, **self.session_args - ) + self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session self.sessions[session_id].set_system_prompt(system_prompt) session = self.sessions[session_id] @@ -71,9 +69,7 @@ def session_query(self, query, session_id): total_tokens = session.discard_exceeding(max_tokens, None) logger.debug("prompt tokens used={}".format(total_tokens)) except Exception as e: - logger.debug( - "Exception when counting tokens precisely for prompt: {}".format(str(e)) - ) + logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) return session def session_reply(self, reply, session_id, total_tokens=None): @@ -82,17 +78,9 @@ def session_reply(self, reply, session_id, total_tokens=None): try: max_tokens = conf().get("conversation_max_tokens", 1000) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) - logger.debug( - "raw total_tokens={}, savesession tokens={}".format( - total_tokens, tokens_cnt - ) - ) + logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) except Exception as e: - logger.debug( - "Exception when counting tokens precisely for session: {}".format( - str(e) - ) - ) + logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) return session def clear_session(self, session_id): diff --git a/bridge/context.py b/bridge/context.py index c1eb10c1f..ab004c0f6 100644 --- a/bridge/context.py +++ b/bridge/context.py @@ -60,6 +60,4 @@ def __delitem__(self, key): del self.kwargs[key] def __str__(self): - return "Context(type={}, content={}, kwargs={})".format( - self.type, self.content, self.kwargs - ) + return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 04a9496d2..b3df9063e 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -53,9 +53,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): group_id = cmsg.other_user_id group_name_white_list = config.get("group_name_white_list", []) - group_name_keyword_white_list = config.get( - "group_name_keyword_white_list", [] - ) + group_name_keyword_white_list = config.get("group_name_keyword_white_list", []) if any( [ group_name in group_name_white_list, @@ -63,9 +61,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): check_contain(group_name, group_name_keyword_white_list), ] ): - group_chat_in_one_session = conf().get( - "group_chat_in_one_session", [] - ) + group_chat_in_one_session = conf().get("group_chat_in_one_session", []) session_id = cmsg.actual_user_id if any( [ @@ -81,17 +77,11 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): else: context["session_id"] = cmsg.other_user_id context["receiver"] = cmsg.other_user_id - e_context = PluginManager().emit_event( - EventContext( - Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} - ) - ) + e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context})) context = e_context["context"] if e_context.is_pass() or context is None: return context - if cmsg.from_user_id == self.user_id and not config.get( - "trigger_by_self", True - ): + if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True): logger.debug("[WX]self message skipped") return None @@ -119,19 +109,13 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): if not flag: if context["origin_ctype"] == ContextType.VOICE: - logger.info( - "[WX]receive group voice, but checkprefix didn't match" - ) + logger.info("[WX]receive group voice, but checkprefix didn't match") return None else: # 单聊 - match_prefix = check_prefix( - content, conf().get("single_chat_prefix", [""]) - ) + match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""])) if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 content = content.replace(match_prefix, "", 1).strip() - elif ( - context["origin_ctype"] == ContextType.VOICE - ): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 + elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 pass else: return None @@ -143,18 +127,10 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): else: context.type = ContextType.TEXT context.content = content.strip() - if ( - "desire_rtype" not in context - and conf().get("always_reply_voice") - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - if ( - "desire_rtype" not in context - and conf().get("voice_reply_voice") - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE return context @@ -182,15 +158,8 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: ) reply = e_context["reply"] if not e_context.is_pass(): - logger.debug( - "[WX] ready to handle context: type={}, content={}".format( - context.type, context.content - ) - ) - if ( - context.type == ContextType.TEXT - or context.type == ContextType.IMAGE_CREATE - ): # 文字和图片消息 + logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content)) + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 reply = super().build_reply_content(context.content, context) elif context.type == ContextType.VOICE: # 语音消息 cmsg = context["msg"] @@ -214,9 +183,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: # logger.warning("[WX]delete temp file error: " + str(e)) if reply.type == ReplyType.TEXT: - new_context = self._compose_context( - ContextType.TEXT, reply.content, **context.kwargs - ) + new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs) if new_context: reply = self._generate_reply(new_context) else: @@ -246,48 +213,24 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply: if reply.type == ReplyType.TEXT: reply_text = reply.content - if ( - desire_rtype == ReplyType.VOICE - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): - reply_text = ( - "@" - + context["msg"].actual_user_nickname - + " " - + reply_text.strip() - ) - reply_text = ( - conf().get("group_chat_reply_prefix", "") + reply_text - ) + reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "") + reply_text else: - reply_text = ( - conf().get("single_chat_reply_prefix", "") + reply_text - ) + reply_text = conf().get("single_chat_reply_prefix", "") + reply_text reply.content = reply_text elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: reply.content = "[" + str(reply.type) + "]\n" + reply.content - elif ( - reply.type == ReplyType.IMAGE_URL - or reply.type == ReplyType.VOICE - or reply.type == ReplyType.IMAGE - ): + elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: pass else: logger.error("[WX] unknown reply type: {}".format(reply.type)) return - if ( - desire_rtype - and desire_rtype != reply.type - and reply.type not in [ReplyType.ERROR, ReplyType.INFO] - ): - logger.warning( - "[WX] desire_rtype: {}, but reply type: {}".format( - context.get("desire_rtype"), reply.type - ) - ) + if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: + logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) return reply def _send_reply(self, context: Context, reply: Reply): @@ -300,9 +243,7 @@ def _send_reply(self, context: Context, reply: Reply): ) reply = e_context["reply"] if not e_context.is_pass() and reply and reply.type: - logger.debug( - "[WX] ready to send reply: {}, context: {}".format(reply, context) - ) + logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context)) self._send(reply, context) def _send(self, reply: Reply, context: Context, retry_cnt=0): @@ -328,9 +269,7 @@ def func(worker: Future): try: worker_exception = worker.exception() if worker_exception: - self._fail_callback( - session_id, exception=worker_exception, **kwargs - ) + self._fail_callback(session_id, exception=worker_exception, **kwargs) else: self._success_callback(session_id, **kwargs) except CancelledError as e: @@ -366,24 +305,14 @@ def consume(self): if not context_queue.empty(): context = context_queue.get() logger.debug("[WX] consume context: {}".format(context)) - future: Future = self.handler_pool.submit( - self._handle, context - ) - future.add_done_callback( - self._thread_pool_callback(session_id, context=context) - ) + future: Future = self.handler_pool.submit(self._handle, context) + future.add_done_callback(self._thread_pool_callback(session_id, context=context)) if session_id not in self.futures: self.futures[session_id] = [] self.futures[session_id].append(future) - elif ( - semaphore._initial_value == semaphore._value + 1 - ): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 - self.futures[session_id] = [ - t for t in self.futures[session_id] if not t.done() - ] - assert ( - len(self.futures[session_id]) == 0 - ), "thread pool error" + elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 + self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] + assert len(self.futures[session_id]) == 0, "thread pool error" del self.sessions[session_id] else: semaphore.release() @@ -397,9 +326,7 @@ def cancel_session(self, session_id): future.cancel() cnt = self.sessions[session_id][0].qsize() if cnt > 0: - logger.info( - "Cancel {} messages in session {}".format(cnt, session_id) - ) + logger.info("Cancel {} messages in session {}".format(cnt, session_id)) self.sessions[session_id][0] = Dequeue() def cancel_all_session(self): @@ -409,9 +336,7 @@ def cancel_all_session(self): future.cancel() cnt = self.sessions[session_id][0].qsize() if cnt > 0: - logger.info( - "Cancel {} messages in session {}".format(cnt, session_id) - ) + logger.info("Cancel {} messages in session {}".format(cnt, session_id)) self.sessions[session_id][0] = Dequeue() diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py index e2060789c..9a413dcff 100644 --- a/channel/terminal/terminal_channel.py +++ b/channel/terminal/terminal_channel.py @@ -77,9 +77,7 @@ def startup(self): if check_prefix(prompt, trigger_prefixs) is None: prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 - context = self._compose_context( - ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) - ) + context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)) if context: self.produce(context) else: diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index cf200b17e..16d788c44 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -56,10 +56,7 @@ def wrapper(self, cmsg: ChatMessage): return self.receivedMsgs[msgId] = cmsg create_time = cmsg.create_time # 消息时间戳 - if ( - conf().get("hot_reload") == True - and int(create_time) < int(time.time()) - 60 - ): # 跳过1分钟前的历史消息 + if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 logger.debug("[WX]history message {} skipped".format(msgId)) return return func(self, cmsg) @@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode): url = f"https://login.weixin.qq.com/l/{uuid}" qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) - qr_api2 = ( - "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( - url - ) - ) + qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) - qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( - url - ) + qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) print("You can also scan QRCode in any website below:") print(qr_api3) print(qr_api4) @@ -134,18 +125,12 @@ def startup(self): logger.error("Hot reload failed, try to login without hot reload") itchat.logout() os.remove(status_path) - itchat.auto_login( - enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback - ) + itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) else: raise e self.user_id = itchat.instance.storageClass.userName self.name = itchat.instance.storageClass.nickName - logger.info( - "Wechat login success, user_id: {}, nickname: {}".format( - self.user_id, self.name - ) - ) + logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) # start message listener itchat.run() @@ -173,16 +158,10 @@ def handle_single(self, cmsg: ChatMessage): elif cmsg.ctype == ContextType.PATPAT: logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.TEXT: - logger.debug( - "[WX]receive text msg: {}, cmsg={}".format( - json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg - ) - ) + logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) else: logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) - context = self._compose_context( - cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg - ) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) if context: self.produce(context) @@ -202,9 +181,7 @@ def handle_group(self, cmsg: ChatMessage): pass else: logger.debug("[WX]receive group msg: {}".format(cmsg.content)) - context = self._compose_context( - cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg - ) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) if context: self.produce(context) diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index 18884259c..63c225471 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -27,37 +27,23 @@ def __init__(self, itchat_msg, is_group=False): self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self._prepare_fn = lambda: itchat_msg.download(self.content) elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: - if is_group and ( - "加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"] - ): + if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]): self.ctype = ContextType.JOIN_GROUP self.content = itchat_msg["Content"] # 这里只能得到nickname, actual_user_id还是机器人的id if "加入了群聊" in itchat_msg["Content"]: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[-1] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1] elif "加入群聊" in itchat_msg["Content"]: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[0] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] elif "拍了拍我" in itchat_msg["Content"]: self.ctype = ContextType.PATPAT self.content = itchat_msg["Content"] if is_group: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[0] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] else: - raise NotImplementedError( - "Unsupported note message: " + itchat_msg["Content"] - ) + raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"]) else: - raise NotImplementedError( - "Unsupported message type: Type:{} MsgType:{}".format( - itchat_msg["Type"], itchat_msg["MsgType"] - ) - ) + raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"])) self.from_user_id = itchat_msg["FromUserName"] self.to_user_id = itchat_msg["ToUserName"] diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 7383a206c..051a9cf10 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -60,13 +60,9 @@ def send(self, reply: Reply, context: Context): receiver_id = context["receiver"] loop = asyncio.get_event_loop() if context["isgroup"]: - receiver = asyncio.run_coroutine_threadsafe( - self.bot.Room.find(receiver_id), loop - ).result() + receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result() else: - receiver = asyncio.run_coroutine_threadsafe( - self.bot.Contact.find(receiver_id), loop - ).result() + receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result() msg = None if reply.type == ReplyType.TEXT: msg = reply.content @@ -83,9 +79,7 @@ def send(self, reply: Reply, context: Context): voiceLength = int(any_to_sil(file_path, sil_file)) if voiceLength >= 60000: voiceLength = 60000 - logger.info( - "[WX] voice too long, length={}, set to 60s".format(voiceLength) - ) + logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength)) # 发送语音 t = int(time.time()) msg = FileBox.from_file(sil_file, name=str(t) + ".sil") @@ -98,9 +92,7 @@ def send(self, reply: Reply, context: Context): os.remove(sil_file) except Exception as e: pass - logger.info( - "[WX] sendVoice={}, receiver={}".format(reply.content, receiver) - ) + logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver)) elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content t = int(time.time()) @@ -111,9 +103,7 @@ def send(self, reply: Reply, context: Context): image_storage = reply.content image_storage.seek(0) t = int(time.time()) - msg = FileBox.from_base64( - base64.b64encode(image_storage.read()), str(t) + ".png" - ) + msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png") asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() logger.info("[WX] sendImage, receiver={}".format(receiver)) diff --git a/channel/wechat/wechaty_message.py b/channel/wechat/wechaty_message.py index 94896f17f..cdb41ddf2 100644 --- a/channel/wechat/wechaty_message.py +++ b/channel/wechat/wechaty_message.py @@ -45,16 +45,12 @@ async def __init__(self, wechaty_msg: Message): def func(): loop = asyncio.get_event_loop() - asyncio.run_coroutine_threadsafe( - voice_file.to_file(self.content), loop - ).result() + asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result() self._prepare_fn = func else: - raise NotImplementedError( - "Unsupported message type: {}".format(wechaty_msg.type()) - ) + raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) from_contact = wechaty_msg.talker() # 获取消息的发送者 self.from_user_id = from_contact.contact_id @@ -73,9 +69,7 @@ def func(): self.to_user_id = to_contact.contact_id self.to_user_nickname = to_contact.name - if ( - self.is_group or wechaty_msg.is_self() - ): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 + if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 self.other_user_id = self.to_user_id self.other_user_nickname = self.to_user_nickname else: diff --git a/channel/wechatmp/active_reply.py b/channel/wechatmp/active_reply.py index d33f06ea2..12975a561 100644 --- a/channel/wechatmp/active_reply.py +++ b/channel/wechatmp/active_reply.py @@ -1,16 +1,17 @@ import time import web +from wechatpy import parse_message +from wechatpy.replies import create_reply -from channel.wechatmp.wechatmp_message import WeChatMPMessage from bridge.context import * from bridge.reply import * from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel -from wechatpy import parse_message +from channel.wechatmp.wechatmp_message import WeChatMPMessage from common.log import logger from config import conf -from wechatpy.replies import create_reply + # This class is instantiated once per query class Query: @@ -50,29 +51,19 @@ def POST(self): ) ) if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): - context = channel._compose_context( - wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) else: - context = channel._compose_context( - wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) if context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) - context["openai_api_key"] = user_data.get( - "openai_api_key" - ) # None or user openai_api_key + context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key channel.produce(context) # The reply will be sent by channel.send() in another thread return "success" elif msg.type == "event": - logger.info( - "[wechatmp] Event {} from {}".format( - msg.event, msg.source - ) - ) + logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) if msg.event in ["subscribe", "subscribe_scan"]: reply_text = subscribe_msg() replyPost = create_reply(reply_text, msg) diff --git a/channel/wechatmp/common.py b/channel/wechatmp/common.py index 696585f92..b6f206c5a 100644 --- a/channel/wechatmp/common.py +++ b/channel/wechatmp/common.py @@ -1,10 +1,12 @@ import textwrap -import web -from config import conf -from wechatpy.utils import check_signature +import web from wechatpy.crypto import WeChatCrypto from wechatpy.exceptions import InvalidSignatureException +from wechatpy.utils import check_signature + +from config import conf + MAX_UTF8_LEN = 2048 diff --git a/channel/wechatmp/passive_reply.py b/channel/wechatmp/passive_reply.py index 6c722efab..cd0f012b8 100644 --- a/channel/wechatmp/passive_reply.py +++ b/channel/wechatmp/passive_reply.py @@ -1,17 +1,18 @@ -import time import asyncio +import time import web +from wechatpy import parse_message +from wechatpy.replies import ImageReply, VoiceReply, create_reply -from channel.wechatmp.wechatmp_message import WeChatMPMessage from bridge.context import * from bridge.reply import * from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel +from channel.wechatmp.wechatmp_message import WeChatMPMessage from common.log import logger from config import conf -from wechatpy import parse_message -from wechatpy.replies import create_reply, ImageReply, VoiceReply + # This class is instantiated once per query class Query: @@ -49,21 +50,15 @@ def POST(self): if ( from_user not in channel.cache_dict and from_user not in channel.running - or content.startswith("#") - and message_id not in channel.request_cnt # insert the godcmd + or content.startswith("#") + and message_id not in channel.request_cnt # insert the godcmd ): # The first query begin if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): - context = channel._compose_context( - wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) else: - context = channel._compose_context( - wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg - ) - logger.debug( - "[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported) - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) + logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)) if supported and context: # set private openai_api_key @@ -94,23 +89,17 @@ def POST(self): """\ 未知错误,请稍后再试""" ) - + replyPost = create_reply(reply_text, msg) return encrypt_func(replyPost.render()) - # Wechat official server will request 3 times (5 seconds each), with the same message_id. # Because the interval is 5 seconds, here assumed that do not have multithreading problems. request_cnt = channel.request_cnt.get(message_id, 0) + 1 channel.request_cnt[message_id] = request_cnt logger.info( "[wechatmp] Request {} from {} {} {}:{}\n{}".format( - request_cnt, - from_user, - message_id, - web.ctx.env.get("REMOTE_ADDR"), - web.ctx.env.get("REMOTE_PORT"), - content + request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content ) ) @@ -130,7 +119,7 @@ def POST(self): time.sleep(2) # and do nothing, waiting for the next request return "success" - else: # request_cnt == 3: + else: # request_cnt == 3: # return timeout message reply_text = "【正在思考中,回复任意文字尝试获取回复】" replyPost = create_reply(reply_text, msg) @@ -140,10 +129,7 @@ def POST(self): channel.request_cnt.pop(message_id) # no return because of bandwords or other reasons - if ( - from_user not in channel.cache_dict - and from_user not in channel.running - ): + if from_user not in channel.cache_dict and from_user not in channel.running: return "success" # Only one request can access to the cached data @@ -152,7 +138,7 @@ def POST(self): except KeyError: return "success" - if (reply_type == "text"): + if reply_type == "text": if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: reply_text = reply_content else: @@ -177,7 +163,7 @@ def POST(self): replyPost = create_reply(reply_text, msg) return encrypt_func(replyPost.render()) - elif (reply_type == "voice"): + elif reply_type == "voice": media_id = reply_content asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) logger.info( @@ -193,7 +179,7 @@ def POST(self): replyPost.media_id = media_id return encrypt_func(replyPost.render()) - elif (reply_type == "image"): + elif reply_type == "image": media_id = reply_content asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) logger.info( @@ -210,11 +196,7 @@ def POST(self): return encrypt_func(replyPost.render()) elif msg.type == "event": - logger.info( - "[wechatmp] Event {} from {}".format( - msg.event, msg.source - ) - ) + logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) if msg.event in ["subscribe", "subscribe_scan"]: reply_text = subscribe_msg() replyPost = create_reply(reply_text, msg) diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 9d63b8411..aa1fc74d7 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -1,24 +1,26 @@ # -*- coding: utf-8 -*- +import asyncio +import imghdr import io import os +import threading import time -import imghdr + import requests -import asyncio -import threading -from config import conf +import web +from wechatpy.crypto import WeChatCrypto +from wechatpy.exceptions import WeChatClientException + from bridge.context import * from bridge.reply import * -from common.log import logger -from common.singleton import singleton -from voice.audio_convert import any_to_mp3 from channel.chat_channel import ChatChannel from channel.wechatmp.common import * from channel.wechatmp.wechatmp_client import WechatMPClient -from wechatpy.exceptions import WeChatClientException -from wechatpy.crypto import WeChatCrypto +from common.log import logger +from common.singleton import singleton +from config import conf +from voice.audio_convert import any_to_mp3 -import web # If using SSL, uncomment the following lines, and modify the certificate path. # from cheroot.server import HTTPServer # from cheroot.ssl.builtin import BuiltinSSLAdapter @@ -54,7 +56,6 @@ def __init__(self, passive_reply=True): t.setDaemon(True) t.start() - def startup(self): if self.passive_reply: urls = ("/wx", "channel.wechatmp.passive_reply.Query") @@ -84,7 +85,7 @@ def send(self, reply: Reply, context: Context): elif reply.type == ReplyType.VOICE: try: voice_file_path = reply.content - with open(voice_file_path, 'rb') as f: + with open(voice_file_path, "rb") as f: # support: <2M, <60s, mp3/wma/wav/amr response = self.client.material.add("voice", f) logger.debug("[wechatmp] upload voice response: {}".format(response)) @@ -107,7 +108,7 @@ def send(self, reply: Reply, context: Context): image_storage.write(block) image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type try: response = self.client.material.add("image", (filename, image_storage, content_type)) @@ -122,7 +123,7 @@ def send(self, reply: Reply, context: Context): image_storage = reply.content image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type try: response = self.client.material.add("image", (filename, image_storage, content_type)) @@ -137,7 +138,7 @@ def send(self, reply: Reply, context: Context): if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: reply_text = reply.content texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) - if len(texts)>1: + if len(texts) > 1: logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) for text in texts: self.client.message.send_text(receiver, text) @@ -174,7 +175,7 @@ def send(self, reply: Reply, context: Context): image_storage.write(block) image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type try: response = self.client.media.upload("image", (filename, image_storage, content_type)) @@ -188,7 +189,7 @@ def send(self, reply: Reply, context: Context): image_storage = reply.content image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + str(context['msg'].msg_id) + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type try: response = self.client.media.upload("image", (filename, image_storage, content_type)) @@ -201,20 +202,12 @@ def send(self, reply: Reply, context: Context): return def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 - logger.debug( - "[wechatmp] Success to generate reply, msgId={}".format( - context["msg"].msg_id - ) - ) + logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id)) if self.passive_reply: self.running.remove(session_id) def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 - logger.exception( - "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( - context["msg"].msg_id, exception - ) - ) + logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception)) if self.passive_reply: assert session_id not in self.cache_dict self.running.remove(session_id) diff --git a/channel/wechatmp/wechatmp_client.py b/channel/wechatmp/wechatmp_client.py index ee0ec845f..a87578478 100644 --- a/channel/wechatmp/wechatmp_client.py +++ b/channel/wechatmp/wechatmp_client.py @@ -1,17 +1,16 @@ -import time import threading -from channel.wechatmp.common import * +import time + from wechatpy.client import WeChatClient -from common.log import logger from wechatpy.exceptions import APILimitedException +from channel.wechatmp.common import * +from common.log import logger + class WechatMPClient(WeChatClient): - def __init__(self, appid, secret, access_token=None, - session=None, timeout=None, auto_retry=True): - super(WechatMPClient, self).__init__( - appid, secret, access_token, session, timeout, auto_retry - ) + def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True): + super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry) self.fetch_access_token_lock = threading.Lock() def clear_quota(self): @@ -20,7 +19,7 @@ def clear_quota(self): def clear_quota_v2(self): return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret}) - def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token + def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token with self.fetch_access_token_lock: access_token = self.session.get(self.access_token_key) if access_token: @@ -31,11 +30,11 @@ def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复 return access_token return super().fetch_access_token() - def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试 + def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试 try: return super()._request(method, url_or_endpoint, **kwargs) except APILimitedException as e: logger.error("[wechatmp] API quata has been used up. {}".format(e)) response = self.clear_quota_v2() logger.debug("[wechatmp] API quata has been cleard, {}".format(response)) - return super()._request(method, url_or_endpoint, **kwargs) \ No newline at end of file + return super()._request(method, url_or_endpoint, **kwargs) diff --git a/channel/wechatmp/wechatmp_message.py b/channel/wechatmp/wechatmp_message.py index fd0724344..27c9fbb85 100644 --- a/channel/wechatmp/wechatmp_message.py +++ b/channel/wechatmp/wechatmp_message.py @@ -6,7 +6,6 @@ from common.tmp_dir import TmpDir - class WeChatMPMessage(ChatMessage): def __init__(self, msg, client=None): super().__init__(msg) @@ -18,12 +17,9 @@ def __init__(self, msg, client=None): self.ctype = ContextType.TEXT self.content = msg.content elif msg.type == "voice": - if msg.recognition == None: self.ctype = ContextType.VOICE - self.content = ( - TmpDir().path() + msg.media_id + "." + msg.format - ) # content直接存临时目录路径 + self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径 def download_voice(): # 如果响应状态码是200,则将响应内容写入本地文件 @@ -32,9 +28,7 @@ def download_voice(): with open(self.content, "wb") as f: f.write(response.content) else: - logger.info( - f"[wechatmp] Failed to download voice file, {response.content}" - ) + logger.info(f"[wechatmp] Failed to download voice file, {response.content}") self._prepare_fn = download_voice else: @@ -43,6 +37,7 @@ def download_voice(): elif msg.type == "image": self.ctype = ContextType.IMAGE self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 + def download_image(): # 如果响应状态码是200,则将响应内容写入本地文件 response = client.media.download(msg.media_id) @@ -50,15 +45,11 @@ def download_image(): with open(self.content, "wb") as f: f.write(response.content) else: - logger.info( - f"[wechatmp] Failed to download image file, {response.content}" - ) + logger.info(f"[wechatmp] Failed to download image file, {response.content}") self._prepare_fn = download_image else: - raise NotImplementedError( - "Unsupported message type: Type:{} ".format(msg.type) - ) + raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type)) self.from_user_id = msg.source self.to_user_id = msg.target diff --git a/common/time_check.py b/common/time_check.py index 808f71ab3..5c2dacba6 100644 --- a/common/time_check.py +++ b/common/time_check.py @@ -13,23 +13,15 @@ def _time_checker(self, *args, **kwargs): if chat_time_module: chat_start_time = _config.get("chat_start_time", "00:00") chat_stopt_time = _config.get("chat_stop_time", "24:00") - time_regex = re.compile( - r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" - ) # 时间匹配,包含24:00 + time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 # 时间格式检查 - if not ( - starttime_format_check and stoptime_format_check and chat_time_check - ): - logger.warn( - "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( - starttime_format_check, stoptime_format_check - ) - ) + if not (starttime_format_check and stoptime_format_check and chat_time_check): + logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check)) if chat_start_time > "23:59": logger.error("启动时间可能存在问题,请修改!") diff --git a/config.py b/config.py index 4e8b60ad6..b7de15ebf 100644 --- a/config.py +++ b/config.py @@ -158,9 +158,7 @@ def load_config(): for name, value in os.environ.items(): name = name.lower() if name in available_setting: - logger.info( - "[INIT] override config by environ args: {}={}".format(name, value) - ) + logger.info("[INIT] override config by environ args: {}={}".format(name, value)) try: config[name] = eval(value) except: diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py index 4f7f75cde..118b9631c 100644 --- a/plugins/banwords/banwords.py +++ b/plugins/banwords/banwords.py @@ -50,9 +50,7 @@ def __init__(self): self.reply_action = conf.get("reply_action", "ignore") logger.info("[Banwords] inited") except Exception as e: - logger.warn( - "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." - ) + logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") raise e def on_handle_context(self, e_context: EventContext): @@ -72,9 +70,7 @@ def on_handle_context(self, e_context: EventContext): return elif self.action == "replace": if self.searchr.ContainsAny(content): - reply = Reply( - ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) - ) + reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return @@ -94,9 +90,7 @@ def on_decorate_reply(self, e_context: EventContext): return elif self.reply_action == "replace": if self.searchr.ContainsAny(content): - reply = Reply( - ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) - ) + reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)) e_context["reply"] = reply e_context.action = EventAction.CONTINUE return diff --git a/plugins/bdunit/bdunit.py b/plugins/bdunit/bdunit.py index 4523b940a..e41e8d274 100644 --- a/plugins/bdunit/bdunit.py +++ b/plugins/bdunit/bdunit.py @@ -76,9 +76,7 @@ def get_token(self): Returns: string: access_token """ - url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( - self.api_key, self.secret_key - ) + url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key) payload = "" headers = {"Content-Type": "application/json", "Accept": "application/json"} @@ -94,10 +92,7 @@ def getUnit(self, query): :returns: UNIT 解析结果。如果解析失败,返回 None """ - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" - + self.access_token - ) + url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token request = { "query": query, "user_id": str(get_mac())[:32], @@ -124,10 +119,7 @@ def getUnit2(self, query): :param query: 用户的指令字符串 :returns: UNIT 解析结果。如果解析失败,返回 None """ - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" - + self.access_token - ) + url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token request = {"query": query, "user_id": str(get_mac())[:32]} body = { "log_id": str(uuid.uuid1()), @@ -170,11 +162,7 @@ def hasIntent(self, parsed, intent): if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] for response in response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: return True return False else: @@ -198,12 +186,7 @@ def getSlots(self, parsed, intent=""): logger.warning(e) return [] for response in response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and "slots" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent: return response["schema"]["slots"] return [] else: @@ -239,11 +222,7 @@ def getSayByConfidence(self, parsed): if ( "schema" in response and "intent_confidence" in response["schema"] - and ( - not answer - or response["schema"]["intent_confidence"] - > answer["schema"]["intent_confidence"] - ) + and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"]) ): answer = response return answer["action_list"][0]["say"] @@ -267,11 +246,7 @@ def getSay(self, parsed, intent=""): logger.warning(e) return "" for response in response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: try: return response["action_list"][0]["say"] except Exception as e: diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py index 2e3cdf1ad..5b129d606 100644 --- a/plugins/dungeon/dungeon.py +++ b/plugins/dungeon/dungeon.py @@ -84,9 +84,7 @@ def on_handle_context(self, e_context: EventContext): if len(clist) > 1: story = clist[1] else: - story = ( - "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" - ) + story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" self.games[sessionid] = StoryTeller(bot, sessionid, story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) e_context["reply"] = reply @@ -102,11 +100,7 @@ def get_help_text(self, **kwargs): if kwargs.get("verbose") != True: return help_text trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = ( - f"{trigger_prefix}开始冒险 " - + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" - + f"{trigger_prefix}停止冒险: 结束游戏。\n" - ) + help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n" if kwargs.get("verbose") == True: help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" return help_text diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 1b89df1e5..99d563290 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup): if plugins[plugin].enabled and not plugins[plugin].hidden: namecn = plugins[plugin].namecn help_text += "\n%s:" % namecn - help_text += ( - PluginManager().instances[plugin].get_help_text(verbose=False).strip() - ) + help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() if ADMIN_COMMANDS and isadmin: help_text += "\n\n管理员指令:\n" @@ -191,9 +189,7 @@ def __init__(self): COMMANDS["reset"]["alias"].append(custom_command) self.password = gconf["password"] - self.admin_users = gconf[ - "admin_users" - ] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 + self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 self.isrunning = True # 机器人是否运行中 self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context @@ -215,7 +211,7 @@ def on_handle_context(self, e_context: EventContext): reply.content = f"空指令,输入#help查看指令列表\n" e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS - return + return # msg = e_context['context']['msg'] channel = e_context["channel"] user = e_context["context"]["receiver"] @@ -248,11 +244,7 @@ def on_handle_context(self, e_context: EventContext): if not plugincls.enabled: continue if query_name == name or query_name == plugincls.namecn: - ok, result = True, PluginManager().instances[ - name - ].get_help_text( - isgroup=isgroup, isadmin=isadmin, verbose=True - ) + ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) break if not ok: result = "插件不存在或未启用" @@ -285,11 +277,7 @@ def on_handle_context(self, e_context: EventContext): if isgroup: ok, result = False, "群聊不可执行管理员指令" else: - cmd = next( - c - for c, info in ADMIN_COMMANDS.items() - if cmd in info["alias"] - ) + cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"]) if cmd == "stop": self.isrunning = False ok, result = True, "服务已暂停" @@ -325,18 +313,14 @@ def on_handle_context(self, e_context: EventContext): PluginManager().activate_plugins() if len(new_plugins) > 0: result += "\n发现新插件:\n" - result += "\n".join( - [f"{p.name}_v{p.version}" for p in new_plugins] - ) + result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) else: result += ", 未发现新插件" elif cmd == "setpri": if len(args) != 2: ok, result = False, "请提供插件名和优先级" else: - ok = PluginManager().set_plugin_priority( - args[0], int(args[1]) - ) + ok = PluginManager().set_plugin_priority(args[0], int(args[1])) if ok: result = "插件" + args[0] + "优先级已设置为" + args[1] else: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 254b17254..fc8fe7051 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -33,9 +33,7 @@ def on_handle_context(self, e_context: EventContext): if e_context["context"].type == ContextType.JOIN_GROUP: e_context["context"].type = ContextType.TEXT msg: ChatMessage = e_context["context"]["msg"] - e_context[ - "context" - ].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' + e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 return @@ -53,9 +51,7 @@ def on_handle_context(self, e_context: EventContext): reply.type = ReplyType.TEXT msg: ChatMessage = e_context["context"]["msg"] if e_context["context"]["isgroup"]: - reply.content = ( - f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" - ) + reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" else: reply.content = f"Hello, {msg.from_user_nickname}" e_context["reply"] = reply diff --git a/plugins/keyword/config.json.template b/plugins/keyword/config.json.template index 9a8332f3e..dbd5efe34 100644 --- a/plugins/keyword/config.json.template +++ b/plugins/keyword/config.json.template @@ -1,5 +1,5 @@ { "keyword": { - "关键字匹配": "测试成功" + "关键字匹配": "测试成功" } -} \ No newline at end of file +} diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py index 376f748e9..97ebe26ac 100644 --- a/plugins/keyword/keyword.py +++ b/plugins/keyword/keyword.py @@ -41,9 +41,7 @@ def __init__(self): self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[keyword] inited.") except Exception as e: - logger.warn( - "[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ." - ) + logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .") raise e def on_handle_context(self, e_context: EventContext): diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index b014e5f06..d2ee75e39 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -31,23 +31,14 @@ def wrapper(plugincls): plugincls.desc = kwargs.get("desc") plugincls.author = kwargs.get("author") plugincls.path = self.current_plugin_path - plugincls.version = ( - kwargs.get("version") if kwargs.get("version") != None else "1.0" - ) - plugincls.namecn = ( - kwargs.get("namecn") if kwargs.get("namecn") != None else name - ) - plugincls.hidden = ( - kwargs.get("hidden") if kwargs.get("hidden") != None else False - ) + plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0" + plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name + plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False plugincls.enabled = True if self.current_plugin_path == None: raise Exception("Plugin path not set") self.plugins[name.upper()] = plugincls - logger.info( - "Plugin %s_v%s registered, path=%s" - % (name, plugincls.version, plugincls.path) - ) + logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) return wrapper @@ -62,9 +53,7 @@ def load_config(self): if os.path.exists("./plugins/plugins.json"): with open("./plugins/plugins.json", "r", encoding="utf-8") as f: pconf = json.load(f) - pconf["plugins"] = SortedDict( - lambda k, v: v["priority"], pconf["plugins"], reverse=True - ) + pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True) else: modified = True pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} @@ -90,26 +79,16 @@ def scan_plugins(self): if plugin_path in self.loaded: if self.loaded[plugin_path] == None: logger.info("reload module %s" % plugin_name) - self.loaded[plugin_path] = importlib.reload( - sys.modules[import_path] - ) - dependent_module_names = [ - name - for name in sys.modules.keys() - if name.startswith(import_path + ".") - ] + self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) + dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")] for name in dependent_module_names: logger.info("reload module %s" % name) importlib.reload(sys.modules[name]) else: - self.loaded[plugin_path] = importlib.import_module( - import_path - ) + self.loaded[plugin_path] = importlib.import_module(import_path) self.current_plugin_path = None except Exception as e: - logger.exception( - "Failed to import plugin %s: %s" % (plugin_name, e) - ) + logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) continue pconf = self.pconf news = [self.plugins[name] for name in self.plugins] @@ -119,9 +98,7 @@ def scan_plugins(self): rawname = plugincls.name if rawname not in pconf["plugins"]: modified = True - logger.info( - "Plugin %s not found in pconfig, adding to pconfig..." % name - ) + logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) pconf["plugins"][rawname] = { "enabled": plugincls.enabled, "priority": plugincls.priority, @@ -136,9 +113,7 @@ def scan_plugins(self): def refresh_order(self): for event in self.listening_plugins.keys(): - self.listening_plugins[event].sort( - key=lambda name: self.plugins[name].priority, reverse=True - ) + self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) def activate_plugins(self): # 生成新开启的插件实例 failed_plugins = [] @@ -184,13 +159,8 @@ def load_plugins(self): def emit_event(self, e_context: EventContext, *args, **kwargs): if e_context.event in self.listening_plugins: for name in self.listening_plugins[e_context.event]: - if ( - self.plugins[name].enabled - and e_context.action == EventAction.CONTINUE - ): - logger.debug( - "Plugin %s triggered by event %s" % (name, e_context.event) - ) + if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name, e_context.event)) instance = self.instances[name] instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context @@ -262,9 +232,7 @@ def install_plugin(self, repo: str): source = json.load(f) if repo in source["repo"]: repo = source["repo"][repo]["url"] - match = re.match( - r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo - ) + match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) if not match: return False, "安装插件失败,source中的仓库地址不合法" else: diff --git a/plugins/role/role.py b/plugins/role/role.py index 9788cc1cd..69c523347 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -69,13 +69,9 @@ def __init__(self): logger.info("[Role] inited") except Exception as e: if isinstance(e, FileNotFoundError): - logger.warn( - f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." - ) + logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") else: - logger.warn( - "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." - ) + logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") raise e def get_role(self, name, find_closest=True, min_sim=0.35): @@ -143,9 +139,7 @@ def on_handle_context(self, e_context: EventContext): else: help_text = f"未知角色类型。\n" help_text += "目前的角色类型有: \n" - help_text += ( - ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" - ) + help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" else: help_text = f"请输入角色类型。\n" help_text += "目前的角色类型有: \n" @@ -158,9 +152,7 @@ def on_handle_context(self, e_context: EventContext): return logger.debug("[Role] on_handle_context. content: %s" % content) if desckey is not None: - if len(clist) == 1 or ( - len(clist) > 1 and clist[1].lower() in ["help", "帮助"] - ): + if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS @@ -178,9 +170,7 @@ def on_handle_context(self, e_context: EventContext): self.roles[role][desckey], self.roles[role].get("wrapper", "%s"), ) - reply = Reply( - ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] - ) + reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS elif customize == True: @@ -199,17 +189,10 @@ def get_help_text(self, verbose=False, **kwargs): if not verbose: return help_text trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = ( - f"使用方法:\n{trigger_prefix}角色" - + " 预设角色名: 设定角色为{预设角色名}。\n" - + f"{trigger_prefix}role" - + " 预设角色名: 同上,但使用英文设定。\n" - ) + help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n" help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" - help_text += ( - f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" - ) + help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" help_text += "\n目前的角色类型有: \n" help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" diff --git a/plugins/tool/README.md b/plugins/tool/README.md index 3ce92da47..9f91b288a 100644 --- a/plugins/tool/README.md +++ b/plugins/tool/README.md @@ -60,7 +60,7 @@ > 该tool每天返回内容相同 -#### 6.3. finance-news +#### 6.3. finance-news ###### 获取实时的金融财政新闻 > 该工具需要解决browser tool 的google-chrome依赖安装 diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py index 5a9c81a6c..e99b5188f 100644 --- a/plugins/tool/tool.py +++ b/plugins/tool/tool.py @@ -82,9 +82,7 @@ def on_handle_context(self, e_context: EventContext): return elif content_list[1].startswith("reset"): logger.debug("[tool]: remind") - e_context[ - "context" - ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" + e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" e_context.action = EventAction.BREAK return @@ -93,18 +91,14 @@ def on_handle_context(self, e_context: EventContext): # Don't modify bot name all_sessions = Bridge().get_bot("chat").sessions - user_session = all_sessions.session_query( - query, e_context["context"]["session_id"] - ).messages + user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages # chatgpt-tool-hub will reply you with many tools logger.debug("[tool]: just-go") try: _reply = self.app.ask(query, user_session) e_context.action = EventAction.BREAK_PASS - all_sessions.session_reply( - _reply, e_context["context"]["session_id"] - ) + all_sessions.session_reply(_reply, e_context["context"]["session_id"]) except Exception as e: logger.exception(e) logger.error(str(e)) @@ -178,4 +172,4 @@ def _reset_app(self) -> App: # filter not support tool tool_list = self._filter_tool_list(tool_config.get("tools", [])) - return app.create_app(tools_list=tool_list, **app_kwargs) \ No newline at end of file + return app.create_app(tools_list=tool_list, **app_kwargs) diff --git a/voice/audio_convert.py b/voice/audio_convert.py index 77de4ed17..610170038 100644 --- a/voice/audio_convert.py +++ b/voice/audio_convert.py @@ -33,6 +33,7 @@ def get_pcm_from_wav(wav_path): wav = wave.open(wav_path, "rb") return wav.readframes(wav.getnframes()) + def any_to_mp3(any_path, mp3_path): """ 把任意格式转成mp3文件 @@ -40,16 +41,13 @@ def any_to_mp3(any_path, mp3_path): if any_path.endswith(".mp3"): shutil.copy2(any_path, mp3_path) return - if ( - any_path.endswith(".sil") - or any_path.endswith(".silk") - or any_path.endswith(".slk") - ): + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): sil_to_wav(any_path, any_path) any_path = mp3_path audio = AudioSegment.from_file(any_path) audio.export(mp3_path, format="mp3") + def any_to_wav(any_path, wav_path): """ 把任意格式转成wav文件 @@ -57,11 +55,7 @@ def any_to_wav(any_path, wav_path): if any_path.endswith(".wav"): shutil.copy2(any_path, wav_path) return - if ( - any_path.endswith(".sil") - or any_path.endswith(".silk") - or any_path.endswith(".slk") - ): + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): return sil_to_wav(any_path, wav_path) audio = AudioSegment.from_file(any_path) audio.export(wav_path, format="wav") @@ -71,11 +65,7 @@ def any_to_sil(any_path, sil_path): """ 把任意格式转成sil文件 """ - if ( - any_path.endswith(".sil") - or any_path.endswith(".silk") - or any_path.endswith(".slk") - ): + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): shutil.copy2(any_path, sil_path) return 10000 audio = AudioSegment.from_file(any_path) diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py index 3ee95043e..f911246a3 100644 --- a/voice/azure/azure_voice.py +++ b/voice/azure/azure_voice.py @@ -40,57 +40,33 @@ def __init__(self): config = json.load(fr) self.api_key = conf().get("azure_voice_api_key") self.api_region = conf().get("azure_voice_region") - self.speech_config = speechsdk.SpeechConfig( - subscription=self.api_key, region=self.api_region - ) - self.speech_config.speech_synthesis_voice_name = config[ - "speech_synthesis_voice_name" - ] - self.speech_config.speech_recognition_language = config[ - "speech_recognition_language" - ] + self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) + self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"] + self.speech_config.speech_recognition_language = config["speech_recognition_language"] except Exception as e: logger.warn("AzureVoice init failed: %s, ignore " % e) def voiceToText(self, voice_file): audio_config = speechsdk.AudioConfig(filename=voice_file) - speech_recognizer = speechsdk.SpeechRecognizer( - speech_config=self.speech_config, audio_config=audio_config - ) + speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) result = speech_recognizer.recognize_once() if result.reason == speechsdk.ResultReason.RecognizedSpeech: - logger.info( - "[Azure] voiceToText voice file name={} text={}".format( - voice_file, result.text - ) - ) + logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text)) reply = Reply(ReplyType.TEXT, result.text) else: - logger.error( - "[Azure] voiceToText error, result={}, canceldetails={}".format( - result, result.cancellation_details - ) - ) + logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details)) reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") return reply def textToVoice(self, text): fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" audio_config = speechsdk.AudioConfig(filename=fileName) - speech_synthesizer = speechsdk.SpeechSynthesizer( - speech_config=self.speech_config, audio_config=audio_config - ) + speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) result = speech_synthesizer.speak_text(text) if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: - logger.info( - "[Azure] textToVoice text={} voice file name={}".format(text, fileName) - ) + logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName)) reply = Reply(ReplyType.VOICE, fileName) else: - logger.error( - "[Azure] textToVoice error, result={}, canceldetails={}".format( - result, result.cancellation_details - ) - ) + logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details)) reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index ccde6c4d1..114e4a770 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -85,9 +85,7 @@ def textToVoice(self, text): fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" with open(fileName, "wb") as f: f.write(result) - logger.info( - "[Baidu] textToVoice text={} voice file name={}".format(text, fileName) - ) + logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName)) reply = Reply(ReplyType.VOICE, fileName) else: logger.error("[Baidu] textToVoice error={}".format(result)) diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 4f7b8ade3..2cc20aab7 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -24,11 +24,7 @@ def voiceToText(self, voice_file): audio = self.recognizer.record(source) try: text = self.recognizer.recognize_google(audio, language="zh-CN") - logger.info( - "[Google] voiceToText text={} voice file name={}".format( - text, voice_file - ) - ) + logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file)) reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") @@ -42,9 +38,7 @@ def textToVoice(self, text): mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" tts = gTTS(text=text, lang="zh") tts.save(mp3File) - logger.info( - "[Google] textToVoice text={} voice file name={}".format(text, mp3File) - ) + logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File)) reply = Reply(ReplyType.VOICE, mp3File) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 06c221b21..b02d92651 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -22,11 +22,7 @@ def voiceToText(self, voice_file): result = openai.Audio.transcribe("whisper-1", file) text = result["text"] reply = Reply(ReplyType.TEXT, text) - logger.info( - "[Openai] voiceToText text={} voice file name={}".format( - text, voice_file - ) - ) + logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file)) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) finally: diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py index 17cd6ff16..eb5a84928 100644 --- a/voice/pytts/pytts_voice.py +++ b/voice/pytts/pytts_voice.py @@ -5,6 +5,7 @@ import os import sys import time + import pyttsx3 from bridge.reply import Reply, ReplyType @@ -12,6 +13,7 @@ from common.tmp_dir import TmpDir from voice.voice import Voice + class PyttsVoice(Voice): engine = pyttsx3.init() @@ -20,7 +22,7 @@ def __init__(self): self.engine.setProperty("rate", 125) # 音量 self.engine.setProperty("volume", 1.0) - if sys.platform == 'win32': + if sys.platform == "win32": for voice in self.engine.getProperty("voices"): if "Chinese" in voice.name: self.engine.setProperty("voice", voice.id) @@ -33,23 +35,23 @@ def __init__(self): def textToVoice(self, text): try: # avoid the same filename - wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7fffffff) + ".wav" + wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" wavFile = TmpDir().path() + wavFileName logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) self.engine.save_to_file(text, wavFile) - if sys.platform == 'win32': + if sys.platform == "win32": self.engine.runAndWait() else: - # In ubuntu, runAndWait do not really wait until the file created. + # In ubuntu, runAndWait do not really wait until the file created. # It will return once the task queue is empty, but the task is still running in coroutine. # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this. # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file. # self.engine.runAndWait() # Before espeak fix this problem, we iterate the generator and control the waiting by ourself. - # But this is not the canonical way to use it, for example if the file already exists it also cannot wait. + # But this is not the canonical way to use it, for example if the file already exists it also cannot wait. self.engine.iterate() while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()): time.sleep(0.1)