-
Notifications
You must be signed in to change notification settings - Fork 295
/
types.ts
101 lines (80 loc) · 2.5 KB
/
types.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import type { BaseNode } from "../Node.js";
import { MetadataMode } from "../Node.js";
import type { TransformComponent } from "../ingestion/types.js";
import { SimilarityType, similarity } from "./utils.js";
// Initial value of `BaseEmbedding.embedBatchSize`, which is forwarded to
// `batchEmbeddings` as the chunk size (number of inputs per embedding call).
const DEFAULT_EMBED_BATCH_SIZE = 10;
// Callback that embeds one batch of values and resolves to a list of
// embedding vectors. NOTE(review): callers in this file index the result by
// input position, so it presumably must return one vector per input value,
// in input order — confirm against implementations.
type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>;
/**
 * Abstract base for embedding models.
 *
 * Subclasses supply the single-text and single-query embedding calls; this
 * class layers multi-text, batched, and node-transform helpers on top of them.
 * It implements `TransformComponent`, so an instance can embed `BaseNode`s
 * through `transform`.
 */
export abstract class BaseEmbedding implements TransformComponent {
  /** Chunk size handed to `batchEmbeddings` by `getTextEmbeddingsBatch`. */
  embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;

  /**
   * Compare two embedding vectors by delegating to the `similarity` helper
   * imported from `./utils.js`.
   * @param embedding1 first vector
   * @param embedding2 second vector
   * @param mode similarity measure to apply; defaults to `SimilarityType.DEFAULT`
   * @returns the score computed by the helper
   */
  similarity(
    embedding1: number[],
    embedding2: number[],
    mode: SimilarityType = SimilarityType.DEFAULT,
  ): number {
    const score = similarity(embedding1, embedding2, mode);
    return score;
  }

  /** Produce the embedding vector for a single text. */
  abstract getTextEmbedding(text: string): Promise<number[]>;

  /** Produce the embedding vector for a single query string. */
  abstract getQueryEmbedding(query: string): Promise<number[]>;

  /**
   * Embed several texts. This default implementation awaits
   * `getTextEmbedding` for each text in turn; override it when the backend
   * can return multiple embeddings from a single request.
   * @param texts texts to embed
   * @returns one vector per text, in input order
   */
  async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
    const vectors: number[][] = [];
    for (const item of texts) {
      // Awaited one at a time on purpose: requests are issued sequentially.
      vectors.push(await this.getTextEmbedding(item));
    }
    return vectors;
  }

  /**
   * Embed texts in chunks of `embedBatchSize`, calling `getTextEmbeddings`
   * once per chunk via `batchEmbeddings`.
   * @param texts texts to embed
   * @param options forwarded to `batchEmbeddings`; `logProgress` enables its
   *   progress logging
   * @returns the concatenated vectors returned for every chunk
   */
  async getTextEmbeddingsBatch(
    texts: string[],
    options?: {
      logProgress?: boolean;
    },
  ): Promise<Array<number[]>> {
    // Arrow wrapper keeps `this` bound when the helper invokes the callback.
    const embedChunk = (chunk: string[]) => this.getTextEmbeddings(chunk);
    const vectors = await batchEmbeddings(
      texts,
      embedChunk,
      this.embedBatchSize,
      options,
    );
    return vectors;
  }

  /**
   * Embed the `MetadataMode.EMBED` content of every node and store each
   * resulting vector on the node's `embedding` property.
   * @param nodes nodes to embed; they are mutated in place
   * @param _options passed through unchanged to `getTextEmbeddingsBatch`
   * @returns the same `nodes` array that was passed in
   */
  async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
    const contents = nodes.map((node) => node.getContent(MetadataMode.EMBED));
    const vectors = await this.getTextEmbeddingsBatch(contents, _options);
    // Vectors are matched to nodes purely by position.
    for (const [position, node] of nodes.entries()) {
      node.embedding = vectors[position];
    }
    return nodes;
  }
}
/**
 * Embed `values` in consecutive chunks, awaiting `embedFunc` once per chunk
 * (chunks are processed sequentially, never in parallel) and concatenating
 * the returned vectors in call order.
 *
 * @param values items to embed; the array itself is not modified
 * @param embedFunc callback that embeds one chunk; it receives a fresh array
 *   on every call, so it may safely keep a reference to its argument
 * @param chunkSize maximum number of items per `embedFunc` call. When it is
 *   not a positive integer (zero, negative, fractional, NaN, Infinity) all
 *   values are sent in a single call
 * @param options `logProgress: true` prints a progress line after each chunk
 * @returns every vector returned by `embedFunc`, in order; `[]` when `values`
 *   is empty (in which case `embedFunc` is never called)
 */
export async function batchEmbeddings<T>(
  values: T[],
  embedFunc: EmbedFunc<T>,
  chunkSize: number,
  options?: {
    logProgress?: boolean;
  },
): Promise<Array<number[]>> {
  const resultEmbeddings: Array<number[]> = [];
  const total = values.length;
  // An unusable chunk size falls back to "everything in one batch". This also
  // guarantees a strictly positive step whenever the loop body can run.
  const batchSize =
    Number.isInteger(chunkSize) && chunkSize > 0 ? chunkSize : total;

  for (let start = 0; start < total; start += batchSize) {
    const end = Math.min(start + batchSize, total);
    // slice() hands the callback its own array instead of a shared buffer
    // that is emptied afterwards.
    const embeddings = await embedFunc(values.slice(start, end));
    // Push one at a time: spreading a very large result into push() can
    // exceed the engine's argument-count limit.
    for (const embedding of embeddings) {
      resultEmbeddings.push(embedding);
    }
    if (options?.logProgress) {
      // `end` is the number of items processed so far, so the final line
      // reads "total / total".
      console.log(`getting embedding progress: ${end} / ${total}`);
    }
  }
  return resultEmbeddings;
}