pydantic_selectors.py
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

from llama_index.core.base.base_selector import (
    BaseSelector,
    MultiSelection,
    SelectorResult,
    SingleSelection,
)
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.schema import QueryBundle
from llama_index.core.selectors.llm_selectors import _build_choices_text
from llama_index.core.selectors.prompts import (
    DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
    DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
)
from llama_index.core.tools.types import ToolMetadata
from llama_index.core.types import BasePydanticProgram

if TYPE_CHECKING:
    from llama_index.llms.openai import OpenAI  # pants: no-infer-dep


def _pydantic_output_to_selector_result(output: Any) -> SelectorResult:
    """
    Convert pydantic output to selector result.

    Takes into account zero-indexing on answer indexes.
    """
    if isinstance(output, SingleSelection):
        output.index -= 1
        return SelectorResult(selections=[output])
    elif isinstance(output, MultiSelection):
        for idx in range(len(output.selections)):
            output.selections[idx].index -= 1
        return SelectorResult(selections=output.selections)
    else:
        raise ValueError(f"Unsupported output type: {type(output)}")
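
# Illustrative sketch (not part of the original module): the selection prompts
# ask the LLM for 1-based choice numbers, so the helper above shifts them to
# 0-based indexes before wrapping them in a SelectorResult. Roughly:
#
#     raw = SingleSelection(index=1, reason="best match")
#     result = _pydantic_output_to_selector_result(raw)
#     assert result.selections[0].index == 0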


class PydanticSingleSelector(BaseSelector):
    """Selector that picks a single choice using a pydantic structured-output program."""

    def __init__(self, selector_program: BasePydanticProgram) -> None:
        self._selector_program = selector_program

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional["OpenAI"] = None,
        prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
        verbose: bool = False,
    ) -> "PydanticSingleSelector":
        try:
            from llama_index.program.openai import (
                OpenAIPydanticProgram,
            )  # pants: no-infer-dep
        except ImportError as e:
            raise ImportError(
                "`llama-index-program-openai` package is missing. "
                "Please install using `pip install llama-index-program-openai`."
            ) from e

        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=SingleSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        raise NotImplementedError(
            "Async selection not supported for Pydantic Selectors."
        )
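
# A minimal usage sketch (assumes the `llama-index-program-openai` package and
# an OpenAI API key are available; the model name is illustrative):
#
#     from llama_index.core.schema import QueryBundle
#     from llama_index.core.tools.types import ToolMetadata
#     from llama_index.llms.openai import OpenAI
#
#     selector = PydanticSingleSelector.from_defaults(llm=OpenAI(model="gpt-4o-mini"))
#     result = selector.select(
#         choices=[
#             ToolMetadata(name="vector", description="Useful for semantic search"),
#             ToolMetadata(name="summary", description="Useful for summarization"),
#         ],
#         query=QueryBundle(query_str="Summarize the document."),
#     )
#     print(result.selections[0].index, result.selections[0].reason)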


class PydanticMultiSelector(BaseSelector):
    """Selector that can pick multiple choices using a pydantic structured-output program."""

    def __init__(
        self, selector_program: BasePydanticProgram, max_outputs: Optional[int] = None
    ) -> None:
        self._selector_program = selector_program
        self._max_outputs = max_outputs

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional["OpenAI"] = None,
        prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
        max_outputs: Optional[int] = None,
        verbose: bool = False,
    ) -> "PydanticMultiSelector":
        try:
            from llama_index.program.openai import (
                OpenAIPydanticProgram,
            )  # pants: no-infer-dep
        except ImportError as e:
            raise ImportError(
                "`llama-index-program-openai` package is missing. "
                "Please install using `pip install llama-index-program-openai`."
            ) from e

        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=MultiSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program, max_outputs=max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # NOTE: falls back to the synchronous path; the underlying pydantic
        # program is invoked synchronously here.
        return self._select(choices, query)
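
# A minimal usage sketch (same assumptions as the single-selector example above;
# `max_outputs` caps how many choices the LLM may return):
#
#     selector = PydanticMultiSelector.from_defaults(max_outputs=2)
#     result = selector.select(
#         choices=[
#             ToolMetadata(name="vector", description="Useful for semantic search"),
#             ToolMetadata(name="summary", description="Useful for summarization"),
#             ToolMetadata(name="keyword", description="Useful for keyword lookup"),
#         ],
#         query=QueryBundle(query_str="Compare the key themes across chapters."),
#     )
#     for sel in result.selections:
#         print(sel.index, sel.reason)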