-
Notifications
You must be signed in to change notification settings - Fork 4.8k
/
settings.py
304 lines (230 loc) · 9.47 KB
/
settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional

from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.callbacks.base import BaseCallbackHandler, CallbackManager
from llama_index.core.embeddings.utils import EmbedType, resolve_embed_model
from llama_index.core.indices.prompt_helper import PromptHelper
from llama_index.core.llms import LLM
from llama_index.core.llms.utils import LLMType, resolve_llm
from llama_index.core.node_parser import NodeParser, SentenceSplitter
from llama_index.core.schema import TransformComponent
from llama_index.core.types import PydanticProgramMode
from llama_index.core.utils import get_tokenizer, set_global_tokenizer

if TYPE_CHECKING:
    from llama_index.core.service_context import ServiceContext
@dataclass
class _Settings:
    """Settings for the Llama Index, lazily initialized.

    Components live in private ``_``-prefixed fields and most of them are
    created on first access through the matching property rather than at
    construction time. The tokenizer and the global handler are different:
    they are delegated to ``llama_index.core`` globals (read through module
    attributes, written through ``set_global_*`` helpers).
    """

    # lazy initialization
    _llm: Optional[LLM] = None
    _embed_model: Optional[BaseEmbedding] = None
    _callback_manager: Optional[CallbackManager] = None
    # NOTE(review): no property in this class reads or writes this field; the
    # ``tokenizer`` property goes through ``llama_index.core.global_tokenizer``.
    _tokenizer: Optional[Callable[[str], List[Any]]] = None
    _node_parser: Optional[NodeParser] = None
    _prompt_helper: Optional[PromptHelper] = None
    _transformations: Optional[List[TransformComponent]] = None

    # ---- LLM ----

    @property
    def llm(self) -> LLM:
        """Get the LLM, resolving ``"default"`` via ``resolve_llm`` on first use.

        When a callback manager has been configured it is attached to the LLM
        on every access, so a callback manager set after the LLM still applies.
        """
        if self._llm is None:
            self._llm = resolve_llm("default")

        if self._callback_manager is not None:
            self._llm.callback_manager = self._callback_manager

        return self._llm

    @llm.setter
    def llm(self, llm: LLMType) -> None:
        """Set the LLM; the value is normalized through ``resolve_llm``."""
        self._llm = resolve_llm(llm)

    @property
    def pydantic_program_mode(self) -> PydanticProgramMode:
        """Get the pydantic program mode of the current LLM.

        Reading this goes through ``llm`` and therefore resolves the default
        LLM if none has been set yet.
        """
        return self.llm.pydantic_program_mode

    @pydantic_program_mode.setter
    def pydantic_program_mode(self, pydantic_program_mode: PydanticProgramMode) -> None:
        """Set the pydantic program mode on the current LLM (mutated in place)."""
        self.llm.pydantic_program_mode = pydantic_program_mode

    # ---- Embedding ----

    @property
    def embed_model(self) -> BaseEmbedding:
        """Get the embedding model, resolving ``"default"`` on first use.

        A configured callback manager is attached on every access, mirroring
        ``llm``.
        """
        if self._embed_model is None:
            self._embed_model = resolve_embed_model("default")

        if self._callback_manager is not None:
            self._embed_model.callback_manager = self._callback_manager

        return self._embed_model

    @embed_model.setter
    def embed_model(self, embed_model: EmbedType) -> None:
        """Set the embedding model; normalized through ``resolve_embed_model``."""
        self._embed_model = resolve_embed_model(embed_model)

    # ---- Callbacks ----

    @property
    def global_handler(self) -> Optional[BaseCallbackHandler]:
        """Get the global handler stored on the ``llama_index.core`` module."""
        # Imported lazily so the module attribute is read at access time.
        import llama_index.core

        # TODO: deprecated?
        return llama_index.core.global_handler

    @global_handler.setter
    def global_handler(self, eval_mode: str, **eval_params: Any) -> None:
        """Set the global handler via ``set_global_handler(eval_mode)``.

        Property assignment (``Settings.global_handler = "..."``) can only
        supply ``eval_mode``; ``eval_params`` is always empty on that path.
        To pass extra handler parameters, call
        ``llama_index.core.set_global_handler`` directly.
        """
        from llama_index.core import set_global_handler

        # TODO: deprecated?
        set_global_handler(eval_mode, **eval_params)

    @property
    def callback_manager(self) -> CallbackManager:
        """Get the callback manager, creating a ``CallbackManager()`` on first use."""
        if self._callback_manager is None:
            self._callback_manager = CallbackManager()
        return self._callback_manager

    @callback_manager.setter
    def callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set the callback manager used by ``llm``, ``embed_model`` and ``node_parser``."""
        self._callback_manager = callback_manager

    # ---- Tokenizer ----

    @property
    def tokenizer(self) -> Callable[[str], List[Any]]:
        """Get the global tokenizer, falling back to ``get_tokenizer()`` if unset."""
        import llama_index.core

        if llama_index.core.global_tokenizer is None:
            return get_tokenizer()

        # TODO: deprecated?
        return llama_index.core.global_tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: Callable[[str], List[Any]]) -> None:
        """Set the global tokenizer.

        A ``transformers.PreTrainedTokenizerBase`` instance is wrapped as
        ``partial(tokenizer.encode, add_special_tokens=False)`` before being
        stored; any other callable is stored as given.
        """
        try:
            from transformers import PreTrainedTokenizerBase  # pants: no-infer-dep
        except ImportError:
            # ``transformers`` is optional: without it no wrapping is possible.
            pass
        else:
            if isinstance(tokenizer, PreTrainedTokenizerBase):
                from functools import partial

                tokenizer = partial(tokenizer.encode, add_special_tokens=False)

        # TODO: deprecated?
        set_global_tokenizer(tokenizer)

    # ---- Node parser ----

    @property
    def node_parser(self) -> NodeParser:
        """Get the node parser, defaulting to ``SentenceSplitter()`` on first use.

        A configured callback manager is attached on every access, mirroring
        ``llm``.
        """
        if self._node_parser is None:
            self._node_parser = SentenceSplitter()

        if self._callback_manager is not None:
            self._node_parser.callback_manager = self._callback_manager

        return self._node_parser

    @node_parser.setter
    def node_parser(self, node_parser: NodeParser) -> None:
        """Set the node parser."""
        self._node_parser = node_parser

    def _node_parser_with(self, attr: str, label: str) -> NodeParser:
        """Return the node parser, raising ``ValueError`` if it lacks ``attr``.

        ``label`` is the human-readable name used in the error message.
        """
        node_parser = self.node_parser
        if not hasattr(node_parser, attr):
            raise ValueError(f"Configured node parser does not have {label}.")
        return node_parser

    @property
    def chunk_size(self) -> int:
        """Get the chunk size of the configured node parser.

        Raises:
            ValueError: If the node parser has no ``chunk_size`` attribute.
        """
        return self._node_parser_with("chunk_size", "chunk size").chunk_size

    @chunk_size.setter
    def chunk_size(self, chunk_size: int) -> None:
        """Set the chunk size on the configured node parser (mutated in place).

        Raises:
            ValueError: If the node parser has no ``chunk_size`` attribute.
        """
        node_parser = self._node_parser_with("chunk_size", "chunk size")
        node_parser.chunk_size = chunk_size

    @property
    def chunk_overlap(self) -> int:
        """Get the chunk overlap of the configured node parser.

        Raises:
            ValueError: If the node parser has no ``chunk_overlap`` attribute.
        """
        return self._node_parser_with("chunk_overlap", "chunk overlap").chunk_overlap

    @chunk_overlap.setter
    def chunk_overlap(self, chunk_overlap: int) -> None:
        """Set the chunk overlap on the configured node parser (mutated in place).

        Raises:
            ValueError: If the node parser has no ``chunk_overlap`` attribute.
        """
        node_parser = self._node_parser_with("chunk_overlap", "chunk overlap")
        node_parser.chunk_overlap = chunk_overlap

    # ---- Node parser alias ----

    @property
    def text_splitter(self) -> NodeParser:
        """Get the text splitter (alias for ``node_parser``)."""
        return self.node_parser

    @text_splitter.setter
    def text_splitter(self, text_splitter: NodeParser) -> None:
        """Set the text splitter (alias for ``node_parser``)."""
        self.node_parser = text_splitter

    # ---- Prompt helper ----

    @property
    def prompt_helper(self) -> PromptHelper:
        """Get the prompt helper, creating and caching it on first use.

        If an LLM has already been set (``_llm`` is not None), the helper is
        built from that LLM's metadata; otherwise ``PromptHelper()`` is used.
        ``_llm`` is checked directly, so this never triggers resolution of the
        default LLM. The cached helper is not rebuilt if the LLM changes later.
        """
        if self._prompt_helper is None:
            if self._llm is not None:
                self._prompt_helper = PromptHelper.from_llm_metadata(self._llm.metadata)
            else:
                self._prompt_helper = PromptHelper()
        return self._prompt_helper

    @prompt_helper.setter
    def prompt_helper(self, prompt_helper: PromptHelper) -> None:
        """Set the prompt helper."""
        self._prompt_helper = prompt_helper

    @property
    def num_output(self) -> int:
        """Get ``num_output`` from the prompt helper."""
        return self.prompt_helper.num_output

    @num_output.setter
    def num_output(self, num_output: int) -> None:
        """Set ``num_output`` on the prompt helper (mutated in place)."""
        self.prompt_helper.num_output = num_output

    @property
    def context_window(self) -> int:
        """Get ``context_window`` from the prompt helper."""
        return self.prompt_helper.context_window

    @context_window.setter
    def context_window(self, context_window: int) -> None:
        """Set ``context_window`` on the prompt helper (mutated in place)."""
        self.prompt_helper.context_window = context_window

    # ---- Transformations ----

    @property
    def transformations(self) -> List[TransformComponent]:
        """Get the transformations, defaulting to ``[self.node_parser]``.

        The default list is built once and cached, so it keeps referring to
        whichever node parser was current when it was first created.
        """
        if self._transformations is None:
            self._transformations = [self.node_parser]
        return self._transformations

    @transformations.setter
    def transformations(self, transformations: List[TransformComponent]) -> None:
        """Set the transformations."""
        self._transformations = transformations
# Module-level singleton: the one shared, lazily-initialized settings object.
Settings: _Settings = _Settings()
# -- Helper functions for deprecation/migration --
def llm_from_settings_or_context(
    settings: _Settings, context: Optional["ServiceContext"]
) -> LLM:
    """Return the LLM, preferring ``context`` over ``settings``.

    Args:
        settings: Settings object used when no context is supplied.
        context: Optional ``ServiceContext``; when given, its ``llm`` is used.
    """
    return settings.llm if context is None else context.llm
def embed_model_from_settings_or_context(
    settings: _Settings, context: Optional["ServiceContext"]
) -> BaseEmbedding:
    """Return the embedding model, preferring ``context`` over ``settings``.

    Args:
        settings: Settings object used when no context is supplied.
        context: Optional ``ServiceContext``; when given, its ``embed_model``
            is used.
    """
    return settings.embed_model if context is None else context.embed_model
def callback_manager_from_settings_or_context(
    settings: _Settings, context: Optional["ServiceContext"]
) -> CallbackManager:
    """Return the callback manager, preferring ``context`` over ``settings``.

    Args:
        settings: Settings object used when no context is supplied.
        context: Optional ``ServiceContext``; when given, its
            ``callback_manager`` is used.
    """
    return (
        settings.callback_manager if context is None else context.callback_manager
    )
def node_parser_from_settings_or_context(
    settings: _Settings, context: Optional["ServiceContext"]
) -> NodeParser:
    """Return the node parser, preferring ``context`` over ``settings``.

    Args:
        settings: Settings object used when no context is supplied.
        context: Optional ``ServiceContext``; when given, its ``node_parser``
            is used.
    """
    return settings.node_parser if context is None else context.node_parser
def transformations_from_settings_or_context(
    settings: _Settings, context: Optional["ServiceContext"]
) -> List[TransformComponent]:
    """Return the transformations, preferring ``context`` over ``settings``.

    Args:
        settings: Settings object used when no context is supplied.
        context: Optional ``ServiceContext``; when given, its
            ``transformations`` are used.
    """
    return (
        settings.transformations if context is None else context.transformations
    )