formatting: run precommit on all files
lanvent committed Apr 22, 2023
1 parent eaf4e91 commit 618c94e
Showing 40 changed files with 228 additions and 646 deletions.
1 change: 1 addition & 0 deletions app.py
@@ -19,6 +19,7 @@ def func(_signo, _stack_frame):
if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame)
sys.exit(0)

signal.signal(_signo, func)


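For context, the handler being reformatted above chains to whatever handler was registered before it, and only exits on its own when no callable predecessor exists. A condensed sketch of that pattern (hypothetical helper name; the loop over signal numbers is implied by the enclosing code, not shown in this hunk):

import signal
import sys

def install_graceful_shutdown(signals=(signal.SIGINT, signal.SIGTERM)):
    # Hypothetical helper illustrating the pattern in app.py above.
    for _signo in signals:
        old_handler = signal.getsignal(_signo)

        def func(_signo, _stack_frame, old_handler=old_handler):
            if callable(old_handler):  # delegate to a previously installed handler
                return old_handler(_signo, _stack_frame)
            sys.exit(0)  # otherwise exit cleanly ourselves

        signal.signal(_signo, func)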
12 changes: 2 additions & 10 deletions bot/baidu/baidu_unit_bot.py
@@ -10,10 +10,7 @@
class BaiduUnitBot(Bot):
def reply(self, query, context=None):
token = self.get_token()
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ token
)
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query
@@ -32,12 +29,7 @@ def reply(self, query, context=None):
def get_token(self):
access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY"
host = (
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+ access_key
+ "&client_secret="
+ secret_key
)
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
response = requests.get(host)
if response:
print(response.json())
29 changes: 7 additions & 22 deletions bot/chatgpt/chat_gpt_bot.py
@@ -30,23 +30,15 @@ def __init__(self):
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

self.sessions = SessionManager(
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
)
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # name of the chat model
"temperature": conf().get("temperature", 0.9), # in [0,1]; the larger the value, the more random the reply
# "max_tokens":4096, # maximum length of a reply
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # in [-2,2]; the larger the value, the more the model favors new content
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # in [-2,2]; the larger the value, the more the model favors new content
"request_timeout": conf().get(
"request_timeout", None
), # request timeout; the openai API defaults to 600, and hard questions usually need longer
"frequency_penalty": conf().get("frequency_penalty", 0.0), # in [-2,2]; the larger the value, the more the model favors new content
"presence_penalty": conf().get("presence_penalty", 0.0), # in [-2,2]; the larger the value, the more the model favors new content
"request_timeout": conf().get("request_timeout", None), # request timeout; the openai API defaults to 600, and hard questions usually need longer
"timeout": conf().get("request_timeout", None), # retry timeout; within this window, failed requests are retried automatically
}

@@ -87,15 +79,10 @@ def reply(self, query, context=None):
reply_content["completion_tokens"],
)
)
if (
reply_content["completion_tokens"] == 0
and len(reply_content["content"]) > 0
):
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(
reply_content["content"], session_id, reply_content["total_tokens"]
)
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
@@ -126,9 +113,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(
api_key=api_key, messages=session.messages, **self.args
)
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],
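The rate-limit path above calls self.tb4chatgpt.get_token() before each request and raises RateLimitError when it returns falsy. The TokenBucket class itself is outside this diff; a minimal sketch of such a bucket, assuming a refill rate in tokens per second and the non-blocking get_token() interface used above (the repo's actual implementation and units may differ):

import threading
import time

class TokenBucket:
    # Hypothetical stand-in for the TokenBucket this diff imports but does not show.
    # Refills `rate` tokens per second up to `capacity`; get_token() never blocks.
    def __init__(self, rate, capacity=None):
        self.rate = rate
        self.capacity = capacity if capacity is not None else rate
        self.tokens = float(self.capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def get_token(self):
        # Consume one token and return True when available, else return False.
        with self.lock:
            now = time.monotonic()
            self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False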
22 changes: 5 additions & 17 deletions bot/chatgpt/chat_gpt_session.py
@@ -25,9 +25,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = False
if cur_tokens is None:
raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
@@ -39,16 +37,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn(
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
)
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug(
"max_tokens={}, total_tokens={}, len(messages)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
logger.warn(
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
)
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0
for message in messages:
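The num_tokens_from_messages hunk above is cut off right after the message loop begins. For reference, the per-message constants it sets follow OpenAI's cookbook recipe: each message costs tokens_per_message plus the encoded length of every field, a name adjusts by tokens_per_name, and every reply is primed with 3 extra tokens. A sketch of the remaining loop body under those assumptions, using tiktoken (not necessarily the repo's exact code):

import tiktoken

def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"):
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
    tokens_per_name = -1  # if there's a name, the role is omitted
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens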
28 changes: 7 additions & 21 deletions bot/openai/open_ai_bot.py
@@ -28,23 +28,15 @@ def __init__(self):
if proxy:
openai.proxy = proxy

self.sessions = SessionManager(
OpenAISession, model=conf().get("model") or "text-davinci-003"
)
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = {
"model": conf().get("model") or "text-davinci-003", # name of the conversation model
"temperature": conf().get("temperature", 0.9), # in [0,1]; the larger the value, the more random the reply
"max_tokens": 1200, # maximum length of a reply
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # in [-2,2]; the larger the value, the more the model favors new content
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # in [-2,2]; the larger the value, the more the model favors new content
"request_timeout": conf().get(
"request_timeout", None
), # request timeout; the openai API defaults to 600, and hard questions usually need longer
"frequency_penalty": conf().get("frequency_penalty", 0.0), # in [-2,2]; the larger the value, the more the model favors new content
"presence_penalty": conf().get("presence_penalty", 0.0), # in [-2,2]; the larger the value, the more the model favors new content
"request_timeout": conf().get("request_timeout", None), # request timeout; the openai API defaults to 600, and hard questions usually need longer
"timeout": conf().get("request_timeout", None), # retry timeout; within this window, failed requests are retried automatically
"stop": ["\n\n\n"],
}
@@ -71,17 +63,13 @@ def reply(self, query, context=None):
result["content"],
)
logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
str(session), session_id, reply_content, completion_tokens
)
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
)

if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
self.sessions.session_reply(
reply_content, session_id, total_tokens
)
self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
@@ -96,9 +84,7 @@ def reply(self, query, context=None):
def reply_text(self, session: OpenAISession, retry_count=0):
try:
response = openai.Completion.create(prompt=str(session), **self.args)
res_content = (
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content))
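Unlike the ChatCompletion bot earlier in this commit, this file drives the legacy Completion endpoint, so the whole session is flattened into one prompt string via str(session). A minimal standalone sketch of that call shape for the pre-1.0 openai SDK (the prompt text here is made up):

import openai

response = openai.Completion.create(
    model="text-davinci-003",
    prompt="User: hello\nAI:",  # str(session) plays this role above
    max_tokens=1200,
    stop=["\n\n\n"],
)
print(response.choices[0]["text"].strip().replace("<|endoftext|>", ""))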
10 changes: 2 additions & 8 deletions bot/openai/open_ai_image.py
@@ -23,9 +23,7 @@ def create_img(self, query, retry_count=0):
response = openai.Image.create(
prompt=query, # image description
n=1, # number of images generated per request
size=conf().get(
"image_create_size", "256x256"
), # image size; options are 256x256, 512x512, 1024x1024
size=conf().get("image_create_size", "256x256"), # image size; options are 256x256, 512x512, 1024x1024
)
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
@@ -34,11 +32,7 @@ def create_img(self, query, retry_count=0):
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn(
"[OPEN_AI] ImgCreate RateLimit exceed, retry {}".format(
retry_count + 1
)
)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, retry {}".format(retry_count + 1))
return self.create_img(query, retry_count + 1)
else:
return False, "You're asking too quickly, please take a break and then ask me again"
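create_img returns a (success, payload) pair: the hunk shows the failure branch returning False plus a user-facing message, and the success branch (implied but not shown in this diff) presumably returns True plus the image URL. A hypothetical call site under that assumption, with bot standing in for an OpenAIImage instance:

ok, result = bot.create_img("a watercolor painting of a cat")
if ok:
    print("image url:", result)  # the URL taken from response["data"][0]["url"]
else:
    print("giving up:", result)  # message returned after retries are exhausted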
16 changes: 3 additions & 13 deletions bot/openai/open_ai_session.py
@@ -36,9 +36,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = False
if cur_tokens is None:
raise e
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 1:
self.messages.pop(0)
@@ -50,18 +48,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = len(str(self))
break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn(
"user question exceed max_tokens. total_tokens={}".format(
cur_tokens
)
)
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug(
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
20 changes: 4 additions & 16 deletions bot/session_manager.py
@@ -55,9 +55,7 @@ def build_session(self, session_id, system_prompt=None):
return self.sessioncls(session_id, system_prompt, **self.session_args)

if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(
session_id, system_prompt, **self.session_args
)
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # if there is a new system_prompt, update it and reset the session
self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id]
@@ -71,9 +69,7 @@ def session_query(self, query, session_id):
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
logger.debug(
"Exception when counting tokens precisely for prompt: {}".format(str(e))
)
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session

def session_reply(self, reply, session_id, total_tokens=None):
@@ -82,17 +78,9 @@ def session_reply(self, reply, session_id, total_tokens=None):
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug(
"raw total_tokens={}, savesession tokens={}".format(
total_tokens, tokens_cnt
)
)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e:
logger.debug(
"Exception when counting tokens precisely for session: {}".format(
str(e)
)
)
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
return session

def clear_session(self, session_id):
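Taken together, session_query and session_reply above implement the bot's read-trim-write cycle: record the user turn and trim history to conversation_max_tokens, then record the assistant reply along with the token count the API reported. A hedged usage sketch (session id and message strings are made up; imports follow the file paths in this commit):

from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager

sessions = SessionManager(ChatGPTSession, model="gpt-3.5-turbo")

# Record the user turn; discard_exceeding() trims history to conversation_max_tokens.
session = sessions.session_query("hello", session_id="user-42")

# ... send session.messages to the model, then store the assistant reply
# together with the total_tokens figure from the API response:
sessions.session_reply("hi there!", "user-42", total_tokens=42)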
4 changes: 1 addition & 3 deletions bridge/context.py
@@ -60,6 +60,4 @@ def __delitem__(self, key):
del self.kwargs[key]

def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(
self.type, self.content, self.kwargs
)
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)