From c1491ad3867a3e912f4b1d9f98fd34eb25fa31c9 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 8 Sep 2025 12:06:33 -0400 Subject: [PATCH] interrupt --- src/strands/agent/agent.py | 45 ++++- src/strands/agent/agent_result.py | 4 +- src/strands/agent/execution_state.py | 14 ++ src/strands/event_loop/event_loop.py | 219 +++++++++++++--------- src/strands/experimental/hooks/events.py | 8 +- src/strands/hooks/interrupt.py | 37 ++++ src/strands/hooks/registry.py | 7 +- src/strands/tools/decorator.py | 24 ++- src/strands/tools/executors/_executor.py | 47 ++++- src/strands/tools/executors/concurrent.py | 2 +- src/strands/tools/executors/sequential.py | 5 +- src/strands/types/_events.py | 15 +- src/strands/types/agent.py | 2 +- src/strands/types/content.py | 4 +- src/strands/types/event_loop.py | 2 + src/strands/types/tools.py | 3 + 16 files changed, 325 insertions(+), 113 deletions(-) create mode 100644 src/strands/agent/execution_state.py create mode 100644 src/strands/hooks/interrupt.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bb602d66b..ea1b7c13f 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -61,6 +61,7 @@ ConversationManager, SlidingWindowConversationManager, ) +from .execution_state import ExecutionState from .state import AgentState logger = logging.getLogger(__name__) @@ -142,6 +143,17 @@ def caller( Raises: AttributeError: If the tool doesn't exist. """ + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call and self._agent.execution_state != ExecutionState.ASSISTANT: + raise RuntimeError( + f"execution_state=<{self._agent.execution_state}> " + f"| recording direct tool calls is only allowed in ASSISTANT execution state" + ) + normalized_name = self._find_normalized_tool_name(name) # Create unique tool ID and set up the tool request @@ -167,11 +179,6 @@ def tcall() -> ToolResult: future = executor.submit(tcall) tool_result = future.result() - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - if should_record_direct_tool_call: # Create a record of this tool execution in the message history self._agent._record_tool_execution(tool_use, tool_result, user_message_override) @@ -349,6 +356,10 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + self.execution_state = ExecutionState.ASSISTANT + + self.interrupts = {} + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -540,6 +551,7 @@ async def stream_async( Args: prompt: User input in various formats: - str: Simple text input + - ContentBlock: Multi-modal content block - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history @@ -564,6 +576,8 @@ async def stream_async( yield event["data"] ``` """ + self._resume(prompt) + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) @@ -585,6 +599,11 @@ async def stream_async( result = AgentResult(*event["stop"]) callback_handler(result=result) + + if result.stop_reason == "interrupt": + self.execution_state = ExecutionState.INTERRUPT + self.interrupts = {interrupt.name: interrupt for interrupt in result.interrupts} + yield AgentResultEvent(result=result).as_dict() self._end_agent_trace_span(response=result) @@ -593,6 +612,16 @@ async def stream_async( self._end_agent_trace_span(error=e) raise + def _resume(self, prompt: AgentInput) -> None: + if self.execution_state != ExecutionState.INTERRUPT: + return + + if not isinstance(prompt, dict) or "resume" not in prompt: + raise ValueError(".") + + for interrupt in self.interrupts.values(): + interrupt.resume = prompt["resume"][interrupt.name] + async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -673,6 +702,8 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if isinstance(prompt, str): # String input - convert to user message messages = [{"role": "user", "content": [{"text": prompt}]}] + elif isinstance(prompt, dict): + messages = [{"role": "user", "content": prompt}] if "resume" not in prompt else [] elif isinstance(prompt, list): if len(prompt) == 0: # Empty list @@ -692,7 +723,9 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: else: messages = [] if messages is None: - raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") + raise ValueError( + "Input prompt must be of type: `str | ContentBlock | list[Contentblock] | Messages | None`." + ) return messages def _record_tool_execution( diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index f3758c8d2..44267e103 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,8 +4,9 @@ """ from dataclasses import dataclass -from typing import Any +from typing import Any, Optional +from ..hooks.interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics from ..types.content import Message from ..types.streaming import StopReason @@ -26,6 +27,7 @@ class AgentResult: message: Message metrics: EventLoopMetrics state: Any + interrupts: Optional[list[Interrupt]] = None def __str__(self) -> str: """Get the agent's last message as a string. diff --git a/src/strands/agent/execution_state.py b/src/strands/agent/execution_state.py new file mode 100644 index 000000000..40bb555b2 --- /dev/null +++ b/src/strands/agent/execution_state.py @@ -0,0 +1,14 @@ +"""Agent execution state.""" + +from enum import Enum + + +class ExecutionState(Enum): + """Represents the current execution state of an agent. + + ASSISTANT: Agent is waiting for user message (default). + INTERRUPT: Agent is waiting for user feedback to resume tool execution. + """ + + ASSISTANT = "assistant" + INTERRUPT = "interrupt" diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 1d437e944..f5e15fa1e 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,6 +15,7 @@ from opentelemetry import trace as trace_api +from ..agent.execution_state import ExecutionState from ..experimental.hooks import ( AfterModelInvocationEvent, BeforeModelInvocationEvent, @@ -23,7 +24,7 @@ MessageAddedEvent, ) from ..telemetry.metrics import Trace -from ..telemetry.tracer import get_tracer +from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools from ..types._events import ( EventLoopStopEvent, @@ -33,6 +34,7 @@ ModelStopReason, StartEvent, StartEventLoopEvent, + ToolInterruptEvent, ToolResultMessageEvent, TypedEvent, ) @@ -112,104 +114,22 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span - # Create a trace for the stream_messages call - stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) - cycle_trace.add_child(stream_trace) - - # Process messages with exponential backoff for throttling - message: Message stop_reason: StopReason - usage: Any - metrics: Metrics - - # Retry loop for handling throttling exceptions - current_delay = INITIAL_DELAY - for attempt in range(MAX_ATTEMPTS): - model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None - model_invoke_span = tracer.start_model_invoke_span( - messages=agent.messages, - parent_span=cycle_span, - model_id=model_id, - ) - with trace_api.use_span(model_invoke_span): - agent.hooks.invoke_callbacks( - BeforeModelInvocationEvent( - agent=agent, - ) - ) - - tool_specs = agent.tool_registry.get_all_tool_specs() - try: - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if not isinstance(event, ModelStopReason): - yield event + if agent.execution_state == ExecutionState.INTERRUPT: + stop_reason = "tool_use" + else: + events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for event in events: + if isinstance(event, ModelStopReason): + stop_reason = event["stop"][0] + continue - stop_reason, message, usage, metrics = event["stop"] - invocation_state.setdefault("request_state", {}) + yield event - agent.hooks.invoke_callbacks( - AfterModelInvocationEvent( - agent=agent, - stop_response=AfterModelInvocationEvent.ModelStopResponse( - stop_reason=stop_reason, - message=message, - ), - ) - ) - - if stop_reason == "max_tokens": - message = recover_message_on_max_tokens_reached(message) - - if model_invoke_span: - tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) - break # Success! Break out of retry loop - - except Exception as e: - if model_invoke_span: - tracer.end_span_with_error(model_invoke_span, str(e), e) - - agent.hooks.invoke_callbacks( - AfterModelInvocationEvent( - agent=agent, - exception=e, - ) - ) - - if isinstance(e, ModelThrottledException): - if attempt + 1 == MAX_ATTEMPTS: - yield ForceStopEvent(reason=e) - raise e - - logger.debug( - "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " - "| throttling exception encountered " - "| delaying before next retry", - current_delay, - MAX_ATTEMPTS, - attempt + 1, - ) - await asyncio.sleep(current_delay) - current_delay = min(current_delay * 2, MAX_DELAY) - - yield EventLoopThrottleEvent(delay=current_delay) - else: - raise e + message = agent.messages[-1] try: - # Add message in trace and mark the end of the stream messages trace - stream_trace.add_message(message) - stream_trace.end() - - # Add the response message to the conversation - agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield ModelMessageEvent(message=message) - - # Update metrics - agent.event_loop_metrics.update_usage(usage) - agent.event_loop_metrics.update_metrics(metrics) - if stop_reason == "max_tokens": """ Handle max_tokens limit reached by the model. @@ -307,6 +227,105 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) - recursive_trace.end() +async def _handle_model_execution( + agent: "Agent", + cycle_span: Any, + cycle_trace: Trace, + invocation_state: dict[str, Any], + tracer: Tracer, +) -> AsyncGenerator[TypedEvent, None]: + """.""" + # Create a trace for the stream_messages call + stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) + cycle_trace.add_child(stream_trace) + + # Retry loop for handling throttling exceptions + current_delay = INITIAL_DELAY + for attempt in range(MAX_ATTEMPTS): + model_id = agent.model.config.get("model_id") if hasattr(agent.model, "config") else None + model_invoke_span = tracer.start_model_invoke_span( + messages=agent.messages, + parent_span=cycle_span, + model_id=model_id, + ) + with trace_api.use_span(model_invoke_span): + agent.hooks.invoke_callbacks( + BeforeModelInvocationEvent( + agent=agent, + ) + ) + + tool_specs = agent.tool_registry.get_all_tool_specs() + + try: + async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): + yield event + + stop_reason, message, usage, metrics = event["stop"] + invocation_state.setdefault("request_state", {}) + + agent.hooks.invoke_callbacks( + AfterModelInvocationEvent( + agent=agent, + stop_response=AfterModelInvocationEvent.ModelStopResponse( + stop_reason=stop_reason, + message=message, + ), + ) + ) + + if stop_reason == "max_tokens": + message = recover_message_on_max_tokens_reached(message) + + if model_invoke_span: + tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason) + break # Success! Break out of retry loop + + except Exception as e: + if model_invoke_span: + tracer.end_span_with_error(model_invoke_span, str(e), e) + + agent.hooks.invoke_callbacks( + AfterModelInvocationEvent( + agent=agent, + exception=e, + ) + ) + + if isinstance(e, ModelThrottledException): + if attempt + 1 == MAX_ATTEMPTS: + yield ForceStopEvent(reason=e) + raise e + + logger.debug( + "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> " + "| throttling exception encountered " + "| delaying before next retry", + current_delay, + MAX_ATTEMPTS, + attempt + 1, + ) + await asyncio.sleep(current_delay) + current_delay = min(current_delay * 2, MAX_DELAY) + + yield EventLoopThrottleEvent(delay=current_delay) + else: + raise e + + # Add message in trace and mark the end of the stream messages trace + stream_trace.add_message(message) + stream_trace.end() + + # Update metrics + agent.event_loop_metrics.update_usage(usage) + agent.event_loop_metrics.update_metrics(metrics) + + # Add the response message to the conversation + agent.messages.append(message) + agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + yield ModelMessageEvent(message=message) + + async def _handle_tool_execution( stop_reason: StopReason, message: Message, @@ -345,15 +364,29 @@ async def _handle_tool_execution( yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) return + tool_interrupts = [] tool_events = agent.tool_executor._execute( agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state ) async for tool_event in tool_events: yield tool_event + if isinstance(tool_event, ToolInterruptEvent): + tool_interrupts.append(tool_event["tool_interrupt_event"]["interrupt"]) + # Store parent cycle ID for the next cycle invocation_state["event_loop_parent_cycle_id"] = invocation_state["event_loop_cycle_id"] + if tool_interrupts: + # TODO: deal with metrics and traces + yield EventLoopStopEvent( + "interrupt", message, agent.event_loop_metrics, invocation_state["request_state"], tool_interrupts + ) + return + + agent.execution_state = ExecutionState.ASSISTANT + agent.interrupts = {} + tool_result_message: Message = { "role": "user", "content": [{"toolResult": result} for result in tool_results], diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d03e65d85..0834293b6 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -7,13 +7,14 @@ from typing import Any, Optional from ...hooks import HookEvent +from ...hooks.interrupt import InterruptEvent from ...types.content import Message from ...types.streaming import StopReason from ...types.tools import AgentTool, ToolResult, ToolUse @dataclass -class BeforeToolInvocationEvent(HookEvent): +class BeforeToolInvocationEvent(HookEvent, InterruptEvent): """Event triggered before a tool is invoked. This event is fired just before the agent executes a tool, allowing hook @@ -26,14 +27,17 @@ class BeforeToolInvocationEvent(HookEvent): to change which tool gets executed. This may be None if tool lookup failed. tool_use: The tool parameters that will be passed to selected_tool. invocation_state: Keyword arguments that will be passed to the tool. + cancel: A user defined message that when set, will lead to canceling of the tool call. + The message is used to populate a tool result with status "error". """ selected_tool: Optional[AgentTool] tool_use: ToolUse invocation_state: dict[str, Any] + cancel: Optional[str] = None def _can_write(self, name: str) -> bool: - return name in ["selected_tool", "tool_use"] + return name in ["cancel", "interrupt", "selected_tool", "tool_use"] @dataclass diff --git a/src/strands/hooks/interrupt.py b/src/strands/hooks/interrupt.py new file mode 100644 index 000000000..fbc19aea7 --- /dev/null +++ b/src/strands/hooks/interrupt.py @@ -0,0 +1,37 @@ +""".""" +import abc +from dataclasses import dataclass +from typing import Any + + +@dataclass +class Interrupt: + """.""" + + name: str + reasons: list[Any] + resume: Any = None + activated: bool = False + + def __call__(self, reason: Any) -> Any: + """.""" + if self.resume: + self.activated = False + return self.resume + + self.reasons.append(reason) + self.activated = True + raise InterruptException(self) + + +class InterruptException(Exception): + """.""" + def __init__(self, interrupt: Interrupt) -> None: + self.interrupt = interrupt + + +@dataclass +class InterruptEvent: + """.""" + + interrupt: Interrupt diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index a3b76d743..e6afc3db7 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,6 +10,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from .interrupt import InterruptException + if TYPE_CHECKING: from ..agent import Agent @@ -200,7 +202,10 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: ``` """ for callback in self.get_callbacks_for(event): - callback(event) + try: + callback(event) + except InterruptException: + pass return event diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 99aa7e372..86c18c960 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -62,7 +62,8 @@ def my_tool(param1: str, param2: int = 42) -> dict: from pydantic import BaseModel, Field, create_model from typing_extensions import override -from ..types._events import ToolResultEvent, ToolStreamEvent +from ..hooks.interrupt import Interrupt, InterruptException +from ..types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from ..types.tools import AgentTool, JSONSchema, ToolContext, ToolGenerator, ToolResult, ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -270,9 +271,16 @@ def inject_special_parameters( invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), agent.invoke_async(), etc.). """ + if self._context_param and self._context_param in self.signature.parameters: + tool_name = tool_use["name"] + agent = invocation_state["agent"] + tool_context = ToolContext( - tool_use=tool_use, agent=invocation_state["agent"], invocation_state=invocation_state + tool_use=tool_use, + agent=agent, + invocation_state=invocation_state, + interrupt=agent.interrupts.get(tool_name) or Interrupt(tool_name, reasons=[]), ) validated_input[self._context_param] = tool_context @@ -447,6 +455,7 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw """ # This is a tool use call - process accordingly tool_use_id = tool_use.get("toolUseId", "unknown") + tool_name = tool_use["name"] tool_input: dict[str, Any] = tool_use.get("input", {}) try: @@ -477,6 +486,17 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw result = await asyncio.to_thread(self._tool_func, **validated_input) # type: ignore yield self._wrap_tool_result(tool_use_id, result) + except InterruptException as e: + yield ToolInterruptEvent(e.interrupt) + yield self._wrap_tool_result( + tool_use_id, + { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"{tool_name} interrupted"}] + + }, + ) except ValueError as e: # Special handling for validation errors error_msg = str(e) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5354991c3..fff3f692c 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -12,9 +12,10 @@ from opentelemetry import trace as trace_api from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent +from ...hooks.interrupt import Interrupt, InterruptException from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer -from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent +from ...types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -78,9 +79,51 @@ async def _stream( selected_tool=tool_func, tool_use=tool_use, invocation_state=invocation_state, + interrupt=agent.interrupts.get(tool_name) or Interrupt(tool_name, reasons=[]), ) ) + if before_event.interrupt.activated: + yield ToolInterruptEvent(before_event.interrupt) + + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": f"{tool_name} interrupted"}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + selected_tool=None, + ) + ) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + + return + + if before_event.cancel: + result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": before_event.cancel}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolInvocationEvent( + agent=agent, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + selected_tool=None, + ) + ) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return + try: selected_tool = before_event.selected_tool tool_use = before_event.tool_use @@ -129,7 +172,7 @@ async def _stream( # below the last "event" must point to the tool_result event = event.tool_result break - elif isinstance(event, ToolStreamEvent): + elif isinstance(event, (ToolStreamEvent, ToolInterruptEvent)): yield event else: yield ToolStreamEvent(tool_use, event) diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 767071bae..4a985c6af 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -72,7 +72,7 @@ async def _execute( yield event task_events[task_id].set() - asyncio.gather(*tasks) + await asyncio.gather(*tasks) async def _task( self, diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 60e5c7fa7..24b411492 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -5,7 +5,7 @@ from typing_extensions import override from ...telemetry.metrics import Trace -from ...types._events import TypedEvent +from ...types._events import ToolInterruptEvent, TypedEvent from ...types.tools import ToolResult, ToolUse from ._executor import ToolExecutor @@ -45,3 +45,6 @@ async def _execute( ) async for event in events: yield event + + if isinstance(event, ToolInterruptEvent): + break diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 3d0f1d0f0..02937b1f2 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,10 +5,11 @@ agent lifecycle. """ -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override +from ..hooks.interrupt import Interrupt from ..telemetry import EventLoopMetrics from .citations import Citation from .content import Message @@ -220,6 +221,7 @@ def __init__( message: Message, metrics: "EventLoopMetrics", request_state: Any, + interrupts: Optional[list[Interrupt]] = None ) -> None: """Initialize with the final execution results. @@ -228,8 +230,9 @@ def __init__( message: Final message from the model metrics: Execution metrics and performance data request_state: Final state of the agent execution + interrupts: TODO. """ - super().__init__({"stop": (stop_reason, message, metrics, request_state)}) + super().__init__({"stop": (stop_reason, message, metrics, request_state, interrupts)}) @property @override @@ -298,6 +301,14 @@ def tool_use_id(self) -> str: return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) +class ToolInterruptEvent(TypedEvent): + """Event emitted when a tool is interrupted.""" + + def __init__(self, interrupt: Interrupt) -> None: + """TODO.""" + super().__init__({"tool_interrupt_event": {"interrupt": interrupt}}) + + class ModelMessageEvent(TypedEvent): """Event emitted when the model invocation has completed. diff --git a/src/strands/types/agent.py b/src/strands/types/agent.py index 151c88f89..382721bab 100644 --- a/src/strands/types/agent.py +++ b/src/strands/types/agent.py @@ -7,4 +7,4 @@ from .content import ContentBlock, Messages -AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None +AgentInput: TypeAlias = str | ContentBlock | list[ContentBlock] | Messages | None diff --git a/src/strands/types/content.py b/src/strands/types/content.py index c3eddca4d..91bc29686 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,7 +6,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional from typing_extensions import TypedDict @@ -80,6 +80,7 @@ class ContentBlock(TypedDict, total=False): guardContent: Contains the content to assess with the guardrail. image: Image to include in the message. reasoningContent: Contains content regarding the reasoning that is carried out by the model. + resume: TODO text: Text to include in the message. toolResult: The result for a tool request that a model makes. toolUse: Information about a tool use request from a model. @@ -92,6 +93,7 @@ class ContentBlock(TypedDict, total=False): guardContent: GuardContent image: ImageContent reasoningContent: ReasoningContentBlock + resume: dict[str, Any] text: str toolResult: ToolResult toolUse: ToolUse diff --git a/src/strands/types/event_loop.py b/src/strands/types/event_loop.py index 2c240972b..17a8bd36b 100644 --- a/src/strands/types/event_loop.py +++ b/src/strands/types/event_loop.py @@ -37,6 +37,7 @@ class Metrics(TypedDict): "content_filtered", "end_turn", "guardrail_intervened", + "interrupt", "max_tokens", "stop_sequence", "tool_use", @@ -46,6 +47,7 @@ class Metrics(TypedDict): - "content_filtered": Content was filtered due to policy violation - "end_turn": Normal completion of the response - "guardrail_intervened": Guardrail system intervened +- "interrupt": Agent was interrupted - "max_tokens": Maximum token limit reached - "stop_sequence": Stop sequence encountered - "tool_use": Model requested to use a tool diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index e8d5531b2..1bbe1e677 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -11,6 +11,7 @@ from typing_extensions import TypedDict +from ..hooks.interrupt import Interrupt from .media import DocumentContent, ImageContent if TYPE_CHECKING: @@ -134,6 +135,7 @@ class ToolContext: model configuration, and other agent state. invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), agent.invoke_async(), etc.). + interrupt: TODO Note: This class is intended to be instantiated by the SDK. Direct construction by users @@ -143,6 +145,7 @@ class ToolContext: tool_use: ToolUse agent: "Agent" invocation_state: dict[str, Any] + interrupt: Interrupt # Individual ToolChoice type aliases