Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ConversationManager,
SlidingWindowConversationManager,
)
from .execution_state import ExecutionState
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With interrupts, the agent now enters into a state that requires specific input from the user to continue. To track this, I created an ExecutionState enum. More details below.

from .state import AgentState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,6 +143,17 @@ def caller(
Raises:
AttributeError: If the tool doesn't exist.
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need to figure out how to support interrupts in direct tool calls. I would prefer to allow users to pass in the resume context into the call (e.g., agent.tool.my_tool({"resume": ...})). I don't think though we can add this in a backwards compatible manner.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to support interruptions for direct tool calls? Seems a bit silly IMHO

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

if should_record_direct_tool_call and self._agent.execution_state != ExecutionState.ASSISTANT:
raise RuntimeError(
f"execution_state=<{self._agent.execution_state}> "
f"| recording direct tool calls is only allowed in ASSISTANT execution state"
)

normalized_name = self._find_normalized_tool_name(name)

# Create unique tool ID and set up the tool request
Expand All @@ -167,11 +179,6 @@ def tcall() -> ToolResult:
future = executor.submit(tcall)
tool_result = future.result()

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
Expand Down Expand Up @@ -349,6 +356,10 @@ def __init__(
self.hooks.add_hook(hook)
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))

self.execution_state = ExecutionState.ASSISTANT

self.interrupts = {}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If in an INTERRUPT execution state, the agent will hold a reference to the interrupts raised by the user. To get things working, I am storing the interrupts in a dictionary in memory. As a follow up, I will think of a more formal mechanism for storing the interrupt state that can also be serialized for session management.


@property
def tool(self) -> ToolCaller:
"""Call tool as a function.
Expand Down Expand Up @@ -540,6 +551,7 @@ async def stream_async(
Args:
prompt: User input in various formats:
- str: Simple text input
- ContentBlock: Multi-modal content block
- list[ContentBlock]: Multi-modal content blocks
- list[Message]: Complete messages with roles
- None: Use existing conversation history
Expand All @@ -564,6 +576,8 @@ async def stream_async(
yield event["data"]
```
"""
self._resume(prompt)

callback_handler = kwargs.get("callback_handler", self.callback_handler)

# Process input and get message to add (if any)
Expand All @@ -585,6 +599,11 @@ async def stream_async(

result = AgentResult(*event["stop"])
callback_handler(result=result)

if result.stop_reason == "interrupt":
self.execution_state = ExecutionState.INTERRUPT
self.interrupts = {interrupt.name: interrupt for interrupt in result.interrupts}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Right now we are only supporting interrupts for tools. We use tool names for the interrupt names.
  • We support raising multiple interrupts in a single request to the agent because the agent can execute multiple tools in parallel.
  • Only one interrupt is allowed for each tool. However, users can provide multiple reasons for interrupting a tool. More details on this below.


yield AgentResultEvent(result=result).as_dict()

self._end_agent_trace_span(response=result)
Expand All @@ -593,6 +612,16 @@ async def stream_async(
self._end_agent_trace_span(error=e)
raise

def _resume(self, prompt: AgentInput) -> None:
if self.execution_state != ExecutionState.INTERRUPT:
return

if not isinstance(prompt, dict) or "resume" not in prompt:
raise ValueError("<TODO>.")

for interrupt in self.interrupts.values():
interrupt.resume = prompt["resume"][interrupt.name]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users fill in a resume content block that maps interrupt names to the user provided input required for resuming a tool execution after interrupt.


async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
"""Execute the agent's event loop with the given message and parameters.

Expand Down Expand Up @@ -673,6 +702,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
if isinstance(prompt, str):
# String input - convert to user message
messages = [{"role": "user", "content": [{"text": prompt}]}]
elif isinstance(prompt, dict):
messages = [{"role": "user", "content": prompt}] if "resume" not in prompt else []
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not add resume prompts to the model messages array since it is Strands specific.

elif isinstance(prompt, list):
if len(prompt) == 0:
# Empty list
Expand All @@ -692,7 +723,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
else:
messages = []
if messages is None:
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
raise ValueError(
"Input prompt must be of type: `str | ContentBlock | list[Contentblock] | Messages | None`."
)
return messages

def _record_tool_execution(
Expand Down
4 changes: 3 additions & 1 deletion src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""

from dataclasses import dataclass
from typing import Any
from typing import Any, Optional

from ..hooks.interrupt import Interrupt
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import Message
from ..types.streaming import StopReason
Expand All @@ -26,6 +27,7 @@ class AgentResult:
message: Message
metrics: EventLoopMetrics
state: Any
interrupts: Optional[list[Interrupt]] = None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users can parse the raised interrupts from the AgentResult returned from the agent invoke. interrupts will be populated when stop_reason is "interrupt".


def __str__(self) -> str:
"""Get the agent's last message as a string.
Expand Down
14 changes: 14 additions & 0 deletions src/strands/agent/execution_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Agent execution state."""

from enum import Enum


class ExecutionState(Enum):
"""Represents the current execution state of an agent.
ASSISTANT: Agent is waiting for user message (default).
INTERRUPT: Agent is waiting for user feedback to resume tool execution.
"""

ASSISTANT = "assistant"
INTERRUPT = "interrupt"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we could add another state called "TOOL" to indicate that the agent is waiting for tool results. Under this state, users would be able to pass in tool result content blocks into agent invoke. This is something to consider for follow up though.

Loading
Loading