-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
types.py
172 lines (139 loc) · 5.51 KB
/
types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from enum import Enum
from typing import Any, AsyncGenerator, Generator, Optional, Union, List
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
try:
from pydantic import BaseModel as V2BaseModel
from pydantic.v1 import BaseModel as V1BaseModel
except ImportError:
from pydantic import BaseModel as V2BaseModel
V1BaseModel = V2BaseModel
class MessageRole(str, Enum):
    """Role of the author of a chat message.

    Every member is also a ``str``, so it compares equal to its lowercase
    string value (``MessageRole.USER == "user"``) and can be looked up from
    that string with ``MessageRole("user")``.
    """

    # Iteration follows declaration order, so keep the members in this order.
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    FUNCTION = "function"
    TOOL = "tool"
    CHATBOT = "chatbot"
    MODEL = "model"
# ===== Generic Model Input - Chat =====
class ChatMessage(BaseModel):
    """A single chat message: a role, its content, and free-form extras.

    Attributes:
        role: Author of the message; defaults to ``MessageRole.USER``.
        content: Message payload (typed ``Optional[Any]``); defaults to ``""``.
        additional_kwargs: Extra data attached to the message. Its values are
            converted to plain Python data when ``dict()`` is called.
    """

    role: MessageRole = MessageRole.USER
    content: Optional[Any] = ""
    additional_kwargs: dict = Field(default_factory=dict)

    def __str__(self) -> str:
        """Return the message formatted as ``"<role value>: <content>"``."""
        return f"{self.role.value}: {self.content}"

    @classmethod
    def from_str(
        cls,
        content: str,
        role: Union[MessageRole, str] = MessageRole.USER,
        **kwargs: Any,
    ) -> "ChatMessage":
        """Build a message from plain text.

        Args:
            content: Text of the message.
            role: A ``MessageRole`` member or its string value.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            The newly constructed message.

        Raises:
            ValueError: If ``role`` is a string that is not a valid
                ``MessageRole`` value.
        """
        if isinstance(role, str):
            # MessageRole members are themselves ``str``; the lookup maps a
            # member to itself and a raw string to its member.
            role = MessageRole(role)
        return cls(role=role, content=content, **kwargs)

    def _recursive_serialization(self, value: Any) -> Any:
        """Convert pydantic models nested in ``value`` into plain dicts.

        Recurses through ``dict`` values and ``list`` items; every other
        object is returned unchanged.
        """
        if isinstance(value, (V1BaseModel, V2BaseModel)):
            # pydantic v2 deprecates ``.dict()`` in favour of
            # ``.model_dump()``; v1-style models only provide ``.dict()``.
            dump = getattr(value, "model_dump", None)
            return dump() if callable(dump) else value.dict()
        if isinstance(value, dict):
            return {
                key: self._recursive_serialization(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [self._recursive_serialization(item) for item in value]
        return value

    def dict(self, **kwargs: Any) -> dict:
        """Return the message as a dict with plain-data ``additional_kwargs``.

        Args:
            **kwargs: Forwarded to the parent class's ``dict()``.

        Returns:
            The parent's dict output, with each ``additional_kwargs`` value
            passed through ``_recursive_serialization``.

        Raises:
            ValueError: If a top-level ``additional_kwargs`` value is, after
                conversion, not a str, int, float, bool, dict, list or None.
        """
        # ensure all additional_kwargs are serializable
        msg = super().dict(**kwargs)
        allowed_types = (str, int, float, bool, dict, list, type(None))
        # Only existing keys are reassigned below, so mutating the mapping
        # while iterating over it is safe.
        for key, value in msg.get("additional_kwargs", {}).items():
            value = self._recursive_serialization(value)
            if not isinstance(value, allowed_types):
                raise ValueError(
                    f"Failed to serialize additional_kwargs value: {value}"
                )
            msg["additional_kwargs"][key] = value
        return msg
class LogProb(BaseModel):
    """Log-probability record for a single token.

    Attributes:
        token: The token text; defaults to an empty string.
        logprob: Log probability of the token; defaults to ``0.0``.
        bytes: List of ints attached to the token (presumably its byte
            values -- TODO confirm); defaults to an empty list.
    """

    token: str = Field(default_factory=str)
    logprob: float = Field(default_factory=float)
    bytes: List[int] = Field(default_factory=list)
# ===== Generic Model Output - Chat =====
class ChatResponse(BaseModel):
    """Result of a chat call.

    Attributes:
        message: The chat message carried by this response (required).
        raw: Optional raw payload dict; defaults to ``None``.
        delta: Optional string; presumably the newly streamed text chunk
            when streaming -- TODO confirm. Defaults to ``None``.
        logprobs: Optional nested list of ``LogProb`` entries; defaults to
            ``None``.
        additional_kwargs: Extra data attached to the response; defaults to
            an empty dict.
    """

    message: ChatMessage
    raw: Optional[dict] = None
    delta: Optional[str] = None
    logprobs: Optional[List[List[LogProb]]] = None
    additional_kwargs: dict = Field(default_factory=dict)

    def __str__(self) -> str:
        """Return the string form of the wrapped message."""
        rendered = str(self.message)
        return rendered
# Generator type aliases whose yielded items are ChatResponse objects:
# a synchronous generator (nothing sent in, returns None) and its async form.
ChatResponseGen = Generator[ChatResponse, None, None]
ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None]
# ===== Generic Model Output - Completion =====
class CompletionResponse(BaseModel):
    """
    Result of a text-completion call.

    Fields:
        text: The response text. When streaming, this is the text streamed
            so far rather than only the latest piece.
        additional_kwargs: Extra information about the response (e.g. token
            counts, function calling information).
        raw: Optional raw JSON that was parsed to populate text, if relevant.
        logprobs: Optional nested list of LogProb entries; None by default.
        delta: The piece of text that just streamed in (only relevant when
            streaming).
    """

    text: str
    additional_kwargs: dict = Field(default_factory=dict)
    raw: Optional[dict] = None
    logprobs: Optional[List[List[LogProb]]] = None
    delta: Optional[str] = None

    def __str__(self) -> str:
        """Return the response text."""
        return self.text
# Generator type aliases whose yielded items are CompletionResponse objects:
# a synchronous generator (nothing sent in, returns None) and its async form.
CompletionResponseGen = Generator[CompletionResponse, None, None]
CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None]
class LLMMetadata(BaseModel):
    """Descriptive metadata about an LLM.

    Every field has a default, so an instance can be created with no
    arguments. Each field's meaning is given by its ``description`` below.
    """

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can be input and output for one response."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name used for logging, testing, and sanity checking. For some"
            " models this can be automatically discerned. For other models, like"
            " locally loaded models, this must be manually specified."
        ),
    )
    system_role: MessageRole = Field(
        default=MessageRole.SYSTEM,
        # Adjacent string literals are concatenated with no separator, so the
        # joining space has to be written out explicitly.
        description=(
            "The role this specific LLM provider expects for system prompt."
            " E.g. 'SYSTEM' for OpenAI, 'CHATBOT' for Cohere"
        ),
    )