-
Notifications
You must be signed in to change notification settings - Fork 295
/
MultiModalResponseSynthesizer.ts
99 lines (90 loc) · 2.93 KB
/
MultiModalResponseSynthesizer.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import type { ImageNode } from "../Node.js";
import { MetadataMode, ModalityType, splitNodesByType } from "../Node.js";
import { Response } from "../Response.js";
import type { ServiceContext } from "../ServiceContext.js";
import { llmFromSettingsOrContext } from "../Settings.js";
import { imageToDataUrl } from "../embeddings/index.js";
import type { MessageContentDetail } from "../llm/types.js";
import { PromptMixin } from "../prompts/Mixin.js";
import type { TextQaPrompt } from "./../Prompt.js";
import { defaultTextQaPrompt } from "./../Prompt.js";
import type {
BaseSynthesizer,
SynthesizeParamsNonStreaming,
SynthesizeParamsStreaming,
} from "./types.js";
/**
 * Response synthesizer that answers a query from a mix of text and image
 * nodes by sending one multi-part prompt to the resolved LLM.
 *
 * Text nodes are rendered with `metadataMode`, joined with blank lines and
 * passed as `context` to `textQATemplate`. Image nodes are appended after
 * the text part as `image_url` content entries whose URL comes from
 * `imageToDataUrl`.
 */
export class MultiModalResponseSynthesizer
  extends PromptMixin
  implements BaseSynthesizer
{
  /** Optional service context forwarded to `llmFromSettingsOrContext` when resolving the LLM. */
  serviceContext?: ServiceContext;
  /** Metadata mode passed to `getContent` when rendering text nodes. */
  metadataMode: MetadataMode;
  /** QA prompt template; called with `{ context, query }`. */
  textQATemplate: TextQaPrompt;

  /**
   * @param options.serviceContext - optional service context used for LLM lookup.
   * @param options.textQATemplate - QA prompt; defaults to `defaultTextQaPrompt`.
   * @param options.metadataMode - metadata mode for text nodes; defaults to `MetadataMode.NONE`.
   */
  constructor({
    serviceContext,
    textQATemplate,
    metadataMode,
  }: Partial<MultiModalResponseSynthesizer> = {}) {
    super();
    this.serviceContext = serviceContext;
    this.metadataMode = metadataMode ?? MetadataMode.NONE;
    this.textQATemplate = textQATemplate ?? defaultTextQaPrompt;
  }

  /** Exposes the QA template to the prompt-mixin machinery. */
  protected _getPrompts(): { textQATemplate: TextQaPrompt } {
    return {
      textQATemplate: this.textQATemplate,
    };
  }

  /** Replaces the QA template when the supplied dict contains one. */
  protected _updatePrompts(promptsDict: {
    textQATemplate: TextQaPrompt;
  }): void {
    // A missing/falsy entry leaves the current template in place.
    if (promptsDict.textQATemplate) {
      this.textQATemplate = promptsDict.textQATemplate;
    }
  }

  /**
   * Builds a multi-part prompt (text first, then images) from the retrieved
   * nodes, runs a single non-streaming completion and wraps the result.
   *
   * @param params.query - the user query inserted into `textQATemplate`.
   * @param params.nodesWithScore - retrieved nodes; passed through unchanged
   *   to the returned `Response`.
   * @param params.stream - streaming is not supported by this synthesizer.
   * @returns a `Response` holding the completion text and `nodesWithScore`.
   * @throws Error "streaming not implemented" when `stream` is truthy.
   */
  synthesize(
    params: SynthesizeParamsStreaming,
  ): Promise<AsyncIterable<Response>>;
  synthesize(params: SynthesizeParamsNonStreaming): Promise<Response>;
  async synthesize({
    query,
    nodesWithScore,
    stream,
  }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise<
    AsyncIterable<Response> | Response
  > {
    if (stream) {
      throw new Error("streaming not implemented");
    }

    const nodes = nodesWithScore.map(({ node }) => node);
    const nodeMap = splitNodesByType(nodes);
    // A modality may be absent from the split; keep `undefined` in the
    // asserted type so the `?? []` fallback is honestly typed.
    const imageNodes =
      (nodeMap[ModalityType.IMAGE] as ImageNode[] | undefined) ?? [];
    const textNodes = nodeMap[ModalityType.TEXT] ?? [];

    const textChunks = textNodes.map((node) =>
      node.getContent(this.metadataMode),
    );
    // TODO: use builders to generate context
    const context = textChunks.join("\n\n");
    const textPrompt = this.textQATemplate({ context, query });

    // Declared return type (instead of `as`) lets the compiler check the
    // content-part shape rather than trusting an assertion.
    const images = await Promise.all(
      imageNodes.map(
        async (node: ImageNode): Promise<MessageContentDetail> => ({
          type: "image_url",
          image_url: {
            url: await imageToDataUrl(node.image),
          },
        }),
      ),
    );

    const prompt: MessageContentDetail[] = [
      { type: "text", text: textPrompt },
      ...images,
    ];

    const llm = llmFromSettingsOrContext(this.serviceContext);
    const response = await llm.complete({
      prompt,
    });
    return new Response(response.text, nodesWithScore);
  }
}