-
Notifications
You must be signed in to change notification settings - Fork 324
/
VectorIndexRetriever.ts
72 lines (62 loc) · 1.98 KB
/
VectorIndexRetriever.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import { globalsHelper } from "../../GlobalsHelper";
import { NodeWithScore } from "../../Node";
import { BaseRetriever } from "../../Retriever";
import { ServiceContext } from "../../ServiceContext";
import { Event } from "../../callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
import {
VectorStoreQuery,
VectorStoreQueryMode,
} from "../../storage/vectorStore/types";
import { VectorStoreIndex } from "./VectorStoreIndex";
/**
 * VectorIndexRetriever retrieves nodes from a VectorIndex.
 *
 * A query string is embedded with the service context's embedding model.
 * The index's vector store is then queried for the `similarityTopK` nearest
 * entries, and the returned ids are mapped back to nodes held in the index
 * struct's `nodesDict`.
 */
export class VectorIndexRetriever implements BaseRetriever {
  index: VectorStoreIndex;
  /** Number of results requested from the vector store per query. */
  similarityTopK: number;
  private serviceContext: ServiceContext;

  /**
   * @param index - The vector store index to retrieve from. Its
   *   `serviceContext` is captured once, here, and reused for every query.
   * @param similarityTopK - How many results to request from the vector
   *   store; falls back to `DEFAULT_SIMILARITY_TOP_K` when omitted.
   */
  constructor({
    index,
    similarityTopK,
  }: {
    index: VectorStoreIndex;
    similarityTopK?: number;
  }) {
    this.index = index;
    this.serviceContext = this.index.serviceContext;
    this.similarityTopK = similarityTopK ?? DEFAULT_SIMILARITY_TOP_K;
  }

  /**
   * Retrieve the nodes most similar to `query`.
   *
   * Runs a `VectorStoreQueryMode.DEFAULT` query. If the service context's
   * callback manager has an `onRetrieve` handler, that handler is invoked
   * with the query and the results before they are returned.
   *
   * @param query - Text to embed and search for.
   * @param parentEvent - Optional parent for the "retrieve" event handed to
   *   the `onRetrieve` callback.
   * @returns Nodes paired with the similarity the vector store reported for
   *   them, in the order the store returned them. Ids the store returns that
   *   have no entry in this index's `nodesDict` are omitted.
   */
  async retrieve(query: string, parentEvent?: Event): Promise<NodeWithScore[]> {
    const queryEmbedding =
      await this.serviceContext.embedModel.getQueryEmbedding(query);

    const q: VectorStoreQuery = {
      queryEmbedding,
      mode: VectorStoreQueryMode.DEFAULT,
      similarityTopK: this.similarityTopK,
    };
    const result = await this.index.vectorStore.query(q);

    const nodesWithScores: NodeWithScore[] = [];
    result.ids.forEach((id, i) => {
      const node = this.index.indexStruct.nodesDict[id];
      // An id with no node in this index would otherwise produce a result
      // whose `node` is undefined and fail later in consumers; drop it here.
      // The score is still read at the same position `i` as its id.
      if (!node) {
        return;
      }
      nodesWithScores.push({
        node,
        score: result.similarities[i],
      });
    });

    if (this.serviceContext.callbackManager.onRetrieve) {
      this.serviceContext.callbackManager.onRetrieve({
        query,
        nodes: nodesWithScores,
        event: globalsHelper.createEvent({
          parentEvent,
          type: "retrieve",
        }),
      });
    }

    return nodesWithScores;
  }

  /** Returns the service context captured from the index at construction. */
  getServiceContext(): ServiceContext {
    return this.serviceContext;
  }
}