-
Notifications
You must be signed in to change notification settings - Fork 305
/
JinaAIReranker.ts
89 lines (77 loc) · 2.26 KB
/
JinaAIReranker.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import { getEnv } from "@llamaindex/env";
import type { NodeWithScore } from "../../Node.js";
import { MetadataMode } from "../../Node.js";
import type { BaseNodePostprocessor } from "../types.js";
/**
 * One entry of the `results` array returned by the Jina AI rerank endpoint.
 */
interface JinaAIRerankerResult {
  /** Score assigned by the reranker; copied onto the node as its new `score`. */
  relevance_score: number;
  /** Position of the scored document within the `documents` array of the request. */
  index: number;
  /** Optional document payload (presumably an echo of the submitted text — confirm against the API docs). */
  document?: {
    text?: string;
  };
}
/**
 * Node postprocessor that re-scores nodes against a query by calling the
 * Jina AI rerank HTTP endpoint (`https://api.jina.ai/v1/rerank`).
 */
export class JinaAIReranker implements BaseNodePostprocessor {
  /** Name of the reranker model sent in the request body. */
  model: string = "jina-reranker-v1-base-en";
  /** Value sent as `top_n` in the request body; defaults to 2 in the constructor. */
  topN?: number;
  /** Bearer token used in the `Authorization` header. */
  apiKey?: string = undefined;

  /**
   * @param init Optional overrides for `topN`, `model` and `apiKey`.
   * @throws Error when no API key is supplied via `init.apiKey` and the
   *   `JINAAI_API_KEY` environment variable is not set.
   */
  constructor(init?: Partial<JinaAIReranker>) {
    this.topN = init?.topN ?? 2;
    this.model = init?.model ?? "jina-reranker-v1-base-en";
    // Honour an explicitly supplied key. `||` (not `??`) is deliberate so an
    // empty string still falls back to the environment variable.
    this.apiKey = init?.apiKey || getEnv("JINAAI_API_KEY");
    if (!this.apiKey) {
      throw new Error(
        "Set Jina AI API Key in JINAAI_API_KEY env variable. Get one for free or top up your key at https://jina.ai/reranker",
      );
    }
  }

  /**
   * Sends `documents` to the rerank endpoint and returns the `results` array
   * of the JSON response.
   *
   * @param query Query text the documents are scored against.
   * @param documents Document texts to score.
   * @param topN Value sent as `top_n`; defaults to this instance's `topN`.
   * @returns The `results` entries of the response; an empty array when
   *   `documents` is empty (no request is made in that case).
   * @throws Error when the request fails, the server answers with a non-2xx
   *   status, or the response body has no `results` array.
   */
  async rerank(
    query: string,
    documents: string[],
    topN: number | undefined = this.topN,
  ): Promise<JinaAIRerankerResult[]> {
    // Nothing to rank; mirrors the empty-input shortcut in postprocessNodes.
    if (documents.length === 0) {
      return [];
    }

    const url = "https://api.jina.ai/v1/rerank";
    const headers = {
      "Content-Type": "application/json",
      Authorization: `Bearer ${this.apiKey}`,
    };
    const data = {
      model: this.model,
      query: query,
      documents: documents,
      top_n: topN,
    };

    try {
      const response = await fetch(url, {
        method: "POST",
        headers: headers,
        body: JSON.stringify(data),
      });
      // fetch() only rejects on network failure; HTTP error statuses must be
      // detected explicitly, otherwise an error body is treated as a result.
      if (!response.ok) {
        throw new Error(
          `Jina AI rerank request failed with status ${response.status} ${response.statusText}`,
        );
      }
      const jsonData: unknown = await response.json();
      const results = (jsonData as { results?: unknown } | null)?.results;
      if (!Array.isArray(results)) {
        throw new Error("Jina AI rerank response did not contain a results array");
      }
      return results as JinaAIRerankerResult[];
    } catch (error) {
      // The detailed cause is logged; callers get one stable error message.
      console.error("Error while reranking:", error);
      throw new Error("Failed to rerank documents due to an API error");
    }
  }

  /**
   * Reranks `nodes` against `query` and returns new node/score pairs in the
   * order given by the rerank response, scored with `relevance_score`.
   *
   * @param nodes Candidate nodes; their content is extracted with
   *   `MetadataMode.ALL` and sent as the documents to rank.
   * @param query Query text; required whenever `nodes` is non-empty.
   * @returns One entry per rerank result; an empty array when `nodes` is empty.
   * @throws Error when `query` is undefined, when reranking fails, or when
   *   the response refers to an index outside `nodes`.
   */
  async postprocessNodes(
    nodes: NodeWithScore[],
    query?: string,
  ): Promise<NodeWithScore[]> {
    if (nodes.length === 0) {
      return [];
    }
    if (query === undefined) {
      throw new Error("JinaAIReranker requires a query");
    }

    const documents = nodes.map((n) => n.node.getContent(MetadataMode.ALL));
    const results = await this.rerank(query, documents, this.topN);

    const newNodes: NodeWithScore[] = [];
    for (const result of results) {
      // `result.index` points back into the documents (and thus nodes) we sent.
      const source = nodes[result.index];
      if (source === undefined) {
        throw new Error(
          `Jina AI rerank returned an out-of-range document index: ${result.index}`,
        );
      }
      newNodes.push({
        node: source.node,
        score: result.relevance_score,
      });
    }
    return newNodes;
  }
}