-
Notifications
You must be signed in to change notification settings - Fork 323
/
LlamaCloudRetriever.ts
86 lines (78 loc) · 2.65 KB
/
LlamaCloudRetriever.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import type { PlatformApi, PlatformApiClient } from "@llamaindex/cloud";
import type { NodeWithScore } from "../Node.js";
import { ObjectType, jsonToNode } from "../Node.js";
import type { BaseRetriever, RetrieveParams } from "../Retriever.js";
import { Settings } from "../Settings.js";
import { wrapEventCaller } from "../internal/context/EventCaller.js";
import type { ClientParams, CloudConstructorParams } from "./types.js";
import { DEFAULT_PROJECT_NAME } from "./types.js";
import { getClient } from "./utils.js";
/**
 * Retrieval options accepted by {@link LlamaCloudRetriever}.
 *
 * This is `PlatformApi.RetrievalParams` without the `query`, `searchFilters`,
 * `pipelineId` and `className` keys, plus an optional `similarityTopK`.
 * The retriever's constructor copies a truthy `similarityTopK` into
 * `denseSimilarityTopK`.
 */
export type CloudRetrieveParams = Omit<
PlatformApi.RetrievalParams,
"query" | "searchFilters" | "pipelineId" | "className"
> & { similarityTopK?: number };
/**
 * Retriever that looks up a LlamaCloud pipeline by project and pipeline name
 * and runs a search against it, returning the hits as `NodeWithScore` objects.
 */
export class LlamaCloudRetriever implements BaseRetriever {
  /** Platform client; created on first use by {@link getClient}. */
  client?: PlatformApiClient;
  /** Credentials / endpoint used to create the client. */
  clientParams: ClientParams;
  /** Options spread into every `runSearch` request. */
  retrieveParams: CloudRetrieveParams;
  /** Project searched for the pipeline; defaults to `DEFAULT_PROJECT_NAME`. */
  projectName: string = DEFAULT_PROJECT_NAME;
  /** Name of the pipeline to search, taken from `params.name`. */
  pipelineName: string;

  /**
   * Converts the nodes returned by the search API into `NodeWithScore` objects.
   *
   * @param nodes - Scored nodes from the search response.
   * @returns One `NodeWithScore` per input node, in the same order.
   */
  private resultNodesToNodeWithScore(
    nodes: PlatformApi.TextNodeWithScore[],
  ): NodeWithScore[] {
    return nodes.map((node: PlatformApi.TextNodeWithScore) => {
      return {
        // Currently LlamaCloud only supports text nodes
        node: jsonToNode(node.node, ObjectType.TEXT),
        score: node.score,
      };
    });
  }

  /**
   * @param params - Client settings (`apiKey`, `baseUrl`), the pipeline
   *   `name`, an optional `projectName`, and the retrieval options.
   */
  constructor(params: CloudConstructorParams & CloudRetrieveParams) {
    this.clientParams = { apiKey: params.apiKey, baseUrl: params.baseUrl };
    // Store a shallow copy so the caller's object is never mutated. A truthy
    // `similarityTopK` overrides `denseSimilarityTopK`.
    // NOTE(review): the copy still carries the constructor-only fields
    // (apiKey, baseUrl, name, projectName) and they are spread into the
    // runSearch request — confirm the client ignores unknown keys.
    this.retrieveParams = params.similarityTopK
      ? { ...params, denseSimilarityTopK: params.similarityTopK }
      : { ...params };
    this.pipelineName = params.name;
    if (params.projectName) {
      this.projectName = params.projectName;
    }
  }

  /**
   * Returns the cached client, creating it from `clientParams` on first call.
   */
  private async getClient(): Promise<PlatformApiClient> {
    if (!this.client) {
      this.client = await getClient(this.clientParams);
    }
    return this.client;
  }

  /**
   * Searches the configured pipeline and dispatches a `"retrieve"` event
   * with the query and the resulting nodes.
   *
   * @param query - Query forwarded to the pipeline search.
   * @param preFilters - Forwarded as the request's `searchFilters`.
   * @returns The scored nodes from the search response.
   * @throws Error if the pipeline lookup returns no pipeline with an id.
   */
  @wrapEventCaller
  async retrieve({
    query,
    preFilters,
  }: RetrieveParams): Promise<NodeWithScore[]> {
    const client = await this.getClient();
    const pipelines = await client.pipeline.searchPipelines({
      projectName: this.projectName,
      pipelineName: this.pipelineName,
    });

    // Guard on the id itself: this covers both an empty result and a result
    // whose first entry lacks an id, so runSearch never receives `undefined`.
    // When several pipelines match, the first one is used.
    const pipelineId = pipelines[0]?.id;
    if (!pipelineId) {
      throw new Error(
        `No pipeline found with name ${this.pipelineName} in project ${this.projectName}`,
      );
    }

    const results = await client.pipeline.runSearch(pipelineId, {
      ...this.retrieveParams,
      query,
      searchFilters: preFilters as Record<string, unknown[]>,
    });

    const nodes = this.resultNodesToNodeWithScore(results.retrievalNodes);

    Settings.callbackManager.dispatchEvent("retrieve", {
      query,
      nodes,
    });

    return nodes;
  }
}