-
Notifications
You must be signed in to change notification settings - Fork 421
[DRAFT] tool interrupt #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ | |
ConversationManager, | ||
SlidingWindowConversationManager, | ||
) | ||
from .execution_state import ExecutionState | ||
from .state import AgentState | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -142,6 +143,17 @@ def caller( | |
Raises: | ||
AttributeError: If the tool doesn't exist. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Still need to figure out how to support interrupts in direct tool calls. I would prefer to allow users to pass the resume context into the call. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we have to support interruptions for direct tool calls? Seems a bit silly IMHO. |
||
if record_direct_tool_call is not None: | ||
should_record_direct_tool_call = record_direct_tool_call | ||
else: | ||
should_record_direct_tool_call = self._agent.record_direct_tool_call | ||
|
||
if should_record_direct_tool_call and self._agent.execution_state != ExecutionState.ASSISTANT: | ||
raise RuntimeError( | ||
f"execution_state=<{self._agent.execution_state}> " | ||
f"| recording direct tool calls is only allowed in ASSISTANT execution state" | ||
) | ||
|
||
normalized_name = self._find_normalized_tool_name(name) | ||
|
||
# Create unique tool ID and set up the tool request | ||
|
@@ -167,11 +179,6 @@ def tcall() -> ToolResult: | |
future = executor.submit(tcall) | ||
tool_result = future.result() | ||
|
||
if record_direct_tool_call is not None: | ||
should_record_direct_tool_call = record_direct_tool_call | ||
else: | ||
should_record_direct_tool_call = self._agent.record_direct_tool_call | ||
|
||
if should_record_direct_tool_call: | ||
# Create a record of this tool execution in the message history | ||
self._agent._record_tool_execution(tool_use, tool_result, user_message_override) | ||
|
@@ -349,6 +356,10 @@ def __init__( | |
self.hooks.add_hook(hook) | ||
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) | ||
|
||
self.execution_state = ExecutionState.ASSISTANT | ||
|
||
self.interrupts = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If in an INTERRUPT execution state, the agent will hold a reference to the interrupts raised by the user. To get things working, I am storing the interrupts in an in-memory dictionary. As a follow-up, I will think of a more formal mechanism for storing the interrupt state that can also be serialized for session management. |
||
|
||
@property | ||
def tool(self) -> ToolCaller: | ||
"""Call tool as a function. | ||
|
@@ -540,6 +551,7 @@ async def stream_async( | |
Args: | ||
prompt: User input in various formats: | ||
- str: Simple text input | ||
- ContentBlock: Multi-modal content block | ||
- list[ContentBlock]: Multi-modal content blocks | ||
- list[Message]: Complete messages with roles | ||
- None: Use existing conversation history | ||
|
@@ -564,6 +576,8 @@ async def stream_async( | |
yield event["data"] | ||
``` | ||
""" | ||
self._resume(prompt) | ||
|
||
callback_handler = kwargs.get("callback_handler", self.callback_handler) | ||
|
||
# Process input and get message to add (if any) | ||
|
@@ -585,6 +599,11 @@ async def stream_async( | |
|
||
result = AgentResult(*event["stop"]) | ||
callback_handler(result=result) | ||
|
||
if result.stop_reason == "interrupt": | ||
self.execution_state = ExecutionState.INTERRUPT | ||
self.interrupts = {interrupt.name: interrupt for interrupt in result.interrupts} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
yield AgentResultEvent(result=result).as_dict() | ||
|
||
self._end_agent_trace_span(response=result) | ||
|
@@ -593,6 +612,16 @@ async def stream_async( | |
self._end_agent_trace_span(error=e) | ||
raise | ||
|
||
def _resume(self, prompt: AgentInput) -> None:
    """Apply user-provided resume values to the interrupts pending on this agent.

    Does nothing unless the agent is currently in the INTERRUPT execution state. While interrupted, the
    prompt must carry a "resume" entry that maps each pending interrupt name to the value stored on that
    interrupt's `resume` attribute.

    Args:
        prompt: Agent input. While interrupted, this must be a dict with a "resume" key whose value is
            indexable by interrupt name.

    Raises:
        ValueError: If the agent is interrupted and the prompt is not a dict containing "resume".
        KeyError: If the "resume" entry has no value for one or more pending interrupts.
    """
    if self.execution_state != ExecutionState.INTERRUPT:
        return

    if not isinstance(prompt, dict) or "resume" not in prompt:
        raise ValueError(
            f"prompt_type=<{type(prompt).__name__}> "
            f"| agent is in INTERRUPT execution state and must be invoked with a resume prompt"
        )

    resume = prompt["resume"]

    # Validate every pending interrupt before mutating any of them, so an incomplete resume prompt
    # never leaves the agent with only some interrupts resumed.
    missing_names = [interrupt.name for interrupt in self.interrupts.values() if interrupt.name not in resume]
    if missing_names:
        raise KeyError(
            f"interrupt_names=<{missing_names}> | resume prompt is missing values for pending interrupts"
        )

    for interrupt in self.interrupts.values():
        interrupt.resume = resume[interrupt.name]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Users fill in a resume content block that maps interrupt names to the user-provided input required for resuming a tool execution after an interrupt. |
||
|
||
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: | ||
"""Execute the agent's event loop with the given message and parameters. | ||
|
||
|
@@ -673,6 +702,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: | |
if isinstance(prompt, str): | ||
# String input - convert to user message | ||
messages = [{"role": "user", "content": [{"text": prompt}]}] | ||
elif isinstance(prompt, dict): | ||
messages = [{"role": "user", "content": prompt}] if "resume" not in prompt else [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do not add resume prompts to the model messages array since it is Strands specific. |
||
elif isinstance(prompt, list): | ||
if len(prompt) == 0: | ||
# Empty list | ||
|
@@ -692,7 +723,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: | |
else: | ||
messages = [] | ||
if messages is None: | ||
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") | ||
raise ValueError( | ||
"Input prompt must be of type: `str | ContentBlock | list[Contentblock] | Messages | None`." | ||
) | ||
return messages | ||
|
||
def _record_tool_execution( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,9 @@ | |
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any | ||
from typing import Any, Optional | ||
|
||
from ..hooks.interrupt import Interrupt | ||
from ..telemetry.metrics import EventLoopMetrics | ||
from ..types.content import Message | ||
from ..types.streaming import StopReason | ||
|
@@ -26,6 +27,7 @@ class AgentResult: | |
message: Message | ||
metrics: EventLoopMetrics | ||
state: Any | ||
interrupts: Optional[list[Interrupt]] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users can parse the raised interrupts from the AgentResult returned from the agent invoke. |
||
|
||
def __str__(self) -> str: | ||
"""Get the agent's last message as a string. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Agent execution state.""" | ||
|
||
from enum import Enum | ||
|
||
|
||
class ExecutionState(Enum):
    """Execution state an agent can be in between invocations.

    Members:
        ASSISTANT: The agent is waiting for a user message (default).
        INTERRUPT: The agent is waiting for user feedback to resume tool execution.
    """

    # Waiting for a regular user message.
    ASSISTANT = "assistant"
    # Waiting for the user's resume input after an interrupt.
    INTERRUPT = "interrupt"
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking we could add another state called "TOOL" to indicate that the agent is waiting for tool results. Under this state, users would be able to pass in tool result content blocks into agent invoke. This is something to consider for follow up though. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With interrupts, the agent now enters into a state that requires specific input from the user to continue. To track this, I created an ExecutionState enum. More details below.