/
flappy-agent.ts
200 lines (180 loc) · 7.2 KB
/
flappy-agent.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import { type FlappyAgentInterface, type FlappyAgentConfig } from './flappy-agent.interface'
import { type LLMBase } from './llms/llm-base'
import { type ChatMLResponse, type ChatMLMessage } from './llms/interface'
import { STEP_PREFIX } from './flappy-agent.constants'
import { z } from './flappy-type'
import { convertJsonToYaml, zodToCleanJsonSchema, log } from './utils'
import { type FindFlappyFeature, type FlappyFeatureNames, type AnyFlappyFeature } from './flappy-feature'
import { type JsonObject } from 'roarr/dist/types'
import { templateRenderer } from './renderer'
/**
 * Build the zod schema describing the plan the planner LLM must return:
 * an array of steps, each naming a function and the arguments to call it with.
 *
 * @param enableCoT when true, every step additionally carries a `thought` string
 * placed before the other fields.
 */
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
const lanOutputSchema = (enableCoT: boolean) => {
  const argsDescription = `an object encapsulating all arguments for a function call. If an argument's value is derived from the return of a previous step, it should be as '${STEP_PREFIX}' + the ID of the previous step (e.g. '${STEP_PREFIX}1'). If an argument's value is derived from the **previous** step's function's return value's properties, '.' should be used to access its properties, else just use id with prefix. This approach should remain compatible with the 'args' attribute in the function's JSON schema.`

  // Keys are added in the order they must appear in the schema: the optional
  // `thought` first, then the fields every step has.
  const stepShape: any = {}
  if (enableCoT) {
    stepShape.thought = z.string().describe('The thought why this step is needed.')
  }
  stepShape.id = z.number().int().positive().describe('Increment id starting from 1')
  stepShape.functionName = z.string()
  stepShape.args = z.record(z.any()).describe(argsDescription)

  const stepSchema = z.object(stepShape)
  return z.array(stepSchema).describe('An array storing the steps.')
}
// Fallback for `FlappyAgent.retry` when the agent config does not set `retry`.
const DEFAULT_RETRY = 1
/**
 * Agent that asks an LLM to plan a sequence of feature calls for a user prompt
 * and then executes that plan step by step.
 */
export class FlappyAgent<
  TFeatures extends readonly AnyFlappyFeature[] = readonly AnyFlappyFeature[],
  TNames extends string = FlappyFeatureNames<TFeatures>
> implements FlappyAgentInterface
{
  /** Configuration the agent was constructed with. */
  config: FlappyAgentConfig<TFeatures>
  /** `config.llm`. */
  llm: LLMBase
  /** LLM used for planning: `config.llmPlaner`, or `config.llm` when it is not set. */
  llmPlaner: LLMBase
  /** Number of extra planning attempts allowed after a failed one. */
  retry: number

  constructor(config: FlappyAgentConfig<TFeatures>) {
    this.config = config
    this.llm = config.llm
    this.llmPlaner = config.llmPlaner ?? config.llm
    this.retry = config.retry ?? DEFAULT_RETRY
  }

  /**
   * Get function definitions as a JSON Schema object array.
   * @returns the `callingSchema` of every configured feature, in configuration order
   */
  public featuresDefinitions(): object[] {
    return this.config.features.map((fn: AnyFlappyFeature) => fn.callingSchema)
  }

  /**
   * Find function by name.
   * @param name value compared against each feature's `define.name`
   * @returns the first configured feature with that name
   * @throws Error when no configured feature has that name
   */
  public findFeature<TName extends TNames, TFunction extends AnyFlappyFeature = FindFlappyFeature<TFeatures, TName>>(
    name: TName
  ): TFunction {
    const fn = this.config.features.find((fn: AnyFlappyFeature) => fn.define.name === name)
    if (!fn) throw new Error(`Function definition not found: ${name}`)
    return fn as TFunction
  }

  /**
   * Call a feature by name.
   * @param name name of the feature to call
   * @param args arguments forwarded to the feature's `call`
   * @returns whatever the feature's `call` resolves to
   * @throws Error when no configured feature has that name
   */
  public async callFeature<
    TName extends TNames,
    TFunction extends AnyFlappyFeature = FindFlappyFeature<TFeatures, TName>
  >(name: TName, args: Parameters<TFunction['call']>[1]): Promise<ReturnType<TFunction['call']>> {
    const fn = this.findFeature(name)
    // eslint-disable-next-line @typescript-eslint/return-await
    return await fn.call(this, args)
  }

  /**
   * Build the system message for a planning request.
   * @param enableCot whether the advertised return schema includes the `thought` field. Default is true.
   * @returns a `system` message rendered from the `agent/systemMessage` template with the
   * feature definitions (as YAML) and the expected return schema (as JSON)
   */
  public executePlanSystemMessage(enableCot: boolean = true): ChatMLMessage {
    const functions = convertJsonToYaml(this.featuresDefinitions())
    const zodSchema = lanOutputSchema(enableCot)
    const returnSchema = JSON.stringify(zodToCleanJsonSchema(zodSchema), null, 4)
    return {
      role: 'system',
      content: templateRenderer('agent/systemMessage', { functions, returnSchema })
    }
  }

  /**
   * executePlan
   *
   * Asks the planner LLM for a plan, validates it, then runs every step in order.
   * A failed planning attempt is retried up to `this.retry` more times.
   *
   * @param prompt user input prompt
   * @param enableCot enable CoT to improve the plan quality, but it will be generally more tokens. Default is true.
   * @returns the result of the last executed step, or `undefined` for an empty plan
   * @throws Error when no valid plan could be created within the allowed attempts;
   * errors thrown while executing a step propagate unchanged
   */
  public async executePlan(prompt: string, enableCot: boolean = true): Promise<any> {
    log.debug('Start planing')

    const zodSchema = lanOutputSchema(enableCot)
    const originalRequestMessage: ChatMLMessage[] = [
      this.executePlanSystemMessage(enableCot),
      { role: 'user', content: templateRenderer('agent/userMessage', { prompt }) }
    ]
    let requestMessage = originalRequestMessage
    let plan: any[] = []
    let retry = this.retry
    let result: ChatMLResponse | undefined

    while (true) {
      // Forget the previous attempt's response: if this request throws, the
      // repair prompt must not pair an old response with an unrelated error.
      result = undefined
      try {
        if (retry !== this.retry) log.debug(`Attempt retry: ${this.retry - retry}`)
        log.debug({ data: requestMessage } as unknown as JsonObject, 'Submit the request message')

        result = await this.llmPlaner.chatComplete(requestMessage)
        plan = this.parseComplete(result)

        // Validate inside the retried section so a malformed plan is sent back
        // to the LLM for repair instead of aborting the whole run.
        zodSchema.parse(plan)

        // check for function calling in each step (findFeature throws on unknown names)
        for (const step of plan) {
          this.findFeature(step.functionName)
        }
        break
      } catch (err) {
        console.error(err)
        if (retry <= 0) throw new Error('Interrupted, create plan failed. Please refer to the error message above.')
        retry -= 1

        // if the response came from chatComplete is failed, retry it directly.
        // Otherwise, update message for repairing
        if (result?.success && result.data) {
          const message = err instanceof Error ? err.message : String(err)
          requestMessage = [
            ...originalRequestMessage,
            {
              role: 'assistant',
              content: result.data
            },
            {
              role: 'user',
              content: templateRenderer('error/retry', { message })
            }
          ]
        }
      }
    }

    // Results of executed steps, keyed by step id, so later steps can reference them.
    const returnStore = new Map<number, any>()
    let lastResult: any

    for (let i = 0; i < plan.length; i++) {
      const step = plan[i]
      const fn = this.findFeature<TNames>(step.functionName)

      const args = Object.fromEntries(
        Object.entries(step.args).map(([k, v]) => {
          if (typeof v === 'string' && v.startsWith(STEP_PREFIX)) {
            // `<prefix><id>` or `<prefix><id>.<prop>.<prop>...`
            const [idPart = '', ...path] = v.slice(STEP_PREFIX.length).split('.')
            const stepResult = returnStore.get(parseInt(idPart, 10))
            // access object property; an empty path yields the step result itself
            return [k, path.reduce((acc, cur) => acc[cur], stepResult)]
          }
          return [k, v]
        })
      )

      log.debug(`Start step ${i + 1}`)
      log.debug(`Start function call: ${step.functionName}`)
      const stepOutput = await fn.call(this, args)
      log.debug(`End Function call: ${step.functionName}`)

      returnStore.set(step.id, stepOutput)
      lastResult = stepOutput
    }

    // return the last step result, regardless of how the plan numbered its steps
    return lastResult
  }

  /**
   * Extract the plan array from an LLM response.
   *
   * Takes the text between the first `[` and the last `]` of `resp.data`
   * (inclusive) and parses it as JSON, ignoring any text around it.
   *
   * @param resp response returned by the LLM
   * @returns the parsed JSON array
   * @throws Error `'Invalid JSON response'` when `resp.data` is missing or holds no `[...]` span;
   * `JSON.parse` errors propagate when the span is not valid JSON
   */
  protected parseComplete(resp: ChatMLResponse): any[] {
    // A missing payload is treated like an empty one, so it fails the bracket check below.
    const text = resp.data ?? ''
    const startIdx = text.indexOf('[')
    const endIdx = text.lastIndexOf(']')
    if (startIdx < 0 || endIdx <= startIdx) throw new Error('Invalid JSON response')

    const json = JSON.parse(text.slice(startIdx, endIdx + 1))
    log.debug({ data: json }, 'The plan details')
    return json
  }
}
/**
 * Create a flappy agent.
 * @param config configuration handed unchanged to the `FlappyAgent` constructor
 * @returns the newly constructed agent, typed by the features in `config`
 */
export const createFlappyAgent = <const TFeatures extends readonly AnyFlappyFeature[]>(
  config: FlappyAgentConfig<TFeatures>
): FlappyAgent<TFeatures> => {
  const agent = new FlappyAgent<TFeatures>(config)
  return agent
}