ContextChatEngine.ts
import type { ChatHistory } from "../../ChatHistory.js";
import { getHistory } from "../../ChatHistory.js";
import type { ContextSystemPrompt } from "../../Prompt.js";
import { Response } from "../../Response.js";
import type { BaseRetriever } from "../../Retriever.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
import {
extractText,
streamConverter,
streamReducer,
} from "../../llm/utils.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/Mixin.js";
import { DefaultContextGenerator } from "./DefaultContextGenerator.js";
import type {
ChatEngine,
ChatEngineParamsNonStreaming,
ChatEngineParamsStreaming,
ContextGenerator,
} from "./types.js";
/**
 * ContextChatEngine uses a retriever over the index to fetch context relevant
 * to each query. That context is injected into the system prompt while the
 * chat history is preserved, so each turn is answered with both conversation
 * state and freshly retrieved context.
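 *
 * @example
 * ```typescript
 * // Minimal usage sketch; assumes an existing index (e.g. a VectorStoreIndex
 * // named `index`) from which a retriever can be derived.
 * const chatEngine = new ContextChatEngine({ retriever: index.asRetriever() });
 *
 * const response = await chatEngine.chat({ message: "What does the document say?" });
 * console.log(response.response);
 *
 * // Streaming variant: `stream: true` yields partial responses as they arrive.
 * const stream = await chatEngine.chat({ message: "Summarize it.", stream: true });
 * for await (const chunk of stream) {
 *   process.stdout.write(chunk.response);
 * }
 * ```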
*/
export class ContextChatEngine extends PromptMixin implements ChatEngine {
chatModel: LLM;
chatHistory: ChatHistory;
contextGenerator: ContextGenerator;
constructor(init: {
retriever: BaseRetriever;
chatModel?: LLM;
chatHistory?: ChatMessage[];
contextSystemPrompt?: ContextSystemPrompt;
nodePostprocessors?: BaseNodePostprocessor[];
}) {
super();
this.chatModel =
init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
    this.chatHistory = getHistory(init.chatHistory);
    this.contextGenerator = new DefaultContextGenerator({
      retriever: init.retriever,
      contextSystemPrompt: init.contextSystemPrompt,
      nodePostprocessors: init.nodePostprocessors,
    });
}
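
  /**
   * Exposes the context generator as a prompt module so that its prompts
   * (such as the context system prompt) can be inspected and overridden
   * through the PromptMixin API.
   */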
protected _getPromptModules(): Record<string, ContextGenerator> {
return {
contextGenerator: this.contextGenerator,
};
}
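
  /**
   * Sends a message to the chat model after retrieving context for the query
   * and injecting it into the request. Overloaded so that `stream: true`
   * resolves to an async iterable of partial responses, while the
   * non-streaming form resolves to a single Response.
   */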
chat(params: ChatEngineParamsStreaming): Promise<AsyncIterable<Response>>;
chat(params: ChatEngineParamsNonStreaming): Promise<Response>;
@wrapEventCaller
async chat(
params: ChatEngineParamsStreaming | ChatEngineParamsNonStreaming,
): Promise<Response | AsyncIterable<Response>> {
const { message, stream } = params;
const chatHistory = params.chatHistory
? getHistory(params.chatHistory)
: this.chatHistory;
const requestMessages = await this.prepareRequestMessages(
message,
chatHistory,
);
if (stream) {
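      // Streaming path: accumulate the deltas so the complete reply can be
      // added to the chat history once the stream finishes, while forwarding
      // each chunk to the caller as a Response carrying the retrieved nodes.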
      const chunkStream = await this.chatModel.chat({
        messages: requestMessages.messages,
        stream: true,
      });
      return streamConverter(
        streamReducer({
          stream: chunkStream,
initialValue: "",
          reducer: (accumulator, part) => accumulator + part.delta,
finished: (accumulator) => {
chatHistory.addMessage({ content: accumulator, role: "assistant" });
},
}),
(r: ChatResponseChunk) => new Response(r.delta, requestMessages.nodes),
);
}
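    // Non-streaming path: record the assistant reply in the history and
    // return it together with the nodes used to build the context.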
const response = await this.chatModel.chat({
messages: requestMessages.messages,
});
chatHistory.addMessage(response.message);
return new Response(
extractText(response.message.content),
requestMessages.nodes,
);
}
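
  /** Clears the stored chat history so the next conversation starts fresh. */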
reset() {
this.chatHistory.reset();
}
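
  /**
   * Adds the user message to the history, generates retrieval context from
   * its text content, and assembles the final message list (context system
   * message plus history) for the LLM request.
   */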
private async prepareRequestMessages(
message: MessageContent,
chatHistory: ChatHistory,
) {
chatHistory.addMessage({
content: message,
role: "user",
});
const textOnly = extractText(message);
const context = await this.contextGenerator.generate(textOnly);
const messages = await chatHistory.requestMessages(
context ? [context.message] : undefined,
);
return { nodes: context.nodes, messages };
}
}