-
Notifications
You must be signed in to change notification settings - Fork 66
/
sequential_action.py
277 lines (204 loc) · 7.74 KB
/
sequential_action.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import math # noqa: F401
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional
from fvalues import F
from ice.recipe import recipe
from ice.recipes.primer.search_string import search_string
# Running record of what the agent has done so far: one natural-language entry
# per completed action (gather_info appends the action's make_log_entry output).
Log = list[str]
def render_enumerate(items: Sequence[object]) -> str:
    """Render *items* as a numbered list starting at 1.

    Each entry is formatted as ``"<n>. <item>"`` and consecutive entries are
    separated by a blank line (two newline characters), not a single newline.
    """
    separator = F("\n\n")
    numbered = [F(f"{i+1}. {item}") for i, item in enumerate(items)]
    return separator.join(numbered)
def render_context(question: str, log: Log) -> str:
    """Describe the question and, if any steps were taken, the numbered log.

    With an empty *log* only the question line is returned.
    """
    question_context = F(f'The question you want to answer: "{question}"')
    if log:
        return F(
            f"""{question_context}
What you've done so far:
{render_enumerate(log)}"""
        )
    return question_context
def render_action_context(question: str, log: Log, max_actions: int) -> str:
    """Render the question/log context followed by the remaining action budget.

    ``max_actions`` counts the action being taken right now, so the follow-up
    count shown to the model is ``max_actions - 1``.
    """
    if max_actions != 1:
        action_count_text = F(
            f"You have {max_actions} actions left. (The one you're taking right now, and {max_actions - 1} follow-up actions.)"
        )
    else:
        action_count_text = "You have one action left. (The one you're taking right now.)"
    return F(
        f"""{render_context(question, log)}
{action_count_text}"""
    )
def make_knowledge_prompt(question: str, log: Log) -> str:
    """Build a Yes/No prompt asking whether the context suffices to answer.

    The prompt ends with ``A:`` so the next token is the model's verdict.
    """
    prompt = F(
        f"""{render_context(question, log)}
Q: Do you have enough information to correctly answer the question? Say "A: Yes" or "A: No"
A:"""
    )
    return prompt
def make_answer_prompt(question: str, log: Log) -> str:
    """Build a Q/A prompt that asks *question* after the gathered context."""
    prompt = F(
        f"""{render_context(question, log)}
Q: {question}
A:"""
    )
    return prompt
async def is_info_sufficient(
    question: str, log: Log, *, threshold: float = 0.7
) -> bool:
    """Ask the model whether *log* already contains enough to answer *question*.

    The knowledge prompt is classified over the choices " Yes" / " No".

    Args:
        question: The question being answered.
        log: Entries describing the actions taken so far.
        threshold: Minimum probability of " Yes" (exclusive) required to treat
            the information as sufficient. Defaults to 0.7, the value that
            was previously hard-coded here.

    Returns:
        True when P(" Yes") > ``threshold``.
    """
    knowledge_prompt = make_knowledge_prompt(question, log)
    has_knowledge_probs, _ = await recipe.agent().classify(
        prompt=knowledge_prompt, choices=(" Yes", " No")
    )
    # A missing " Yes" entry is treated as probability 0.
    return has_knowledge_probs.get(" Yes", 0.0) > threshold
async def answer_directly(question: str, log: Log) -> str:
    """Answer *question* with a single completion conditioned on *log*."""
    agent = recipe.agent("instruct-reasoning")
    return await agent.complete(prompt=make_answer_prompt(question, log))
class Action(ABC):
    """Abstract base for one step the agent can take to gather information.

    A concrete action is created by ``propose``, executed by ``run``, and
    summarized for the running log by ``make_log_entry``.
    """

    @classmethod
    @abstractmethod
    async def propose(cls, question: str, log: Log, max_actions: int) -> "Action":
        """Construct a concrete instance of this action for *question*."""
        ...

    @abstractmethod
    async def run(self) -> str:
        """Execute the action and return its result as text.

        Declared ``async`` (and returning ``str``) because every concrete
        action in this module implements ``run`` as a coroutine returning a
        string, and ``gather_info`` awaits the result.
        """
        ...

    @abstractmethod
    def make_log_entry(self, result: str) -> str:
        """Describe this action and its *result* as a single log entry."""
        ...
@dataclass
class CalculationAction(Action):
    """Action that evaluates a single-line Python expression chosen by the model."""

    calculation: str  # expression text exactly as returned by the model

    @classmethod
    def make_proposal_prompt(cls, question: str, log: Log, max_actions: int) -> str:
        """Build a REPL-style prompt whose completion is one Python expression.

        The prompt ends with ``>>> import math`` followed by a bare ``>>>``, so the
        model continues with the expression to evaluate.
        """
        return F(
            f"""{render_action_context(question, log, max_actions)}
You have chosen to take the action "Do a calculation".
You have access to a Python interpreter. What single-line calculation would most help you answer the question "{question}"?
>>> import math
>>>"""
        )

    @classmethod
    async def propose(
        cls, question: str, log: Log, max_actions: int
    ) -> "CalculationAction":
        """Ask the model for one calculation and wrap it in an action."""
        calculation_prompt = cls.make_proposal_prompt(question, log, max_actions)
        # Stop at the first newline: only a single-line expression is wanted.
        calculation = await recipe.agent("instruct-reasoning").complete(
            prompt=calculation_prompt, stop="\n"
        )
        return cls(calculation)

    async def run(self) -> str:
        """Evaluate the calculation and return its result as a string.

        Any ``Exception`` raised while evaluating is turned into an
        ``"Error: ..."`` string, so the failure is recorded in the log instead
        of aborting the recipe.
        """
        # SECURITY: ``self.calculation`` is model-generated text. ``eval`` is NOT
        # a sandbox: builtins (including ``__import__``) stay reachable even with
        # the namespace below. The explicit namespace only keeps the expression
        # away from this module's globals and this method's locals, and makes
        # the ``math`` module (promised by the prompt) available explicitly.
        namespace = {"math": math}
        try:
            return str(eval(self.calculation, namespace))
        except Exception as e:
            return F(f"Error: {e}")

    def make_log_entry(self, result: str) -> str:
        """Describe the calculation and its result for the running log."""
        return F(f"You calculated '{self.calculation}' and got the result '{result}'.")

    def __str__(self):
        # Used when listing candidate actions in the action-choice prompt.
        return F(f"Do calculation: {self.calculation}")
@dataclass
class WebSearchAction(Action):
    """Action that runs a web search for a query chosen by the model."""

    search_term: str  # query text proposed by the model, whitespace-stripped

    @classmethod
    def make_proposal_prompt(cls, question: str, log: Log, max_actions: int) -> str:
        """Build a prompt whose completion is a single search query.

        The prompt ends by opening a double quote (``Query: "``), so the model
        writes the query and closes the quote. ``propose`` stops the
        completion at that closing quote.
        """
        return F(
            f"""{render_action_context(question, log, max_actions)}
You have chosen to take the action "Run a web search".
What is a first web search query you could run to help you answer the question "{question}"?
Query: \""""
        )

    @classmethod
    async def propose(
        cls, question: str, log: Log, max_actions: int
    ) -> "WebSearchAction":
        """Ask the model for one search query and wrap it in an action."""
        search_term_prompt = cls.make_proposal_prompt(question, log, max_actions)
        # The prompt already opened the quote, so '"' marks the end of the query
        # rather than its start (previously a quoted answer produced an empty term).
        search_term = await recipe.agent("instruct-reasoning").complete(
            prompt=search_term_prompt, stop='"'
        )
        return cls(search_term.strip())

    async def run(self) -> str:
        """Run the search and return the result text from ``search_string``."""
        results_str = await search_string(self.search_term)
        return results_str

    def make_log_entry(self, result: str) -> str:
        """Describe the search and its result for the running log."""
        return F(
            f"You searched the web for '{self.search_term}' and got the result '{result}'."
        )

    def __str__(self):
        # Used when listing candidate actions in the action-choice prompt.
        return F(f"Run web search: {self.search_term}")
async def get_action_candidates(
    question: str, log: Log, max_actions: int
) -> list[Action]:
    """Collect one concrete proposal from each action type.

    Proposals are requested one after another, calculation first and then web
    search, and returned in that order.
    """
    action_types = (CalculationAction, WebSearchAction)
    return [
        await action_type.propose(question, log, max_actions)
        for action_type in action_types
    ]
def render_numbers(n: int) -> str:
    """Spell out the choices 1..n as e.g. ``"1, 2 or 3"``.

    For ``n <= 1`` the number itself is returned unchanged as a string.
    """
    if n <= 1:
        return str(n)
    leading = ", ".join(map(str, range(1, n)))
    return f"{leading} or {n}"
def make_action_choice_prompt(
    question: str, log: Log, actions: list[Action], max_actions: int
) -> str:
    """Build a prompt asking the model to pick one of *actions* by number.

    The candidates are listed via their ``str()`` form, and the prompt ends with
    ``Answer:`` so the next token is the chosen number.
    """
    if max_actions != 1:
        follow_up_text = F(f", and {max_actions - 1} similar follow-up actions")
    else:
        follow_up_text = ""
    prompt = F(
        f"""{render_context(question, log)}
You can take one of the following actions now{follow_up_text} before you need to answer:
{render_enumerate(actions)}
Question: What next action should you take to make progress on answering the question "{question}"? {render_numbers(len(actions))}?
Answer:"""
    )
    return prompt
async def choose_action(
    question: str, log: Log, actions: list[Action], max_actions
) -> Action:
    """Let the model choose one of *actions* and return the chosen action.

    ``max_actions`` is the remaining action budget shown in the prompt.
    """
    prompt = make_action_choice_prompt(question, log, actions, max_actions)
    probs = await get_action_choice_probs(prompt, actions)
    return actions[get_best_action_index(probs)]
async def get_action_choice_probs(action_choice_prompt, actions):
    """Classify the prompt over the labels " 1" .. " N", one per action.

    Returns the probability mapping produced by the agent's ``classify`` call.
    """
    labels = tuple(F(f" {i}") for i in range(1, len(actions) + 1))
    agent = recipe.agent("instruct-reasoning-crowd")
    probs, _ = await agent.classify(prompt=action_choice_prompt, choices=labels)
    return probs
def get_best_action_index(action_choice_probs: dict[str, float]) -> int:
    """Return the 0-based index of the most probable action choice.

    Keys are expected to be 1-based numeric labels such as ``" 1"``. Surrounding
    whitespace is stripped before ``int()``. Among equally probable labels, the
    first in iteration order wins.

    Raises:
        ValueError: if no label has a strictly positive probability (this
            includes an empty mapping). This replaces an ``assert``, which is
            stripped under ``python -O`` and would then fail later with an
            unrelated ``TypeError``.
    """
    best_action_index = None
    best_action_prob = 0.0
    for action_label, action_prob in action_choice_probs.items():
        if action_prob > best_action_prob:
            best_action_index = int(action_label.strip())
            best_action_prob = action_prob
    if best_action_index is None:
        raise ValueError(
            f"no action choice has a positive probability: {action_choice_probs!r}"
        )
    return best_action_index - 1
async def gather_info(
    *,
    question: str,
    log: Optional[Log] = None,
    max_actions: int = 3,
) -> Log:
    """Take exactly one information-gathering step and return the extended log.

    Candidate actions are proposed, the model picks one, and that action is run.
    The result is a new list made of the previous entries plus this action's log
    entry; the list passed in is not mutated.
    """
    history = log if log is not None else []
    candidates = await get_action_candidates(question, history, max_actions)
    picked = await choose_action(question, history, candidates, max_actions)
    outcome = await picked.run()
    return history + [picked.make_log_entry(outcome)]
async def sequential_action(
    *,
    question: str = "How far would all the film frames that make up the 400-plus episodes of The Simpsons stretch?",
    max_actions: int = 3,
):
    """Gather information for up to *max_actions* steps, then answer *question*.

    Before each step the model is asked whether it already knows enough. If so,
    gathering stops early. The budget passed to each step counts down from
    ``max_actions`` to 1.
    """
    log: Log = []
    for remaining in reversed(range(1, max_actions + 1)):
        if await is_info_sufficient(question, log):
            break
        log = await gather_info(question=question, log=log, max_actions=remaining)
    return await answer_directly(question, log)
# Entry point: hand the top-level recipe function to ICE's recipe runner.
# NOTE(review): whether this runs only when the file is executed directly (rather
# than on import) depends on ice.recipe's implementation of `main` — confirm.
recipe.main(sequential_action)