-
Notifications
You must be signed in to change notification settings - Fork 310
/
SummaryIndexRetriever.ts
134 lines (116 loc) · 4.13 KB
/
SummaryIndexRetriever.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import _ from "lodash";
import { globalsHelper } from "../../GlobalsHelper";
import { NodeWithScore } from "../../Node";
import { ChoiceSelectPrompt, defaultChoiceSelectPrompt } from "../../Prompt";
import { BaseRetriever } from "../../Retriever";
import { ServiceContext } from "../../ServiceContext";
import { Event } from "../../callbacks/CallbackManager";
import { SummaryIndex } from "./SummaryIndex";
import {
ChoiceSelectParserFunction,
NodeFormatterFunction,
defaultFormatNodeBatchFn,
defaultParseChoiceSelectAnswerFn,
} from "./utils";
/**
 * Retriever for a SummaryIndex that does no ranking at all: every node stored
 * in the index is loaded and returned with a constant score of 1.
 */
export class SummaryIndexRetriever implements BaseRetriever {
  index: SummaryIndex;

  /**
   * @param index - The SummaryIndex whose nodes will be returned.
   */
  constructor(index: SummaryIndex) {
    this.index = index;
  }

  /**
   * Load every node referenced by the index.
   *
   * @param query - Not used for selection; only forwarded to the `onRetrieve` callback.
   * @param parentEvent - Optional parent for the "retrieve" event passed to the callback.
   * @returns All nodes, in the order the document store returns them, each scored 1.
   */
  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
    const { indexStruct, docStore } = this.index;
    const allNodes = await docStore.getNodes(indexStruct.nodes);
    const scored = allNodes.map((node) => ({ node, score: 1 }));

    // Notify listeners only when a retrieve hook has been registered.
    const callbackManager = this.index.serviceContext.callbackManager;
    if (callbackManager.onRetrieve) {
      const event = globalsHelper.createEvent({
        parentEvent,
        type: "retrieve",
      });
      callbackManager.onRetrieve({ query, nodes: scored, event });
    }

    return scored;
  }

  /** The service context of the underlying index. */
  getServiceContext(): ServiceContext {
    return this.index.serviceContext;
  }
}
/**
 * LLM retriever for SummaryIndex which lets you select the most relevant chunks.
 *
 * The index's nodes are shown to the LLM in batches of `choiceBatchSize`. For
 * each batch, the LLM is asked (via `choiceSelectPrompt`) which documents are
 * relevant. Its answer is parsed by `parseChoiceSelectAnswerFn`, and the selected
 * nodes are returned with the relevance score the LLM gave them.
 */
export class SummaryIndexLLMRetriever implements BaseRetriever {
  index: SummaryIndex;
  choiceSelectPrompt: ChoiceSelectPrompt;
  choiceBatchSize: number;
  formatNodeBatchFn: NodeFormatterFunction;
  parseChoiceSelectAnswerFn: ChoiceSelectParserFunction;
  serviceContext: ServiceContext;

  /**
   * @param index - The SummaryIndex whose nodes are candidates for retrieval.
   * @param choiceSelectPrompt - Prompt asking the LLM to choose relevant documents;
   *   defaults to `defaultChoiceSelectPrompt`.
   * @param choiceBatchSize - Number of nodes shown to the LLM per call (should be >= 1).
   * @param formatNodeBatchFn - Renders a batch of nodes into the prompt's `context`;
   *   defaults to `defaultFormatNodeBatchFn`.
   * @param parseChoiceSelectAnswerFn - Parses the LLM answer into a map from document
   *   number to relevance score; defaults to `defaultParseChoiceSelectAnswerFn`.
   * @param serviceContext - Overrides the index's service context (LLM and callbacks).
   */
  constructor(
    index: SummaryIndex,
    choiceSelectPrompt?: ChoiceSelectPrompt,
    choiceBatchSize: number = 10,
    formatNodeBatchFn?: NodeFormatterFunction,
    parseChoiceSelectAnswerFn?: ChoiceSelectParserFunction,
    serviceContext?: ServiceContext,
  ) {
    this.index = index;
    this.choiceSelectPrompt = choiceSelectPrompt ?? defaultChoiceSelectPrompt;
    this.choiceBatchSize = choiceBatchSize;
    this.formatNodeBatchFn = formatNodeBatchFn ?? defaultFormatNodeBatchFn;
    this.parseChoiceSelectAnswerFn =
      parseChoiceSelectAnswerFn ?? defaultParseChoiceSelectAnswerFn;
    this.serviceContext = serviceContext ?? index.serviceContext;
  }

  /**
   * Ask the LLM, batch by batch, which of the index's nodes are relevant to `query`.
   *
   * @param query - The user query placed in the prompt alongside each batch.
   * @param parentEvent - Optional parent for the "retrieve" event passed to the callback.
   * @returns The selected nodes across all batches, in index order, each with the
   *   score parsed from the LLM answer (1 if no score could be read).
   * @throws Error if `choiceBatchSize` is not at least 1, since the batch loop would
   *   otherwise never advance.
   */
  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
    const batchSize = this.choiceBatchSize;
    if (!(batchSize >= 1)) {
      throw new Error(`choiceBatchSize must be at least 1, got ${batchSize}`);
    }

    const nodeIds = this.index.indexStruct.nodes;
    const results: NodeWithScore[] = [];

    for (let start = 0; start < nodeIds.length; start += batchSize) {
      const nodeIdsBatch = nodeIds.slice(start, start + batchSize);
      const nodesBatch = await this.index.docStore.getNodes(nodeIdsBatch);
      const fmtBatchStr = this.formatNodeBatchFn(nodesBatch);
      const input = { context: fmtBatchStr, query: query };
      const rawResponse = (
        await this.serviceContext.llm.complete(this.choiceSelectPrompt(input))
      ).message.content;

      // parseResult maps a document number (string key) to its relevance score.
      // Documents are numbered from 1 in the prompt and in the parsed answer, so
      // the node at batch position i corresponds to key `${i + 1}`.
      const parseResult = this.parseChoiceSelectAnswerFn(
        rawResponse,
        nodesBatch.length,
      );

      // Pair each selected node with its own score. This uses the nodes already
      // loaded for this batch: they are exactly what the LLM was shown, so no
      // second docstore round-trip is needed.
      nodesBatch.forEach((node, i) => {
        const docNum = `${i + 1}`;
        if (docNum in parseResult) {
          results.push({ node, score: _.get(parseResult, docNum, 1) });
        }
      });
    }

    if (this.serviceContext.callbackManager.onRetrieve) {
      this.serviceContext.callbackManager.onRetrieve({
        query,
        nodes: results,
        event: globalsHelper.createEvent({
          parentEvent,
          type: "retrieve",
        }),
      });
    }
    return results;
  }

  /** The service context used for LLM calls and callbacks. */
  getServiceContext(): ServiceContext {
    return this.serviceContext;
  }
}
// Legacy names kept for backward compatibility. Each name is exported both as a
// value (so `new ListIndexRetriever(index)` still works) and as a type (so it
// can still be used in annotations).

/** @deprecated Use `SummaryIndexRetriever` instead. */
export const ListIndexRetriever = SummaryIndexRetriever;
/** @deprecated Use `SummaryIndexRetriever` instead. */
export type ListIndexRetriever = SummaryIndexRetriever;

/** @deprecated Use `SummaryIndexLLMRetriever` instead. */
export const ListIndexLLMRetriever = SummaryIndexLLMRetriever;
/** @deprecated Use `SummaryIndexLLMRetriever` instead. */
export type ListIndexLLMRetriever = SummaryIndexLLMRetriever;