ClipEmbedding.ts
import type { ImageType } from "../Node.js";
import { MultiModalEmbedding } from "./MultiModalEmbedding.js";
import { readImage } from "./utils.js";

export enum ClipEmbeddingModelType {
  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
}
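
/**
 * CLIP multi-modal embedding backed by @xenova/transformers.
 * The tokenizer, processor, and text/vision models are imported and loaded
 * lazily on first use, then cached on the instance for subsequent calls.
 */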
export class ClipEmbedding extends MultiModalEmbedding {
  modelType: ClipEmbeddingModelType =
    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;

  private tokenizer: any;
  private processor: any;
  private visionModel: any;
  private textModel: any;

  async getTokenizer() {
    if (!this.tokenizer) {
      const { AutoTokenizer } = await import("@xenova/transformers");
      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
    }
    return this.tokenizer;
  }

  async getProcessor() {
    if (!this.processor) {
      const { AutoProcessor } = await import("@xenova/transformers");
      this.processor = await AutoProcessor.from_pretrained(this.modelType);
    }
    return this.processor;
  }

  async getVisionModel() {
    if (!this.visionModel) {
      const { CLIPVisionModelWithProjection } = await import(
        "@xenova/transformers"
      );
      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
        this.modelType,
      );
    }
    return this.visionModel;
  }

  async getTextModel() {
    if (!this.textModel) {
      const { CLIPTextModelWithProjection } = await import(
        "@xenova/transformers"
      );
      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
        this.modelType,
      );
    }
    return this.textModel;
  }

  /** Embed an image by running it through the CLIP processor and vision model. */
  async getImageEmbedding(image: ImageType): Promise<number[]> {
    const loadedImage = await readImage(image);
    const imageInputs = await (await this.getProcessor())(loadedImage);
    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
    return Array.from(image_embeds.data);
  }

  /** Embed a text string with the CLIP tokenizer and text model. */
  async getTextEmbedding(text: string): Promise<number[]> {
    const textInputs = await (
      await this.getTokenizer()
    )([text], { padding: true, truncation: true });
    const { text_embeds } = await (await this.getTextModel())(textInputs);
    // text_embeds.data is a typed array; convert it so the declared
    // Promise<number[]> return type actually holds.
    return Array.from(text_embeds.data);
  }

  /** Queries are embedded the same way as plain text. */
  async getQueryEmbedding(query: string): Promise<number[]> {
    return this.getTextEmbedding(query);
  }
}
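
// A minimal usage sketch (not part of this module): embed a caption and an
// image, then compare them with a small cosine-similarity helper written here
// for illustration. The image path is a placeholder, not a real asset, and
// the calls are assumed to run inside an async context.
//
//   const clip = new ClipEmbedding();
//   const textVec = await clip.getTextEmbedding("a photo of a cat");
//   const imageVec = await clip.getImageEmbedding("./cat.jpg");
//
//   const cosine = (a: number[], b: number[]) => {
//     let dot = 0;
//     let normA = 0;
//     let normB = 0;
//     for (let i = 0; i < a.length; i++) {
//       dot += a[i] * b[i];
//       normA += a[i] * a[i];
//       normB += b[i] * b[i];
//     }
//     return dot / (Math.sqrt(normA) * Math.sqrt(normB));
//   };
//   console.log("text/image similarity:", cosine(textVec, imageVec));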