/
openai_like.py
168 lines (141 loc) · 5.99 KB
/
openai_like.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from typing import Any, Optional, Sequence, Union
from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW
from llama_index.legacy.llms.generic_utils import (
async_stream_completion_response_to_chat_response,
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.openai import OpenAI, Tokenizer
from llama_index.legacy.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseAsyncGen,
CompletionResponseGen,
LLMMetadata,
)
class OpenAILike(OpenAI):
    """
    Thin subclass of ``OpenAI`` for third-party, OpenAI-compatible endpoints.

    The stock ``OpenAI`` class needs to infer metadata from the model name,
    which prevents using custom models with it. This subclass instead exposes
    that metadata (``context_window``, ``is_chat_model``,
    ``is_function_calling_model``) as plain fields set by the caller.

    NOTE: You still need to provide the API base and API key, either through
    the ``api_base`` / ``api_key`` constructor arguments or through the
    environment. The original note names the variables ``OPENAI_BASE_API`` and
    ``OPENAI_API_KEY``; NOTE(review): the base-URL variable is presumably
    ``OPENAI_API_BASE`` -- confirm against the ``OpenAI`` parent class.
    The API key can normally be set to anything, but that depends on the tool
    you are using.
    """

    # Field descriptions are taken from the matching LLMMetadata fields.
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=LLMMetadata.__fields__["context_window"].field_info.description,
    )
    is_chat_model: bool = Field(
        default=False,
        description=LLMMetadata.__fields__["is_chat_model"].field_info.description,
    )
    is_function_calling_model: bool = Field(
        default=False,
        description=LLMMetadata.__fields__[
            "is_function_calling_model"
        ].field_info.description,
    )
    # Either a tokenizer object, a tokenizer name (string), or None.
    tokenizer: Union[Tokenizer, str, None] = Field(
        default=None,
        description=(
            "An instance of a tokenizer object that has an encode method, or the name"
            " of a tokenizer model from Hugging Face. If left as None, then this"
            " disables inference of max_tokens."
        ),
    )

    @property
    def metadata(self) -> LLMMetadata:
        """Build an ``LLMMetadata`` from this instance's fields.

        ``num_output`` is ``max_tokens`` when that is truthy, otherwise ``-1``.
        """
        output_budget = self.max_tokens or -1
        return LLMMetadata(
            context_window=self.context_window,
            num_output=output_budget,
            is_chat_model=self.is_chat_model,
            is_function_calling_model=self.is_function_calling_model,
            model_name=self.model,
        )

    @property
    def _tokenizer(self) -> Optional[Tokenizer]:
        """Return the tokenizer to use, or ``None`` when none is configured.

        A non-string ``tokenizer`` value is returned as-is. A string is passed
        to ``transformers.AutoTokenizer.from_pretrained``; nothing is cached
        here, so that call happens on every access.

        Raises:
            ImportError: if ``tokenizer`` is a string and ``transformers``
                cannot be imported.
        """
        configured = self.tokenizer
        if not isinstance(configured, str):
            return configured

        # Imported lazily so transformers is only required for this code path.
        try:
            from transformers import AutoTokenizer
        except ImportError as err:
            raise ImportError(
                "Please install transformers (pip install transformers) to use "
                "huggingface tokenizers with OpenAILike."
            ) from err

        return AutoTokenizer.from_pretrained(configured)

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this class."""
        return "OpenAILike"

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete the prompt.

        Unless ``formatted`` is true, the prompt is first passed through
        ``completion_to_prompt``. ``formatted`` is consumed here and is not
        forwarded to the parent implementation.
        """
        final_prompt = prompt if formatted else self.completion_to_prompt(prompt)
        return super().complete(final_prompt, **kwargs)

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream complete the prompt.

        Unless ``formatted`` is true, the prompt is first passed through
        ``completion_to_prompt``.
        """
        final_prompt = prompt if formatted else self.completion_to_prompt(prompt)
        return super().stream_complete(final_prompt, **kwargs)

    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat with the model.

        When ``metadata.is_chat_model`` is false, the messages are flattened
        with ``messages_to_prompt``, sent through ``complete`` as an already
        formatted prompt, and the result is converted to a chat response.
        """
        if self.metadata.is_chat_model:
            return super().chat(messages, **kwargs)

        flattened = self.messages_to_prompt(messages)
        completion = self.complete(flattened, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat with the model.

        When ``metadata.is_chat_model`` is false, the messages are flattened
        with ``messages_to_prompt``, sent through ``stream_complete`` as an
        already formatted prompt, and the stream is converted to chat responses.
        """
        if self.metadata.is_chat_model:
            return super().stream_chat(messages, **kwargs)

        flattened = self.messages_to_prompt(messages)
        completion_stream = self.stream_complete(flattened, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_stream)

    # -- Async methods --

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete the prompt (async).

        Unless ``formatted`` is true, the prompt is first passed through
        ``completion_to_prompt``.
        """
        final_prompt = prompt if formatted else self.completion_to_prompt(prompt)
        return await super().acomplete(final_prompt, **kwargs)

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Stream complete the prompt (async).

        Unless ``formatted`` is true, the prompt is first passed through
        ``completion_to_prompt``.
        """
        final_prompt = prompt if formatted else self.completion_to_prompt(prompt)
        return await super().astream_complete(final_prompt, **kwargs)

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Chat with the model (async).

        When ``metadata.is_chat_model`` is false, the messages are flattened
        with ``messages_to_prompt``, sent through ``acomplete`` as an already
        formatted prompt, and the result is converted to a chat response.
        """
        if self.metadata.is_chat_model:
            return await super().achat(messages, **kwargs)

        flattened = self.messages_to_prompt(messages)
        completion = await self.acomplete(flattened, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion)

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Stream a chat with the model (async).

        When ``metadata.is_chat_model`` is false, the messages are flattened
        with ``messages_to_prompt``, sent through ``astream_complete`` as an
        already formatted prompt, and the stream is converted to chat responses.
        """
        if self.metadata.is_chat_model:
            return await super().astream_chat(messages, **kwargs)

        flattened = self.messages_to_prompt(messages)
        completion_stream = await self.astream_complete(
            flattened, formatted=True, **kwargs
        )
        return async_stream_completion_response_to_chat_response(completion_stream)