/
retrievers.py
183 lines (155 loc) · 6.93 KB
/
retrievers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""Document summary retrievers.
This module contains retrievers for document summary indices.
"""
import logging
from typing import Any, Callable, List, Optional
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.core.base_retriever import BaseRetriever
from llama_index.legacy.indices.document_summary.base import DocumentSummaryIndex
from llama_index.legacy.indices.utils import (
default_format_node_batch_fn,
default_parse_choice_select_answer_fn,
)
from llama_index.legacy.prompts import BasePromptTemplate
from llama_index.legacy.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.legacy.schema import NodeWithScore, QueryBundle
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.vector_stores.types import VectorStoreQuery
logger = logging.getLogger(__name__)
class DocumentSummaryIndexLLMRetriever(BaseRetriever):
    """Document Summary Index LLM Retriever.

    By default, select relevant summaries from index using LLM calls.

    Args:
        index (DocumentSummaryIndex): The index to retrieve from.
        choice_select_prompt (Optional[BasePromptTemplate]): The prompt to use
            for selecting relevant summaries.
        choice_batch_size (int): The number of summary nodes to send to LLM at a time.
        choice_top_k (int): The number of summary nodes to retrieve.
        format_node_batch_fn (Callable): Function to format a batch of nodes for LLM.
        parse_choice_select_answer_fn (Callable): Function to parse LLM response.
        service_context (ServiceContext): The service context to use.
    """

    def __init__(
        self,
        index: DocumentSummaryIndex,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        choice_top_k: int = 1,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        self._index = index
        self._choice_select_prompt = (
            choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        )
        self._choice_batch_size = choice_batch_size
        self._choice_top_k = choice_top_k
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )
        # Fall back to the index's own service context when none is provided.
        self._service_context = service_context or index.service_context
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes.

        Scores the index's summary nodes in batches via LLM choice-select,
        keeps the ``choice_top_k`` highest-relevance summaries, and returns
        every content node covered by those summaries, each carrying its
        summary's relevance score.
        """
        summary_ids = self._index.index_struct.summary_ids
        # Loop-invariant: hoist the query text out of the batch loop instead
        # of re-reading it on every iteration.
        query_str = query_bundle.query_str

        all_summary_ids: List[str] = []
        all_relevances: List[float] = []
        for idx in range(0, len(summary_ids), self._choice_batch_size):
            summary_ids_batch = summary_ids[idx : idx + self._choice_batch_size]
            summary_nodes = self._index.docstore.get_nodes(summary_ids_batch)
            fmt_batch_str = self._format_node_batch_fn(summary_nodes)
            # call each batch independently
            raw_response = self._service_context.llm.predict(
                self._choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )
            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(summary_nodes)
            )
            # LLM choices are 1-based; convert to 0-based indices into the batch.
            choice_idxs = [choice - 1 for choice in raw_choices]
            choice_summary_ids = [summary_ids_batch[ci] for ci in choice_idxs]
            all_summary_ids.extend(choice_summary_ids)
            all_relevances.extend(relevances)

        # Rank all selected summaries by relevance (descending) and keep top-k.
        zipped_list = list(zip(all_summary_ids, all_relevances))
        sorted_list = sorted(zipped_list, key=lambda x: x[1], reverse=True)
        top_k_list = sorted_list[: self._choice_top_k]

        results = []
        for summary_id, relevance in top_k_list:
            # Expand each chosen summary into its underlying content nodes.
            node_ids = self._index.index_struct.summary_id_to_node_ids[summary_id]
            nodes = self._index.docstore.get_nodes(node_ids)
            results.extend([NodeWithScore(node=n, score=relevance) for n in nodes])
        return results
class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever):
    """Document Summary Index Embedding Retriever.

    Matches the query against summary embeddings in the index's vector store,
    then returns the content nodes underlying the matched summaries.

    Args:
        index (DocumentSummaryIndex): The index to retrieve from.
        similarity_top_k (int): The number of summary nodes to retrieve.
    """

    def __init__(
        self,
        index: DocumentSummaryIndex,
        similarity_top_k: int = 1,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        self._index = index
        # Cache the index's collaborators locally for quick access.
        self._vector_store = index.vector_store
        self._service_context = index.service_context
        self._docstore = index.docstore
        self._index_struct = index.index_struct
        self._similarity_top_k = similarity_top_k
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes."""
        # Compute the query embedding lazily, only for embedding-capable stores
        # and only when the bundle does not already carry one.
        if self._vector_store.is_embedding_query and query_bundle.embedding is None:
            embed_model = self._service_context.embed_model
            query_bundle.embedding = embed_model.get_agg_embedding_from_queries(
                query_bundle.embedding_strs
            )

        query_result = self._vector_store.query(
            VectorStoreQuery(
                query_embedding=query_bundle.embedding,
                similarity_top_k=self._similarity_top_k,
            )
        )

        # The store may answer with ids directly or with full nodes.
        top_k_summary_ids: List[str]
        if query_result.ids is not None:
            top_k_summary_ids = query_result.ids
        elif query_result.nodes is not None:
            top_k_summary_ids = [node.node_id for node in query_result.nodes]
        else:
            raise ValueError(
                "Vector store query result should return "
                "at least one of nodes or ids."
            )

        retrieved: List[NodeWithScore] = []
        for summary_id in top_k_summary_ids:
            # Expand each matched summary into the content nodes it covers.
            content_node_ids = self._index_struct.summary_id_to_node_ids[summary_id]
            retrieved.extend(
                NodeWithScore(node=node)
                for node in self._docstore.get_nodes(content_node_ids)
            )
        return retrieved
# legacy, backward compatibility
# Older releases exposed the LLM-based retriever under this name; keep the
# alias so existing imports of ``DocumentSummaryIndexRetriever`` keep working.
DocumentSummaryIndexRetriever = DocumentSummaryIndexLLMRetriever