-
Notifications
You must be signed in to change notification settings - Fork 360
/
RetrieverQueryEngine.ts
90 lines (80 loc) · 2.61 KB
/
RetrieverQueryEngine.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import type { NodeWithScore } from "../../Node.js";
import type { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/Mixin.js";
import type { BaseSynthesizer } from "../../synthesizers/index.js";
import { ResponseSynthesizer } from "../../synthesizers/index.js";
import type {
QueryEngine,
QueryEngineParamsNonStreaming,
QueryEngineParamsStreaming,
} from "../../types.js";
/**
 * Query engine that answers a query in three stages:
 *
 * 1. fetch candidate nodes from a {@link BaseRetriever},
 * 2. pass them through zero or more {@link BaseNodePostprocessor}s, in array
 *    order, each stage consuming the previous stage's output,
 * 3. hand the surviving nodes to a {@link BaseSynthesizer}, which builds the
 *    final {@link Response} (or a stream of them).
 */
export class RetrieverQueryEngine extends PromptMixin implements QueryEngine {
  /** Supplies the candidate nodes for each query. */
  retriever: BaseRetriever;
  /** Builds the final (optionally streamed) response from the query and nodes. */
  responseSynthesizer: BaseSynthesizer;
  /** Applied one after another to the retrieved nodes; empty when none were given. */
  nodePostprocessors: BaseNodePostprocessor[];
  /**
   * Forwarded untouched to `retriever.retrieve` as `preFilters` on every
   * query. This class never inspects the value itself.
   */
  preFilters?: unknown;

  /**
   * @param retriever - Retriever queried for candidate nodes.
   * @param responseSynthesizer - Synthesizer for the final answer. When
   *   omitted (or falsy), a {@link ResponseSynthesizer} is created from
   *   `retriever.serviceContext`.
   * @param preFilters - Opaque value passed through to the retriever.
   * @param nodePostprocessors - Postprocessors to run in order; an empty list
   *   is used when omitted (or falsy).
   */
  constructor(
    retriever: BaseRetriever,
    responseSynthesizer?: BaseSynthesizer,
    preFilters?: unknown,
    nodePostprocessors?: BaseNodePostprocessor[],
  ) {
    super();
    this.retriever = retriever;
    this.responseSynthesizer = responseSynthesizer
      ? responseSynthesizer
      : new ResponseSynthesizer({ serviceContext: retriever.serviceContext });
    this.preFilters = preFilters;
    this.nodePostprocessors = nodePostprocessors ? nodePostprocessors : [];
  }

  /**
   * Prompt-bearing sub-modules of this engine, keyed by name. Only the
   * response synthesizer is exposed, under the key `responseSynthesizer`.
   * NOTE(review): presumably consumed by PromptMixin when collecting nested
   * prompts — confirm against the base class.
   */
  _getPromptModules() {
    const modules = { responseSynthesizer: this.responseSynthesizer };
    return modules;
  }

  /**
   * Runs `nodes` through every configured postprocessor, sequentially; each
   * stage receives the previous stage's output. With no postprocessors the
   * input array itself is returned.
   */
  private async applyNodePostprocessors(nodes: NodeWithScore[], query: string) {
    let current = nodes;
    for (const stage of this.nodePostprocessors) {
      // Sequential on purpose: a stage must see the previous stage's result.
      current = await stage.postprocessNodes(current, query);
    }
    return current;
  }

  /**
   * Fetches nodes for `queryText` from the retriever (forwarding
   * `preFilters`), then applies the postprocessor chain to them.
   */
  private async retrieve(queryText: string) {
    const retrieved = await this.retriever.retrieve({
      query: queryText,
      preFilters: this.preFilters,
    });
    const processed = await this.applyNodePostprocessors(retrieved, queryText);
    return processed;
  }

  /**
   * Answers `params.query`: retrieves and postprocesses nodes, then
   * synthesizes a response from them.
   *
   * When `params.stream` is truthy the synthesizer is called with
   * `stream: true` and an `AsyncIterable<Response>` is returned. Otherwise the
   * synthesizer is called without a `stream` key and a single `Response` is
   * returned.
   */
  query(params: QueryEngineParamsStreaming): Promise<AsyncIterable<Response>>;
  query(params: QueryEngineParamsNonStreaming): Promise<Response>;
  @wrapEventCaller
  async query(
    params: QueryEngineParamsStreaming | QueryEngineParamsNonStreaming,
  ): Promise<Response | AsyncIterable<Response>> {
    // Read both fields up front, before any await.
    const { query: queryText, stream: streaming } = params;
    const nodesWithScore = await this.retrieve(queryText);

    if (!streaming) {
      // Non-streaming: the `stream` key is deliberately left out.
      return this.responseSynthesizer.synthesize({
        query: queryText,
        nodesWithScore,
      });
    }

    return this.responseSynthesizer.synthesize({
      query: queryText,
      nodesWithScore,
      stream: true,
    });
  }
}