-
Notifications
You must be signed in to change notification settings - Fork 314
/
QuestionGenerator.ts
56 lines (49 loc) · 1.49 KB
/
QuestionGenerator.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import {
BaseOutputParser,
StructuredOutput,
SubQuestionOutputParser,
} from "./OutputParser";
import {
SubQuestionPrompt,
buildToolsText,
defaultSubQuestionPrompt,
} from "./Prompt";
import { ToolMetadata } from "./Tool";
import { LLM, OpenAI } from "./llm/LLM";
/**
 * A single generated sub-question paired with the tool expected to answer it.
 */
export interface SubQuestion {
  /** The text of the generated sub-question. */
  subQuestion: string;
  /** Name of the tool this sub-question is routed to (matches a ToolMetadata name — TODO confirm against prompt format). */
  toolName: string;
}
/**
 * QuestionGenerators generate new questions for the LLM using tools and a user query.
 */
export interface BaseQuestionGenerator {
  /**
   * Decomposes a user query into sub-questions, each targeted at one of the given tools.
   *
   * @param tools - Metadata describing the tools available to answer sub-questions.
   * @param query - The original user query to decompose.
   * @returns The generated sub-questions.
   */
  generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]>;
}
/**
 * LLMQuestionGenerator uses an LLM to decompose a user query into
 * sub-questions, one per relevant tool, by filling a sub-question prompt
 * and parsing the model's completion into structured output.
 */
export class LLMQuestionGenerator implements BaseQuestionGenerator {
  /** LLM used to produce the completion; defaults to OpenAI. */
  llm: LLM;
  /** Prompt template that receives the tools text and the query text. */
  prompt: SubQuestionPrompt;
  /** Parser that turns the raw completion into SubQuestion[]. */
  outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

  /**
   * @param init - Optional overrides for the LLM, prompt, or output parser;
   *   any field left unset falls back to its default.
   */
  constructor(init?: Partial<LLMQuestionGenerator>) {
    this.llm = init?.llm ?? new OpenAI();
    this.prompt = init?.prompt ?? defaultSubQuestionPrompt;
    this.outputParser = init?.outputParser ?? new SubQuestionOutputParser();
  }

  /**
   * Generates sub-questions for the given tools and user query.
   *
   * @param tools - Metadata for the tools the sub-questions may target.
   * @param query - The user query to decompose.
   * @returns Structured sub-questions parsed from the LLM completion.
   */
  async generate(tools: ToolMetadata[], query: string): Promise<SubQuestion[]> {
    // Render the available tools and the query into the prompt template.
    const promptText = this.prompt({
      toolsStr: buildToolsText(tools),
      queryStr: query,
    });

    // Ask the LLM and take the raw text content of its reply.
    const response = await this.llm.complete(promptText);
    const prediction = response.message.content;

    // Parse the free-form completion into structured sub-questions.
    const structured = this.outputParser.parse(prediction);
    return structured.parsedOutput;
  }
}