diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..c614a6302 --- /dev/null +++ b/.flake8 @@ -0,0 +1,13 @@ +[flake8] +max-line-length = 88 +select = E303,W293,W291,W292,E305,E231,E302 +exclude = + .tox, + __pycache__, + *.pyc, + .env + venv/* + .venv/* + reports/* + dist/* + lib/* \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 4abe42801..7d697ec53 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -27,5 +27,5 @@ ### 环境 - 操作系统类型 (Mac/Windows/Linux): - - Python版本 ( 执行 `python3 -V` ): + - Python版本 ( 执行 `python3 -V` ): - pip版本 ( 依赖问题此项必填,执行 `pip3 -V`): diff --git a/.github/workflows/deploy-image.yml b/.github/workflows/deploy-image.yml index 9dc0a6337..d9f829e98 100644 --- a/.github/workflows/deploy-image.yml +++ b/.github/workflows/deploy-image.yml @@ -49,9 +49,9 @@ jobs: file: ./docker/Dockerfile.latest tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} - + - uses: actions/delete-package-versions@v4 - with: + with: package-name: 'chatgpt-on-wechat' package-type: 'container' min-versions-to-keep: 10 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..b6c6e19d5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,29 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: fix-byte-order-marker + - id: check-case-conflict + - id: check-merge-conflict + - id: debug-statements + - id: pretty-format-json + types: [text] + files: \.json(.template)?$ + args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys] + - id: trailing-whitespace + exclude: '(\/|^)lib\/' + args: [ --markdown-linebreak-ext=md ] + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + exclude: '(\/|^)lib\/' + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + exclude: '(\/|^)lib\/' + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 diff --git a/README.md b/README.md index ffbe616c2..bb86b4130 100644 --- a/README.md +++ b/README.md @@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech ```bash # config.json文件内容示例 -{ +{ "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 "proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 @@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 - "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 + "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 "speech_recognition": false, # 是否开启语音识别 @@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech **4.其他配置** + `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放) -+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) ++ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix ` + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。 @@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech ```bash python3 app.py ``` -终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 +终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 ### 2.服务器部署 @@ -189,7 +189,7 @@ python3 app.py 使用nohup命令在后台运行程序: ```bash -touch nohup.out # 首次运行需要新建日志文件 +touch nohup.out # 首次运行需要新建日志文件 nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码 ``` 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。 diff --git a/app.py b/app.py index 145b9c7a9..637b6e462 100644 --- a/app.py +++ b/app.py @@ -1,23 +1,28 @@ # encoding:utf-8 import os -from config import conf, load_config +import signal +import sys + from channel import channel_factory from common.log import logger +from config import conf, load_config from plugins import * -import signal -import sys + def sigterm_handler_wrap(_signo): old_handler = signal.getsignal(_signo) + def func(_signo, _stack_frame): logger.info("signal {} received, exiting...".format(_signo)) conf().save_user_datas() - if callable(old_handler): # check old_handler + if callable(old_handler): # check old_handler return old_handler(_signo, _stack_frame) sys.exit(0) + signal.signal(_signo, func) + def run(): try: # load config @@ -28,17 +33,17 @@ def run(): sigterm_handler_wrap(signal.SIGTERM) # create channel - channel_name=conf().get('channel_type', 'wx') + channel_name = conf().get("channel_type", "wx") if "--cmd" in sys.argv: - channel_name = 'terminal' + channel_name = "terminal" - if channel_name == 'wxy': - os.environ['WECHATY_LOG']="warn" + if channel_name == "wxy": + os.environ["WECHATY_LOG"] = "warn" # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' channel = channel_factory.create_channel(channel_name) - if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']: + if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]: PluginManager().load_plugins() # startup channel @@ -47,5 +52,6 @@ def run(): logger.error("App startup failed!") logger.exception(e) -if __name__ == '__main__': - run() \ No newline at end of file + +if __name__ == "__main__": + run() diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index 2b7dd8d2e..d8a0aca11 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -1,6 +1,7 @@ # encoding:utf-8 import requests + from bot.bot import Bot from bridge.reply import Reply, ReplyType @@ -9,20 +10,35 @@ class BaiduUnitBot(Bot): def reply(self, query, context=None): token = self.get_token() - url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token - post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}" + url = ( + "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + + token + ) + post_data = ( + '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' + + query + + '", "hyper_params": {"chat_custom_bot_profile": 1}}}' + ) print(post_data) - headers = {'content-type': 'application/x-www-form-urlencoded'} + headers = {"content-type": "application/x-www-form-urlencoded"} response = requests.post(url, data=post_data.encode(), headers=headers) if response: - reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) + reply = Reply( + ReplyType.TEXT, + response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1], + ) return reply def get_token(self): - access_key = 'YOUR_ACCESS_KEY' - secret_key = 'YOUR_SECRET_KEY' - host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key + access_key = "YOUR_ACCESS_KEY" + secret_key = "YOUR_SECRET_KEY" + host = ( + "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + + access_key + + "&client_secret=" + + secret_key + ) response = requests.get(host) if response: print(response.json()) - return response.json()['access_token'] + return response.json()["access_token"] diff --git a/bot/bot.py b/bot/bot.py index fd56e5022..ca6e1aa12 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -8,7 +8,7 @@ class Bot(object): - def reply(self, query, context : Context =None) -> Reply: + def reply(self, query, context: Context = None) -> Reply: """ bot auto-reply content :param req: received message diff --git a/bot/bot_factory.py b/bot/bot_factory.py index cf9cfe7ae..77797e76a 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -13,20 +13,24 @@ def create_bot(bot_type): if bot_type == const.BAIDU: # Baidu Unit对话接口 from bot.baidu.baidu_unit_bot import BaiduUnitBot + return BaiduUnitBot() elif bot_type == const.CHATGPT: # ChatGPT 网页端web接口 from bot.chatgpt.chat_gpt_bot import ChatGPTBot + return ChatGPTBot() elif bot_type == const.OPEN_AI: # OpenAI 官方对话模型API from bot.openai.open_ai_bot import OpenAIBot + return OpenAIBot() elif bot_type == const.CHATGPTONAZURE: # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot + return AzureChatGPTBot() raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index feb428f36..59acaa6b4 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -1,42 +1,53 @@ # encoding:utf-8 +import time + +import openai +import openai.error + from bot.bot import Bot from bot.chatgpt.chat_gpt_session import ChatGPTSession from bot.openai.open_ai_image import OpenAIImage from bot.session_manager import SessionManager from bridge.context import ContextType from bridge.reply import Reply, ReplyType -from config import conf, load_config from common.log import logger from common.token_bucket import TokenBucket -import openai -import openai.error -import time +from config import conf, load_config + # OpenAI对话模型API (可用) -class ChatGPTBot(Bot,OpenAIImage): +class ChatGPTBot(Bot, OpenAIImage): def __init__(self): super().__init__() # set the default api_key - openai.api_key = conf().get('open_ai_api_key') - if conf().get('open_ai_api_base'): - openai.api_base = conf().get('open_ai_api_base') - proxy = conf().get('proxy') + openai.api_key = conf().get("open_ai_api_key") + if conf().get("open_ai_api_base"): + openai.api_base = conf().get("open_ai_api_base") + proxy = conf().get("proxy") if proxy: openai.proxy = proxy - if conf().get('rate_limit_chatgpt'): - self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) - - self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo") - self.args ={ + if conf().get("rate_limit_chatgpt"): + self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) + + self.sessions = SessionManager( + ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" + ) + self.args = { "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 - "temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 + "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 # "max_tokens":4096, # 回复最大的字符数 - "top_p":1, - "frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 + "top_p": 1, + "frequency_penalty": conf().get( + "frequency_penalty", 0.0 + ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get( + "presence_penalty", 0.0 + ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get( + "request_timeout", None + ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 } def reply(self, query, context=None): @@ -44,39 +55,50 @@ def reply(self, query, context=None): if context.type == ContextType.TEXT: logger.info("[CHATGPT] query={}".format(query)) - - session_id = context['session_id'] + session_id = context["session_id"] reply = None - clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) + clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) if query in clear_memory_commands: self.sessions.clear_session(session_id) - reply = Reply(ReplyType.INFO, '记忆已清除') - elif query == '#清除所有': + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": self.sessions.clear_all_session() - reply = Reply(ReplyType.INFO, '所有人记忆已清除') - elif query == '#更新配置': + reply = Reply(ReplyType.INFO, "所有人记忆已清除") + elif query == "#更新配置": load_config() - reply = Reply(ReplyType.INFO, '配置已更新') + reply = Reply(ReplyType.INFO, "配置已更新") if reply: return reply session = self.sessions.session_query(query, session_id) logger.debug("[CHATGPT] session query={}".format(session.messages)) - api_key = context.get('openai_api_key') + api_key = context.get("openai_api_key") # if context.get('stream'): # # reply in stream # return self.reply_text_stream(query, new_query, session_id) reply_content = self.reply_text(session, api_key) - logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) - if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: - reply = Reply(ReplyType.ERROR, reply_content['content']) + logger.debug( + "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + session.messages, + session_id, + reply_content["content"], + reply_content["completion_tokens"], + ) + ) + if ( + reply_content["completion_tokens"] == 0 + and len(reply_content["content"]) > 0 + ): + reply = Reply(ReplyType.ERROR, reply_content["content"]) elif reply_content["completion_tokens"] > 0: - self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) + self.sessions.session_reply( + reply_content["content"], session_id, reply_content["total_tokens"] + ) reply = Reply(ReplyType.TEXT, reply_content["content"]) else: - reply = Reply(ReplyType.ERROR, reply_content['content']) + reply = Reply(ReplyType.ERROR, reply_content["content"]) logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content)) return reply @@ -89,53 +111,55 @@ def reply(self, query, context=None): reply = Reply(ReplyType.ERROR, retstring) return reply else: - reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) + reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict: - ''' + def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: + """ call openai's ChatCompletion to get the answer :param session: a conversation session :param session_id: session id :param retry_count: retry count :return: {} - ''' + """ try: - if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): + if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") # if api_key == None, the default openai.api_key will be used response = openai.ChatCompletion.create( api_key=api_key, messages=session.messages, **self.args ) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) - return {"total_tokens": response["usage"]["total_tokens"], - "completion_tokens": response["usage"]["completion_tokens"], - "content": response.choices[0]['message']['content']} + return { + "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response["usage"]["completion_tokens"], + "content": response.choices[0]["message"]["content"], + } except Exception as e: need_retry = retry_count < 2 result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} if isinstance(e, openai.error.RateLimitError): logger.warn("[CHATGPT] RateLimitError: {}".format(e)) - result['content'] = "提问太快啦,请休息一下再问我吧" + result["content"] = "提问太快啦,请休息一下再问我吧" if need_retry: time.sleep(5) elif isinstance(e, openai.error.Timeout): logger.warn("[CHATGPT] Timeout: {}".format(e)) - result['content'] = "我没有收到你的消息" + result["content"] = "我没有收到你的消息" if need_retry: time.sleep(5) elif isinstance(e, openai.error.APIConnectionError): logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) need_retry = False - result['content'] = "我连接不到你的网络" + result["content"] = "我连接不到你的网络" else: logger.warn("[CHATGPT] Exception: {}".format(e)) need_retry = False self.sessions.clear_session(session.session_id) if need_retry: - logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) - return self.reply_text(session, api_key, retry_count+1) + logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, api_key, retry_count + 1) else: return result @@ -145,4 +169,4 @@ def __init__(self): super().__init__() openai.api_type = "azure" openai.api_version = "2023-03-15-preview" - self.args["deployment_id"] = conf().get("azure_deployment_id") \ No newline at end of file + self.args["deployment_id"] = conf().get("azure_deployment_id") diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index ed986a73e..525793ffa 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -1,20 +1,23 @@ from bot.session_manager import Session from common.log import logger -''' + +""" e.g. [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who won the world series in 2020?"}, {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, {"role": "user", "content": "Where was it played?"} ] -''' +""" + + class ChatGPTSession(Session): - def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"): + def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"): super().__init__(session_id, system_prompt) self.model = model self.reset() - - def discard_exceeding(self, max_tokens, cur_tokens= None): + + def discard_exceeding(self, max_tokens, cur_tokens=None): precise = True try: cur_tokens = self.calc_tokens() @@ -22,7 +25,9 @@ def discard_exceeding(self, max_tokens, cur_tokens= None): precise = False if cur_tokens is None: raise e - logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + logger.debug( + "Exception when counting tokens precisely for query: {}".format(e) + ) while cur_tokens > max_tokens: if len(self.messages) > 2: self.messages.pop(1) @@ -34,25 +39,32 @@ def discard_exceeding(self, max_tokens, cur_tokens= None): cur_tokens = cur_tokens - max_tokens break elif len(self.messages) == 2 and self.messages[1]["role"] == "user": - logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) + logger.warn( + "user message exceed max_tokens. total_tokens={}".format(cur_tokens) + ) break else: - logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) + logger.debug( + "max_tokens={}, total_tokens={}, len(messages)={}".format( + max_tokens, cur_tokens, len(self.messages) + ) + ) break if precise: cur_tokens = self.calc_tokens() else: cur_tokens = cur_tokens - max_tokens return cur_tokens - + def calc_tokens(self): return num_tokens_from_messages(self.messages, self.model) - + # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def num_tokens_from_messages(messages, model): """Returns the number of tokens used by a list of messages.""" import tiktoken + try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model): elif model == "gpt-4": return num_tokens_from_messages(messages, model="gpt-4-0314") elif model == "gpt-3.5-turbo-0301": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_message = ( + 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + ) tokens_per_name = -1 # if there's a name, the role is omitted elif model == "gpt-4-0314": tokens_per_message = 3 tokens_per_name = 1 else: - logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.") + logger.warn( + f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." + ) return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") num_tokens = 0 for message in messages: diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index 757f36a6f..8a56cf18f 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -1,41 +1,52 @@ # encoding:utf-8 +import time + +import openai +import openai.error + from bot.bot import Bot from bot.openai.open_ai_image import OpenAIImage from bot.openai.open_ai_session import OpenAISession from bot.session_manager import SessionManager from bridge.context import ContextType from bridge.reply import Reply, ReplyType -from config import conf from common.log import logger -import openai -import openai.error -import time +from config import conf user_session = dict() + # OpenAI对话模型API (可用) class OpenAIBot(Bot, OpenAIImage): def __init__(self): super().__init__() - openai.api_key = conf().get('open_ai_api_key') - if conf().get('open_ai_api_base'): - openai.api_base = conf().get('open_ai_api_base') - proxy = conf().get('proxy') + openai.api_key = conf().get("open_ai_api_key") + if conf().get("open_ai_api_base"): + openai.api_base = conf().get("open_ai_api_base") + proxy = conf().get("proxy") if proxy: openai.proxy = proxy - self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003") + self.sessions = SessionManager( + OpenAISession, model=conf().get("model") or "text-davinci-003" + ) self.args = { "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 - "temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 - "max_tokens":1200, # 回复最大的字符数 - "top_p":1, - "frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - "timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 - "stop":["\n\n\n"] + "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 + "max_tokens": 1200, # 回复最大的字符数 + "top_p": 1, + "frequency_penalty": conf().get( + "frequency_penalty", 0.0 + ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get( + "presence_penalty", 0.0 + ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get( + "request_timeout", None + ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 + "stop": ["\n\n\n"], } def reply(self, query, context=None): @@ -43,24 +54,34 @@ def reply(self, query, context=None): if context and context.type: if context.type == ContextType.TEXT: logger.info("[OPEN_AI] query={}".format(query)) - session_id = context['session_id'] + session_id = context["session_id"] reply = None - if query == '#清除记忆': + if query == "#清除记忆": self.sessions.clear_session(session_id) - reply = Reply(ReplyType.INFO, '记忆已清除') - elif query == '#清除所有': + reply = Reply(ReplyType.INFO, "记忆已清除") + elif query == "#清除所有": self.sessions.clear_all_session() - reply = Reply(ReplyType.INFO, '所有人记忆已清除') + reply = Reply(ReplyType.INFO, "所有人记忆已清除") else: session = self.sessions.session_query(query, session_id) result = self.reply_text(session) - total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content'] - logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)) + total_tokens, completion_tokens, reply_content = ( + result["total_tokens"], + result["completion_tokens"], + result["content"], + ) + logger.debug( + "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( + str(session), session_id, reply_content, completion_tokens + ) + ) - if total_tokens == 0 : + if total_tokens == 0: reply = Reply(ReplyType.ERROR, reply_content) else: - self.sessions.session_reply(reply_content, session_id, total_tokens) + self.sessions.session_reply( + reply_content, session_id, total_tokens + ) reply = Reply(ReplyType.TEXT, reply_content) return reply elif context.type == ContextType.IMAGE_CREATE: @@ -72,42 +93,44 @@ def reply(self, query, context=None): reply = Reply(ReplyType.ERROR, retstring) return reply - def reply_text(self, session:OpenAISession, retry_count=0): + def reply_text(self, session: OpenAISession, retry_count=0): try: - response = openai.Completion.create( - prompt=str(session), **self.args + response = openai.Completion.create(prompt=str(session), **self.args) + res_content = ( + response.choices[0]["text"].strip().replace("<|endoftext|>", "") ) - res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '') total_tokens = response["usage"]["total_tokens"] completion_tokens = response["usage"]["completion_tokens"] logger.info("[OPEN_AI] reply={}".format(res_content)) - return {"total_tokens": total_tokens, - "completion_tokens": completion_tokens, - "content": res_content} + return { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "content": res_content, + } except Exception as e: need_retry = retry_count < 2 result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} if isinstance(e, openai.error.RateLimitError): logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) - result['content'] = "提问太快啦,请休息一下再问我吧" + result["content"] = "提问太快啦,请休息一下再问我吧" if need_retry: time.sleep(5) elif isinstance(e, openai.error.Timeout): logger.warn("[OPEN_AI] Timeout: {}".format(e)) - result['content'] = "我没有收到你的消息" + result["content"] = "我没有收到你的消息" if need_retry: time.sleep(5) elif isinstance(e, openai.error.APIConnectionError): logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) need_retry = False - result['content'] = "我连接不到你的网络" + result["content"] = "我连接不到你的网络" else: logger.warn("[OPEN_AI] Exception: {}".format(e)) need_retry = False self.sessions.clear_session(session.session_id) if need_retry: - logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) - return self.reply_text(session, retry_count+1) + logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, retry_count + 1) else: - return result \ No newline at end of file + return result diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py index 2fae243ff..5dbbd23ed 100644 --- a/bot/openai/open_ai_image.py +++ b/bot/openai/open_ai_image.py @@ -1,38 +1,47 @@ import time + import openai import openai.error -from common.token_bucket import TokenBucket + from common.log import logger +from common.token_bucket import TokenBucket from config import conf + # OPENAI提供的画图接口 class OpenAIImage(object): def __init__(self): - openai.api_key = conf().get('open_ai_api_key') - if conf().get('rate_limit_dalle'): - self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) - + openai.api_key = conf().get("open_ai_api_key") + if conf().get("rate_limit_dalle"): + self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) + def create_img(self, query, retry_count=0): try: - if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): + if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): return False, "请求太快了,请休息一下再问我吧" logger.info("[OPEN_AI] image_query={}".format(query)) response = openai.Image.create( - prompt=query, #图片描述 - n=1, #每次生成图片的数量 - size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 + prompt=query, # 图片描述 + n=1, # 每次生成图片的数量 + size=conf().get( + "image_create_size", "256x256" + ), # 图片大小,可选有 256x256, 512x512, 1024x1024 ) - image_url = response['data'][0]['url'] + image_url = response["data"][0]["url"] logger.info("[OPEN_AI] image_url={}".format(image_url)) return True, image_url except openai.error.RateLimitError as e: logger.warn(e) if retry_count < 1: time.sleep(5) - logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) - return self.create_img(query, retry_count+1) + logger.warn( + "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format( + retry_count + 1 + ) + ) + return self.create_img(query, retry_count + 1) else: return False, "提问太快啦,请休息一下再问我吧" except Exception as e: logger.exception(e) - return False, str(e) \ No newline at end of file + return False, str(e) diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py index 28dd7eca1..78cf43900 100644 --- a/bot/openai/open_ai_session.py +++ b/bot/openai/open_ai_session.py @@ -1,32 +1,34 @@ from bot.session_manager import Session from common.log import logger + + class OpenAISession(Session): - def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"): + def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): super().__init__(session_id, system_prompt) self.model = model self.reset() def __str__(self): # 构造对话模型的输入 - ''' + """ e.g. Q: xxx A: xxx Q: xxx - ''' + """ prompt = "" for item in self.messages: - if item['role'] == 'system': - prompt += item['content'] + "<|endoftext|>\n\n\n" - elif item['role'] == 'user': - prompt += "Q: " + item['content'] + "\n" - elif item['role'] == 'assistant': - prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n" + if item["role"] == "system": + prompt += item["content"] + "<|endoftext|>\n\n\n" + elif item["role"] == "user": + prompt += "Q: " + item["content"] + "\n" + elif item["role"] == "assistant": + prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" - if len(self.messages) > 0 and self.messages[-1]['role'] == 'user': + if len(self.messages) > 0 and self.messages[-1]["role"] == "user": prompt += "A: " return prompt - def discard_exceeding(self, max_tokens, cur_tokens= None): + def discard_exceeding(self, max_tokens, cur_tokens=None): precise = True try: cur_tokens = self.calc_tokens() @@ -34,7 +36,9 @@ def discard_exceeding(self, max_tokens, cur_tokens= None): precise = False if cur_tokens is None: raise e - logger.debug("Exception when counting tokens precisely for query: {}".format(e)) + logger.debug( + "Exception when counting tokens precisely for query: {}".format(e) + ) while cur_tokens > max_tokens: if len(self.messages) > 1: self.messages.pop(0) @@ -46,24 +50,34 @@ def discard_exceeding(self, max_tokens, cur_tokens= None): cur_tokens = len(str(self)) break elif len(self.messages) == 1 and self.messages[0]["role"] == "user": - logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) + logger.warn( + "user question exceed max_tokens. total_tokens={}".format( + cur_tokens + ) + ) break else: - logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) + logger.debug( + "max_tokens={}, total_tokens={}, len(conversation)={}".format( + max_tokens, cur_tokens, len(self.messages) + ) + ) break if precise: cur_tokens = self.calc_tokens() else: cur_tokens = len(str(self)) return cur_tokens - + def calc_tokens(self): return num_tokens_from_string(str(self), self.model) + # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def num_tokens_from_string(string: str, model: str) -> int: """Returns the number of tokens in a text string.""" import tiktoken + encoding = tiktoken.encoding_for_model(model) - num_tokens = len(encoding.encode(string,disallowed_special=())) - return num_tokens \ No newline at end of file + num_tokens = len(encoding.encode(string, disallowed_special=())) + return num_tokens diff --git a/bot/session_manager.py b/bot/session_manager.py index ce189ed94..1aff647b5 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -2,6 +2,7 @@ from common.log import logger from config import conf + class Session(object): def __init__(self, session_id, system_prompt=None): self.session_id = session_id @@ -13,7 +14,7 @@ def __init__(self, session_id, system_prompt=None): # 重置会话 def reset(self): - system_item = {'role': 'system', 'content': self.system_prompt} + system_item = {"role": "system", "content": self.system_prompt} self.messages = [system_item] def set_system_prompt(self, system_prompt): @@ -21,13 +22,13 @@ def set_system_prompt(self, system_prompt): self.reset() def add_query(self, query): - user_item = {'role': 'user', 'content': query} + user_item = {"role": "user", "content": query} self.messages.append(user_item) def add_reply(self, reply): - assistant_item = {'role': 'assistant', 'content': reply} + assistant_item = {"role": "assistant", "content": reply} self.messages.append(assistant_item) - + def discard_exceeding(self, max_tokens=None, cur_tokens=None): raise NotImplementedError @@ -37,8 +38,8 @@ def calc_tokens(self): class SessionManager(object): def __init__(self, sessioncls, **session_args): - if conf().get('expires_in_seconds'): - sessions = ExpiredDict(conf().get('expires_in_seconds')) + if conf().get("expires_in_seconds"): + sessions = ExpiredDict(conf().get("expires_in_seconds")) else: sessions = dict() self.sessions = sessions @@ -46,20 +47,22 @@ def __init__(self, sessioncls, **session_args): self.session_args = session_args def build_session(self, session_id, system_prompt=None): - ''' - 如果session_id不在sessions中,创建一个新的session并添加到sessions中 - 如果system_prompt不会空,会更新session的system_prompt并重置session - ''' + """ + 如果session_id不在sessions中,创建一个新的session并添加到sessions中 + 如果system_prompt不会空,会更新session的system_prompt并重置session + """ if session_id is None: return self.sessioncls(session_id, system_prompt, **self.session_args) - + if session_id not in self.sessions: - self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) + self.sessions[session_id] = self.sessioncls( + session_id, system_prompt, **self.session_args + ) elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session self.sessions[session_id].set_system_prompt(system_prompt) session = self.sessions[session_id] return session - + def session_query(self, query, session_id): session = self.build_session(session_id) session.add_query(query) @@ -68,23 +71,33 @@ def session_query(self, query, session_id): total_tokens = session.discard_exceeding(max_tokens, None) logger.debug("prompt tokens used={}".format(total_tokens)) except Exception as e: - logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) + logger.debug( + "Exception when counting tokens precisely for prompt: {}".format(str(e)) + ) return session - def session_reply(self, reply, session_id, total_tokens = None): + def session_reply(self, reply, session_id, total_tokens=None): session = self.build_session(session_id) session.add_reply(reply) try: max_tokens = conf().get("conversation_max_tokens", 1000) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) - logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) + logger.debug( + "raw total_tokens={}, savesession tokens={}".format( + total_tokens, tokens_cnt + ) + ) except Exception as e: - logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) + logger.debug( + "Exception when counting tokens precisely for session: {}".format( + str(e) + ) + ) return session def clear_session(self, session_id): if session_id in self.sessions: - del(self.sessions[session_id]) + del self.sessions[session_id] def clear_all_session(self): self.sessions.clear() diff --git a/bridge/bridge.py b/bridge/bridge.py index a19fc8379..dcf6e7e07 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,31 +1,31 @@ +from bot import bot_factory from bridge.context import Context from bridge.reply import Reply +from common import const from common.log import logger -from bot import bot_factory from common.singleton import singleton -from voice import voice_factory from config import conf -from common import const +from voice import voice_factory @singleton class Bridge(object): def __init__(self): - self.btype={ + self.btype = { "chat": const.CHATGPT, "voice_to_text": conf().get("voice_to_text", "openai"), - "text_to_voice": conf().get("text_to_voice", "google") + "text_to_voice": conf().get("text_to_voice", "google"), } model_type = conf().get("model") if model_type in ["text-davinci-003"]: - self.btype['chat'] = const.OPEN_AI + self.btype["chat"] = const.OPEN_AI if conf().get("use_azure_chatgpt", False): - self.btype['chat'] = const.CHATGPTONAZURE - self.bots={} + self.btype["chat"] = const.CHATGPTONAZURE + self.bots = {} - def get_bot(self,typename): + def get_bot(self, typename): if self.bots.get(typename) is None: - logger.info("create bot {} for {}".format(self.btype[typename],typename)) + logger.info("create bot {} for {}".format(self.btype[typename], typename)) if typename == "text_to_voice": self.bots[typename] = voice_factory.create_voice(self.btype[typename]) elif typename == "voice_to_text": @@ -33,18 +33,15 @@ def get_bot(self,typename): elif typename == "chat": self.bots[typename] = bot_factory.create_bot(self.btype[typename]) return self.bots[typename] - - def get_bot_type(self,typename): - return self.btype[typename] + def get_bot_type(self, typename): + return self.btype[typename] - def fetch_reply_content(self, query, context : Context) -> Reply: + def fetch_reply_content(self, query, context: Context) -> Reply: return self.get_bot("chat").reply(query, context) - def fetch_voice_to_text(self, voiceFile) -> Reply: return self.get_bot("voice_to_text").voiceToText(voiceFile) def fetch_text_to_voice(self, text) -> Reply: return self.get_bot("text_to_voice").textToVoice(text) - diff --git a/bridge/context.py b/bridge/context.py index 728333326..7a857b3b6 100644 --- a/bridge/context.py +++ b/bridge/context.py @@ -2,36 +2,39 @@ from enum import Enum -class ContextType (Enum): - TEXT = 1 # 文本消息 - VOICE = 2 # 音频消息 - IMAGE = 3 # 图片消息 - IMAGE_CREATE = 10 # 创建图片命令 - + +class ContextType(Enum): + TEXT = 1 # 文本消息 + VOICE = 2 # 音频消息 + IMAGE = 3 # 图片消息 + IMAGE_CREATE = 10 # 创建图片命令 + def __str__(self): return self.name + + class Context: - def __init__(self, type : ContextType = None , content = None, kwargs = dict()): + def __init__(self, type: ContextType = None, content=None, kwargs=dict()): self.type = type self.content = content self.kwargs = kwargs def __contains__(self, key): - if key == 'type': + if key == "type": return self.type is not None - elif key == 'content': + elif key == "content": return self.content is not None else: return key in self.kwargs - + def __getitem__(self, key): - if key == 'type': + if key == "type": return self.type - elif key == 'content': + elif key == "content": return self.content else: return self.kwargs[key] - + def get(self, key, default=None): try: return self[key] @@ -39,20 +42,22 @@ def get(self, key, default=None): return default def __setitem__(self, key, value): - if key == 'type': + if key == "type": self.type = value - elif key == 'content': + elif key == "content": self.content = value else: self.kwargs[key] = value def __delitem__(self, key): - if key == 'type': + if key == "type": self.type = None - elif key == 'content': + elif key == "content": self.content = None else: del self.kwargs[key] - + def __str__(self): - return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) \ No newline at end of file + return "Context(type={}, content={}, kwargs={})".format( + self.type, self.content, self.kwargs + ) diff --git a/bridge/reply.py b/bridge/reply.py index c6bcd5465..d9d6703a4 100644 --- a/bridge/reply.py +++ b/bridge/reply.py @@ -1,22 +1,25 @@ - # encoding:utf-8 from enum import Enum + class ReplyType(Enum): - TEXT = 1 # 文本 - VOICE = 2 # 音频文件 - IMAGE = 3 # 图片文件 - IMAGE_URL = 4 # 图片URL - + TEXT = 1 # 文本 + VOICE = 2 # 音频文件 + IMAGE = 3 # 图片文件 + IMAGE_URL = 4 # 图片URL + INFO = 9 ERROR = 10 + def __str__(self): return self.name + class Reply: - def __init__(self, type : ReplyType = None , content = None): + def __init__(self, type: ReplyType = None, content=None): self.type = type self.content = content + def __str__(self): - return "Reply(type={}, content={})".format(self.type, self.content) \ No newline at end of file + return "Reply(type={}, content={})".format(self.type, self.content) diff --git a/channel/channel.py b/channel/channel.py index 01e20d617..6464d771e 100644 --- a/channel/channel.py +++ b/channel/channel.py @@ -6,8 +6,10 @@ from bridge.context import Context from bridge.reply import * + class Channel(object): NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] + def startup(self): """ init channel @@ -27,15 +29,15 @@ def send(self, reply: Reply, context: Context): send message to user :param msg: message content :param receiver: receiver channel account - :return: + :return: """ raise NotImplementedError - def build_reply_content(self, query, context : Context=None) -> Reply: + def build_reply_content(self, query, context: Context = None) -> Reply: return Bridge().fetch_reply_content(query, context) def build_voice_to_text(self, voice_file) -> Reply: return Bridge().fetch_voice_to_text(voice_file) - + def build_text_to_voice(self, text) -> Reply: return Bridge().fetch_text_to_voice(text) diff --git a/channel/channel_factory.py b/channel/channel_factory.py index e272206eb..ebd973254 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -2,25 +2,31 @@ channel factory """ + def create_channel(channel_type): """ create a channel instance :param channel_type: channel type code :return: channel instance """ - if channel_type == 'wx': + if channel_type == "wx": from channel.wechat.wechat_channel import WechatChannel + return WechatChannel() - elif channel_type == 'wxy': + elif channel_type == "wxy": from channel.wechat.wechaty_channel import WechatyChannel + return WechatyChannel() - elif channel_type == 'terminal': + elif channel_type == "terminal": from channel.terminal.terminal_channel import TerminalChannel + return TerminalChannel() - elif channel_type == 'wechatmp': + elif channel_type == "wechatmp": from channel.wechatmp.wechatmp_channel import WechatMPChannel - return WechatMPChannel(passive_reply = True) - elif channel_type == 'wechatmp_service': + + return WechatMPChannel(passive_reply=True) + elif channel_type == "wechatmp_service": from channel.wechatmp.wechatmp_channel import WechatMPChannel - return WechatMPChannel(passive_reply = False) + + return WechatMPChannel(passive_reply=False) raise RuntimeError diff --git a/channel/chat_channel.py b/channel/chat_channel.py index a89d0adfd..89dfe145a 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -1,137 +1,172 @@ - - -from asyncio import CancelledError -from concurrent.futures import Future, ThreadPoolExecutor import os import re import threading import time -from common.dequeue import Dequeue -from channel.channel import Channel -from bridge.reply import * +from asyncio import CancelledError +from concurrent.futures import Future, ThreadPoolExecutor + from bridge.context import * -from config import conf +from bridge.reply import * +from channel.channel import Channel +from common.dequeue import Dequeue from common.log import logger +from config import conf from plugins import * + try: from voice.audio_convert import any_to_wav except Exception as e: pass + # 抽象类, 它包含了与消息通道无关的通用处理逻辑 class ChatChannel(Channel): - name = None # 登录的用户名 - user_id = None # 登录的用户id - futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 - sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 - lock = threading.Lock() # 用于控制对sessions的访问 + name = None # 登录的用户名 + user_id = None # 登录的用户id + futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 + sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 + lock = threading.Lock() # 用于控制对sessions的访问 handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 def __init__(self): _thread = threading.Thread(target=self.consume) _thread.setDaemon(True) _thread.start() - # 根据消息构造context,消息内容相关的触发项写在这里 def _compose_context(self, ctype: ContextType, content, **kwargs): context = Context(ctype, content) context.kwargs = kwargs - # context首次传入时,origin_ctype是None, + # context首次传入时,origin_ctype是None, # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。 # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀 - if 'origin_ctype' not in context: - context['origin_ctype'] = ctype + if "origin_ctype" not in context: + context["origin_ctype"] = ctype # context首次传入时,receiver是None,根据类型设置receiver - first_in = 'receiver' not in context + first_in = "receiver" not in context # 群名匹配过程,设置session_id和receiver - if first_in: # context首次传入时,receiver是None,根据类型设置receiver + if first_in: # context首次传入时,receiver是None,根据类型设置receiver config = conf() - cmsg = context['msg'] + cmsg = context["msg"] if context.get("isgroup", False): group_name = cmsg.other_user_nickname group_id = cmsg.other_user_id - group_name_white_list = config.get('group_name_white_list', []) - group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) - if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]): - group_chat_in_one_session = conf().get('group_chat_in_one_session', []) + group_name_white_list = config.get("group_name_white_list", []) + group_name_keyword_white_list = config.get( + "group_name_keyword_white_list", [] + ) + if any( + [ + group_name in group_name_white_list, + "ALL_GROUP" in group_name_white_list, + check_contain(group_name, group_name_keyword_white_list), + ] + ): + group_chat_in_one_session = conf().get( + "group_chat_in_one_session", [] + ) session_id = cmsg.actual_user_id - if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): + if any( + [ + group_name in group_chat_in_one_session, + "ALL_GROUP" in group_chat_in_one_session, + ] + ): session_id = group_id else: return None - context['session_id'] = session_id - context['receiver'] = group_id + context["session_id"] = session_id + context["receiver"] = group_id else: - context['session_id'] = cmsg.other_user_id - context['receiver'] = cmsg.other_user_id - e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context})) - context = e_context['context'] + context["session_id"] = cmsg.other_user_id + context["receiver"] = cmsg.other_user_id + e_context = PluginManager().emit_event( + EventContext( + Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} + ) + ) + context = e_context["context"] if e_context.is_pass() or context is None: return context - if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True): + if cmsg.from_user_id == self.user_id and not config.get( + "trigger_by_self", True + ): logger.debug("[WX]self message skipped") return None # 消息内容匹配过程,并处理content if ctype == ContextType.TEXT: - if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 + if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 logger.debug("[WX]reference query skipped") return None - - if context.get("isgroup", False): # 群聊 + + if context.get("isgroup", False): # 群聊 # 校验关键字 - match_prefix = check_prefix(content, conf().get('group_chat_prefix')) - match_contain = check_contain(content, conf().get('group_chat_keyword')) + match_prefix = check_prefix(content, conf().get("group_chat_prefix")) + match_contain = check_contain(content, conf().get("group_chat_keyword")) flag = False if match_prefix is not None or match_contain is not None: flag = True if match_prefix: - content = content.replace(match_prefix, '', 1).strip() - if context['msg'].is_at: + content = content.replace(match_prefix, "", 1).strip() + if context["msg"].is_at: logger.info("[WX]receive group at") if not conf().get("group_at_off", False): flag = True - pattern = f'@{self.name}(\u2005|\u0020)' - content = re.sub(pattern, r'', content) - + pattern = f"@{self.name}(\u2005|\u0020)" + content = re.sub(pattern, r"", content) + if not flag: if context["origin_ctype"] == ContextType.VOICE: - logger.info("[WX]receive group voice, but checkprefix didn't match") + logger.info( + "[WX]receive group voice, but checkprefix didn't match" + ) return None - else: # 单聊 - match_prefix = check_prefix(content, conf().get('single_chat_prefix',[''])) - if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 - content = content.replace(match_prefix, '', 1).strip() - elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 + else: # 单聊 + match_prefix = check_prefix( + content, conf().get("single_chat_prefix", [""]) + ) + if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 + content = content.replace(match_prefix, "", 1).strip() + elif ( + context["origin_ctype"] == ContextType.VOICE + ): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 pass else: - return None - - img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) + return None + + img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) if img_match_prefix: - content = content.replace(img_match_prefix, '', 1) + content = content.replace(img_match_prefix, "", 1) context.type = ContextType.IMAGE_CREATE else: context.type = ContextType.TEXT context.content = content.strip() - if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: - context['desire_rtype'] = ReplyType.VOICE - elif context.type == ContextType.VOICE: - if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: - context['desire_rtype'] = ReplyType.VOICE + if ( + "desire_rtype" not in context + and conf().get("always_reply_voice") + and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE + ): + context["desire_rtype"] = ReplyType.VOICE + elif context.type == ContextType.VOICE: + if ( + "desire_rtype" not in context + and conf().get("voice_reply_voice") + and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE + ): + context["desire_rtype"] = ReplyType.VOICE return context def _handle(self, context: Context): if context is None or not context.content: return - logger.debug('[WX] ready to handle context: {}'.format(context)) + logger.debug("[WX] ready to handle context: {}".format(context)) # reply的构建步骤 reply = self._generate_reply(context) - logger.debug('[WX] ready to decorate reply: {}'.format(reply)) + logger.debug("[WX] ready to decorate reply: {}".format(reply)) # reply的包装步骤 reply = self._decorate_reply(context, reply) @@ -139,20 +174,31 @@ def _handle(self, context: Context): self._send_reply(context, reply) def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: - e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { - 'channel': self, 'context': context, 'reply': reply})) - reply = e_context['reply'] + e_context = PluginManager().emit_event( + EventContext( + Event.ON_HANDLE_CONTEXT, + {"channel": self, "context": context, "reply": reply}, + ) + ) + reply = e_context["reply"] if not e_context.is_pass(): - logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) - if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 + logger.debug( + "[WX] ready to handle context: type={}, content={}".format( + context.type, context.content + ) + ) + if ( + context.type == ContextType.TEXT + or context.type == ContextType.IMAGE_CREATE + ): # 文字和图片消息 reply = super().build_reply_content(context.content, context) elif context.type == ContextType.VOICE: # 语音消息 - cmsg = context['msg'] + cmsg = context["msg"] cmsg.prepare() file_path = context.content - wav_path = os.path.splitext(file_path)[0] + '.wav' + wav_path = os.path.splitext(file_path)[0] + ".wav" try: - any_to_wav(file_path, wav_path) + any_to_wav(file_path, wav_path) except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 logger.warning("[WX]any to wav error, use raw path. " + str(e)) wav_path = file_path @@ -169,7 +215,8 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: if reply.type == ReplyType.TEXT: new_context = self._compose_context( - ContextType.TEXT, reply.content, **context.kwargs) + ContextType.TEXT, reply.content, **context.kwargs + ) if new_context: reply = self._generate_reply(new_context) else: @@ -177,18 +224,21 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑 pass else: - logger.error('[WX] unknown context type: {}'.format(context.type)) + logger.error("[WX] unknown context type: {}".format(context.type)) return return reply def _decorate_reply(self, context: Context, reply: Reply) -> Reply: if reply and reply.type: - e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { - 'channel': self, 'context': context, 'reply': reply})) - reply = e_context['reply'] - desire_rtype = context.get('desire_rtype') + e_context = PluginManager().emit_event( + EventContext( + Event.ON_DECORATE_REPLY, + {"channel": self, "context": context, "reply": reply}, + ) + ) + reply = e_context["reply"] + desire_rtype = context.get("desire_rtype") if not e_context.is_pass() and reply and reply.type: - if reply.type in self.NOT_SUPPORT_REPLYTYPE: logger.error("[WX]reply type not support: " + str(reply.type)) reply.type = ReplyType.ERROR @@ -196,59 +246,91 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply: if reply.type == ReplyType.TEXT: reply_text = reply.content - if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: + if ( + desire_rtype == ReplyType.VOICE + and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE + ): reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): - reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip() - reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + reply_text = ( + "@" + + context["msg"].actual_user_nickname + + " " + + reply_text.strip() + ) + reply_text = ( + conf().get("group_chat_reply_prefix", "") + reply_text + ) else: - reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + reply_text = ( + conf().get("single_chat_reply_prefix", "") + reply_text + ) reply.content = reply_text elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: - reply.content = "["+str(reply.type)+"]\n" + reply.content - elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: + reply.content = "[" + str(reply.type) + "]\n" + reply.content + elif ( + reply.type == ReplyType.IMAGE_URL + or reply.type == ReplyType.VOICE + or reply.type == ReplyType.IMAGE + ): pass else: - logger.error('[WX] unknown reply type: {}'.format(reply.type)) + logger.error("[WX] unknown reply type: {}".format(reply.type)) return - if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: - logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type)) + if ( + desire_rtype + and desire_rtype != reply.type + and reply.type not in [ReplyType.ERROR, ReplyType.INFO] + ): + logger.warning( + "[WX] desire_rtype: {}, but reply type: {}".format( + context.get("desire_rtype"), reply.type + ) + ) return reply def _send_reply(self, context: Context, reply: Reply): if reply and reply.type: - e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { - 'channel': self, 'context': context, 'reply': reply})) - reply = e_context['reply'] + e_context = PluginManager().emit_event( + EventContext( + Event.ON_SEND_REPLY, + {"channel": self, "context": context, "reply": reply}, + ) + ) + reply = e_context["reply"] if not e_context.is_pass() and reply and reply.type: - logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context)) + logger.debug( + "[WX] ready to send reply: {}, context: {}".format(reply, context) + ) self._send(reply, context) - def _send(self, reply: Reply, context: Context, retry_cnt = 0): + def _send(self, reply: Reply, context: Context, retry_cnt=0): try: self.send(reply, context) except Exception as e: - logger.error('[WX] sendMsg error: {}'.format(str(e))) + logger.error("[WX] sendMsg error: {}".format(str(e))) if isinstance(e, NotImplementedError): return logger.exception(e) if retry_cnt < 2: - time.sleep(3+3*retry_cnt) - self._send(reply, context, retry_cnt+1) + time.sleep(3 + 3 * retry_cnt) + self._send(reply, context, retry_cnt + 1) - def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 + def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 logger.debug("Worker return success, session_id = {}".format(session_id)) - def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 + def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 logger.exception("Worker return exception: {}".format(exception)) def _thread_pool_callback(self, session_id, **kwargs): - def func(worker:Future): + def func(worker: Future): try: worker_exception = worker.exception() if worker_exception: - self._fail_callback(session_id, exception = worker_exception, **kwargs) + self._fail_callback( + session_id, exception=worker_exception, **kwargs + ) else: self._success_callback(session_id, **kwargs) except CancelledError as e: @@ -257,15 +339,19 @@ def func(worker:Future): logger.exception("Worker raise exception: {}".format(e)) with self.lock: self.sessions[session_id][1].release() + return func def produce(self, context: Context): - session_id = context['session_id'] + session_id = context["session_id"] with self.lock: if session_id not in self.sessions: - self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))] - if context.type == ContextType.TEXT and context.content.startswith("#"): - self.sessions[session_id][0].putleft(context) # 优先处理管理命令 + self.sessions[session_id] = [ + Dequeue(), + threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)), + ] + if context.type == ContextType.TEXT and context.content.startswith("#"): + self.sessions[session_id][0].putleft(context) # 优先处理管理命令 else: self.sessions[session_id][0].put(context) @@ -276,44 +362,58 @@ def consume(self): session_ids = list(self.sessions.keys()) for session_id in session_ids: context_queue, semaphore = self.sessions[session_id] - if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 + if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除 if not context_queue.empty(): context = context_queue.get() logger.debug("[WX] consume context: {}".format(context)) - future:Future = self.handler_pool.submit(self._handle, context) - future.add_done_callback(self._thread_pool_callback(session_id, context = context)) + future: Future = self.handler_pool.submit( + self._handle, context + ) + future.add_done_callback( + self._thread_pool_callback(session_id, context=context) + ) if session_id not in self.futures: self.futures[session_id] = [] self.futures[session_id].append(future) - elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 - self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] - assert len(self.futures[session_id]) == 0, "thread pool error" + elif ( + semaphore._initial_value == semaphore._value + 1 + ): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 + self.futures[session_id] = [ + t for t in self.futures[session_id] if not t.done() + ] + assert ( + len(self.futures[session_id]) == 0 + ), "thread pool error" del self.sessions[session_id] else: semaphore.release() time.sleep(0.1) # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 - def cancel_session(self, session_id): + def cancel_session(self, session_id): with self.lock: if session_id in self.sessions: for future in self.futures[session_id]: future.cancel() cnt = self.sessions[session_id][0].qsize() - if cnt>0: - logger.info("Cancel {} messages in session {}".format(cnt, session_id)) + if cnt > 0: + logger.info( + "Cancel {} messages in session {}".format(cnt, session_id) + ) self.sessions[session_id][0] = Dequeue() - + def cancel_all_session(self): with self.lock: for session_id in self.sessions: for future in self.futures[session_id]: future.cancel() cnt = self.sessions[session_id][0].qsize() - if cnt>0: - logger.info("Cancel {} messages in session {}".format(cnt, session_id)) + if cnt > 0: + logger.info( + "Cancel {} messages in session {}".format(cnt, session_id) + ) self.sessions[session_id][0] = Dequeue() - + def check_prefix(content, prefix_list): if not prefix_list: @@ -323,6 +423,7 @@ def check_prefix(content, prefix_list): return prefix return None + def check_contain(content, keyword_list): if not keyword_list: return None diff --git a/channel/chat_message.py b/channel/chat_message.py index 2f8993c5d..fdd4d9005 100644 --- a/channel/chat_message.py +++ b/channel/chat_message.py @@ -1,5 +1,4 @@ - -""" +""" 本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。 填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel @@ -20,7 +19,7 @@ other_user_nickname: 同上 is_group: 是否是群消息 (群聊必填) -is_at: 是否被at +is_at: 是否被at - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) actual_user_id: 实际发送者id (群聊必填) @@ -34,20 +33,22 @@ _rawmsg: 原始消息对象 """ + + class ChatMessage(object): msg_id = None create_time = None - + ctype = None content = None - + from_user_id = None from_user_nickname = None to_user_id = None to_user_nickname = None other_user_id = None other_user_nickname = None - + is_group = False is_at = False actual_user_id = None @@ -57,8 +58,7 @@ class ChatMessage(object): _prepared = False _rawmsg = None - - def __init__(self,_rawmsg): + def __init__(self, _rawmsg): self._rawmsg = _rawmsg def prepare(self): @@ -67,7 +67,7 @@ def prepare(self): self._prepare_fn() def __str__(self): - return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format( + return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format( self.msg_id, self.create_time, self.ctype, @@ -82,4 +82,4 @@ def __str__(self): self.is_at, self.actual_user_id, self.actual_user_nickname, - ) \ No newline at end of file + ) diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py index 16c7acdc6..e2060789c 100644 --- a/channel/terminal/terminal_channel.py +++ b/channel/terminal/terminal_channel.py @@ -1,14 +1,23 @@ +import sys + from bridge.context import * from bridge.reply import Reply, ReplyType from channel.chat_channel import ChatChannel, check_prefix from channel.chat_message import ChatMessage -import sys - -from config import conf from common.log import logger +from config import conf + class TerminalMessage(ChatMessage): - def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"): + def __init__( + self, + msg_id, + content, + ctype=ContextType.TEXT, + from_user_id="User", + to_user_id="Chatgpt", + other_user_id="Chatgpt", + ): self.msg_id = msg_id self.ctype = ctype self.content = content @@ -16,6 +25,7 @@ def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "U self.to_user_id = to_user_id self.other_user_id = other_user_id + class TerminalChannel(ChatChannel): NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] @@ -23,14 +33,18 @@ def send(self, reply: Reply, context: Context): print("\nBot:") if reply.type == ReplyType.IMAGE: from PIL import Image + image_storage = reply.content image_storage.seek(0) img = Image.open(image_storage) print("") img.show() - elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + import io + + import requests from PIL import Image - import requests,io + img_url = reply.content pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() @@ -59,11 +73,13 @@ def startup(self): print("\nExiting...") sys.exit() msg_id += 1 - trigger_prefixs = conf().get("single_chat_prefix",[""]) + trigger_prefixs = conf().get("single_chat_prefix", [""]) if check_prefix(prompt, trigger_prefixs) is None: - prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 - - context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt)) + prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 + + context = self._compose_context( + ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) + ) if context: self.produce(context) else: diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index b11611d52..52c8ee3e0 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -4,40 +4,45 @@ wechat channel """ +import io +import json import os import threading -import requests -import io import time -import json + +import requests + +from bridge.context import * +from bridge.reply import * from channel.chat_channel import ChatChannel from channel.wechat.wechat_message import * -from common.singleton import singleton +from common.expired_dict import ExpiredDict from common.log import logger +from common.singleton import singleton +from common.time_check import time_checker +from config import conf, get_appdata_dir from lib import itchat from lib.itchat.content import * -from bridge.reply import * -from bridge.context import * -from config import conf -from common.time_check import time_checker -from common.expired_dict import ExpiredDict from plugins import * -@itchat.msg_register([TEXT,VOICE,PICTURE]) + +@itchat.msg_register([TEXT, VOICE, PICTURE]) def handler_single_msg(msg): # logger.debug("handler_single_msg: {}".format(msg)) - if msg['Type'] == PICTURE and msg['MsgType'] == 47: + if msg["Type"] == PICTURE and msg["MsgType"] == 47: return None WechatChannel().handle_single(WeChatMessage(msg)) return None -@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True) + +@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True) def handler_group_msg(msg): - if msg['Type'] == PICTURE and msg['MsgType'] == 47: + if msg["Type"] == PICTURE and msg["MsgType"] == 47: return None - WechatChannel().handle_group(WeChatMessage(msg,True)) + WechatChannel().handle_group(WeChatMessage(msg, True)) return None + def _check(func): def wrapper(self, cmsg: ChatMessage): msgId = cmsg.msg_id @@ -45,21 +50,27 @@ def wrapper(self, cmsg: ChatMessage): logger.info("Wechat message {} already received, ignore".format(msgId)) return self.receivedMsgs[msgId] = cmsg - create_time = cmsg.create_time # 消息时间戳 - if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 + create_time = cmsg.create_time # 消息时间戳 + if ( + conf().get("hot_reload") == True + and int(create_time) < int(time.time()) - 60 + ): # 跳过1分钟前的历史消息 logger.debug("[WX]history message {} skipped".format(msgId)) return return func(self, cmsg) + return wrapper -#可用的二维码生成接口 -#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com -#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com -def qrCallback(uuid,status,qrcode): + +# 可用的二维码生成接口 +# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com +# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com +def qrCallback(uuid, status, qrcode): # logger.debug("qrCallback: {} {}".format(uuid,status)) - if status == '0': + if status == "0": try: from PIL import Image + img = Image.open(io.BytesIO(qrcode)) _thread = threading.Thread(target=img.show, args=("QRCode",)) _thread.setDaemon(True) @@ -68,48 +79,68 @@ def qrCallback(uuid,status,qrcode): pass import qrcode + url = f"https://login.weixin.qq.com/l/{uuid}" - - qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) - qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) - qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url) - qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) + + qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) + qr_api2 = ( + "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( + url + ) + ) + qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) + qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( + url + ) print("You can also scan QRCode in any website below:") print(qr_api3) print(qr_api4) print(qr_api2) print(qr_api1) - + qr = qrcode.QRCode(border=1) qr.add_data(url) qr.make(fit=True) qr.print_ascii(invert=True) + @singleton class WechatChannel(ChatChannel): NOT_SUPPORT_REPLYTYPE = [] + def __init__(self): super().__init__() - self.receivedMsgs = ExpiredDict(60*60*24) + self.receivedMsgs = ExpiredDict(60 * 60 * 24) def startup(self): - itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 # login by scan QRCode - hotReload = conf().get('hot_reload', False) + hotReload = conf().get("hot_reload", False) + status_path = os.path.join(get_appdata_dir(), "itchat.pkl") try: - itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) + itchat.auto_login( + enableCmdQR=2, + hotReload=hotReload, + statusStorageDir=status_path, + qrCallback=qrCallback, + ) except Exception as e: if hotReload: logger.error("Hot reload failed, try to login without hot reload") itchat.logout() - os.remove("itchat.pkl") - itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) + os.remove(status_path) + itchat.auto_login( + enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback + ) else: raise e self.user_id = itchat.instance.storageClass.userName self.name = itchat.instance.storageClass.nickName - logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) + logger.info( + "Wechat login success, user_id: {}, nickname: {}".format( + self.user_id, self.name + ) + ) # start message listener itchat.run() @@ -127,24 +158,30 @@ def startup(self): @time_checker @_check - def handle_single(self, cmsg : ChatMessage): + def handle_single(self, cmsg: ChatMessage): if cmsg.ctype == ContextType.VOICE: - if conf().get('speech_recognition') != True: + if conf().get("speech_recognition") != True: return logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.IMAGE: logger.debug("[WX]receive image msg: {}".format(cmsg.content)) else: - logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) - context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) + logger.debug( + "[WX]receive text msg: {}, cmsg={}".format( + json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg + ) + ) + context = self._compose_context( + cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg + ) if context: self.produce(context) @time_checker @_check - def handle_group(self, cmsg : ChatMessage): + def handle_group(self, cmsg: ChatMessage): if cmsg.ctype == ContextType.VOICE: - if conf().get('speech_recognition') != True: + if conf().get("speech_recognition") != True: return logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.IMAGE: @@ -152,23 +189,25 @@ def handle_group(self, cmsg : ChatMessage): else: # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) pass - context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) + context = self._compose_context( + cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg + ) if context: self.produce(context) - + # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply: Reply, context: Context): receiver = context["receiver"] if reply.type == ReplyType.TEXT: itchat.send(reply.content, toUserName=receiver) - logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: itchat.send(reply.content, toUserName=receiver) - logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) elif reply.type == ReplyType.VOICE: itchat.send_file(reply.content, toUserName=receiver) - logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) - elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver)) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content pic_res = requests.get(img_url, stream=True) image_storage = io.BytesIO() @@ -176,9 +215,9 @@ def send(self, reply: Reply, context: Context): image_storage.write(block) image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) - logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) - elif reply.type == ReplyType.IMAGE: # 从文件读取图片 + logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) + elif reply.type == ReplyType.IMAGE: # 从文件读取图片 image_storage = reply.content image_storage.seek(0) itchat.send_image(image_storage, toUserName=receiver) - logger.info('[WX] sendImage, receiver={}'.format(receiver)) + logger.info("[WX] sendImage, receiver={}".format(receiver)) diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index bc12b43c8..92182bde7 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -1,54 +1,54 @@ - - from bridge.context import ContextType from channel.chat_message import ChatMessage -from common.tmp_dir import TmpDir from common.log import logger -from lib.itchat.content import * +from common.tmp_dir import TmpDir from lib import itchat +from lib.itchat.content import * -class WeChatMessage(ChatMessage): +class WeChatMessage(ChatMessage): def __init__(self, itchat_msg, is_group=False): - super().__init__( itchat_msg) - self.msg_id = itchat_msg['MsgId'] - self.create_time = itchat_msg['CreateTime'] + super().__init__(itchat_msg) + self.msg_id = itchat_msg["MsgId"] + self.create_time = itchat_msg["CreateTime"] self.is_group = is_group - - if itchat_msg['Type'] == TEXT: + + if itchat_msg["Type"] == TEXT: self.ctype = ContextType.TEXT - self.content = itchat_msg['Text'] - elif itchat_msg['Type'] == VOICE: + self.content = itchat_msg["Text"] + elif itchat_msg["Type"] == VOICE: self.ctype = ContextType.VOICE - self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 + self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self._prepare_fn = lambda: itchat_msg.download(self.content) - elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3: + elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3: self.ctype = ContextType.IMAGE - self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 + self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self._prepare_fn = lambda: itchat_msg.download(self.content) else: - raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type'])) - - self.from_user_id = itchat_msg['FromUserName'] - self.to_user_id = itchat_msg['ToUserName'] - + raise NotImplementedError( + "Unsupported message type: {}".format(itchat_msg["Type"]) + ) + + self.from_user_id = itchat_msg["FromUserName"] + self.to_user_id = itchat_msg["ToUserName"] + user_id = itchat.instance.storageClass.userName nickname = itchat.instance.storageClass.nickName - + # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下 # 以下很繁琐,一句话总结:能填的都填了。 if self.from_user_id == user_id: self.from_user_nickname = nickname if self.to_user_id == user_id: self.to_user_nickname = nickname - try: # 陌生人时候, 'User'字段可能不存在 - self.other_user_id = itchat_msg['User']['UserName'] - self.other_user_nickname = itchat_msg['User']['NickName'] + try: # 陌生人时候, 'User'字段可能不存在 + self.other_user_id = itchat_msg["User"]["UserName"] + self.other_user_nickname = itchat_msg["User"]["NickName"] if self.other_user_id == self.from_user_id: self.from_user_nickname = self.other_user_nickname if self.other_user_id == self.to_user_id: self.to_user_nickname = self.other_user_nickname - except KeyError as e: # 处理偶尔没有对方信息的情况 + except KeyError as e: # 处理偶尔没有对方信息的情况 logger.warn("[WX]get other_user_id failed: " + str(e)) if self.from_user_id == user_id: self.other_user_id = self.to_user_id @@ -56,6 +56,6 @@ def __init__(self, itchat_msg, is_group=False): self.other_user_id = self.from_user_id if self.is_group: - self.is_at = itchat_msg['IsAt'] - self.actual_user_id = itchat_msg['ActualUserName'] - self.actual_user_nickname = itchat_msg['ActualNickName'] + self.is_at = itchat_msg["IsAt"] + self.actual_user_id = itchat_msg["ActualUserName"] + self.actual_user_nickname = itchat_msg["ActualNickName"] diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 65348bacc..7383a206c 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -4,104 +4,118 @@ wechaty channel Python Wechaty - https://github.com/wechaty/python-wechaty """ +import asyncio import base64 import os import time -import asyncio -from bridge.context import Context -from wechaty_puppet import FileBox -from wechaty import Wechaty, Contact + +from wechaty import Contact, Wechaty from wechaty.user import Message -from bridge.reply import * +from wechaty_puppet import FileBox + from bridge.context import * +from bridge.context import Context +from bridge.reply import * from channel.chat_channel import ChatChannel from channel.wechat.wechaty_message import WechatyMessage from common.log import logger from common.singleton import singleton from config import conf + try: from voice.audio_convert import any_to_sil except Exception as e: pass + @singleton class WechatyChannel(ChatChannel): NOT_SUPPORT_REPLYTYPE = [] + def __init__(self): super().__init__() def startup(self): config = conf() - token = config.get('wechaty_puppet_service_token') - os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token + token = config.get("wechaty_puppet_service_token") + os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token asyncio.run(self.main()) async def main(self): - loop = asyncio.get_event_loop() - #将asyncio的loop传入处理线程 - self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) + # 将asyncio的loop传入处理线程 + self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop) self.bot = Wechaty() - self.bot.on('login', self.on_login) - self.bot.on('message', self.on_message) + self.bot.on("login", self.on_login) + self.bot.on("message", self.on_message) await self.bot.start() async def on_login(self, contact: Contact): self.user_id = contact.contact_id self.name = contact.name - logger.info('[WX] login user={}'.format(contact)) + logger.info("[WX] login user={}".format(contact)) # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 def send(self, reply: Reply, context: Context): - receiver_id = context['receiver'] + receiver_id = context["receiver"] loop = asyncio.get_event_loop() - if context['isgroup']: - receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result() + if context["isgroup"]: + receiver = asyncio.run_coroutine_threadsafe( + self.bot.Room.find(receiver_id), loop + ).result() else: - receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result() + receiver = asyncio.run_coroutine_threadsafe( + self.bot.Contact.find(receiver_id), loop + ).result() msg = None if reply.type == ReplyType.TEXT: msg = reply.content - asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() - logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() + logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: msg = reply.content - asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() - logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) + asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() + logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) elif reply.type == ReplyType.VOICE: voiceLength = None file_path = reply.content - sil_file = os.path.splitext(file_path)[0] + '.sil' + sil_file = os.path.splitext(file_path)[0] + ".sil" voiceLength = int(any_to_sil(file_path, sil_file)) if voiceLength >= 60000: voiceLength = 60000 - logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength)) + logger.info( + "[WX] voice too long, length={}, set to 60s".format(voiceLength) + ) # 发送语音 t = int(time.time()) - msg = FileBox.from_file(sil_file, name=str(t) + '.sil') + msg = FileBox.from_file(sil_file, name=str(t) + ".sil") if voiceLength is not None: - msg.metadata['voiceLength'] = voiceLength - asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() + msg.metadata["voiceLength"] = voiceLength + asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() try: os.remove(file_path) if sil_file != file_path: os.remove(sil_file) except Exception as e: pass - logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver)) - elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 + logger.info( + "[WX] sendVoice={}, receiver={}".format(reply.content, receiver) + ) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content t = int(time.time()) - msg = FileBox.from_url(url=img_url, name=str(t) + '.png') - asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() - logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) - elif reply.type == ReplyType.IMAGE: # 从文件读取图片 + msg = FileBox.from_url(url=img_url, name=str(t) + ".png") + asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() + logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) + elif reply.type == ReplyType.IMAGE: # 从文件读取图片 image_storage = reply.content image_storage.seek(0) t = int(time.time()) - msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png') - asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() - logger.info('[WX] sendImage, receiver={}'.format(receiver)) + msg = FileBox.from_base64( + base64.b64encode(image_storage.read()), str(t) + ".png" + ) + asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() + logger.info("[WX] sendImage, receiver={}".format(receiver)) async def on_message(self, msg: Message): """ @@ -110,16 +124,16 @@ async def on_message(self, msg: Message): try: cmsg = await WechatyMessage(msg) except NotImplementedError as e: - logger.debug('[WX] {}'.format(e)) + logger.debug("[WX] {}".format(e)) return except Exception as e: - logger.exception('[WX] {}'.format(e)) + logger.exception("[WX] {}".format(e)) return - logger.debug('[WX] message:{}'.format(cmsg)) + logger.debug("[WX] message:{}".format(cmsg)) room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None isgroup = room is not None ctype = cmsg.ctype context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) if context: - logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) - self.produce(context) \ No newline at end of file + logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context)) + self.produce(context) diff --git a/channel/wechat/wechaty_message.py b/channel/wechat/wechaty_message.py index be2fe5de6..f7d27faf8 100644 --- a/channel/wechat/wechaty_message.py +++ b/channel/wechat/wechaty_message.py @@ -1,17 +1,21 @@ import asyncio import re + from wechaty import MessageType +from wechaty.user import Message + from bridge.context import ContextType from channel.chat_message import ChatMessage -from common.tmp_dir import TmpDir from common.log import logger -from wechaty.user import Message +from common.tmp_dir import TmpDir + class aobject(object): """Inheriting this class allows you to define an async __init__. So you can create objects by doing something like `await MyClass(params)` """ + async def __new__(cls, *a, **kw): instance = super().__new__(cls) await instance.__init__(*a, **kw) @@ -19,17 +23,18 @@ async def __new__(cls, *a, **kw): async def __init__(self): pass -class WechatyMessage(ChatMessage, aobject): + +class WechatyMessage(ChatMessage, aobject): async def __init__(self, wechaty_msg: Message): super().__init__(wechaty_msg) - + room = wechaty_msg.room() self.msg_id = wechaty_msg.message_id self.create_time = wechaty_msg.payload.timestamp self.is_group = room is not None - + if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: self.ctype = ContextType.TEXT self.content = wechaty_msg.text() @@ -40,12 +45,17 @@ async def __init__(self, wechaty_msg: Message): def func(): loop = asyncio.get_event_loop() - asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result() + asyncio.run_coroutine_threadsafe( + voice_file.to_file(self.content), loop + ).result() + self._prepare_fn = func - + else: - raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) - + raise NotImplementedError( + "Unsupported message type: {}".format(wechaty_msg.type()) + ) + from_contact = wechaty_msg.talker() # 获取消息的发送者 self.from_user_id = from_contact.contact_id self.from_user_nickname = from_contact.name @@ -54,7 +64,7 @@ def func(): # wecahty: from是消息实际发送者, to:所在群 # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己 # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户 - + if self.is_group: self.to_user_id = room.room_id self.to_user_nickname = await room.topic() @@ -63,22 +73,22 @@ def func(): self.to_user_id = to_contact.contact_id self.to_user_nickname = to_contact.name - if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 + if ( + self.is_group or wechaty_msg.is_self() + ): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 self.other_user_id = self.to_user_id self.other_user_nickname = self.to_user_nickname else: self.other_user_id = self.from_user_id self.other_user_nickname = self.from_user_nickname - - - if self.is_group: # wechaty群聊中,实际发送用户就是from_user + if self.is_group: # wechaty群聊中,实际发送用户就是from_user self.is_at = await wechaty_msg.mention_self() - if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 + if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 name = wechaty_msg.wechaty.user_self().name - pattern = f'@{name}(\u2005|\u0020)' - if re.search(pattern,self.content): - logger.debug(f'wechaty message {self.msg_id} include at') + pattern = f"@{name}(\u2005|\u0020)" + if re.search(pattern, self.content): + logger.debug(f"wechaty message {self.msg_id} include at") self.is_at = True self.actual_user_id = self.from_user_id diff --git a/channel/wechatmp/README.md b/channel/wechatmp/README.md index c69ca92de..69b8037b9 100644 --- a/channel/wechatmp/README.md +++ b/channel/wechatmp/README.md @@ -21,12 +21,12 @@ pip3 install web.py 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 ``` -"channel_type": "wechatmp", +"channel_type": "wechatmp", "wechatmp_token": "Token", # 微信公众平台的Token "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 -``` +``` 然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径): ``` sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 diff --git a/channel/wechatmp/ServiceAccount.py b/channel/wechatmp/ServiceAccount.py index eeef66d39..699581d99 100644 --- a/channel/wechatmp/ServiceAccount.py +++ b/channel/wechatmp/ServiceAccount.py @@ -1,46 +1,66 @@ -import web import time -import channel.wechatmp.reply as reply + +import web + import channel.wechatmp.receive as receive -from config import conf -from common.log import logger +import channel.wechatmp.reply as reply from bridge.context import * -from channel.wechatmp.common import * +from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel +from common.log import logger +from config import conf -# This class is instantiated once per query -class Query(): +# This class is instantiated once per query +class Query: def GET(self): return verify_server(web.input()) def POST(self): - # Make sure to return the instance that first created, @singleton will do that. + # Make sure to return the instance that first created, @singleton will do that. channel = WechatMPChannel() try: webData = web.data() # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) wechatmp_msg = receive.parse_xml(webData) - if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': + if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": from_user = wechatmp_msg.from_user_id message = wechatmp_msg.content.decode("utf-8") message_id = wechatmp_msg.msg_id - logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) - context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) + logger.info( + "[wechatmp] {}:{} Receive post query {} {}: {}".format( + web.ctx.env.get("REMOTE_ADDR"), + web.ctx.env.get("REMOTE_PORT"), + from_user, + message_id, + message, + ) + ) + context = channel._compose_context( + ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg + ) if context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) - context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key + context["openai_api_key"] = user_data.get( + "openai_api_key" + ) # None or user openai_api_key channel.produce(context) # The reply will be sent by channel.send() in another thread return "success" - elif wechatmp_msg.msg_type == 'event': - logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id)) + elif wechatmp_msg.msg_type == "event": + logger.info( + "[wechatmp] Event {} from {}".format( + wechatmp_msg.Event, wechatmp_msg.from_user_id + ) + ) content = subscribe_msg() - replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) + replyMsg = reply.TextMsg( + wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content + ) return replyMsg.send() else: logger.info("暂且不处理") @@ -48,4 +68,3 @@ def POST(self): except Exception as exc: logger.exception(exc) return exc - diff --git a/channel/wechatmp/SubscribeAccount.py b/channel/wechatmp/SubscribeAccount.py index 7de2abd4a..8eeedb4bf 100644 --- a/channel/wechatmp/SubscribeAccount.py +++ b/channel/wechatmp/SubscribeAccount.py @@ -1,81 +1,117 @@ -import web import time -import channel.wechatmp.reply as reply + +import web + import channel.wechatmp.receive as receive -from config import conf -from common.log import logger +import channel.wechatmp.reply as reply from bridge.context import * -from channel.wechatmp.common import * +from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel +from common.log import logger +from config import conf -# This class is instantiated once per query -class Query(): +# This class is instantiated once per query +class Query: def GET(self): return verify_server(web.input()) def POST(self): - # Make sure to return the instance that first created, @singleton will do that. + # Make sure to return the instance that first created, @singleton will do that. channel = WechatMPChannel() try: query_time = time.time() webData = web.data() logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) wechatmp_msg = receive.parse_xml(webData) - if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': + if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": from_user = wechatmp_msg.from_user_id to_user = wechatmp_msg.to_user_id message = wechatmp_msg.content.decode("utf-8") message_id = wechatmp_msg.msg_id - logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) + logger.info( + "[wechatmp] {}:{} Receive post query {} {}: {}".format( + web.ctx.env.get("REMOTE_ADDR"), + web.ctx.env.get("REMOTE_PORT"), + from_user, + message_id, + message, + ) + ) supported = True if "【收到不支持的消息类型,暂无法显示】" in message: - supported = False # not supported, used to refresh + supported = False # not supported, used to refresh cache_key = from_user reply_text = "" # New request - if cache_key not in channel.cache_dict and cache_key not in channel.running: + if ( + cache_key not in channel.cache_dict + and cache_key not in channel.running + ): # The first query begin, reset the cache - context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) - logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) - if message_id in channel.received_msgs: # received and finished + context = channel._compose_context( + ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg + ) + logger.debug( + "[wechatmp] context: {} {}".format(context, wechatmp_msg) + ) + if message_id in channel.received_msgs: # received and finished # no return because of bandwords or other reasons return "success" if supported and context: # set private openai_api_key # if from_user is not changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) - context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key + context["openai_api_key"] = user_data.get( + "openai_api_key" + ) # None or user openai_api_key channel.received_msgs[message_id] = wechatmp_msg channel.running.add(cache_key) channel.produce(context) else: - trigger_prefix = conf().get('single_chat_prefix',[''])[0] + trigger_prefix = conf().get("single_chat_prefix", [""])[0] if trigger_prefix or not supported: if trigger_prefix: - content = textwrap.dedent(f"""\ + content = textwrap.dedent( + f"""\ 请输入'{trigger_prefix}'接你想说的话跟我说话。 例如: - {trigger_prefix}你好,很高兴见到你。""") + {trigger_prefix}你好,很高兴见到你。""" + ) else: - content = textwrap.dedent("""\ + content = textwrap.dedent( + """\ 你好,很高兴见到你。 - 请跟我说话吧。""") + 请跟我说话吧。""" + ) else: logger.error(f"[wechatmp] unknown error") - content = textwrap.dedent("""\ - 未知错误,请稍后再试""") - replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) + content = textwrap.dedent( + """\ + 未知错误,请稍后再试""" + ) + replyMsg = reply.TextMsg( + wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content + ) return replyMsg.send() channel.query1[cache_key] = False channel.query2[cache_key] = False channel.query3[cache_key] = False # User request again, and the answer is not ready - elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: - channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. - channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. + elif ( + cache_key in channel.running + and channel.query1.get(cache_key) == True + and channel.query2.get(cache_key) == True + and channel.query3.get(cache_key) == True + ): + channel.query1[ + cache_key + ] = False # To improve waiting experience, this can be set to True. + channel.query2[ + cache_key + ] = False # To improve waiting experience, this can be set to True. channel.query3[cache_key] = False # User request again, and the answer is ready elif cache_key in channel.cache_dict: @@ -84,7 +120,9 @@ def POST(self): channel.query2[cache_key] = True channel.query3[cache_key] = True - assert not (cache_key in channel.cache_dict and cache_key in channel.running) + assert not ( + cache_key in channel.cache_dict and cache_key in channel.running + ) if channel.query1.get(cache_key) == False: # The first query from wechat official server @@ -128,14 +166,20 @@ def POST(self): # Have waiting for 3x5 seconds # return timeout message reply_text = "【正在思考中,回复任意文字尝试获取回复】" - logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id)) + logger.info( + "[wechatmp] Three queries has finished For {}: {}".format( + from_user, message_id + ) + ) replyPost = reply.TextMsg(from_user, to_user, reply_text).send() return replyPost else: pass - - if cache_key not in channel.cache_dict and cache_key not in channel.running: + if ( + cache_key not in channel.cache_dict + and cache_key not in channel.running + ): # no return because of bandwords or other reasons return "success" @@ -147,26 +191,42 @@ def POST(self): if cache_key in channel.cache_dict: content = channel.cache_dict[cache_key] - if len(content.encode('utf8'))<=MAX_UTF8_LEN: + if len(content.encode("utf8")) <= MAX_UTF8_LEN: reply_text = channel.cache_dict[cache_key] channel.cache_dict.pop(cache_key) else: continue_text = "\n【未完待续,回复任意文字以继续】" - splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1) + splits = split_string_by_utf8_length( + content, + MAX_UTF8_LEN - len(continue_text.encode("utf-8")), + max_split=1, + ) reply_text = splits[0] + continue_text channel.cache_dict[cache_key] = splits[1] - logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) + logger.info( + "[wechatmp] {}:{} Do send {}".format( + web.ctx.env.get("REMOTE_ADDR"), + web.ctx.env.get("REMOTE_PORT"), + reply_text, + ) + ) replyPost = reply.TextMsg(from_user, to_user, reply_text).send() return replyPost - elif wechatmp_msg.msg_type == 'event': - logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id)) + elif wechatmp_msg.msg_type == "event": + logger.info( + "[wechatmp] Event {} from {}".format( + wechatmp_msg.content, wechatmp_msg.from_user_id + ) + ) content = subscribe_msg() - replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) + replyMsg = reply.TextMsg( + wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content + ) return replyMsg.send() else: logger.info("暂且不处理") return "success" except Exception as exc: logger.exception(exc) - return exc \ No newline at end of file + return exc diff --git a/channel/wechatmp/common.py b/channel/wechatmp/common.py index 27609e0e9..192b86bda 100644 --- a/channel/wechatmp/common.py +++ b/channel/wechatmp/common.py @@ -1,9 +1,11 @@ -from config import conf import hashlib import textwrap +from config import conf + MAX_UTF8_LEN = 2048 + class WeChatAPIException(Exception): pass @@ -16,13 +18,13 @@ def verify_server(data): timestamp = data.timestamp nonce = data.nonce echostr = data.echostr - token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 + token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写 data_list = [token, timestamp, nonce] data_list.sort() sha1 = hashlib.sha1() # map(sha1.update, data_list) #python2 - sha1.update("".join(data_list).encode('utf-8')) + sha1.update("".join(data_list).encode("utf-8")) hashcode = sha1.hexdigest() print("handle/GET func: hashcode, signature: ", hashcode, signature) if hashcode == signature: @@ -32,9 +34,11 @@ def verify_server(data): except Exception as Argument: return Argument + def subscribe_msg(): - trigger_prefix = conf().get('single_chat_prefix',[''])[0] - msg = textwrap.dedent(f"""\ + trigger_prefix = conf().get("single_chat_prefix", [""])[0] + msg = textwrap.dedent( + f"""\ 感谢您的关注! 这里是ChatGPT,可以自由对话。 资源有限,回复较慢,请勿着急。 @@ -42,22 +46,23 @@ def subscribe_msg(): 暂时不支持图片输入。 支持图片输出,画字开头的问题将回复图片链接。 支持角色扮演和文字冒险两种定制模式对话。 - 输入'{trigger_prefix}#帮助' 查看详细指令。""") + 输入'{trigger_prefix}#帮助' 查看详细指令。""" + ) return msg def split_string_by_utf8_length(string, max_length, max_split=0): - encoded = string.encode('utf-8') + encoded = string.encode("utf-8") start, end = 0, 0 result = [] while end < len(encoded): if max_split > 0 and len(result) >= max_split: - result.append(encoded[start:].decode('utf-8')) + result.append(encoded[start:].decode("utf-8")) break end = start + max_length # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: end -= 1 - result.append(encoded[start:end].decode('utf-8')) + result.append(encoded[start:end].decode("utf-8")) start = end - return result \ No newline at end of file + return result diff --git a/channel/wechatmp/receive.py b/channel/wechatmp/receive.py index bb431dcf7..1285fd155 100644 --- a/channel/wechatmp/receive.py +++ b/channel/wechatmp/receive.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*-# # filename: receive.py import xml.etree.ElementTree as ET + from bridge.context import ContextType from channel.chat_message import ChatMessage from common.log import logger @@ -12,34 +13,35 @@ def parse_xml(web_data): xmlData = ET.fromstring(web_data) return WeChatMPMessage(xmlData) + class WeChatMPMessage(ChatMessage): def __init__(self, xmlData): super().__init__(xmlData) - self.to_user_id = xmlData.find('ToUserName').text - self.from_user_id = xmlData.find('FromUserName').text - self.create_time = xmlData.find('CreateTime').text - self.msg_type = xmlData.find('MsgType').text + self.to_user_id = xmlData.find("ToUserName").text + self.from_user_id = xmlData.find("FromUserName").text + self.create_time = xmlData.find("CreateTime").text + self.msg_type = xmlData.find("MsgType").text try: - self.msg_id = xmlData.find('MsgId').text + self.msg_id = xmlData.find("MsgId").text except: - self.msg_id = self.from_user_id+self.create_time + self.msg_id = self.from_user_id + self.create_time self.is_group = False - + # reply to other_user_id self.other_user_id = self.from_user_id - if self.msg_type == 'text': + if self.msg_type == "text": self.ctype = ContextType.TEXT - self.content = xmlData.find('Content').text.encode("utf-8") - elif self.msg_type == 'voice': + self.content = xmlData.find("Content").text.encode("utf-8") + elif self.msg_type == "voice": self.ctype = ContextType.TEXT - self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果 - elif self.msg_type == 'image': + self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果 + elif self.msg_type == "image": # not implemented - self.pic_url = xmlData.find('PicUrl').text - self.media_id = xmlData.find('MediaId').text - elif self.msg_type == 'event': - self.content = xmlData.find('Event').text - else: # video, shortvideo, location, link + self.pic_url = xmlData.find("PicUrl").text + self.media_id = xmlData.find("MediaId").text + elif self.msg_type == "event": + self.content = xmlData.find("Event").text + else: # video, shortvideo, location, link # not implemented - pass \ No newline at end of file + pass diff --git a/channel/wechatmp/reply.py b/channel/wechatmp/reply.py index 5f3a9347a..2f852f99b 100644 --- a/channel/wechatmp/reply.py +++ b/channel/wechatmp/reply.py @@ -2,6 +2,7 @@ # filename: reply.py import time + class Msg(object): def __init__(self): pass @@ -9,13 +10,14 @@ def __init__(self): def send(self): return "success" + class TextMsg(Msg): def __init__(self, toUserName, fromUserName, content): self.__dict = dict() - self.__dict['ToUserName'] = toUserName - self.__dict['FromUserName'] = fromUserName - self.__dict['CreateTime'] = int(time.time()) - self.__dict['Content'] = content + self.__dict["ToUserName"] = toUserName + self.__dict["FromUserName"] = fromUserName + self.__dict["CreateTime"] = int(time.time()) + self.__dict["Content"] = content def send(self): XmlForm = """ @@ -29,13 +31,14 @@ def send(self): """ return XmlForm.format(**self.__dict) + class ImageMsg(Msg): def __init__(self, toUserName, fromUserName, mediaId): self.__dict = dict() - self.__dict['ToUserName'] = toUserName - self.__dict['FromUserName'] = fromUserName - self.__dict['CreateTime'] = int(time.time()) - self.__dict['MediaId'] = mediaId + self.__dict["ToUserName"] = toUserName + self.__dict["FromUserName"] = fromUserName + self.__dict["CreateTime"] = int(time.time()) + self.__dict["MediaId"] = mediaId def send(self): XmlForm = """ @@ -49,4 +52,4 @@ def send(self): """ - return XmlForm.format(**self.__dict) \ No newline at end of file + return XmlForm.format(**self.__dict) diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index c56c1cb3b..ac3c3ac1a 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -1,17 +1,19 @@ # -*- coding: utf-8 -*- -import web -import time import json -import requests import threading -from common.singleton import singleton -from common.log import logger -from common.expired_dict import ExpiredDict -from config import conf -from bridge.reply import * +import time + +import requests +import web + from bridge.context import * +from bridge.reply import * from channel.chat_channel import ChatChannel -from channel.wechatmp.common import * +from channel.wechatmp.common import * +from common.expired_dict import ExpiredDict +from common.log import logger +from common.singleton import singleton +from config import conf # If using SSL, uncomment the following lines, and modify the certificate path. # from cheroot.server import HTTPServer @@ -20,13 +22,14 @@ # certificate='/ssl/cert.pem', # private_key='/ssl/cert.key') + @singleton class WechatMPChannel(ChatChannel): - def __init__(self, passive_reply = True): + def __init__(self, passive_reply=True): super().__init__() self.passive_reply = passive_reply self.running = set() - self.received_msgs = ExpiredDict(60*60*24) + self.received_msgs = ExpiredDict(60 * 60 * 24) if self.passive_reply: self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] self.cache_dict = dict() @@ -36,8 +39,8 @@ def __init__(self, passive_reply = True): else: # TODO support image self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] - self.app_id = conf().get('wechatmp_app_id') - self.app_secret = conf().get('wechatmp_app_secret') + self.app_id = conf().get("wechatmp_app_id") + self.app_secret = conf().get("wechatmp_app_secret") self.access_token = None self.access_token_expires_time = 0 self.access_token_lock = threading.Lock() @@ -45,13 +48,12 @@ def __init__(self, passive_reply = True): def startup(self): if self.passive_reply: - urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query') + urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query") else: - urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query') + urls = ("/wx", "channel.wechatmp.ServiceAccount.Query") app = web.application(urls, globals(), autoreload=False) - port = conf().get('wechatmp_port', 8080) - web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port)) - + port = conf().get("wechatmp_port", 8080) + web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) def wechatmp_request(self, method, url, **kwargs): r = requests.request(method=method, url=url, **kwargs) @@ -63,7 +65,6 @@ def wechatmp_request(self, method, url, **kwargs): return ret def get_access_token(self): - # return the access_token if self.access_token: if self.access_token_expires_time - time.time() > 60: @@ -76,15 +77,15 @@ def get_access_token(self): # This happens every 2 hours, so it doesn't affect the experience very much time.sleep(1) self.access_token = None - url="https://api.weixin.qq.com/cgi-bin/token" - params={ + url = "https://api.weixin.qq.com/cgi-bin/token" + params = { "grant_type": "client_credential", "appid": self.app_id, - "secret": self.app_secret + "secret": self.app_secret, } - data = self.wechatmp_request(method='get', url=url, params=params) - self.access_token = data['access_token'] - self.access_token_expires_time = int(time.time()) + data['expires_in'] + data = self.wechatmp_request(method="get", url=url, params=params) + self.access_token = data["access_token"] + self.access_token_expires_time = int(time.time()) + data["expires_in"] logger.info("[wechatmp] access_token: {}".format(self.access_token)) self.access_token_lock.release() else: @@ -101,29 +102,37 @@ def send(self, reply: Reply, context: Context): else: receiver = context["receiver"] reply_text = reply.content - url="https://api.weixin.qq.com/cgi-bin/message/custom/send" - params = { - "access_token": self.get_access_token() - } + url = "https://api.weixin.qq.com/cgi-bin/message/custom/send" + params = {"access_token": self.get_access_token()} json_data = { "touser": receiver, "msgtype": "text", - "text": {"content": reply_text} + "text": {"content": reply_text}, } - self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8')) + self.wechatmp_request( + method="post", + url=url, + params=params, + data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), + ) logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) return - - def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 - logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id)) + def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 + logger.debug( + "[wechatmp] Success to generate reply, msgId={}".format( + context["msg"].msg_id + ) + ) if self.passive_reply: self.running.remove(session_id) - - def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 - logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) + def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 + logger.exception( + "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( + context["msg"].msg_id, exception + ) + ) if self.passive_reply: assert session_id not in self.cache_dict self.running.remove(session_id) - diff --git a/common/const.py b/common/const.py index 8336da39b..14d00fd03 100644 --- a/common/const.py +++ b/common/const.py @@ -2,4 +2,4 @@ OPEN_AI = "openAI" CHATGPT = "chatGPT" BAIDU = "baidu" -CHATGPTONAZURE = "chatGPTOnAzure" \ No newline at end of file +CHATGPTONAZURE = "chatGPTOnAzure" diff --git a/common/dequeue.py b/common/dequeue.py index edc9ef087..39baf5816 100644 --- a/common/dequeue.py +++ b/common/dequeue.py @@ -1,7 +1,7 @@ - from queue import Full, Queue from time import monotonic as time + # add implementation of putleft to Queue class Dequeue(Queue): def putleft(self, item, block=True, timeout=None): @@ -30,4 +30,4 @@ def putleft_nowait(self, item): return self.putleft(item, block=False) def _putleft(self, item): - self.queue.appendleft(item) \ No newline at end of file + self.queue.appendleft(item) diff --git a/common/expired_dict.py b/common/expired_dict.py index a86923b74..42fb4b178 100644 --- a/common/expired_dict.py +++ b/common/expired_dict.py @@ -39,4 +39,4 @@ def items(self): return [(key, self[key]) for key in self.keys()] def __iter__(self): - return self.keys().__iter__() \ No newline at end of file + return self.keys().__iter__() diff --git a/common/log.py b/common/log.py index 14c7effbf..f02a365b7 100644 --- a/common/log.py +++ b/common/log.py @@ -10,20 +10,29 @@ def _reset_logger(log): log.handlers.clear() log.propagate = False console_handle = logging.StreamHandler(sys.stdout) - console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S')) - file_handle = logging.FileHandler('run.log', encoding='utf-8') - file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S')) + console_handle.setFormatter( + logging.Formatter( + "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + file_handle = logging.FileHandler("run.log", encoding="utf-8") + file_handle.setFormatter( + logging.Formatter( + "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) log.addHandler(file_handle) log.addHandler(console_handle) + def _get_logger(): - log = logging.getLogger('log') + log = logging.getLogger("log") _reset_logger(log) log.setLevel(logging.INFO) return log # 日志句柄 -logger = _get_logger() \ No newline at end of file +logger = _get_logger() diff --git a/common/package_manager.py b/common/package_manager.py index 9370298bc..8f1aa457c 100644 --- a/common/package_manager.py +++ b/common/package_manager.py @@ -1,15 +1,20 @@ import time + import pip from pip._internal import main as pipmain -from common.log import logger,_reset_logger + +from common.log import _reset_logger, logger + def install(package): - pipmain(['install', package]) + pipmain(["install", package]) + def install_requirements(file): - pipmain(['install', '-r', file, "--upgrade"]) + pipmain(["install", "-r", file, "--upgrade"]) _reset_logger(logger) + def check_dulwich(): needwait = False for i in range(2): @@ -18,13 +23,14 @@ def check_dulwich(): needwait = False try: import dulwich + return except ImportError: try: - install('dulwich') + install("dulwich") except: needwait = True try: import dulwich except ImportError: - raise ImportError("Unable to import dulwich") \ No newline at end of file + raise ImportError("Unable to import dulwich") diff --git a/common/sorted_dict.py b/common/sorted_dict.py index a918a0c2e..7a1e85b94 100644 --- a/common/sorted_dict.py +++ b/common/sorted_dict.py @@ -62,4 +62,4 @@ def __iter__(self): return iter(self.keys()) def __repr__(self): - return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' + return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})" diff --git a/common/time_check.py b/common/time_check.py index 229dd149b..808f71ab3 100644 --- a/common/time_check.py +++ b/common/time_check.py @@ -1,7 +1,11 @@ -import time,re,hashlib +import hashlib +import re +import time + import config from common.log import logger + def time_checker(f): def _time_checker(self, *args, **kwargs): _config = config.conf() @@ -9,17 +13,25 @@ def _time_checker(self, *args, **kwargs): if chat_time_module: chat_start_time = _config.get("chat_start_time", "00:00") chat_stopt_time = _config.get("chat_stop_time", "24:00") - time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00 + time_regex = re.compile( + r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" + ) # 时间匹配,包含24:00 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 - chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 + chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 # 时间格式检查 - if not (starttime_format_check and stoptime_format_check and chat_time_check): - logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check)) - if chat_start_time>"23:59": - logger.error('启动时间可能存在问题,请修改!') + if not ( + starttime_format_check and stoptime_format_check and chat_time_check + ): + logger.warn( + "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( + starttime_format_check, stoptime_format_check + ) + ) + if chat_start_time > "23:59": + logger.error("启动时间可能存在问题,请修改!") # 服务时间检查 now_time = time.strftime("%H:%M", time.localtime()) @@ -27,12 +39,12 @@ def _time_checker(self, *args, **kwargs): f(self, *args, **kwargs) return None else: - if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置 + if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置 f(self, *args, **kwargs) else: - logger.info('非服务时间内,不接受访问') + logger.info("非服务时间内,不接受访问") return None else: f(self, *args, **kwargs) # 未开启时间模块则直接回答 - return _time_checker + return _time_checker diff --git a/common/tmp_dir.py b/common/tmp_dir.py index 63f56ebc3..b01880bdd 100644 --- a/common/tmp_dir.py +++ b/common/tmp_dir.py @@ -1,20 +1,18 @@ - import os import pathlib + from config import conf class TmpDir(object): - """A temporary directory that is deleted when the object is destroyed. - """ + """A temporary directory that is deleted when the object is destroyed.""" + + tmpFilePath = pathlib.Path("./tmp/") - tmpFilePath = pathlib.Path('./tmp/') - def __init__(self): pathExists = os.path.exists(self.tmpFilePath) if not pathExists: os.makedirs(self.tmpFilePath) def path(self): - return str(self.tmpFilePath) + '/' - \ No newline at end of file + return str(self.tmpFilePath) + "/" diff --git a/config-template.json b/config-template.json index 528bd6010..864aa03fe 100644 --- a/config-template.json +++ b/config-template.json @@ -2,16 +2,30 @@ "open_ai_api_key": "YOUR API KEY", "model": "gpt-3.5-turbo", "proxy": "", - "single_chat_prefix": ["bot", "@bot"], + "single_chat_prefix": [ + "bot", + "@bot" + ], "single_chat_reply_prefix": "[bot] ", - "group_chat_prefix": ["@bot"], - "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], - "group_chat_in_one_session": ["ChatGPT测试群"], - "image_create_prefix": ["画", "看", "找"], + "group_chat_prefix": [ + "@bot" + ], + "group_name_white_list": [ + "ChatGPT测试群", + "ChatGPT测试群2" + ], + "group_chat_in_one_session": [ + "ChatGPT测试群" + ], + "image_create_prefix": [ + "画", + "看", + "找" + ], "speech_recognition": false, "group_speech_recognition": false, "voice_reply_voice": false, "conversation_max_tokens": 1000, "expires_in_seconds": 3600, "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" -} \ No newline at end of file +} diff --git a/config.py b/config.py index 8745bf963..8f5d2cabe 100644 --- a/config.py +++ b/config.py @@ -3,9 +3,10 @@ import json import logging import os -from common.log import logger import pickle +from common.log import logger + # 将所有可用的配置项写在字典里, 请使用小写字母 available_setting = { # openai api配置 @@ -16,8 +17,7 @@ # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 "model": "gpt-3.5-turbo", "use_azure_chatgpt": False, # 是否使用azure的chatgpt - "azure_deployment_id": "", #azure 模型部署名称 - + "azure_deployment_id": "", # azure 模型部署名称 # Bot触发配置 "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 @@ -30,25 +30,22 @@ "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 "trigger_by_self": False, # 是否允许机器人触发 "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 - "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 - + "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 + "image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 # chatgpt会话参数 "expires_in_seconds": 3600, # 无操作会话的过期时间 "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 - # chatgpt限流配置 "rate_limit_chatgpt": 20, # chatgpt的调用频率限制 "rate_limit_dalle": 50, # openai dalle的调用频率限制 - # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create "temperature": 0.9, "top_p": 1, "frequency_penalty": 0, "presence_penalty": 0, - "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 - "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 - + "request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 # 语音设置 "speech_recognition": False, # 是否开启语音识别 "group_speech_recognition": False, # 是否开启群组语音识别 @@ -56,50 +53,41 @@ "always_reply_voice": False, # 是否一直使用语音回复 "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure - # baidu 语音api配置, 使用百度语音识别和语音合成时需要 "baidu_app_id": "", "baidu_api_key": "", "baidu_secret_key": "", # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 "baidu_dev_pid": "1536", - # azure 语音api配置, 使用azure语音识别和语音合成时需要 "azure_voice_api_key": "", "azure_voice_region": "japaneast", - # 服务时间限制,目前支持itchat "chat_time_module": False, # 是否开启服务时间限制 "chat_start_time": "00:00", # 服务开始时间 "chat_stop_time": "24:00", # 服务结束时间 - # itchat的配置 "hot_reload": False, # 是否开启热重载 - # wechaty的配置 "wechaty_puppet_service_token": "", # wechaty的token - # wechatmp的配置 - "wechatmp_token": "", # 微信公众平台的Token - "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 - "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 + "wechatmp_token": "", # 微信公众平台的Token + "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 + "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 - # chatgpt指令自定义触发词 - "clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头 - + "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 # channel配置 - "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} - + "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} "debug": False, # 是否开启debug模式,开启后会打印更多日志 - + "appdata_dir": "", # 数据目录 # 插件配置 "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 } class Config(dict): - def __init__(self, d:dict={}): + def __init__(self, d: dict = {}): super().__init__(d) # user_datas: 用户数据,key为用户名,value为用户数据,也是dict self.user_datas = {} @@ -130,7 +118,7 @@ def get_user_data(self, user) -> dict: def load_user_datas(self): try: - with open('user_datas.pkl', 'rb') as f: + with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "rb") as f: self.user_datas = pickle.load(f) logger.info("[Config] User datas loaded.") except FileNotFoundError as e: @@ -141,12 +129,13 @@ def load_user_datas(self): def save_user_datas(self): try: - with open('user_datas.pkl', 'wb') as f: + with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "wb") as f: pickle.dump(self.user_datas, f) logger.info("[Config] User datas saved.") except Exception as e: logger.info("[Config] User datas error: {}".format(e)) + config = Config() @@ -154,7 +143,7 @@ def load_config(): global config config_path = "./config.json" if not os.path.exists(config_path): - logger.info('配置文件不存在,将使用config-template.json模板') + logger.info("配置文件不存在,将使用config-template.json模板") config_path = "./config-template.json" config_str = read_file(config_path) @@ -169,7 +158,8 @@ def load_config(): name = name.lower() if name in available_setting: logger.info( - "[INIT] override config by environ args: {}={}".format(name, value)) + "[INIT] override config by environ args: {}={}".format(name, value) + ) try: config[name] = eval(value) except: @@ -182,20 +172,29 @@ def load_config(): if config.get("debug", False): logger.setLevel(logging.DEBUG) - logger.debug("[INIT] set log level to DEBUG") + logger.debug("[INIT] set log level to DEBUG") logger.info("[INIT] load config: {}".format(config)) config.load_user_datas() + def get_root(): return os.path.dirname(os.path.abspath(__file__)) def read_file(path): - with open(path, mode='r', encoding='utf-8') as f: + with open(path, mode="r", encoding="utf-8") as f: return f.read() def conf(): return config + + +def get_appdata_dir(): + data_path = os.path.join(get_root(), conf().get("appdata_dir", "")) + if not os.path.exists(data_path): + logger.info("[INIT] data path not exists, create it: {}".format(data_path)) + os.makedirs(data_path) + return data_path diff --git a/docker/Dockerfile.debian b/docker/Dockerfile.debian index dfd289de1..798642d06 100644 --- a/docker/Dockerfile.debian +++ b/docker/Dockerfile.debian @@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh \ && groupadd -r noroot \ && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ - && chown -R noroot:noroot ${BUILD_PREFIX} + && chown -R noroot:noroot ${BUILD_PREFIX} USER noroot diff --git a/docker/Dockerfile.debian.latest b/docker/Dockerfile.debian.latest index 95bb352c0..3fe0b76b9 100644 --- a/docker/Dockerfile.debian.latest +++ b/docker/Dockerfile.debian.latest @@ -18,7 +18,7 @@ RUN apt-get update \ && pip install --no-cache -r requirements.txt \ && pip install --no-cache -r requirements-optional.txt \ && pip install azure-cognitiveservices-speech - + WORKDIR ${BUILD_PREFIX} ADD docker/entrypoint.sh /entrypoint.sh diff --git a/docker/build.alpine.sh b/docker/build.alpine.sh index 8f102ddc8..fae7b3459 100644 --- a/docker/build.alpine.sh +++ b/docker/build.alpine.sh @@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \ -t zhayujie/chatgpt-on-wechat . # tag image -docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine +docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine - \ No newline at end of file diff --git a/docker/build.debian.sh b/docker/build.debian.sh index e7caa15dc..9b2d4f391 100644 --- a/docker/build.debian.sh +++ b/docker/build.debian.sh @@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \ -t zhayujie/chatgpt-on-wechat . # tag image -docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian +docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian \ No newline at end of file diff --git a/docker/build.latest.sh b/docker/build.latest.sh index 0f06f8228..92c356497 100644 --- a/docker/build.latest.sh +++ b/docker/build.latest.sh @@ -1,4 +1,8 @@ #!/bin/bash +unset KUBECONFIG + cd .. && docker build -f docker/Dockerfile.latest \ - -t zhayujie/chatgpt-on-wechat . \ No newline at end of file + -t zhayujie/chatgpt-on-wechat . + +docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d) \ No newline at end of file diff --git a/docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine b/docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine index 85248ae43..91e94284f 100644 --- a/docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine +++ b/docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine @@ -9,7 +9,7 @@ RUN apk add --no-cache \ ffmpeg \ espeak \ && pip install --no-cache \ - baidu-aip \ + baidu-aip \ chardet \ SpeechRecognition diff --git a/docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian b/docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian index 4e999e2c5..df58d02a9 100644 --- a/docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian +++ b/docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian @@ -10,7 +10,7 @@ RUN apt-get update \ ffmpeg \ espeak \ && pip install --no-cache \ - baidu-aip \ + baidu-aip \ chardet \ SpeechRecognition diff --git a/docker/sample-chatgpt-on-wechat/Makefile b/docker/sample-chatgpt-on-wechat/Makefile index 31b5f817b..456063456 100644 --- a/docker/sample-chatgpt-on-wechat/Makefile +++ b/docker/sample-chatgpt-on-wechat/Makefile @@ -11,13 +11,13 @@ run_d: docker rm $(CONTAINER_NAME) || echo docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ --env-file=$(DOTENV) \ - $(MOUNT) $(IMG) + $(MOUNT) $(IMG) run_i: docker rm $(CONTAINER_NAME) || echo docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ --env-file=$(DOTENV) \ - $(MOUNT) $(IMG) + $(MOUNT) $(IMG) stop: docker stop $(CONTAINER_NAME) diff --git a/plugins/README.md b/plugins/README.md index 1f07a2cbe..2a4461573 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -24,17 +24,17 @@ 在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。 - 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。 - + - 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。 - + 安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。 - + - 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui - 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git - + 在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。 - + 安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。 ## 插件化实现 @@ -107,14 +107,14 @@ ``` 回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。 - + ```python class ReplyType(Enum): TEXT = 1 # 文本 VOICE = 2 # 音频文件 IMAGE = 3 # 图片文件 IMAGE_URL = 4 # 图片URL - + INFO = 9 ERROR = 10 class Reply: @@ -159,12 +159,12 @@ 目前支持三类触发事件: ``` -1.收到消息 ----> `ON_HANDLE_CONTEXT` -2.产生回复 ----> `ON_DECORATE_REPLY` -3.装饰回复 ----> `ON_SEND_REPLY` +1.收到消息 +---> `ON_HANDLE_CONTEXT` +2.产生回复 +---> `ON_DECORATE_REPLY` +3.装饰回复 +---> `ON_SEND_REPLY` 4.发送回复 ``` @@ -268,6 +268,6 @@ class Hello(Plugin): - 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。 - + - 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 - 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。 diff --git a/plugins/__init__.py b/plugins/__init__.py index 6137d4a82..d515edb7c 100644 --- a/plugins/__init__.py +++ b/plugins/__init__.py @@ -1,9 +1,9 @@ -from .plugin_manager import PluginManager from .event import * from .plugin import * +from .plugin_manager import PluginManager instance = PluginManager() -register = instance.register +register = instance.register # load_plugins = instance.load_plugins # emit_event = instance.emit_event diff --git a/plugins/banwords/WordsSearch.py b/plugins/banwords/WordsSearch.py deleted file mode 100644 index d41d6e7f2..000000000 --- a/plugins/banwords/WordsSearch.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python -# -*- coding:utf-8 -*- -# ToolGood.Words.WordsSearch.py -# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words -# Licensed under the Apache License 2.0 -# 更新日志 -# 2020.04.06 第一次提交 -# 2020.05.16 修改,支持大于0xffff的字符 - -__all__ = ['WordsSearch'] -__author__ = 'Lin Zhijun' -__date__ = '2020.05.16' - -class TrieNode(): - def __init__(self): - self.Index = 0 - self.Index = 0 - self.Layer = 0 - self.End = False - self.Char = '' - self.Results = [] - self.m_values = {} - self.Failure = None - self.Parent = None - - def Add(self,c): - if c in self.m_values : - return self.m_values[c] - node = TrieNode() - node.Parent = self - node.Char = c - self.m_values[c] = node - return node - - def SetResults(self,index): - if (self.End == False): - self.End = True - self.Results.append(index) - -class TrieNode2(): - def __init__(self): - self.End = False - self.Results = [] - self.m_values = {} - self.minflag = 0xffff - self.maxflag = 0 - - def Add(self,c,node3): - if (self.minflag > c): - self.minflag = c - if (self.maxflag < c): - self.maxflag = c - self.m_values[c] = node3 - - def SetResults(self,index): - if (self.End == False) : - self.End = True - if (index in self.Results )==False : - self.Results.append(index) - - def HasKey(self,c): - return c in self.m_values - - - def TryGetValue(self,c): - if (self.minflag <= c and self.maxflag >= c): - if c in self.m_values: - return self.m_values[c] - return None - - -class WordsSearch(): - def __init__(self): - self._first = {} - self._keywords = [] - self._indexs=[] - - def SetKeywords(self,keywords): - self._keywords = keywords - self._indexs=[] - for i in range(len(keywords)): - self._indexs.append(i) - - root = TrieNode() - allNodeLayer={} - - for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++) - p = self._keywords[i] - nd = root - for j in range(len(p)): # for (j = 0; j < p.length; j++) - nd = nd.Add(ord(p[j])) - if (nd.Layer == 0): - nd.Layer = j + 1 - if nd.Layer in allNodeLayer: - allNodeLayer[nd.Layer].append(nd) - else: - allNodeLayer[nd.Layer]=[] - allNodeLayer[nd.Layer].append(nd) - nd.SetResults(i) - - - allNode = [] - allNode.append(root) - for key in allNodeLayer.keys(): - for nd in allNodeLayer[key]: - allNode.append(nd) - allNodeLayer=None - - for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) - if i==0 : - continue - nd=allNode[i] - nd.Index = i - r = nd.Parent.Failure - c = nd.Char - while (r != None and (c in r.m_values)==False): - r = r.Failure - if (r == None): - nd.Failure = root - else: - nd.Failure = r.m_values[c] - for key2 in nd.Failure.Results : - nd.SetResults(key2) - root.Failure = root - - allNode2 = [] - for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++) - allNode2.append( TrieNode2()) - - for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++) - oldNode = allNode[i] - newNode = allNode2[i] - - for key in oldNode.m_values : - index = oldNode.m_values[key].Index - newNode.Add(key, allNode2[index]) - - for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++) - item = oldNode.Results[index] - newNode.SetResults(item) - - oldNode=oldNode.Failure - while oldNode != root: - for key in oldNode.m_values : - if (newNode.HasKey(key) == False): - index = oldNode.m_values[key].Index - newNode.Add(key, allNode2[index]) - for index in range(len(oldNode.Results)): - item = oldNode.Results[index] - newNode.SetResults(item) - oldNode=oldNode.Failure - allNode = None - root = None - - # first = [] - # for index in range(65535):# for (index = 0; index < 0xffff; index++) - # first.append(None) - - # for key in allNode2[0].m_values : - # first[key] = allNode2[0].m_values[key] - - self._first = allNode2[0] - - - def FindFirst(self,text): - ptr = None - for index in range(len(text)): # for (index = 0; index < text.length; index++) - t =ord(text[index]) # text.charCodeAt(index) - tn = None - if (ptr == None): - tn = self._first.TryGetValue(t) - else: - tn = ptr.TryGetValue(t) - if (tn==None): - tn = self._first.TryGetValue(t) - - - if (tn != None): - if (tn.End): - item = tn.Results[0] - keyword = self._keywords[item] - return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] } - ptr = tn - return None - - def FindAll(self,text): - ptr = None - list = [] - - for index in range(len(text)): # for (index = 0; index < text.length; index++) - t =ord(text[index]) # text.charCodeAt(index) - tn = None - if (ptr == None): - tn = self._first.TryGetValue(t) - else: - tn = ptr.TryGetValue(t) - if (tn==None): - tn = self._first.TryGetValue(t) - - - if (tn != None): - if (tn.End): - for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++) - item = tn.Results[j] - keyword = self._keywords[item] - list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }) - ptr = tn - return list - - - def ContainsAny(self,text): - ptr = None - for index in range(len(text)): # for (index = 0; index < text.length; index++) - t =ord(text[index]) # text.charCodeAt(index) - tn = None - if (ptr == None): - tn = self._first.TryGetValue(t) - else: - tn = ptr.TryGetValue(t) - if (tn==None): - tn = self._first.TryGetValue(t) - - if (tn != None): - if (tn.End): - return True - ptr = tn - return False - - def Replace(self,text, replaceChar = '*'): - result = list(text) - - ptr = None - for i in range(len(text)): # for (i = 0; i < text.length; i++) - t =ord(text[i]) # text.charCodeAt(index) - tn = None - if (ptr == None): - tn = self._first.TryGetValue(t) - else: - tn = ptr.TryGetValue(t) - if (tn==None): - tn = self._first.TryGetValue(t) - - if (tn != None): - if (tn.End): - maxLength = len( self._keywords[tn.Results[0]]) - start = i + 1 - maxLength - for j in range(start,i+1): # for (j = start; j <= i; j++) - result[j] = replaceChar - ptr = tn - return ''.join(result) \ No newline at end of file diff --git a/plugins/banwords/__init__.py b/plugins/banwords/__init__.py index 97b59e1b5..503a56364 100644 --- a/plugins/banwords/__init__.py +++ b/plugins/banwords/__init__.py @@ -1 +1 @@ -from .banwords import * \ No newline at end of file +from .banwords import * diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py index 2d94af475..138b4c246 100644 --- a/plugins/banwords/banwords.py +++ b/plugins/banwords/banwords.py @@ -2,56 +2,67 @@ import json import os + +import plugins from bridge.context import ContextType from bridge.reply import Reply, ReplyType -import plugins -from plugins import * from common.log import logger -from .WordsSearch import WordsSearch +from plugins import * + +from .lib.WordsSearch import WordsSearch -@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent") +@plugins.register( + name="Banwords", + desire_priority=100, + hidden=True, + desc="判断消息中是否有敏感词、决定是否回复。", + version="1.0", + author="lanvent", +) class Banwords(Plugin): def __init__(self): super().__init__() try: - curdir=os.path.dirname(__file__) - config_path=os.path.join(curdir,"config.json") - conf=None + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + conf = None if not os.path.exists(config_path): - conf={"action":"ignore"} - with open(config_path,"w") as f: - json.dump(conf,f,indent=4) + conf = {"action": "ignore"} + with open(config_path, "w") as f: + json.dump(conf, f, indent=4) else: - with open(config_path,"r") as f: - conf=json.load(f) + with open(config_path, "r") as f: + conf = json.load(f) self.searchr = WordsSearch() self.action = conf["action"] - banwords_path = os.path.join(curdir,"banwords.txt") - with open(banwords_path, 'r', encoding='utf-8') as f: - words=[] + banwords_path = os.path.join(curdir, "banwords.txt") + with open(banwords_path, "r", encoding="utf-8") as f: + words = [] for line in f: word = line.strip() if word: words.append(word) self.searchr.SetKeywords(words) self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context - if conf.get("reply_filter",True): + if conf.get("reply_filter", True): self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply - self.reply_action = conf.get("reply_action","ignore") + self.reply_action = conf.get("reply_action", "ignore") logger.info("[Banwords] inited") except Exception as e: - logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") + logger.warn( + "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." + ) raise e - - def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: + if e_context["context"].type not in [ + ContextType.TEXT, + ContextType.IMAGE_CREATE, + ]: return - - content = e_context['context'].content + + content = e_context["context"].content logger.debug("[Banwords] on_handle_context. content: %s" % content) if self.action == "ignore": f = self.searchr.FindFirst(content) @@ -61,31 +72,34 @@ def on_handle_context(self, e_context: EventContext): return elif self.action == "replace": if self.searchr.ContainsAny(content): - reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) - e_context['reply'] = reply + reply = Reply( + ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) + ) + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return - - def on_decorate_reply(self, e_context: EventContext): - if e_context['reply'].type not in [ReplyType.TEXT]: + def on_decorate_reply(self, e_context: EventContext): + if e_context["reply"].type not in [ReplyType.TEXT]: return - - reply = e_context['reply'] + + reply = e_context["reply"] content = reply.content if self.reply_action == "ignore": f = self.searchr.FindFirst(content) if f: logger.info("[Banwords] %s in reply" % f["Keyword"]) - e_context['reply'] = None + e_context["reply"] = None e_context.action = EventAction.BREAK_PASS return elif self.reply_action == "replace": if self.searchr.ContainsAny(content): - reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content)) - e_context['reply'] = reply + reply = Reply( + ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) + ) + e_context["reply"] = reply e_context.action = EventAction.CONTINUE return - + def get_help_text(self, **kwargs): - return Banwords.desc \ No newline at end of file + return Banwords.desc diff --git a/plugins/banwords/config.json.template b/plugins/banwords/config.json.template index 8b4d3d101..3117a831f 100644 --- a/plugins/banwords/config.json.template +++ b/plugins/banwords/config.json.template @@ -1,5 +1,5 @@ { - "action": "replace", - "reply_filter": true, - "reply_action": "ignore" -} \ No newline at end of file + "action": "replace", + "reply_filter": true, + "reply_action": "ignore" +} diff --git a/plugins/bdunit/README.md b/plugins/bdunit/README.md index f9f9fe401..a2f2c782f 100644 --- a/plugins/bdunit/README.md +++ b/plugins/bdunit/README.md @@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087 ``` json { "service_id": "s...", #"机器人ID" - "api_key": "", + "api_key": "", "secret_key": "" } ``` \ No newline at end of file diff --git a/plugins/bdunit/__init__.py b/plugins/bdunit/__init__.py index bb327a29d..28f44b424 100644 --- a/plugins/bdunit/__init__.py +++ b/plugins/bdunit/__init__.py @@ -1 +1 @@ -from .bdunit import * \ No newline at end of file +from .bdunit import * diff --git a/plugins/bdunit/bdunit.py b/plugins/bdunit/bdunit.py index 59302a570..4523b940a 100644 --- a/plugins/bdunit/bdunit.py +++ b/plugins/bdunit/bdunit.py @@ -2,21 +2,29 @@ import json import os import uuid +from uuid import getnode as get_mac + import requests + +import plugins from bridge.context import ContextType from bridge.reply import Reply, ReplyType from common.log import logger -import plugins from plugins import * -from uuid import getnode as get_mac - """利用百度UNIT实现智能对话 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理 """ -@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson") +@plugins.register( + name="BDunit", + desire_priority=0, + hidden=True, + desc="Baidu unit bot system", + version="0.1", + author="jackson", +) class BDunit(Plugin): def __init__(self): super().__init__() @@ -40,11 +48,10 @@ def __init__(self): raise e def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return - content = e_context['context'].content + content = e_context["context"].content logger.debug("[BDunit] on_handle_context. content: %s" % content) parsed = self.getUnit2(content) intent = self.getIntent(parsed) @@ -53,7 +60,7 @@ def on_handle_context(self, e_context: EventContext): reply = Reply() reply.type = ReplyType.TEXT reply.content = self.getSay(parsed) - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 else: e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 @@ -70,17 +77,15 @@ def get_token(self): string: access_token """ url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( - self.api_key, self.secret_key) + self.api_key, self.secret_key + ) payload = "" - headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } + headers = {"Content-Type": "application/json", "Accept": "application/json"} response = requests.request("POST", url, headers=headers, data=payload) # print(response.text) - return response.json()['access_token'] + return response.json()["access_token"] def getUnit(self, query): """ @@ -90,11 +95,14 @@ def getUnit(self, query): """ url = ( - 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token ) - request = {"query": query, "user_id": str( - get_mac())[:32], "terminal_id": "88888"} + request = { + "query": query, + "user_id": str(get_mac())[:32], + "terminal_id": "88888", + } body = { "log_id": str(uuid.uuid1()), "version": "3.0", @@ -142,11 +150,7 @@ def getIntent(self, parsed): :param parsed: UNIT 解析结果 :returns: 意图数组 """ - if ( - parsed - and "result" in parsed - and "response_list" in parsed["result"] - ): + if parsed and "result" in parsed and "response_list" in parsed["result"]: try: return parsed["result"]["response_list"][0]["schema"]["intent"] except Exception as e: @@ -163,11 +167,7 @@ def hasIntent(self, parsed, intent): :param intent: 意图的名称 :returns: True: 包含; False: 不包含 """ - if ( - parsed - and "result" in parsed - and "response_list" in parsed["result"] - ): + if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] for response in response_list: if ( @@ -189,11 +189,7 @@ def getSlots(self, parsed, intent=""): :returns: 词槽列表。你可以通过 name 属性筛选词槽, 再通过 normalized_word 属性取出相应的值 """ - if ( - parsed - and "result" in parsed - and "response_list" in parsed["result"] - ): + if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] if intent == "": try: @@ -236,11 +232,7 @@ def getSayByConfidence(self, parsed): :param parsed: UNIT 解析结果 :returns: UNIT 的回复文本 """ - if ( - parsed - and "result" in parsed - and "response_list" in parsed["result"] - ): + if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] answer = {} for response in response_list: @@ -266,11 +258,7 @@ def getSay(self, parsed, intent=""): :param intent: 意图的名称 :returns: UNIT 的回复文本 """ - if ( - parsed - and "result" in parsed - and "response_list" in parsed["result"] - ): + if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] if intent == "": try: diff --git a/plugins/bdunit/config.json.template b/plugins/bdunit/config.json.template index aac83b16f..c3bad56f7 100644 --- a/plugins/bdunit/config.json.template +++ b/plugins/bdunit/config.json.template @@ -1,5 +1,5 @@ { - "service_id": "s...", - "api_key": "", - "secret_key": "" -} \ No newline at end of file + "service_id": "s...", + "api_key": "", + "secret_key": "" +} diff --git a/plugins/dungeon/__init__.py b/plugins/dungeon/__init__.py index 4c35098ac..6b1044364 100644 --- a/plugins/dungeon/__init__.py +++ b/plugins/dungeon/__init__.py @@ -1 +1 @@ -from .dungeon import * \ No newline at end of file +from .dungeon import * diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py index f7628798b..2e3cdf1ad 100644 --- a/plugins/dungeon/dungeon.py +++ b/plugins/dungeon/dungeon.py @@ -1,17 +1,18 @@ # encoding:utf-8 +import plugins from bridge.bridge import Bridge from bridge.context import ContextType from bridge.reply import Reply, ReplyType +from common import const from common.expired_dict import ExpiredDict +from common.log import logger from config import conf -import plugins from plugins import * -from common.log import logger -from common import const + # https://github.com/bupticybee/ChineseAiDungeonChatGPT -class StoryTeller(): +class StoryTeller: def __init__(self, bot, sessionid, story): self.bot = bot self.sessionid = sessionid @@ -27,67 +28,85 @@ def action(self, user_action): if user_action[-1] != "。": user_action = user_action + "。" if self.first_interact: - prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 - 开头是,""" + self.story + " " + user_action + prompt = ( + """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 + 开头是,""" + + self.story + + " " + + user_action + ) self.first_interact = False else: prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action return prompt -@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent") +@plugins.register( + name="Dungeon", + desire_priority=0, + namecn="文字冒险", + desc="A plugin to play dungeon game", + version="1.0", + author="lanvent", +) class Dungeon(Plugin): def __init__(self): super().__init__() self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Dungeon] inited") # 目前没有设计session过期事件,这里先暂时使用过期字典 - if conf().get('expires_in_seconds'): - self.games = ExpiredDict(conf().get('expires_in_seconds')) + if conf().get("expires_in_seconds"): + self.games = ExpiredDict(conf().get("expires_in_seconds")) else: self.games = dict() def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return bottype = Bridge().get_bot_type("chat") if bottype not in (const.CHATGPT, const.OPEN_AI): return bot = Bridge().get_bot("chat") - content = e_context['context'].content[:] - clist = e_context['context'].content.split(maxsplit=1) - sessionid = e_context['context']['session_id'] + content = e_context["context"].content[:] + clist = e_context["context"].content.split(maxsplit=1) + sessionid = e_context["context"]["session_id"] logger.debug("[Dungeon] on_handle_context. content: %s" % clist) - trigger_prefix = conf().get('plugin_trigger_prefix', "$") + trigger_prefix = conf().get("plugin_trigger_prefix", "$") if clist[0] == f"{trigger_prefix}停止冒险": if sessionid in self.games: self.games[sessionid].reset() del self.games[sessionid] reply = Reply(ReplyType.INFO, "冒险结束!") - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": - if len(clist)>1 : + if len(clist) > 1: story = clist[1] else: - story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" + story = ( + "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" + ) self.games[sessionid] = StoryTeller(bot, sessionid, story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) - e_context['reply'] = reply - e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 else: prompt = self.games[sessionid].action(content) - e_context['context'].type = ContextType.TEXT - e_context['context'].content = prompt - e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 + e_context["context"].type = ContextType.TEXT + e_context["context"].content = prompt + e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 + def get_help_text(self, **kwargs): help_text = "可以和机器人一起玩文字冒险游戏。\n" - if kwargs.get('verbose') != True: + if kwargs.get("verbose") != True: return help_text - trigger_prefix = conf().get('plugin_trigger_prefix', "$") - help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n" - if kwargs.get('verbose') == True: + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + help_text = ( + f"{trigger_prefix}开始冒险 " + + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + + f"{trigger_prefix}停止冒险: 结束游戏。\n" + ) + if kwargs.get("verbose") == True: help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" - return help_text \ No newline at end of file + return help_text diff --git a/plugins/event.py b/plugins/event.py index a11a4abbb..df8c609e7 100644 --- a/plugins/event.py +++ b/plugins/event.py @@ -9,17 +9,17 @@ class Event(Enum): e_context = { "channel": 消息channel, "context" : 本次消息的context} """ - ON_HANDLE_CONTEXT = 2 # 处理消息前 + ON_HANDLE_CONTEXT = 2 # 处理消息前 """ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } """ - ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 + ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 """ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } """ - ON_SEND_REPLY = 4 # 发送回复前 + ON_SEND_REPLY = 4 # 发送回复前 """ e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } """ @@ -28,9 +28,9 @@ class Event(Enum): class EventAction(Enum): - CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 - BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 - BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 + CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 + BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 + BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 class EventContext: diff --git a/plugins/finish/__init__.py b/plugins/finish/__init__.py index 42e0ec7ca..8c1cfd9c9 100644 --- a/plugins/finish/__init__.py +++ b/plugins/finish/__init__.py @@ -1 +1 @@ -from .finish import * \ No newline at end of file +from .finish import * diff --git a/plugins/finish/finish.py b/plugins/finish/finish.py index 782b2e3bf..a3c87ea03 100644 --- a/plugins/finish/finish.py +++ b/plugins/finish/finish.py @@ -1,14 +1,21 @@ # encoding:utf-8 +import plugins from bridge.context import ContextType from bridge.reply import Reply, ReplyType +from common.log import logger from config import conf -import plugins from plugins import * -from common.log import logger -@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000") +@plugins.register( + name="Finish", + desire_priority=-999, + hidden=True, + desc="A plugin that check unknown command", + version="1.0", + author="js00000", +) class Finish(Plugin): def __init__(self): super().__init__() @@ -16,19 +23,18 @@ def __init__(self): logger.info("[Finish] inited") def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return - content = e_context['context'].content + content = e_context["context"].content logger.debug("[Finish] on_handle_context. content: %s" % content) - trigger_prefix = conf().get('plugin_trigger_prefix',"$") + trigger_prefix = conf().get("plugin_trigger_prefix", "$") if content.startswith(trigger_prefix): reply = Reply() reply.type = ReplyType.ERROR reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n" - e_context['reply'] = reply - e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 def get_help_text(self, **kwargs): return "" diff --git a/plugins/godcmd/__init__.py b/plugins/godcmd/__init__.py index a0a68bd08..0e2655242 100644 --- a/plugins/godcmd/__init__.py +++ b/plugins/godcmd/__init__.py @@ -1 +1 @@ -from .godcmd import * \ No newline at end of file +from .godcmd import * diff --git a/plugins/godcmd/config.json.template b/plugins/godcmd/config.json.template index 524073858..ed021e075 100644 --- a/plugins/godcmd/config.json.template +++ b/plugins/godcmd/config.json.template @@ -1,4 +1,4 @@ { - "password": "", - "admin_users": [] -} \ No newline at end of file + "password": "", + "admin_users": [] +} diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 46f11afed..4e1b05969 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -6,14 +6,16 @@ import string import traceback from typing import Tuple + +import plugins from bridge.bridge import Bridge from bridge.context import ContextType from bridge.reply import Reply, ReplyType -from config import conf, load_config -import plugins -from plugins import * from common import const from common.log import logger +from config import conf, load_config +from plugins import * + # 定义指令集 COMMANDS = { "help": { @@ -41,7 +43,7 @@ }, "id": { "alias": ["id", "用户"], - "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 + "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 }, "reset": { "alias": ["reset", "重置会话"], @@ -114,18 +116,20 @@ "desc": "开启机器调试日志", }, } + + # 定义帮助函数 def get_help_text(isadmin, isgroup): help_text = "通用指令:\n" for cmd, info in COMMANDS.items(): - if cmd=="auth": #不提示认证指令 + if cmd == "auth": # 不提示认证指令 continue - if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]: + if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]: continue - alias=["#"+a for a in info['alias'][:1]] + alias = ["#" + a for a in info["alias"][:1]] help_text += f"{','.join(alias)} " - if 'args' in info: - args=[a for a in info['args']] + if "args" in info: + args = [a for a in info["args"]] help_text += f"{' '.join(args)}" help_text += f": {info['desc']}\n" @@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup): for plugin in plugins: if plugins[plugin].enabled and not plugins[plugin].hidden: namecn = plugins[plugin].namecn - help_text += "\n%s:"%namecn - help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() + help_text += "\n%s:" % namecn + help_text += ( + PluginManager().instances[plugin].get_help_text(verbose=False).strip() + ) if ADMIN_COMMANDS and isadmin: help_text += "\n\n管理员指令:\n" for cmd, info in ADMIN_COMMANDS.items(): - alias=["#"+a for a in info['alias'][:1]] + alias = ["#" + a for a in info["alias"][:1]] help_text += f"{','.join(alias)} " - if 'args' in info: - args=[a for a in info['args']] + if "args" in info: + args = [a for a in info["args"]] help_text += f"{' '.join(args)}" help_text += f": {info['desc']}\n" return help_text -@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") -class Godcmd(Plugin): +@plugins.register( + name="Godcmd", + desire_priority=999, + hidden=True, + desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", + version="1.0", + author="lanvent", +) +class Godcmd(Plugin): def __init__(self): super().__init__() - curdir=os.path.dirname(__file__) - config_path=os.path.join(curdir,"config.json") - gconf=None + curdir = os.path.dirname(__file__) + config_path = os.path.join(curdir, "config.json") + gconf = None if not os.path.exists(config_path): - gconf={"password":"","admin_users":[]} - with open(config_path,"w") as f: - json.dump(gconf,f,indent=4) + gconf = {"password": "", "admin_users": []} + with open(config_path, "w") as f: + json.dump(gconf, f, indent=4) else: - with open(config_path,"r") as f: - gconf=json.load(f) + with open(config_path, "r") as f: + gconf = json.load(f) if gconf["password"] == "": self.temp_password = "".join(random.sample(string.digits, 4)) - logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password) + logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password) else: self.temp_password = None custom_commands = conf().get("clear_memory_commands", []) @@ -178,41 +191,42 @@ def __init__(self): COMMANDS["reset"]["alias"].append(custom_command) self.password = gconf["password"] - self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 - self.isrunning = True # 机器人是否运行中 + self.admin_users = gconf[ + "admin_users" + ] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 + self.isrunning = True # 机器人是否运行中 self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[Godcmd] inited") - def on_handle_context(self, e_context: EventContext): - context_type = e_context['context'].type + context_type = e_context["context"].type if context_type != ContextType.TEXT: if not self.isrunning: e_context.action = EventAction.BREAK_PASS return - content = e_context['context'].content + content = e_context["context"].content logger.debug("[Godcmd] on_handle_context. content: %s" % content) if content.startswith("#"): # msg = e_context['context']['msg'] - channel = e_context['channel'] - user = e_context['context']['receiver'] - session_id = e_context['context']['session_id'] - isgroup = e_context['context'].get("isgroup", False) + channel = e_context["channel"] + user = e_context["context"]["receiver"] + session_id = e_context["context"]["session_id"] + isgroup = e_context["context"].get("isgroup", False) bottype = Bridge().get_bot_type("chat") bot = Bridge().get_bot("chat") # 将命令和参数分割 command_parts = content[1:].strip().split() cmd = command_parts[0] args = command_parts[1:] - isadmin=False + isadmin = False if user in self.admin_users: - isadmin=True - ok=False - result="string" - if any(cmd in info['alias'] for info in COMMANDS.values()): - cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) + isadmin = True + ok = False + result = "string" + if any(cmd in info["alias"] for info in COMMANDS.values()): + cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"]) if cmd == "auth": ok, result = self.authenticate(user, args, isadmin, isgroup) elif cmd == "help" or cmd == "helpp": @@ -224,10 +238,14 @@ def on_handle_context(self, e_context: EventContext): query_name = args[0].upper() # search name and namecn for name, plugincls in plugins.items(): - if not plugincls.enabled : + if not plugincls.enabled: continue if query_name == name or query_name == plugincls.namecn: - ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) + ok, result = True, PluginManager().instances[ + name + ].get_help_text( + isgroup=isgroup, isadmin=isadmin, verbose=True + ) break if not ok: result = "插件不存在或未启用" @@ -236,14 +254,14 @@ def on_handle_context(self, e_context: EventContext): elif cmd == "set_openai_api_key": if len(args) == 1: user_data = conf().get_user_data(user) - user_data['openai_api_key'] = args[0] + user_data["openai_api_key"] = args[0] ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] else: ok, result = False, "请提供一个api_key" elif cmd == "reset_openai_api_key": try: user_data = conf().get_user_data(user) - user_data.pop('openai_api_key') + user_data.pop("openai_api_key") ok, result = True, "你的OpenAI私有api_key已清除" except Exception as e: ok, result = False, "你没有设置私有api_key" @@ -255,12 +273,16 @@ def on_handle_context(self, e_context: EventContext): else: ok, result = False, "当前对话机器人不支持重置会话" logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) - elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): + elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()): if isadmin: if isgroup: ok, result = False, "群聊不可执行管理员指令" else: - cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) + cmd = next( + c + for c, info in ADMIN_COMMANDS.items() + if cmd in info["alias"] + ) if cmd == "stop": self.isrunning = False ok, result = True, "服务已暂停" @@ -278,13 +300,13 @@ def on_handle_context(self, e_context: EventContext): else: ok, result = False, "当前对话机器人不支持重置会话" elif cmd == "debug": - logger.setLevel('DEBUG') + logger.setLevel("DEBUG") ok, result = True, "DEBUG模式已开启" elif cmd == "plist": plugins = PluginManager().list_plugins() ok = True result = "插件列表:\n" - for name,plugincls in plugins.items(): + for name, plugincls in plugins.items(): result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " if plugincls.enabled: result += "已启用\n" @@ -294,16 +316,20 @@ def on_handle_context(self, e_context: EventContext): new_plugins = PluginManager().scan_plugins() ok, result = True, "插件扫描完成" PluginManager().activate_plugins() - if len(new_plugins) >0 : + if len(new_plugins) > 0: result += "\n发现新插件:\n" - result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) - else : - result +=", 未发现新插件" + result += "\n".join( + [f"{p.name}_v{p.version}" for p in new_plugins] + ) + else: + result += ", 未发现新插件" elif cmd == "setpri": if len(args) != 2: ok, result = False, "请提供插件名和优先级" else: - ok = PluginManager().set_plugin_priority(args[0], int(args[1])) + ok = PluginManager().set_plugin_priority( + args[0], int(args[1]) + ) if ok: result = "插件" + args[0] + "优先级已设置为" + args[1] else: @@ -350,42 +376,42 @@ def on_handle_context(self, e_context: EventContext): else: ok, result = False, "需要管理员权限才能执行该指令" else: - trigger_prefix = conf().get('plugin_trigger_prefix',"$") - if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 return ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" - + reply = Reply() if ok: reply.type = ReplyType.INFO else: reply.type = ReplyType.ERROR reply.content = result - e_context['reply'] = reply + e_context["reply"] = reply - e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 elif not self.isrunning: e_context.action = EventAction.BREAK_PASS - def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : + def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]: if isgroup: - return False,"请勿在群聊中认证" - + return False, "请勿在群聊中认证" + if isadmin: - return False,"管理员账号无需认证" - + return False, "管理员账号无需认证" + if len(args) != 1: - return False,"请提供口令" - + return False, "请提供口令" + password = args[0] if password == self.password: self.admin_users.append(userid) - return True,"认证成功" + return True, "认证成功" elif password == self.temp_password: self.admin_users.append(userid) - return True,"认证成功,请尽快设置口令" + return True, "认证成功,请尽快设置口令" else: - return False,"认证失败" + return False, "认证失败" - def get_help_text(self, isadmin = False, isgroup = False, **kwargs): - return get_help_text(isadmin, isgroup) \ No newline at end of file + def get_help_text(self, isadmin=False, isgroup=False, **kwargs): + return get_help_text(isadmin, isgroup) diff --git a/plugins/hello/__init__.py b/plugins/hello/__init__.py index 3b3590ad0..d9b15a1c7 100644 --- a/plugins/hello/__init__.py +++ b/plugins/hello/__init__.py @@ -1 +1 @@ -from .hello import * \ No newline at end of file +from .hello import * diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 941318382..4067c2ac3 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -1,14 +1,21 @@ # encoding:utf-8 +import plugins from bridge.context import ContextType from bridge.reply import Reply, ReplyType from channel.chat_message import ChatMessage -import plugins -from plugins import * from common.log import logger +from plugins import * -@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent") +@plugins.register( + name="Hello", + desire_priority=-1, + hidden=True, + desc="A simple plugin that says hello", + version="0.1", + author="lanvent", +) class Hello(Plugin): def __init__(self): super().__init__() @@ -16,33 +23,34 @@ def __init__(self): logger.info("[Hello] inited") def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return - - content = e_context['context'].content + + content = e_context["context"].content logger.debug("[Hello] on_handle_context. content: %s" % content) if content == "Hello": reply = Reply() reply.type = ReplyType.TEXT - msg:ChatMessage = e_context['context']['msg'] - if e_context['context']['isgroup']: - reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" + msg: ChatMessage = e_context["context"]["msg"] + if e_context["context"]["isgroup"]: + reply.content = ( + f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" + ) else: reply.content = f"Hello, {msg.from_user_nickname}" - e_context['reply'] = reply - e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 + e_context["reply"] = reply + e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 if content == "Hi": reply = Reply() reply.type = ReplyType.TEXT reply.content = "Hi" - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply if content == "End": # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" - e_context['context'].type = ContextType.IMAGE_CREATE + e_context["context"].type = ContextType.IMAGE_CREATE content = "The World" e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 diff --git a/plugins/plugin.py b/plugins/plugin.py index 289b4f84d..6c82c8de7 100644 --- a/plugins/plugin.py +++ b/plugins/plugin.py @@ -3,4 +3,4 @@ def __init__(self): self.handlers = {} def get_help_text(self, **kwargs): - return "暂无帮助信息" \ No newline at end of file + return "暂无帮助信息" diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index 619edea0e..44b400d4d 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -5,17 +5,19 @@ import json import os import sys + +from common.log import logger from common.singleton import singleton from common.sorted_dict import SortedDict -from .event import * -from common.log import logger from config import conf +from .event import * + @singleton class PluginManager: def __init__(self): - self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) + self.plugins = SortedDict(lambda k, v: v.priority, reverse=True) self.listening_plugins = {} self.instances = {} self.pconf = {} @@ -26,17 +28,27 @@ def register(self, name: str, desire_priority: int = 0, **kwargs): def wrapper(plugincls): plugincls.name = name plugincls.priority = desire_priority - plugincls.desc = kwargs.get('desc') - plugincls.author = kwargs.get('author') + plugincls.desc = kwargs.get("desc") + plugincls.author = kwargs.get("author") plugincls.path = self.current_plugin_path - plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0" - plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name - plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False + plugincls.version = ( + kwargs.get("version") if kwargs.get("version") != None else "1.0" + ) + plugincls.namecn = ( + kwargs.get("namecn") if kwargs.get("namecn") != None else name + ) + plugincls.hidden = ( + kwargs.get("hidden") if kwargs.get("hidden") != None else False + ) plugincls.enabled = True if self.current_plugin_path == None: raise Exception("Plugin path not set") self.plugins[name.upper()] = plugincls - logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) + logger.info( + "Plugin %s_v%s registered, path=%s" + % (name, plugincls.version, plugincls.path) + ) + return wrapper def save_config(self): @@ -50,10 +62,12 @@ def load_config(self): if os.path.exists("./plugins/plugins.json"): with open("./plugins/plugins.json", "r", encoding="utf-8") as f: pconf = json.load(f) - pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) + pconf["plugins"] = SortedDict( + lambda k, v: v["priority"], pconf["plugins"], reverse=True + ) else: modified = True - pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} + pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} self.pconf = pconf if modified: self.save_config() @@ -67,7 +81,7 @@ def scan_plugins(self): plugin_path = os.path.join(plugins_dir, plugin_name) if os.path.isdir(plugin_path): # 判断插件是否包含同名__init__.py文件 - main_module_path = os.path.join(plugin_path,"__init__.py") + main_module_path = os.path.join(plugin_path, "__init__.py") if os.path.isfile(main_module_path): # 导入插件 import_path = "plugins.{}".format(plugin_name) @@ -76,16 +90,26 @@ def scan_plugins(self): if plugin_path in self.loaded: if self.loaded[plugin_path] == None: logger.info("reload module %s" % plugin_name) - self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) - dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')] + self.loaded[plugin_path] = importlib.reload( + sys.modules[import_path] + ) + dependent_module_names = [ + name + for name in sys.modules.keys() + if name.startswith(import_path + ".") + ] for name in dependent_module_names: logger.info("reload module %s" % name) importlib.reload(sys.modules[name]) else: - self.loaded[plugin_path] = importlib.import_module(import_path) + self.loaded[plugin_path] = importlib.import_module( + import_path + ) self.current_plugin_path = None except Exception as e: - logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) + logger.exception( + "Failed to import plugin %s: %s" % (plugin_name, e) + ) continue pconf = self.pconf news = [self.plugins[name] for name in self.plugins] @@ -95,21 +119,28 @@ def scan_plugins(self): rawname = plugincls.name if rawname not in pconf["plugins"]: modified = True - logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) - pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} + logger.info( + "Plugin %s not found in pconfig, adding to pconfig..." % name + ) + pconf["plugins"][rawname] = { + "enabled": plugincls.enabled, + "priority": plugincls.priority, + } else: self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] self.plugins[name].priority = pconf["plugins"][rawname]["priority"] - self.plugins._update_heap(name) # 更新下plugins中的顺序 + self.plugins._update_heap(name) # 更新下plugins中的顺序 if modified: self.save_config() return new_plugins def refresh_order(self): for event in self.listening_plugins.keys(): - self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) + self.listening_plugins[event].sort( + key=lambda name: self.plugins[name].priority, reverse=True + ) - def activate_plugins(self): # 生成新开启的插件实例 + def activate_plugins(self): # 生成新开启的插件实例 failed_plugins = [] for name, plugincls in self.plugins.items(): if plugincls.enabled: @@ -129,7 +160,7 @@ def activate_plugins(self): # 生成新开启的插件实例 self.refresh_order() return failed_plugins - def reload_plugin(self, name:str): + def reload_plugin(self, name: str): name = name.upper() if name in self.instances: for event in self.listening_plugins: @@ -139,13 +170,13 @@ def reload_plugin(self, name:str): self.activate_plugins() return True return False - + def load_plugins(self): self.load_config() self.scan_plugins() pconf = self.pconf logger.debug("plugins.json config={}".format(pconf)) - for name,plugin in pconf["plugins"].items(): + for name, plugin in pconf["plugins"].items(): if name.upper() not in self.plugins: logger.error("Plugin %s not found, but found in plugins.json" % name) self.activate_plugins() @@ -153,13 +184,18 @@ def load_plugins(self): def emit_event(self, e_context: EventContext, *args, **kwargs): if e_context.event in self.listening_plugins: for name in self.listening_plugins[e_context.event]: - if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: - logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) + if ( + self.plugins[name].enabled + and e_context.action == EventAction.CONTINUE + ): + logger.debug( + "Plugin %s triggered by event %s" % (name, e_context.event) + ) instance = self.instances[name] instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context - def set_plugin_priority(self, name:str, priority:int): + def set_plugin_priority(self, name: str, priority: int): name = name.upper() if name not in self.plugins: return False @@ -174,11 +210,11 @@ def set_plugin_priority(self, name:str, priority:int): self.refresh_order() return True - def enable_plugin(self, name:str): + def enable_plugin(self, name: str): name = name.upper() if name not in self.plugins: return False, "插件不存在" - if not self.plugins[name].enabled : + if not self.plugins[name].enabled: self.plugins[name].enabled = True rawname = self.plugins[name].name self.pconf["plugins"][rawname]["enabled"] = True @@ -188,43 +224,47 @@ def enable_plugin(self, name:str): return False, "插件开启失败" return True, "插件已开启" return True, "插件已开启" - - def disable_plugin(self, name:str): + + def disable_plugin(self, name: str): name = name.upper() if name not in self.plugins: return False - if self.plugins[name].enabled : + if self.plugins[name].enabled: self.plugins[name].enabled = False rawname = self.plugins[name].name self.pconf["plugins"][rawname]["enabled"] = False self.save_config() return True return True - + def list_plugins(self): return self.plugins - - def install_plugin(self, repo:str): + + def install_plugin(self, repo: str): try: import common.package_manager as pkgmgr + pkgmgr.check_dulwich() except Exception as e: logger.error("Failed to install plugin, {}".format(e)) return False, "无法导入dulwich,安装插件失败" import re + from dulwich import porcelain logger.info("clone git repo: {}".format(repo)) - + match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) - + if not match: try: - with open("./plugins/source.json","r", encoding="utf-8") as f: + with open("./plugins/source.json", "r", encoding="utf-8") as f: source = json.load(f) if repo in source["repo"]: repo = source["repo"][repo]["url"] - match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) + match = re.match( + r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo + ) if not match: return False, "安装插件失败,source中的仓库地址不合法" else: @@ -232,42 +272,53 @@ def install_plugin(self, repo:str): except Exception as e: logger.error("Failed to install plugin, {}".format(e)) return False, "安装插件失败,请检查仓库地址是否正确" - dirname = os.path.join("./plugins",match.group(4)) + dirname = os.path.join("./plugins", match.group(4)) try: repo = porcelain.clone(repo, dirname, checkout=True) - if os.path.exists(os.path.join(dirname,"requirements.txt")): + if os.path.exists(os.path.join(dirname, "requirements.txt")): logger.info("detect requirements.txt,installing...") - pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) + pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置" except Exception as e: logger.error("Failed to install plugin, {}".format(e)) - return False, "安装插件失败,"+str(e) - - def update_plugin(self, name:str): + return False, "安装插件失败," + str(e) + + def update_plugin(self, name: str): try: import common.package_manager as pkgmgr + pkgmgr.check_dulwich() except Exception as e: logger.error("Failed to install plugin, {}".format(e)) return False, "无法导入dulwich,更新插件失败" from dulwich import porcelain + name = name.upper() if name not in self.plugins: return False, "插件不存在" - if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]: + if name in [ + "HELLO", + "GODCMD", + "ROLE", + "TOOL", + "BDUNIT", + "BANWORDS", + "FINISH", + "DUNGEON", + ]: return False, "预置插件无法更新,请更新主程序仓库" dirname = self.plugins[name].path try: porcelain.pull(dirname, "origin") - if os.path.exists(os.path.join(dirname,"requirements.txt")): + if os.path.exists(os.path.join(dirname, "requirements.txt")): logger.info("detect requirements.txt,installing...") - pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) + pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) return True, "更新插件成功,请重新运行程序" except Exception as e: logger.error("Failed to update plugin, {}".format(e)) - return False, "更新插件失败,"+str(e) - - def uninstall_plugin(self, name:str): + return False, "更新插件失败," + str(e) + + def uninstall_plugin(self, name: str): name = name.upper() if name not in self.plugins: return False, "插件不存在" @@ -276,6 +327,7 @@ def uninstall_plugin(self, name:str): dirname = self.plugins[name].path try: import shutil + shutil.rmtree(dirname) rawname = self.plugins[name].name for event in self.listening_plugins: @@ -288,4 +340,4 @@ def uninstall_plugin(self, name:str): return True, "卸载插件成功" except Exception as e: logger.error("Failed to uninstall plugin, {}".format(e)) - return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e) \ No newline at end of file + return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e) diff --git a/plugins/role/__init__.py b/plugins/role/__init__.py index 9e5b2b1ff..82e73ab5c 100644 --- a/plugins/role/__init__.py +++ b/plugins/role/__init__.py @@ -1 +1 @@ -from .role import * \ No newline at end of file +from .role import * diff --git a/plugins/role/role.py b/plugins/role/role.py index 905a1eaea..9788cc1cd 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -2,17 +2,18 @@ import json import os + +import plugins from bridge.bridge import Bridge from bridge.context import ContextType from bridge.reply import Reply, ReplyType from common import const +from common.log import logger from config import conf -import plugins from plugins import * -from common.log import logger -class RolePlay(): +class RolePlay: def __init__(self, bot, sessionid, desc, wrapper=None): self.bot = bot self.sessionid = sessionid @@ -25,12 +26,20 @@ def reset(self): def action(self, user_action): session = self.bot.sessions.build_session(self.sessionid) - if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 + if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 session.set_system_prompt(self.desc) prompt = self.wrapper % user_action return prompt -@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent") + +@plugins.register( + name="Role", + desire_priority=0, + namecn="角色扮演", + desc="为你的Bot设置预设角色", + version="1.0", + author="lanvent", +) class Role(Plugin): def __init__(self): super().__init__() @@ -39,7 +48,7 @@ def __init__(self): try: with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) - self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()} + self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()} self.roles = {} for role in config["roles"]: self.roles[role["title"].lower()] = role @@ -60,12 +69,16 @@ def __init__(self): logger.info("[Role] inited") except Exception as e: if isinstance(e, FileNotFoundError): - logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") + logger.warn( + f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." + ) else: - logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") + logger.warn( + "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." + ) raise e - def get_role(self, name, find_closest=True, min_sim = 0.35): + def get_role(self, name, find_closest=True, min_sim=0.35): name = name.lower() found_role = None if name in self.roles: @@ -75,6 +88,7 @@ def get_role(self, name, find_closest=True, min_sim = 0.35): def str_simularity(a, b): return difflib.SequenceMatcher(None, a, b).ratio() + max_sim = min_sim max_role = None for role in self.roles: @@ -86,25 +100,24 @@ def str_simularity(a, b): return found_role def on_handle_context(self, e_context: EventContext): - - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return bottype = Bridge().get_bot_type("chat") if bottype not in (const.CHATGPT, const.OPEN_AI): return bot = Bridge().get_bot("chat") - content = e_context['context'].content[:] - clist = e_context['context'].content.split(maxsplit=1) + content = e_context["context"].content[:] + clist = e_context["context"].content.split(maxsplit=1) desckey = None customize = False - sessionid = e_context['context']['session_id'] - trigger_prefix = conf().get('plugin_trigger_prefix', "$") + sessionid = e_context["context"]["session_id"] + trigger_prefix = conf().get("plugin_trigger_prefix", "$") if clist[0] == f"{trigger_prefix}停止扮演": if sessionid in self.roleplays: self.roleplays[sessionid].reset() del self.roleplays[sessionid] reply = Reply(ReplyType.INFO, "角色扮演结束!") - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return elif clist[0] == f"{trigger_prefix}角色": @@ -114,10 +127,10 @@ def on_handle_context(self, e_context: EventContext): elif clist[0] == f"{trigger_prefix}设定扮演": customize = True elif clist[0] == f"{trigger_prefix}角色类型": - if len(clist) >1: + if len(clist) > 1: tag = clist[1].strip() help_text = "角色列表:\n" - for key,value in self.tags.items(): + for key, value in self.tags.items(): if value[0] == tag: tag = key break @@ -130,57 +143,75 @@ def on_handle_context(self, e_context: EventContext): else: help_text = f"未知角色类型。\n" help_text += "目前的角色类型有: \n" - help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" + help_text += ( + ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" + ) else: help_text = f"请输入角色类型。\n" help_text += "目前的角色类型有: \n" - help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" + help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" reply = Reply(ReplyType.INFO, help_text) - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return elif sessionid not in self.roleplays: return logger.debug("[Role] on_handle_context. content: %s" % content) if desckey is not None: - if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): + if len(clist) == 1 or ( + len(clist) > 1 and clist[1].lower() in ["help", "帮助"] + ): reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return role = self.get_role(clist[1]) if role is None: reply = Reply(ReplyType.ERROR, "角色不存在") - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return else: - self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s")) - reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey]) - e_context['reply'] = reply + self.roleplays[sessionid] = RolePlay( + bot, + sessionid, + self.roles[role][desckey], + self.roles[role].get("wrapper", "%s"), + ) + reply = Reply( + ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] + ) + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS elif customize == True: self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s") reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}") - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS else: prompt = self.roleplays[sessionid].action(content) - e_context['context'].type = ContextType.TEXT - e_context['context'].content = prompt + e_context["context"].type = ContextType.TEXT + e_context["context"].content = prompt e_context.action = EventAction.BREAK def get_help_text(self, verbose=False, **kwargs): help_text = "让机器人扮演不同的角色。\n" if not verbose: return help_text - trigger_prefix = conf().get('plugin_trigger_prefix', "$") - help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n" - help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n" + trigger_prefix = conf().get("plugin_trigger_prefix", "$") + help_text = ( + f"使用方法:\n{trigger_prefix}角色" + + " 预设角色名: 设定角色为{预设角色名}。\n" + + f"{trigger_prefix}role" + + " 预设角色名: 同上,但使用英文设定。\n" + ) + help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" - help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" + help_text += ( + f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" + ) help_text += "\n目前的角色类型有: \n" - help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n" + help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" help_text += f"{trigger_prefix}角色类型 所有\n" help_text += f"{trigger_prefix}停止扮演\n" diff --git a/plugins/role/roles.json b/plugins/role/roles.json index 421167330..826627d9b 100644 --- a/plugins/role/roles.json +++ b/plugins/role/roles.json @@ -428,4 +428,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/plugins/source.json b/plugins/source.json index 7e14854a1..438fd9b85 100644 --- a/plugins/source.json +++ b/plugins/source.json @@ -1,16 +1,16 @@ { - "repo": { - "sdwebui": { - "url": "https://github.com/lanvent/plugin_sdwebui.git", - "desc": "利用stable-diffusion画图的插件" - }, - "replicate": { - "url": "https://github.com/lanvent/plugin_replicate.git", - "desc": "利用replicate api画图的插件" - }, - "summary": { - "url": "https://github.com/lanvent/plugin_summary.git", - "desc": "总结聊天记录的插件" - } + "repo": { + "sdwebui": { + "url": "https://github.com/lanvent/plugin_sdwebui.git", + "desc": "利用stable-diffusion画图的插件" + }, + "replicate": { + "url": "https://github.com/lanvent/plugin_replicate.git", + "desc": "利用replicate api画图的插件" + }, + "summary": { + "url": "https://github.com/lanvent/plugin_summary.git", + "desc": "总结聊天记录的插件" } -} \ No newline at end of file + } +} diff --git a/plugins/tool/README.md b/plugins/tool/README.md index 8333d7b92..0a78193fa 100644 --- a/plugins/tool/README.md +++ b/plugins/tool/README.md @@ -1,14 +1,14 @@ ## 插件描述 -一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力 +一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力 使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功 ### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) - - + + ## 使用说明 -使用该插件后将默认使用4个工具, 无需额外配置长期生效: -### 1. python +使用该插件后将默认使用4个工具, 无需额外配置长期生效: +### 1. python ###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务 - + ### 2. url-get ###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响 @@ -23,16 +23,16 @@ > meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334 -## 使用本插件对话(prompt)技巧 -### 1. 有指引的询问 +## 使用本插件对话(prompt)技巧 +### 1. 有指引的询问 #### 例如: -- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub +- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub - 使用Terminal执行curl cip.cc - 使用python查询今天日期 - + ### 2. 使用搜索引擎工具 - 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气 - + ## 其他工具 ### 5. wikipedia @@ -55,9 +55,9 @@ ### 10. google-search * ###### google搜索引擎,申请流程较bing-search繁琐 -###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持 +###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持 #### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md) - + ## config.json 配置说明 ###### 默认工具无需配置,其它工具需手动配置,一个例子: ```json @@ -71,15 +71,15 @@ } ``` -注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对 +注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对 - `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key - `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置 - `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置 - `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具 - `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2 - `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认 - - + + ## 备注 - 强烈建议申请搜索工具搭配使用,推荐bing-search - 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤 diff --git a/plugins/tool/__init__.py b/plugins/tool/__init__.py index d3bf330de..8c9d8dde4 100644 --- a/plugins/tool/__init__.py +++ b/plugins/tool/__init__.py @@ -1 +1 @@ -from .tool import * \ No newline at end of file +from .tool import * diff --git a/plugins/tool/config.json.template b/plugins/tool/config.json.template index 3f237a87f..00d643d7d 100644 --- a/plugins/tool/config.json.template +++ b/plugins/tool/config.json.template @@ -1,8 +1,13 @@ { - "tools": ["python", "url-get", "terminal", "meteo-weather"], + "tools": [ + "python", + "url-get", + "terminal", + "meteo-weather" + ], "kwargs": { - "top_k_results": 2, - "no_default": false, - "model_name": "gpt-3.5-turbo" + "top_k_results": 2, + "no_default": false, + "model_name": "gpt-3.5-turbo" } -} \ No newline at end of file +} diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py index 3e3fc818d..a8c481754 100644 --- a/plugins/tool/tool.py +++ b/plugins/tool/tool.py @@ -4,6 +4,7 @@ from chatgpt_tool_hub.apps import load_app from chatgpt_tool_hub.apps.app import App from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names + import plugins from bridge.bridge import Bridge from bridge.context import ContextType @@ -14,7 +15,13 @@ from plugins import * -@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0) +@plugins.register( + name="tool", + desc="Arming your ChatGPT bot with various tools", + version="0.3", + author="goldfishh", + desire_priority=0, +) class Tool(Plugin): def __init__(self): super().__init__() @@ -28,22 +35,26 @@ def get_help_text(self, verbose=False, **kwargs): help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" if not verbose: return help_text - trigger_prefix = conf().get('plugin_trigger_prefix', "$") + trigger_prefix = conf().get("plugin_trigger_prefix", "$") help_text += "使用说明:\n" - help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" + help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" help_text += f"{trigger_prefix}tool reset: 重置工具。\n" return help_text def on_handle_context(self, e_context: EventContext): - if e_context['context'].type != ContextType.TEXT: + if e_context["context"].type != ContextType.TEXT: return # 暂时不支持未来扩展的bot - if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE): + if Bridge().get_bot_type("chat") not in ( + const.CHATGPT, + const.OPEN_AI, + const.CHATGPTONAZURE, + ): return - content = e_context['context'].content - content_list = e_context['context'].content.split(maxsplit=1) + content = e_context["context"].content + content_list = e_context["context"].content.split(maxsplit=1) if not content or len(content_list) < 1: e_context.action = EventAction.CONTINUE @@ -52,13 +63,13 @@ def on_handle_context(self, e_context: EventContext): logger.debug("[tool] on_handle_context. content: %s" % content) reply = Reply() reply.type = ReplyType.TEXT - trigger_prefix = conf().get('plugin_trigger_prefix', "$") + trigger_prefix = conf().get("plugin_trigger_prefix", "$") # todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能 if content.startswith(f"{trigger_prefix}tool"): if len(content_list) == 1: logger.debug("[tool]: get help") reply.content = self.get_help_text() - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return elif len(content_list) > 1: @@ -66,12 +77,14 @@ def on_handle_context(self, e_context: EventContext): logger.debug("[tool]: reset config") self.app = self._reset_app() reply.content = "重置工具成功" - e_context['reply'] = reply + e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return elif content_list[1].startswith("reset"): logger.debug("[tool]: remind") - e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" + e_context[ + "context" + ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" e_context.action = EventAction.BREAK return @@ -80,34 +93,35 @@ def on_handle_context(self, e_context: EventContext): # Don't modify bot name all_sessions = Bridge().get_bot("chat").sessions - user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages + user_session = all_sessions.session_query( + query, e_context["context"]["session_id"] + ).messages # chatgpt-tool-hub will reply you with many tools logger.debug("[tool]: just-go") try: _reply = self.app.ask(query, user_session) e_context.action = EventAction.BREAK_PASS - all_sessions.session_reply(_reply, e_context['context']['session_id']) + all_sessions.session_reply( + _reply, e_context["context"]["session_id"] + ) except Exception as e: logger.exception(e) logger.error(str(e)) - e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" + e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" reply.type = ReplyType.ERROR e_context.action = EventAction.BREAK return reply.content = _reply - e_context['reply'] = reply + e_context["reply"] = reply return def _read_json(self) -> dict: curdir = os.path.dirname(__file__) config_path = os.path.join(curdir, "config.json") - tool_config = { - "tools": [], - "kwargs": {} - } + tool_config = {"tools": [], "kwargs": {}} if not os.path.exists(config_path): return tool_config else: @@ -123,7 +137,9 @@ def _build_tool_kwargs(self, kwargs: dict): "proxy": conf().get("proxy", ""), "request_timeout": conf().get("request_timeout", 60), # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 - "model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), + "model_name": tool_model_name + if tool_model_name + else conf().get("model", "gpt-3.5-turbo"), "no_default": kwargs.get("no_default", False), "top_k_results": kwargs.get("top_k_results", 2), # for news tool @@ -160,4 +176,7 @@ def _reset_app(self) -> App: # filter not support tool tool_list = self._filter_tool_list(tool_config.get("tools", [])) - return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {}))) + return load_app( + tools_list=tool_list, + **self._build_tool_kwargs(tool_config.get("kwargs", {})), + ) diff --git a/requirements.txt b/requirements.txt index 61f979d99..c47df1e33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ PyQRCode>=1.2.1 qrcode>=7.4.2 requests>=2.28.2 chardet>=5.1.0 +pre-commit \ No newline at end of file diff --git a/scripts/start.sh b/scripts/start.sh index ac92f8851..3037eb5a2 100755 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -8,7 +8,7 @@ echo $BASE_DIR # check the nohup.out log output file if [ ! -f "${BASE_DIR}/nohup.out" ]; then touch "${BASE_DIR}/nohup.out" -echo "create file ${BASE_DIR}/nohup.out" +echo "create file ${BASE_DIR}/nohup.out" fi nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out" diff --git a/scripts/tout.sh b/scripts/tout.sh index 5b71491ad..ffe6de38a 100755 --- a/scripts/tout.sh +++ b/scripts/tout.sh @@ -7,7 +7,7 @@ echo $BASE_DIR # check the nohup.out log output file if [ ! -f "${BASE_DIR}/nohup.out" ]; then - echo "No file ${BASE_DIR}/nohup.out" + echo "No file ${BASE_DIR}/nohup.out" exit -1; fi diff --git a/voice/audio_convert.py b/voice/audio_convert.py index 53e4b151b..241a3a678 100644 --- a/voice/audio_convert.py +++ b/voice/audio_convert.py @@ -1,9 +1,12 @@ import shutil import wave + import pysilk from pydub import AudioSegment -sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 +sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 + + def find_closest_sil_supports(sample_rate): """ 找到最接近的支持的采样率 @@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate): mindiff = diff return closest + def get_pcm_from_wav(wav_path): """ 从 wav 文件中读取 pcm @@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path): wav = wave.open(wav_path, "rb") return wav.readframes(wav.getnframes()) + def any_to_wav(any_path, wav_path): """ 把任意格式转成wav文件 """ - if any_path.endswith('.wav'): + if any_path.endswith(".wav"): shutil.copy2(any_path, wav_path) return - if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): + if ( + any_path.endswith(".sil") + or any_path.endswith(".silk") + or any_path.endswith(".slk") + ): return sil_to_wav(any_path, wav_path) audio = AudioSegment.from_file(any_path) audio.export(wav_path, format="wav") + def any_to_sil(any_path, sil_path): """ 把任意格式转成sil文件 """ - if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): + if ( + any_path.endswith(".sil") + or any_path.endswith(".silk") + or any_path.endswith(".slk") + ): shutil.copy2(any_path, sil_path) return 10000 - if any_path.endswith('.wav'): + if any_path.endswith(".wav"): return pcm_to_sil(any_path, sil_path) - if any_path.endswith('.mp3'): + if any_path.endswith(".mp3"): return mp3_to_sil(any_path, sil_path) raise NotImplementedError("Not support file type: {}".format(any_path)) + def mp3_to_wav(mp3_path, wav_path): """ 把mp3格式转成pcm文件 @@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path): audio = AudioSegment.from_mp3(mp3_path) audio.export(wav_path, format="wav") + def pcm_to_sil(pcm_path, silk_path): """ wav 文件转成 silk @@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path): pcm_s16 = audio.set_sample_width(2) pcm_s16 = pcm_s16.set_frame_rate(rate) wav_data = pcm_s16.raw_data - silk_data = pysilk.encode( - wav_data, data_rate=rate, sample_rate=rate) + silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) with open(silk_path, "wb") as f: f.write(silk_data) return audio.duration_seconds * 1000 + def mp3_to_sil(mp3_path, silk_path): """ mp3 文件转成 silk @@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path): f.write(silk_data) return audio.duration_seconds * 1000 + def sil_to_wav(silk_path, wav_path, rate: int = 24000): """ silk 文件转 wav diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py index 57ea30ce0..9317b8352 100644 --- a/voice/azure/azure_voice.py +++ b/voice/azure/azure_voice.py @@ -1,16 +1,18 @@ - """ azure voice service """ import json import os import time + import azure.cognitiveservices.speech as speechsdk + from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir -from voice.voice import Voice from config import conf +from voice.voice import Voice + """ Azure voice 主目录设置文件中需填写azure_voice_api_key和azure_voice_region @@ -19,50 +21,68 @@ """ -class AzureVoice(Voice): +class AzureVoice(Voice): def __init__(self): try: curdir = os.path.dirname(__file__) config_path = os.path.join(curdir, "config.json") config = None - if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件 - config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"} + if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 + config = { + "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", + "speech_recognition_language": "zh-CN", + } with open(config_path, "w") as fw: json.dump(config, fw, indent=4) else: with open(config_path, "r") as fr: config = json.load(fr) - self.api_key = conf().get('azure_voice_api_key') - self.api_region = conf().get('azure_voice_region') - self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) - self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"] - self.speech_config.speech_recognition_language = config["speech_recognition_language"] + self.api_key = conf().get("azure_voice_api_key") + self.api_region = conf().get("azure_voice_region") + self.speech_config = speechsdk.SpeechConfig( + subscription=self.api_key, region=self.api_region + ) + self.speech_config.speech_synthesis_voice_name = config[ + "speech_synthesis_voice_name" + ] + self.speech_config.speech_recognition_language = config[ + "speech_recognition_language" + ] except Exception as e: logger.warn("AzureVoice init failed: %s, ignore " % e) def voiceToText(self, voice_file): audio_config = speechsdk.AudioConfig(filename=voice_file) - speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) + speech_recognizer = speechsdk.SpeechRecognizer( + speech_config=self.speech_config, audio_config=audio_config + ) result = speech_recognizer.recognize_once() if result.reason == speechsdk.ResultReason.RecognizedSpeech: - logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text)) + logger.info( + "[Azure] voiceToText voice file name={} text={}".format( + voice_file, result.text + ) + ) reply = Reply(ReplyType.TEXT, result.text) else: - logger.error('[Azure] voiceToText error, result={}'.format(result)) + logger.error("[Azure] voiceToText error, result={}".format(result)) reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") return reply def textToVoice(self, text): - fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav' + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" audio_config = speechsdk.AudioConfig(filename=fileName) - speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) + speech_synthesizer = speechsdk.SpeechSynthesizer( + speech_config=self.speech_config, audio_config=audio_config + ) result = speech_synthesizer.speak_text(text) if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: logger.info( - '[Azure] textToVoice text={} voice file name={}'.format(text, fileName)) + "[Azure] textToVoice text={} voice file name={}".format(text, fileName) + ) reply = Reply(ReplyType.VOICE, fileName) else: - logger.error('[Azure] textToVoice error, result={}'.format(result)) + logger.error("[Azure] textToVoice error, result={}".format(result)) reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/azure/config.json.template b/voice/azure/config.json.template index 13b1fbdcc..2dc2176f9 100644 --- a/voice/azure/config.json.template +++ b/voice/azure/config.json.template @@ -1,4 +1,4 @@ { - "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", - "speech_recognition_language": "zh-CN" -} \ No newline at end of file + "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", + "speech_recognition_language": "zh-CN" +} diff --git a/voice/baidu/README.md b/voice/baidu/README.md index 815bb0aef..d4628a1f8 100644 --- a/voice/baidu/README.md +++ b/voice/baidu/README.md @@ -29,7 +29,7 @@ dev_pid 必填 语言选择,填写语言对应的dev_pid值 2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。 参数 可需 描述 -tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 +tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh spd 选填 语速,取值0-15,默认为5中语速 pit 选填 音调,取值0-15,默认为5中语调 @@ -40,14 +40,14 @@ aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav 关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。 ### 配置文件 - + 将文件夹中`config.json.template`复制为`config.json`。 ``` json { - "lang": "zh", + "lang": "zh", "ctp": 1, - "spd": 5, + "spd": 5, "pit": 5, "vol": 5, "per": 0 diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index 8a177282a..ccde6c4d1 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -1,17 +1,19 @@ - """ baidu voice service """ import json import os import time + from aip import AipSpeech + from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir -from voice.voice import Voice -from voice.audio_convert import get_pcm_from_wav from config import conf +from voice.audio_convert import get_pcm_from_wav +from voice.voice import Voice + """ 百度的语音识别API. dev_pid: @@ -28,40 +30,37 @@ class BaiduVoice(Voice): - def __init__(self): try: curdir = os.path.dirname(__file__) config_path = os.path.join(curdir, "config.json") bconf = None - if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件 - bconf = { "lang": "zh", "ctp": 1, "spd": 5, - "pit": 5, "vol": 5, "per": 0} + if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 + bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0} with open(config_path, "w") as fw: json.dump(bconf, fw, indent=4) else: with open(config_path, "r") as fr: bconf = json.load(fr) - - self.app_id = conf().get('baidu_app_id') - self.api_key = conf().get('baidu_api_key') - self.secret_key = conf().get('baidu_secret_key') - self.dev_id = conf().get('baidu_dev_pid') + + self.app_id = conf().get("baidu_app_id") + self.api_key = conf().get("baidu_api_key") + self.secret_key = conf().get("baidu_secret_key") + self.dev_id = conf().get("baidu_dev_pid") self.lang = bconf["lang"] self.ctp = bconf["ctp"] self.spd = bconf["spd"] self.pit = bconf["pit"] self.vol = bconf["vol"] self.per = bconf["per"] - + self.client = AipSpeech(self.app_id, self.api_key, self.secret_key) except Exception as e: logger.warn("BaiduVoice init failed: %s, ignore " % e) - def voiceToText(self, voice_file): # 识别本地文件 - logger.debug('[Baidu] voice file name={}'.format(voice_file)) + logger.debug("[Baidu] voice file name={}".format(voice_file)) pcm = get_pcm_from_wav(voice_file) res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id}) if res["err_no"] == 0: @@ -72,21 +71,25 @@ def voiceToText(self, voice_file): logger.info("百度语音识别出错了: {}".format(res["err_msg"])) if res["err_msg"] == "request pv too much": logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费") - reply = Reply(ReplyType.ERROR, - "百度语音识别出错了;{0}".format(res["err_msg"])) + reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"])) return reply def textToVoice(self, text): - result = self.client.synthesis(text, self.lang, self.ctp, { - 'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per}) + result = self.client.synthesis( + text, + self.lang, + self.ctp, + {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per}, + ) if not isinstance(result, dict): - fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3' - with open(fileName, 'wb') as f: + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" + with open(fileName, "wb") as f: f.write(result) logger.info( - '[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) + "[Baidu] textToVoice text={} voice file name={}".format(text, fileName) + ) reply = Reply(ReplyType.VOICE, fileName) else: - logger.error('[Baidu] textToVoice error={}'.format(result)) + logger.error("[Baidu] textToVoice error={}".format(result)) reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/baidu/config.json.template b/voice/baidu/config.json.template index 9937aa5f8..19e812f0b 100644 --- a/voice/baidu/config.json.template +++ b/voice/baidu/config.json.template @@ -1,8 +1,8 @@ - { - "lang": "zh", - "ctp": 1, - "spd": 5, - "pit": 5, - "vol": 5, - "per": 0 - } \ No newline at end of file +{ + "lang": "zh", + "ctp": 1, + "spd": 5, + "pit": 5, + "vol": 5, + "per": 0 +} diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 3a0537f1e..4f7b8ade3 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py @@ -1,11 +1,12 @@ - """ google voice service """ import time + import speech_recognition from gtts import gTTS + from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir @@ -22,9 +23,12 @@ def voiceToText(self, voice_file): with speech_recognition.AudioFile(voice_file) as source: audio = self.recognizer.record(source) try: - text = self.recognizer.recognize_google(audio, language='zh-CN') + text = self.recognizer.recognize_google(audio, language="zh-CN") logger.info( - '[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) + "[Google] voiceToText text={} voice file name={}".format( + text, voice_file + ) + ) reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") @@ -32,13 +36,15 @@ def voiceToText(self, voice_file): reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) finally: return reply + def textToVoice(self, text): try: - mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3' - tts = gTTS(text=text, lang='zh') - tts.save(mp3File) + mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" + tts = gTTS(text=text, lang="zh") + tts.save(mp3File) logger.info( - '[Google] textToVoice text={} voice file name={}'.format(text, mp3File)) + "[Google] textToVoice text={} voice file name={}".format(text, mp3File) + ) reply = Reply(ReplyType.VOICE, mp3File) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index c98d0c90b..06c221b21 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -1,29 +1,32 @@ - """ google voice service """ import json + import openai + from bridge.reply import Reply, ReplyType -from config import conf from common.log import logger +from config import conf from voice.voice import Voice class OpenaiVoice(Voice): def __init__(self): - openai.api_key = conf().get('open_ai_api_key') + openai.api_key = conf().get("open_ai_api_key") def voiceToText(self, voice_file): - logger.debug( - '[Openai] voice file name={}'.format(voice_file)) + logger.debug("[Openai] voice file name={}".format(voice_file)) try: file = open(voice_file, "rb") result = openai.Audio.transcribe("whisper-1", file) text = result["text"] reply = Reply(ReplyType.TEXT, text) logger.info( - '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) + "[Openai] voiceToText text={} voice file name={}".format( + text, voice_file + ) + ) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) finally: diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py index 2e9cdc045..072e28b41 100644 --- a/voice/pytts/pytts_voice.py +++ b/voice/pytts/pytts_voice.py @@ -1,10 +1,11 @@ - """ pytts voice service (offline) """ import time + import pyttsx3 + from bridge.reply import Reply, ReplyType from common.log import logger from common.tmp_dir import TmpDir @@ -16,20 +17,21 @@ class PyttsVoice(Voice): def __init__(self): # 语速 - self.engine.setProperty('rate', 125) + self.engine.setProperty("rate", 125) # 音量 - self.engine.setProperty('volume', 1.0) - for voice in self.engine.getProperty('voices'): + self.engine.setProperty("volume", 1.0) + for voice in self.engine.getProperty("voices"): if "Chinese" in voice.name: - self.engine.setProperty('voice', voice.id) + self.engine.setProperty("voice", voice.id) def textToVoice(self, text): try: - wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav' + wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" self.engine.save_to_file(text, wavFile) self.engine.runAndWait() logger.info( - '[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile)) + "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile) + ) reply = Reply(ReplyType.VOICE, wavFile) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) diff --git a/voice/voice.py b/voice/voice.py index 52d8aaa52..1ca199b57 100644 --- a/voice/voice.py +++ b/voice/voice.py @@ -2,6 +2,7 @@ Voice service abstract class """ + class Voice(object): def voiceToText(self, voice_file): """ @@ -13,4 +14,4 @@ def textToVoice(self, text): """ Send text to voice service and get voice """ - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/voice/voice_factory.py b/voice/voice_factory.py index de4b3d99a..45fe0d16d 100644 --- a/voice/voice_factory.py +++ b/voice/voice_factory.py @@ -2,25 +2,31 @@ voice factory """ + def create_voice(voice_type): """ create a voice instance :param voice_type: voice type code :return: voice instance """ - if voice_type == 'baidu': + if voice_type == "baidu": from voice.baidu.baidu_voice import BaiduVoice + return BaiduVoice() - elif voice_type == 'google': + elif voice_type == "google": from voice.google.google_voice import GoogleVoice + return GoogleVoice() - elif voice_type == 'openai': + elif voice_type == "openai": from voice.openai.openai_voice import OpenaiVoice + return OpenaiVoice() - elif voice_type == 'pytts': + elif voice_type == "pytts": from voice.pytts.pytts_voice import PyttsVoice + return PyttsVoice() - elif voice_type == 'azure': + elif voice_type == "azure": from voice.azure.azure_voice import AzureVoice + return AzureVoice() raise RuntimeError