Skip to content

Commit 1fe2bc0

Browse files
committed
interrupt
1 parent 1f25512 commit 1fe2bc0

File tree

14 files changed

+298
-110
lines changed

14 files changed

+298
-110
lines changed

src/strands/agent/agent.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
ConversationManager,
6262
SlidingWindowConversationManager,
6363
)
64+
from .execution_state import ExecutionState
6465
from .state import AgentState
6566

6667
logger = logging.getLogger(__name__)
@@ -142,6 +143,17 @@ def caller(
142143
Raises:
143144
AttributeError: If the tool doesn't exist.
144145
"""
146+
if record_direct_tool_call is not None:
147+
should_record_direct_tool_call = record_direct_tool_call
148+
else:
149+
should_record_direct_tool_call = self._agent.record_direct_tool_call
150+
151+
if should_record_direct_tool_call and self._agent.execution_state != ExecutionState.ASSISTANT:
152+
raise RuntimeError(
153+
f"execution_state=<{self._agent.execution_state}> "
154+
f"| recording direct tool calls is only allowed in ASSISTANT execution state"
155+
)
156+
145157
normalized_name = self._find_normalized_tool_name(name)
146158

147159
# Create unique tool ID and set up the tool request
@@ -167,11 +179,6 @@ def tcall() -> ToolResult:
167179
future = executor.submit(tcall)
168180
tool_result = future.result()
169181

170-
if record_direct_tool_call is not None:
171-
should_record_direct_tool_call = record_direct_tool_call
172-
else:
173-
should_record_direct_tool_call = self._agent.record_direct_tool_call
174-
175182
if should_record_direct_tool_call:
176183
# Create a record of this tool execution in the message history
177184
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)
@@ -349,6 +356,10 @@ def __init__(
349356
self.hooks.add_hook(hook)
350357
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
351358

359+
self.execution_state = ExecutionState.ASSISTANT
360+
361+
self.interrupts = {}
362+
352363
@property
353364
def tool(self) -> ToolCaller:
354365
"""Call tool as a function.
@@ -540,6 +551,7 @@ async def stream_async(
540551
Args:
541552
prompt: User input in various formats:
542553
- str: Simple text input
554+
- ContentBlock: Multi-modal content block
543555
- list[ContentBlock]: Multi-modal content blocks
544556
- list[Message]: Complete messages with roles
545557
- None: Use existing conversation history
@@ -564,6 +576,8 @@ async def stream_async(
564576
yield event["data"]
565577
```
566578
"""
579+
self._resume(prompt)
580+
567581
callback_handler = kwargs.get("callback_handler", self.callback_handler)
568582

569583
# Process input and get message to add (if any)
@@ -585,6 +599,11 @@ async def stream_async(
585599

586600
result = AgentResult(*event["stop"])
587601
callback_handler(result=result)
602+
603+
if result.stop_reason == "interrupt":
604+
self.execution_state = ExecutionState.INTERRUPT
605+
self.interrupts = {interrupt.name: interrupt for interrupt in result.interrupts}
606+
588607
yield AgentResultEvent(result=result).as_dict()
589608

590609
self._end_agent_trace_span(response=result)
@@ -593,6 +612,16 @@ async def stream_async(
593612
self._end_agent_trace_span(error=e)
594613
raise
595614

615+
def _resume(self, prompt: AgentInput) -> None:
616+
if self.execution_state != ExecutionState.INTERRUPT:
617+
return
618+
619+
if not isinstance(prompt, dict) or "resume" not in prompt:
620+
raise ValueError("<TODO>.")
621+
622+
for interrupt in self.interrupts.values():
623+
interrupt.resume = prompt["resume"][interrupt.name]
624+
596625
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
597626
"""Execute the agent's event loop with the given message and parameters.
598627
@@ -673,6 +702,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
673702
if isinstance(prompt, str):
674703
# String input - convert to user message
675704
messages = [{"role": "user", "content": [{"text": prompt}]}]
705+
elif isinstance(prompt, dict):
706+
messages = [{"role": "user", "content": prompt}] if "resume" not in prompt else []
676707
elif isinstance(prompt, list):
677708
if len(prompt) == 0:
678709
# Empty list
@@ -692,7 +723,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
692723
else:
693724
messages = []
694725
if messages is None:
695-
raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.")
726+
raise ValueError(
727+
"Input prompt must be of type: `str | ContentBlock | list[Contentblock] | Messages | None`."
728+
)
696729
return messages
697730

698731
def _record_tool_execution(

src/strands/agent/agent_result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""
55

66
from dataclasses import dataclass
7-
from typing import Any
7+
from typing import Any, Optional
88

9+
from ..hooks.interrupt import Interrupt
910
from ..telemetry.metrics import EventLoopMetrics
1011
from ..types.content import Message
1112
from ..types.streaming import StopReason
@@ -26,6 +27,7 @@ class AgentResult:
2627
message: Message
2728
metrics: EventLoopMetrics
2829
state: Any
30+
interrupts: Optional[list[Interrupt]] = None
2931

3032
def __str__(self) -> str:
3133
"""Get the agent's last message as a string.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Agent execution state."""
2+
3+
from enum import Enum
4+
5+
6+
class ExecutionState(Enum):
7+
"""Represents the current execution state of an agent.
8+
9+
ASSISTANT: Agent is waiting for user message (default).
10+
INTERRUPT: Agent is waiting for user feedback to resume tool execution.
11+
"""
12+
13+
ASSISTANT = "assistant"
14+
INTERRUPT = "interrupt"

0 commit comments

Comments
 (0)