diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f2eed063c..d6367e9d9 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -17,7 +17,7 @@ from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace -from ..telemetry.tracer import get_tracer +from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools from ..types._events import ( EventLoopStopEvent, @@ -37,7 +37,7 @@ MaxTokensReachedException, ModelThrottledException, ) -from ..types.streaming import Metrics, StopReason +from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages @@ -106,16 +106,142 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + + try: + if stop_reason == "max_tokens": + """ + Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + ) + + # If the model is requesting to use tools + if stop_reason == "tool_use": + # Handle tool execution + tool_events = _handle_tool_execution( + stop_reason, + message, + agent=agent, + cycle_trace=cycle_trace, + cycle_span=cycle_span, + cycle_start_time=cycle_start_time, + invocation_state=invocation_state, + ) + async for tool_event in tool_events: + yield tool_event + + return + + # End the cycle and return results + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) + if cycle_span: + tracer.end_event_loop_cycle_span( + span=cycle_span, + message=message, + ) + except EventLoopException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + + # Don't yield or log the exception - we already did it when we + # raised the exception and we don't need that duplication. + raise + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e + except Exception as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + + # Handle any other exceptions + yield ForceStopEvent(reason=e) + logger.exception("cycle failed") + raise EventLoopException(e, invocation_state["request_state"]) from e + + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + + +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + """Make a recursive call to event_loop_cycle with the current state. + + This function is used when the event loop needs to continue processing after tool execution. + + Args: + agent: Agent for which the recursive call is being made. + invocation_state: Arguments to pass through event_loop_cycle + + + Yields: + Results from event_loop_cycle where the last result contains: + + - StopReason: Reason the model stopped generating + - Message: The generated message from the model + - EventLoopMetrics: Updated metrics for the event loop + - Any: Updated request state + """ + cycle_trace = invocation_state["event_loop_cycle_trace"] + + # Recursive call trace + recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) + cycle_trace.add_child(recursive_trace) + + yield StartEvent() + + events = event_loop_cycle(agent=agent, invocation_state=invocation_state) + async for event in events: + yield event + + recursive_trace.end() + + +async def _handle_model_execution( + agent: "Agent", + cycle_span: Any, + cycle_trace: Trace, + invocation_state: dict[str, Any], + tracer: Tracer, +) -> AsyncGenerator[TypedEvent, None]: + """Handle model execution with retry logic for throttling exceptions. + + Executes the model inference with automatic retry handling for throttling exceptions. + Manages tracing, hooks, and metrics collection throughout the process. + + Args: + agent: The agent executing the model. + cycle_span: Span object for tracing the cycle. + cycle_trace: Trace object for the current event loop cycle. + invocation_state: State maintained across cycles. + tracer: Tracer instance for span management. + + Yields: + Model stream events and throttle events during retries. + + Raises: + ModelThrottledException: If max retry attempts are exceeded. + Exception: Any other model execution errors. + """ # Create a trace for the stream_messages call stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Process messages with exponential backoff for throttling - message: Message - stop_reason: StopReason - usage: Any - metrics: Metrics - # Retry loop for handling throttling exceptions current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): @@ -136,8 +262,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> try: async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if not isinstance(event, ModelStopReason): - yield event + yield event stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -198,108 +323,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield ModelMessageEvent(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) agent.event_loop_metrics.update_metrics(metrics) - if stop_reason == "max_tokens": - """ - Handle max_tokens limit reached by the model. - - When the model reaches its maximum token limit, this represents a potentially unrecoverable - state where the model's response was truncated. By default, Strands fails hard with an - MaxTokensReachedException to maintain consistency with other failure types. - """ - raise MaxTokensReachedException( - message=( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ) - ) - - # If the model is requesting to use tools - if stop_reason == "tool_use": - # Handle tool execution - events = _handle_tool_execution( - stop_reason, - message, - agent=agent, - cycle_trace=cycle_trace, - cycle_span=cycle_span, - cycle_start_time=cycle_start_time, - invocation_state=invocation_state, - ) - async for typed_event in events: - yield typed_event - - return - - # End the cycle and return results - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, - message=message, - ) - except EventLoopException as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - - # Don't yield or log the exception - we already did it when we - # raised the exception and we don't need that duplication. - raise - except (ContextWindowOverflowException, MaxTokensReachedException) as e: - # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) - # Handle any other exceptions yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - - -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Make a recursive call to event_loop_cycle with the current state. - - This function is used when the event loop needs to continue processing after tool execution. - - Args: - agent: Agent for which the recursive call is being made. - invocation_state: Arguments to pass through event_loop_cycle - - - Yields: - Results from event_loop_cycle where the last result contains: - - - StopReason: Reason the model stopped generating - - Message: The generated message from the model - - EventLoopMetrics: Updated metrics for the event loop - - Any: Updated request state - """ - cycle_trace = invocation_state["event_loop_cycle_trace"] - - # Recursive call trace - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) - cycle_trace.add_child(recursive_trace) - - yield StartEvent() - - events = event_loop_cycle(agent=agent, invocation_state=invocation_state) - async for event in events: - yield event - - recursive_trace.end() - async def _handle_tool_execution( stop_reason: StopReason,