diff --git a/bot/bot_factory.py b/bot/bot_factory.py index 77797e76a..2e9cb2d6a 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -33,4 +33,9 @@ def create_bot(bot_type): from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot return AzureChatGPTBot() + + elif bot_type == const.LINKAI: + from bot.linkai.link_ai_bot import LinkAIBot + return LinkAIBot() + raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 9c54faf4f..60fc3f845 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -66,12 +66,16 @@ def reply(self, query, context=None): logger.debug("[CHATGPT] session query={}".format(session.messages)) api_key = context.get("openai_api_key") - self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo" + model = context.get("gpt_model") + new_args = None + if model: + new_args = self.args.copy() + new_args["model"] = model # if context.get('stream'): # # reply in stream # return self.reply_text_stream(query, new_query, session_id) - reply_content = self.reply_text(session, api_key) + reply_content = self.reply_text(session, api_key, args=new_args) logger.debug( "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( session.messages, @@ -102,7 +106,7 @@ def reply(self, query, context=None): reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) return reply - def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: + def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict: """ call openai's ChatCompletion to get the answer :param session: a conversation session @@ -114,7 +118,9 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") # if api_key == None, the default openai.api_key will be used - response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) + if args is None: + args = self.args + response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) return { "total_tokens": response["usage"]["total_tokens"], @@ -150,7 +156,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di if need_retry: logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, api_key, retry_count + 1) + return self.reply_text(session, api_key, args, retry_count + 1) else: return result diff --git a/bot/linkai/link_ai_bot.py b/bot/linkai/link_ai_bot.py new file mode 100644 index 000000000..a92466db1 --- /dev/null +++ b/bot/linkai/link_ai_bot.py @@ -0,0 +1,72 @@ +from bot.bot import Bot +from bridge.context import ContextType +from bridge.reply import Reply, ReplyType +from common.log import logger +from bridge.context import Context +from bot.chatgpt.chat_gpt_session import ChatGPTSession +from bot.session_manager import SessionManager +from config import conf +import requests +import time + +class LinkAIBot(Bot): + + # authentication failed + AUTH_FAILED_CODE = 401 + + def __init__(self): + self.base_url = "https://api.link-ai.chat/v1" + self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") + + def reply(self, query, context: Context = None) -> Reply: + return self._chat(query, context) + + def _chat(self, query, context, retry_count=0): + if retry_count >= 2: + # exit from retry 2 times + logger.warn("[LINKAI] failed after maximum number of retry times") + return Reply(ReplyType.ERROR, "请再问我一次吧") + + try: + session_id = context["session_id"] + + session = self.sessions.session_query(query, session_id) + + # remove system message + if session.messages[0].get("role") == "system": + session.messages.pop(0) + + # load config + app_code = conf().get("linkai_app_code") + linkai_api_key = conf().get("linkai_api_key") + logger.info(f"[LINKAI] query={query}, app_code={app_code}") + + body = { + "appCode": app_code, + "messages": session.messages + } + headers = {"Authorization": "Bearer " + linkai_api_key} + + # do http request + res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json() + + if not res or not res["success"]: + if res.get("code") == self.AUTH_FAILED_CODE: + logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}") + return Reply(ReplyType.ERROR, "请再问我一次吧") + else: + # retry + time.sleep(2) + logger.warn(f"[LINKAI] do retry, times={retry_count}") + return self._chat(query, context, retry_count + 1) + # execute success + reply_content = res["data"]["content"] + logger.info(f"[LINKAI] reply={reply_content}") + self.sessions.session_reply(reply_content, session_id) + return Reply(ReplyType.TEXT, reply_content) + except Exception as e: + logger.exception(e) + # retry + time.sleep(2) + logger.warn(f"[LINKAI] do retry, times={retry_count}") + return self._chat(query, context, retry_count + 1) diff --git a/bridge/bridge.py b/bridge/bridge.py index 78fe23458..119e855c4 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -23,6 +23,8 @@ def __init__(self): self.btype["chat"] = const.OPEN_AI if conf().get("use_azure_chatgpt", False): self.btype["chat"] = const.CHATGPTONAZURE + if conf().get("linkai_api_key") and conf().get("linkai_app_code"): + self.btype["chat"] = const.LINKAI self.bots = {} def get_bot(self, typename): diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index c888157f3..69f56f3a2 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -23,7 +23,6 @@ from config import conf, get_appdata_dir from lib import itchat from lib.itchat.content import * -from plugins import * @itchat.msg_register([TEXT, VOICE, PICTURE, NOTE]) diff --git a/common/const.py b/common/const.py index b60d44576..481741d13 100644 --- a/common/const.py +++ b/common/const.py @@ -3,5 +3,6 @@ CHATGPT = "chatGPT" BAIDU = "baidu" CHATGPTONAZURE = "chatGPTOnAzure" +LINKAI = "linkai" VERSION = "1.3.0" diff --git a/config.py b/config.py index ae1cfd778..26b2f8a6b 100644 --- a/config.py +++ b/config.py @@ -99,6 +99,9 @@ "appdata_dir": "", # 数据目录 # 插件配置 "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 + # 知识库平台配置 + "linkai_api_key": "", + "linkai_app_code": "" }