Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

11 #31

Merged
merged 7 commits into from
Jun 1, 2023
Merged

11 #31

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions bot/bot_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ def create_bot(bot_type):
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot

return AzureChatGPTBot()

elif bot_type == const.LINKAI:
from bot.linkai.link_ai_bot import LinkAIBot
return LinkAIBot()

raise RuntimeError
16 changes: 11 additions & 5 deletions bot/chatgpt/chat_gpt_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ def reply(self, query, context=None):
logger.debug("[CHATGPT] session query={}".format(session.messages))

api_key = context.get("openai_api_key")
self.args['model'] = context.get('gpt_model') or "gpt-3.5-turbo"
model = context.get("gpt_model")
new_args = None
if model:
new_args = self.args.copy()
new_args["model"] = model
# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, api_key)
reply_content = self.reply_text(session, api_key, args=new_args)
logger.debug(
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
Expand Down Expand Up @@ -102,7 +106,7 @@ def reply(self, query, context=None):
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
Expand All @@ -114,7 +118,9 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
if args is None:
args = self.args
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],
Expand Down Expand Up @@ -150,7 +156,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di

if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, retry_count + 1)
return self.reply_text(session, api_key, args, retry_count + 1)
else:
return result

Expand Down
72 changes: 72 additions & 0 deletions bot/linkai/link_ai_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from bridge.context import Context
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager
from config import conf
import requests
import time

class LinkAIBot(Bot):

# authentication failed
AUTH_FAILED_CODE = 401

def __init__(self):
self.base_url = "https://api.link-ai.chat/v1"
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")

def reply(self, query, context: Context = None) -> Reply:
return self._chat(query, context)

def _chat(self, query, context, retry_count=0):
if retry_count >= 2:
# exit from retry 2 times
logger.warn("[LINKAI] failed after maximum number of retry times")
return Reply(ReplyType.ERROR, "请再问我一次吧")

try:
session_id = context["session_id"]

session = self.sessions.session_query(query, session_id)

# remove system message
if session.messages[0].get("role") == "system":
session.messages.pop(0)

# load config
app_code = conf().get("linkai_app_code")
linkai_api_key = conf().get("linkai_api_key")
logger.info(f"[LINKAI] query={query}, app_code={app_code}")

body = {
"appCode": app_code,
"messages": session.messages
}
headers = {"Authorization": "Bearer " + linkai_api_key}

# do http request
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()

if not res or not res["success"]:
if res.get("code") == self.AUTH_FAILED_CODE:
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
return Reply(ReplyType.ERROR, "请再问我一次吧")
else:
# retry
time.sleep(2)
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)
# execute success
reply_content = res["data"]["content"]
logger.info(f"[LINKAI] reply={reply_content}")
self.sessions.session_reply(reply_content, session_id)
return Reply(ReplyType.TEXT, reply_content)
except Exception as e:
logger.exception(e)
# retry
time.sleep(2)
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)
2 changes: 2 additions & 0 deletions bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(self):
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
self.btype["chat"] = const.CHATGPTONAZURE
if conf().get("linkai_api_key") and conf().get("linkai_app_code"):
self.btype["chat"] = const.LINKAI
self.bots = {}

def get_bot(self, typename):
Expand Down
1 change: 0 additions & 1 deletion channel/wechat/wechat_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from config import conf, get_appdata_dir
from lib import itchat
from lib.itchat.content import *
from plugins import *


@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE])
Expand Down
1 change: 1 addition & 0 deletions common/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
CHATGPT = "chatGPT"
BAIDU = "baidu"
CHATGPTONAZURE = "chatGPTOnAzure"
LINKAI = "linkai"

VERSION = "1.3.0"
3 changes: 3 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@
"appdata_dir": "", # 数据目录
# 插件配置
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
# 知识库平台配置
"linkai_api_key": "",
"linkai_app_code": ""
}


Expand Down