-
Notifications
You must be signed in to change notification settings - Fork 678
/
Copy pathopenai_api.py
187 lines (153 loc) · 7.06 KB
/
openai_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import json
import uuid
from datetime import datetime, timezone
from typing import Optional
import httpx
from pydantic import ValidationError
from api.conf import Config, Credentials
from api.enums import OpenaiApiChatModels, ChatSourceTypes
from api.exceptions import OpenaiApiException
from api.models.doc import OpenaiApiChatMessage, OpenaiApiConversationHistoryDocument, OpenaiApiChatMessageMetadata, \
OpenaiApiChatMessageTextContent
from api.schemas.openai_schemas import OpenaiChatResponse
from utils.common import SingletonMeta
from utils.logger import get_logger
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Application configuration; this module reads its `openai_api` section
# (enabled flag, proxy, base URL, timeouts).
config = Config()
# Credential store; supplies `openai_api_key` for the Authorization header.
credentials = Credentials()
# Safety cap on how many stored messages may be walked (following `parent`
# links) while rebuilding conversation context; exceeding it raises ValueError.
MAX_CONTEXT_MESSAGE_COUNT = 1000
async def _check_response(response: httpx.Response) -> None:
    """Translate an HTTP error status into an ``OpenaiApiException``.

    Relies on httpx's built-in status check (``raise_for_status``). When that
    check fails, the response body is read and re-raised as an
    ``OpenaiApiException`` carrying the body text and the status code, chained
    to the original ``httpx.HTTPStatusError``.

    Args:
        response: The httpx response to validate.

    Raises:
        OpenaiApiException: If ``raise_for_status`` reports an error status.
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as status_error:
        # Read the body before touching `.text`; this function is called on a
        # streamed response whose content has not been loaded yet.
        await response.aread()
        raise OpenaiApiException(
            message=response.text,
            code=response.status_code,
        ) from status_error
def make_session() -> httpx.AsyncClient:
    """Build the ``httpx.AsyncClient`` used for OpenAI API requests.

    The client is created with ``timeout=None``. If ``config.openai_api.proxy``
    is set, that proxy is used for both ``http://`` and ``https://`` URLs.

    Returns:
        A new, unshared ``httpx.AsyncClient``.
    """
    proxy_url = config.openai_api.proxy
    if proxy_url is None:
        return httpx.AsyncClient(timeout=None)

    # Route both schemes through the same configured proxy.
    proxy_map = {scheme: proxy_url for scheme in ("http://", "https://")}
    return httpx.AsyncClient(proxies=proxy_map, timeout=None)
class OpenaiApiChatManager(metaclass=SingletonMeta):
    """
    OpenAI API Manager

    Holds one ``httpx.AsyncClient`` (built by ``make_session``) and uses it to
    stream chat completions from ``config.openai_api.openai_base_url``.

    NOTE(review): the class uses the project's ``SingletonMeta`` metaclass,
    presumably so a single instance is shared -- confirm in ``utils.common``.
    """

    def __init__(self):
        self.session = make_session()

    def reset_session(self):
        """Replace ``self.session`` with a freshly built client.

        NOTE(review): the previous client is only rebound, not closed, here.
        """
        self.session = make_session()

    @staticmethod
    def _extract_payload(line: Optional[str]) -> Optional[str]:
        """Return the payload of one server-sent-event line, or None if blank.

        A leading ``data:`` field name is removed only when it is a real
        prefix, so the same text appearing inside the payload is left intact.
        Lines without the prefix are returned stripped, which lets the caller
        attempt to parse them as JSON exactly as before.
        """
        if not line:
            return None
        if line.startswith("data:"):
            line = line[len("data:"):]
        return line.strip()

    @staticmethod
    async def _load_context(conversation_id: uuid.UUID, parent_message_id: Optional[uuid.UUID],
                            context_message_count: int) -> list:
        """Load stored messages of a conversation, oldest first, for use as context.

        Walks ``parent`` links backwards starting from the conversation's
        ``current_node``.

        Raises:
            ValueError: Conversation not found, not an API conversation, the
                parent id is not in its mapping, it has no ``current_node``,
                or more than ``MAX_CONTEXT_MESSAGE_COUNT`` messages were walked.
        """
        conv_history = await OpenaiApiConversationHistoryDocument.get(conversation_id)
        if not conv_history:
            raise ValueError("conversation_id not found")
        if conv_history.source != ChatSourceTypes.openai_api:
            raise ValueError(f"{conversation_id} is not api conversation")
        if not conv_history.mapping.get(str(parent_message_id)):
            raise ValueError(f"{parent_message_id} is not a valid parent of {conversation_id}")

        # Starting at current_node, walk back through up to
        # context_message_count messages.
        if not conv_history.current_node:
            raise ValueError(f"{conversation_id} current_node is None")
        msg = conv_history.mapping.get(str(conv_history.current_node))
        assert msg, f"{conv_history.id} current_node({conv_history.current_node}) not found in mapping"

        history = []
        while msg:
            history.append(msg)
            collected = len(history)
            if context_message_count != -1 and collected >= context_message_count:
                break
            # Guards against a runaway walk (e.g. a cycle in parent links).
            if collected > MAX_CONTEXT_MESSAGE_COUNT:
                raise ValueError(f"too many messages to iterate, conversation_id={conversation_id}")
            msg = conv_history.mapping.get(str(msg.parent))

        # The walk produced newest-first; the API expects chronological order.
        history.reverse()
        return history

    async def complete(self, model: OpenaiApiChatModels, text_content: str, conversation_id: uuid.UUID = None,
                       parent_message_id: uuid.UUID = None,
                       context_message_count: int = -1, extra_args: Optional[dict] = None, **_kwargs):
        """Stream a chat completion for ``text_content`` as an async generator.

        Args:
            model: Model to use; ``model.code()`` is sent as the ``model`` field.
            text_content: Text of the new user message.
            conversation_id: Existing conversation to continue. When falsy, a
                new conversation is started and only the new message is sent.
            parent_message_id: Id of the message the new one replies to. Must
                be None for a new conversation and must be present in the
                conversation's mapping otherwise.
            context_message_count: How many stored messages to send as context.
                ``-1`` means no limit other than ``MAX_CONTEXT_MESSAGE_COUNT``.
                Any other value below 2 still sends one stored message, since
                the limit is checked after a message is collected.
            extra_args: Extra request-body fields; merged last, so they
                override the defaults built here.
            **_kwargs: Accepted and ignored.

        Yields:
            The assistant ``OpenaiApiChatMessage``. One object is created on
            the first usable chunk and then updated in place and yielded again
            after every later chunk, with the text accumulated so far.

        Raises:
            AssertionError: openai_api is disabled, or ``parent_message_id`` is
                given without ``conversation_id``.
            ValueError: The stored conversation cannot be used as context.
            OpenaiApiException: The API responded with an error status.
        """
        assert config.openai_api.enabled, "openai_api is not enabled"

        now_time = datetime.now(tz=timezone.utc)
        message_id = uuid.uuid4()
        new_message = OpenaiApiChatMessage(
            source="openai_api",
            id=message_id,
            role="user",
            create_time=now_time,
            parent=parent_message_id,
            children=[],
            content=OpenaiApiChatMessageTextContent(content_type="text", text=text_content),
            metadata=OpenaiApiChatMessageMetadata(
                source="openai_api",
            )
        )

        if not conversation_id:
            assert parent_message_id is None, "parent_id must be None when conversation_id is None"
            messages = [new_message]
        else:
            messages = await self._load_context(conversation_id, parent_message_id, context_message_count)
            messages.append(new_message)

        # TODO: credits check

        base_url = config.openai_api.openai_base_url
        data = {
            "model": model.code(),
            "messages": [{"role": msg.role, "content": msg.content.text} for msg in messages],
            "stream": True,
            **(extra_args or {})
        }

        reply_message = None
        text_content = ""  # from here on: the assistant text accumulated so far

        timeout = httpx.Timeout(config.openai_api.read_timeout, connect=config.openai_api.connect_timeout)
        async with self.session.stream(method="POST",
                                       url=f"{base_url}chat/completions",
                                       json=data,
                                       headers={"Authorization": f"Bearer {credentials.openai_api_key}"},
                                       timeout=timeout
                                       ) as response:
            await _check_response(response)
            async for line in response.aiter_lines():
                payload = self._extract_payload(line)
                if not payload:
                    continue
                # Compare the whole payload: a substring test would end the
                # stream early whenever the model's text contains "[DONE]".
                if payload == "[DONE]":
                    break

                try:
                    resp = OpenaiChatResponse.model_validate(json.loads(payload))
                    if not resp.choices:
                        continue
                    choice = resp.choices[0]
                    # `content` may be present but null (e.g. role-only
                    # chunks); treat that as empty text instead of failing on
                    # str + None.
                    if choice.message is not None:
                        text_content = choice.message.get("content") or ""
                    if choice.delta is not None:
                        text_content += choice.delta.get("content") or ""
                    if reply_message is None:
                        reply_message = OpenaiApiChatMessage(
                            source="openai_api",
                            id=uuid.uuid4(),
                            role="assistant",
                            model=model,
                            create_time=datetime.now(tz=timezone.utc),
                            parent=message_id,
                            children=[],
                            content=OpenaiApiChatMessageTextContent(content_type="text", text=text_content),
                            metadata=OpenaiApiChatMessageMetadata(
                                source="openai_api",
                                finish_reason=choice.finish_reason,
                            )
                        )
                    else:
                        reply_message.content = OpenaiApiChatMessageTextContent(content_type="text",
                                                                                text=text_content)
                        # The finish reason normally arrives on a late chunk,
                        # so record it when it appears rather than only at
                        # creation time.
                        if choice.finish_reason is not None:
                            reply_message.metadata.finish_reason = choice.finish_reason
                    if resp.usage:
                        reply_message.metadata.usage = resp.usage
                except json.decoder.JSONDecodeError:
                    logger.warning("OpenAIChatResponse parse json error")
                    continue
                except ValidationError as e:
                    logger.warning(f"OpenAIChatResponse validate error: {e}")
                    continue

                yield reply_message