Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 135 additions & 99 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
from ..telemetry.tracer import Tracer, get_tracer
from ..tools._validator import validate_and_prepare_tools
from ..types._events import (
EventLoopStopEvent,
Expand All @@ -37,7 +37,7 @@
MaxTokensReachedException,
ModelThrottledException,
)
from ..types.streaming import Metrics, StopReason
from ..types.streaming import StopReason
from ..types.tools import ToolResult, ToolUse
from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
from .streaming import stream_messages
Expand Down Expand Up @@ -106,16 +106,142 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
)
invocation_state["event_loop_cycle_span"] = cycle_span

model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)

try:
if stop_reason == "max_tokens":
"""
Handle max_tokens limit reached by the model.

When the model reaches its maximum token limit, this represents a potentially unrecoverable
state where the model's response was truncated. By default, Strands fails hard with an
MaxTokensReachedException to maintain consistency with other failure types.
"""
raise MaxTokensReachedException(
message=(
"Agent has reached an unrecoverable state due to max_tokens limit. "
"For more information see: "
"https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
)
)

# If the model is requesting to use tools
if stop_reason == "tool_use":
# Handle tool execution
tool_events = _handle_tool_execution(
stop_reason,
message,
agent=agent,
cycle_trace=cycle_trace,
cycle_span=cycle_span,
cycle_start_time=cycle_start_time,
invocation_state=invocation_state,
)
async for tool_event in tool_events:
yield tool_event

return

# End the cycle and return results
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
if cycle_span:
tracer.end_event_loop_cycle_span(
span=cycle_span,
message=message,
)
except EventLoopException as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)

# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
except (ContextWindowOverflowException, MaxTokensReachedException) as e:
# Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
except Exception as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)

# Handle any other exceptions
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e

yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])


async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
    """Re-enter event_loop_cycle with the agent's current state.

    Used when the event loop must keep processing after tool execution has
    completed, delegating the actual work back to event_loop_cycle.

    Args:
        agent: Agent for which the recursive call is being made.
        invocation_state: Arguments to pass through event_loop_cycle

    Yields:
        Results from event_loop_cycle where the last result contains:

        - StopReason: Reason the model stopped generating
        - Message: The generated message from the model
        - EventLoopMetrics: Updated metrics for the event loop
        - Any: Updated request state
    """
    parent_trace = invocation_state["event_loop_cycle_trace"]

    # Record this re-entry as a child of the current cycle's trace.
    nested_trace = Trace("Recursive call", parent_id=parent_trace.id)
    parent_trace.add_child(nested_trace)

    yield StartEvent()

    async for nested_event in event_loop_cycle(agent=agent, invocation_state=invocation_state):
        yield nested_event

    nested_trace.end()


async def _handle_model_execution(
agent: "Agent",
cycle_span: Any,
cycle_trace: Trace,
invocation_state: dict[str, Any],
tracer: Tracer,
) -> AsyncGenerator[TypedEvent, None]:
"""Handle model execution with retry logic for throttling exceptions.

Executes the model inference with automatic retry handling for throttling exceptions.
Manages tracing, hooks, and metrics collection throughout the process.

Args:
agent: The agent executing the model.
cycle_span: Span object for tracing the cycle.
cycle_trace: Trace object for the current event loop cycle.
invocation_state: State maintained across cycles.
tracer: Tracer instance for span management.

Yields:
Model stream events and throttle events during retries.

Raises:
ModelThrottledException: If max retry attempts are exceeded.
Exception: Any other model execution errors.
"""
# Create a trace for the stream_messages call
stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
cycle_trace.add_child(stream_trace)

# Process messages with exponential backoff for throttling
message: Message
stop_reason: StopReason
usage: Any
metrics: Metrics

# Retry loop for handling throttling exceptions
current_delay = INITIAL_DELAY
for attempt in range(MAX_ATTEMPTS):
Expand All @@ -136,8 +262,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->

try:
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
if not isinstance(event, ModelStopReason):
yield event
yield event

stop_reason, message, usage, metrics = event["stop"]
invocation_state.setdefault("request_state", {})
Expand Down Expand Up @@ -198,108 +323,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
# Add the response message to the conversation
agent.messages.append(message)
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
yield ModelMessageEvent(message=message)

# Update metrics
agent.event_loop_metrics.update_usage(usage)
agent.event_loop_metrics.update_metrics(metrics)

if stop_reason == "max_tokens":
"""
Handle max_tokens limit reached by the model.

When the model reaches its maximum token limit, this represents a potentially unrecoverable
state where the model's response was truncated. By default, Strands fails hard with an
MaxTokensReachedException to maintain consistency with other failure types.
"""
raise MaxTokensReachedException(
message=(
"Agent has reached an unrecoverable state due to max_tokens limit. "
"For more information see: "
"https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
)
)

# If the model is requesting to use tools
if stop_reason == "tool_use":
# Handle tool execution
events = _handle_tool_execution(
stop_reason,
message,
agent=agent,
cycle_trace=cycle_trace,
cycle_span=cycle_span,
cycle_start_time=cycle_start_time,
invocation_state=invocation_state,
)
async for typed_event in events:
yield typed_event

return

# End the cycle and return results
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
if cycle_span:
tracer.end_event_loop_cycle_span(
span=cycle_span,
message=message,
)
except EventLoopException as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)

# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
raise
except (ContextWindowOverflowException, MaxTokensReachedException) as e:
# Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)
raise e
except Exception as e:
if cycle_span:
tracer.end_span_with_error(cycle_span, str(e), e)

# Handle any other exceptions
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e

yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])


async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
    """Re-enter event_loop_cycle with the agent's current state.

    Used when the event loop must keep processing after tool execution has
    completed, delegating the actual work back to event_loop_cycle.

    Args:
        agent: Agent for which the recursive call is being made.
        invocation_state: Arguments to pass through event_loop_cycle

    Yields:
        Results from event_loop_cycle where the last result contains:

        - StopReason: Reason the model stopped generating
        - Message: The generated message from the model
        - EventLoopMetrics: Updated metrics for the event loop
        - Any: Updated request state
    """
    parent_trace = invocation_state["event_loop_cycle_trace"]

    # Record this re-entry as a child of the current cycle's trace.
    nested_trace = Trace("Recursive call", parent_id=parent_trace.id)
    parent_trace.add_child(nested_trace)

    yield StartEvent()

    async for nested_event in event_loop_cycle(agent=agent, invocation_state=invocation_state):
        yield nested_event

    nested_trace.end()


async def _handle_tool_execution(
stop_reason: StopReason,
Expand Down