diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 2d3d538fe..b421b70c1 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -26,6 +26,15 @@ from .._async import run_async from ..agent import Agent from ..agent.state import AgentState +from ..experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from ..hooks import HookProvider, HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..types._events import ( MultiAgentHandoffEvent, @@ -40,6 +49,8 @@ logger = logging.getLogger(__name__) +_DEFAULT_GRAPH_ID = "default_graph" + @dataclass class GraphState: @@ -223,6 +234,9 @@ def __init__(self) -> None: self._execution_timeout: Optional[float] = None self._node_timeout: Optional[float] = None self._reset_on_revisit: bool = False + self._id: str = _DEFAULT_GRAPH_ID + self._session_manager: Optional[SessionManager] = None + self._hooks: Optional[list[HookProvider]] = None def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" @@ -313,6 +327,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder": self._node_timeout = timeout return self + def set_graph_id(self, graph_id: str) -> "GraphBuilder": + """Set graph id. + + Args: + graph_id: Unique graph id + """ + self._id = graph_id + return self + + def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder": + """Set session manager for the graph. + + Args: + session_manager: SessionManager instance + """ + self._session_manager = session_manager + return self + + def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": + """Set hook providers for the graph. + + Args: + hooks: Customer hooks user passes in + """ + self._hooks = hooks + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -338,6 +379,9 @@ def build(self) -> "Graph": execution_timeout=self._execution_timeout, node_timeout=self._node_timeout, reset_on_revisit=self._reset_on_revisit, + session_manager=self._session_manager, + hooks=self._hooks, + id=self._id, ) def _validate_graph(self) -> None: @@ -365,6 +409,9 @@ def __init__( execution_timeout: Optional[float] = None, node_timeout: Optional[float] = None, reset_on_revisit: bool = False, + session_manager: Optional[SessionManager] = None, + hooks: Optional[list[HookProvider]] = None, + id: str = _DEFAULT_GRAPH_ID, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -376,6 +423,9 @@ def __init__( execution_timeout: Total execution timeout in seconds (default: None - no limit) node_timeout: Individual node timeout in seconds (default: None - no limit) reset_on_revisit: Whether to reset node state when revisited (default: False) + session_manager: Session manager for persisting graph state and execution history (default: None) + hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) + id: Unique graph id (default: None) """ super().__init__() @@ -391,6 +441,19 @@ def __init__( self.reset_on_revisit = reset_on_revisit self.state = GraphState() self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = HookRegistry() + if self.session_manager: + self.hooks.add_hook(self.session_manager) + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + self._resume_next_nodes: list[GraphNode] = [] + self._resume_from_session = False + self.id = id + + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -453,18 +516,25 @@ async def stream_async( if invocation_state is None: invocation_state = {} + self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + logger.debug("task=<%s> | starting graph execution", task) # Initialize state start_time = time.time() - self.state = GraphState( - status=Status.EXECUTING, - task=task, - total_nodes=len(self.nodes), - edges=[(edge.from_node, edge.to_node) for edge in self.edges], - entry_points=list(self.entry_points), - start_time=start_time, - ) + if not self._resume_from_session: + # Initialize state + self.state = GraphState( + status=Status.EXECUTING, + task=task, + total_nodes=len(self.nodes), + edges=[(edge.from_node, edge.to_node) for edge in self.edges], + entry_points=list(self.entry_points), + start_time=start_time, + ) + else: + self.state.status = Status.EXECUTING + self.state.start_time = start_time span = self.tracer.start_multiagent_span(task, "graph") with trace_api.use_span(span, end_on_exit=True): @@ -499,6 +569,9 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self)) + self._resume_from_session = False + self._resume_next_nodes.clear() def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: """Validate graph nodes for duplicate instances.""" @@ -514,7 +587,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" - ready_nodes = list(self.entry_points) + ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -703,7 +776,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" - # Reset the node's state if reset_on_revisit is enabled and it's being revisited + self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state)) + + # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() @@ -844,6 +919,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Re-raise to stop graph execution (fail-fast behavior) raise + finally: + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state)) + def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" self.state.accumulated_usage["inputTokens"] += node_result.accumulated_usage.get("inputTokens", 0) @@ -928,3 +1006,94 @@ def _build_result(self) -> GraphResult: edges=self.state.edges, entry_points=self.state.entry_points, ) + + def serialize_state(self) -> dict[str, Any]: + """Serialize the current graph state to a dictionary.""" + compute_nodes = self._compute_ready_nodes_for_resume() + next_nodes = [n.node_id for n in compute_nodes] if compute_nodes else [] + return { + "type": "graph", + "id": self.id, + "status": self.state.status.value, + "completed_nodes": [n.node_id for n in self.state.completed_nodes], + "failed_nodes": [n.node_id for n in self.state.failed_nodes], + "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, + "next_nodes_to_execute": next_nodes, + "current_task": self.state.task, + "execution_order": [n.node_id for n in self.state.execution_order], + } + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore graph state from a session dict and prepare for execution. + + This method handles two scenarios: + 1. If the graph execution ended (no next_nodes_to_execute, eg: Completed, or Failed with dead end nodes), + resets all nodes and graph state to allow re-execution from the beginning. + 2. If the graph execution was interrupted mid-execution (has next_nodes_to_execute), + restores the persisted state and prepares to resume execution from the next ready nodes. + + Args: + payload: Dictionary containing persisted state data including status, + completed nodes, results, and next nodes to execute. + """ + if not payload.get("next_nodes_to_execute"): + # Reset all nodes + for node in self.nodes.values(): + node.reset_executor_state() + # Reset graph state + self.state = GraphState() + self._resume_from_session = False + return + else: + self._from_dict(payload) + self._resume_from_session = True + + def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: + if self.state.status == Status.PENDING: + return [] + ready_nodes: list[GraphNode] = [] + completed_nodes = set(self.state.completed_nodes) + for node in self.nodes.values(): + if node in completed_nodes: + continue + incoming = [e for e in self.edges if e.to_node is node] + if not incoming: + ready_nodes.append(node) + elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): + ready_nodes.append(node) + + return ready_nodes + + def _from_dict(self, payload: dict[str, Any]) -> None: + self.state.status = Status(payload["status"]) + # Hydrate completed nodes & results + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results + + self.state.failed_nodes = set( + self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes + ) + + # Restore completed nodes from persisted data + completed_node_ids = payload.get("completed_nodes") or [] + self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + + # Execution order (only nodes that still exist) + order_node_ids = payload.get("execution_order") or [] + self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] + + # Task + self.state.task = payload.get("current_task", self.state.task) + + # next nodes to execute + next_nodes = [self.nodes[nid] for nid in (payload.get("next_nodes_to_execute") or []) if nid in self.nodes] + self._resume_next_nodes = next_nodes diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index cd0a2d74c..accd56463 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -18,13 +18,22 @@ import logging import time from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Tuple, cast +from typing import Any, AsyncIterator, Callable, Optional, Tuple, cast from opentelemetry import trace as trace_api from .._async import run_async from ..agent import Agent from ..agent.state import AgentState +from ..experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from ..hooks import HookProvider, HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool from ..types._events import ( @@ -40,6 +49,8 @@ logger = logging.getLogger(__name__) +_DEFAULT_SWARM_ID = "default_swarm" + @dataclass class SwarmNode: @@ -210,10 +221,14 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, + session_manager: Optional[SessionManager] = None, + hooks: Optional[list[HookProvider]] = None, + id: str = _DEFAULT_SWARM_ID, ) -> None: """Initialize Swarm with agents and configuration. Args: + id : Unique swarm id (default: None) nodes: List of nodes (e.g. Agent) to include in the swarm entry_point: Agent to start with. If None, uses the first agent (default: None) max_handoffs: Maximum handoffs to agents and users (default: 20) @@ -224,9 +239,11 @@ def __init__( Disabled by default (default: 0) repetitive_handoff_min_unique_agents: Minimum unique agents required in recent sequence Disabled by default (default: 0) + session_manager: Session manager for persisting graph state and execution history (default: None) + hooks: List of hook providers for monitoring and extending graph execution behavior (default: None) """ super().__init__() - + self.id = id self.entry_point = entry_point self.max_handoffs = max_handoffs self.max_iterations = max_iterations @@ -244,8 +261,19 @@ def __init__( ) self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + if self.session_manager: + self.hooks.add_hook(self.session_manager) + + self._resume_from_session = False + self._setup_swarm(nodes) self._inject_swarm_tools() + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -260,7 +288,6 @@ def __call__( """ if invocation_state is None: invocation_state = {} - return run_async(lambda: self.invoke_async(task, invocation_state)) async def invoke_async( @@ -309,22 +336,24 @@ async def stream_async( if invocation_state is None: invocation_state = {} + self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + logger.debug("starting swarm execution") - # Initialize swarm state with configuration - if self.entry_point: - initial_node = self.nodes[str(self.entry_point.name)] - else: - initial_node = next(iter(self.nodes.values())) + if not self._resume_from_session: + # Initialize swarm state with configuration + initial_node = self._initial_node() - self.state = SwarmState( - current_node=initial_node, - task=task, - completion_status=Status.EXECUTING, - shared_context=self.shared_context, - ) + self.state = SwarmState( + current_node=initial_node, + task=task, + completion_status=Status.EXECUTING, + shared_context=self.shared_context, + ) + else: + self.state.completion_status = Status.EXECUTING + self.state.start_time = time.time() - start_time = time.time() span = self.tracer.start_multiagent_span(task, "swarm") with trace_api.use_span(span, end_on_exit=True): try: @@ -345,7 +374,9 @@ async def stream_async( self.state.completion_status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state)) + self._resume_from_session = False # Yield final result after execution_time is set result = self._build_result() @@ -656,6 +687,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # TODO: Implement cancellation token to stop _execute_node from continuing try: # Execute with timeout wrapper for async generator streaming + self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -666,6 +698,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato self.state.node_history.append(current_node) + # After self.state add current node, swarm state finish updating, we persist here + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) + logger.debug("node=<%s> | node execution completed", current_node.node_id) # Check if handoff occurred during execution @@ -823,3 +858,84 @@ def _build_result(self) -> SwarmResult: execution_time=self.state.execution_time, node_history=self.state.node_history, ) + + def serialize_state(self) -> dict[str, Any]: + """Serialize the current swarm state to a dictionary.""" + status_str = self.state.completion_status.value + next_nodes = ( + [self.state.current_node.node_id] + if self.state.completion_status == Status.EXECUTING and self.state.current_node + else [] + ) + + return { + "type": "swarm", + "id": self.id, + "status": status_str, + "node_history": [n.node_id for n in self.state.node_history], + "node_results": {k: v.to_dict() for k, v in self.state.results.items()}, + "next_nodes_to_execute": next_nodes, + "current_task": self.state.task, + "context": { + "shared_context": getattr(self.state.shared_context, "context", {}) or {}, + "handoff_message": self.state.handoff_message, + }, + } + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore swarm state from a session dict and prepare for execution. + + This method handles two scenarios: + 1. If the persisted status is COMPLETED, FAILED resets all nodes and graph state + to allow re-execution from the beginning. + 2. Otherwise, restores the persisted state and prepares to resume execution + from the next ready nodes. + + Args: + payload: Dictionary containing persisted state data including status, + completed nodes, results, and next nodes to execute. + """ + if not payload.get("next_nodes_to_execute"): + for node in self.nodes.values(): + node.reset_executor_state() + self.state = SwarmState( + current_node=SwarmNode("", Agent()), + task="", + completion_status=Status.PENDING, + ) + self._resume_from_session = False + return + else: + self._from_dict(payload) + self._resume_from_session = True + + def _from_dict(self, payload: dict[str, Any]) -> None: + self.state.completion_status = Status(payload["status"]) + # Hydrate completed nodes & results + context = payload["context"] or {} + self.shared_context.context = context.get("shared_context") or {} + self.state.handoff_message = context.get("handoff_message") + + self.state.node_history = [self.nodes[nid] for nid in (payload.get("node_history") or []) if nid in self.nodes] + + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results + self.state.task = payload.get("current_task", self.state.task) + + next_node_ids = payload.get("next_nodes_to_execute") or [] + if next_node_ids: + self.state.current_node = self.nodes[next_node_ids[0]] if next_node_ids[0] else self._initial_node() + + def _initial_node(self) -> SwarmNode: + if self.entry_point: + return self.nodes[str(self.entry_point.name)] + return next(iter(self.nodes.values())) # First SwarmNode diff --git a/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py b/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py new file mode 100644 index 000000000..4e97a9217 --- /dev/null +++ b/tests/strands/experimental/hooks/multiagent/test_multi_agent_hooks.py @@ -0,0 +1,130 @@ +import pytest + +from strands import Agent +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.multiagent.graph import Graph, GraphBuilder +from strands.multiagent.swarm import Swarm +from tests.fixtures.mock_multiagent_hook_provider import MockMultiAgentHookProvider +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +@pytest.fixture +def hook_provider(): + return MockMultiAgentHookProvider( + [ + BeforeMultiAgentInvocationEvent, + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, + ] + ) + + +@pytest.fixture +def mock_model(): + agent_messages = [ + {"role": "assistant", "content": [{"text": "Task completed"}]}, + {"role": "assistant", "content": [{"text": "Task completed by agent 2"}]}, + {"role": "assistant", "content": [{"text": "Additional response"}]}, + ] + return MockedModelProvider(agent_messages) + + +@pytest.fixture +def agent1(mock_model): + return Agent(model=mock_model, system_prompt="You are agent 1.", name="agent1") + + +@pytest.fixture +def agent2(mock_model): + return Agent(model=mock_model, system_prompt="You are agent 2.", name="agent2") + + +@pytest.fixture +def swarm(agent1, agent2, hook_provider): + swarm = Swarm(nodes=[agent1, agent2], hooks=[hook_provider]) + return swarm + + +@pytest.fixture +def graph(agent1, agent2, hook_provider): + builder = GraphBuilder() + builder.add_node(agent1, "agent1") + builder.add_node(agent2, "agent2") + builder.add_edge("agent1", "agent2") + builder.set_entry_point("agent1") + graph = Graph(nodes=builder.nodes, edges=builder.edges, entry_points=builder.entry_points, hooks=[hook_provider]) + return graph + + +def test_swarm_complete_hook_lifecycle(swarm, hook_provider): + """E2E test verifying complete hook lifecycle for Swarm.""" + result = swarm("test task") + + length, events = hook_provider.get_events() + assert length == 5 + assert result.status.value == "completed" + + events_list = list(events) + + # Check event types and basic properties, ignoring invocation_state + assert isinstance(events_list[0], MultiAgentInitializedEvent) + assert events_list[0].source == swarm + + assert isinstance(events_list[1], BeforeMultiAgentInvocationEvent) + assert events_list[1].source == swarm + + assert isinstance(events_list[2], BeforeNodeCallEvent) + assert events_list[2].source == swarm + assert events_list[2].node_id == "agent1" + + assert isinstance(events_list[3], AfterNodeCallEvent) + assert events_list[3].source == swarm + assert events_list[3].node_id == "agent1" + + assert isinstance(events_list[4], AfterMultiAgentInvocationEvent) + assert events_list[4].source == swarm + + +def test_graph_complete_hook_lifecycle(graph, hook_provider): + """E2E test verifying complete hook lifecycle for Graph.""" + result = graph("test task") + + length, events = hook_provider.get_events() + assert length == 7 + assert result.status.value == "completed" + + events_list = list(events) + + # Check event types and basic properties, ignoring invocation_state + assert isinstance(events_list[0], MultiAgentInitializedEvent) + assert events_list[0].source == graph + + assert isinstance(events_list[1], BeforeMultiAgentInvocationEvent) + assert events_list[1].source == graph + + assert isinstance(events_list[2], BeforeNodeCallEvent) + assert events_list[2].source == graph + assert events_list[2].node_id == "agent1" + + assert isinstance(events_list[3], AfterNodeCallEvent) + assert events_list[3].source == graph + assert events_list[3].node_id == "agent1" + + assert isinstance(events_list[4], BeforeNodeCallEvent) + assert events_list[4].source == graph + assert events_list[4].node_id == "agent2" + + assert isinstance(events_list[5], AfterNodeCallEvent) + assert events_list[5].source == graph + assert events_list[5].node_id == "agent2" + + assert isinstance(events_list[6], AfterMultiAgentInvocationEvent) + assert events_list[6].source == graph diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 07037a447..b32356cb4 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -10,6 +10,7 @@ from strands.hooks.registry import HookProvider, HookRegistry from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status +from strands.session.file_session_manager import FileSessionManager from strands.session.session_manager import SessionManager @@ -1979,3 +1980,56 @@ async def stream_without_result(*args, **kwargs): mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called_once() + + +@pytest.mark.asyncio +async def test_graph_persisted(mock_strands_tracer, mock_use_span): + """Test graph persistence functionality.""" + # Create mock session manager + session_manager = Mock(spec=FileSessionManager) + session_manager.read_multi_agent().return_value = None + + # Create simple graph with session manager + builder = GraphBuilder() + agent = create_mock_agent("test_agent") + builder.add_node(agent, "test_node") + builder.set_entry_point("test_node") + builder.set_session_manager(session_manager) + + graph = builder.build() + + # Test get_state_from_orchestrator + state = graph.serialize_state() + assert state["type"] == "graph" + assert state["id"] == "default_graph" + assert "status" in state + assert "completed_nodes" in state + assert "node_results" in state + + # Test apply_state_from_dict with persisted state + persisted_state = { + "status": "executing", + "completed_nodes": [], + "failed_nodes": [], + "node_results": {}, + "current_task": "persisted task", + "execution_order": [], + "next_nodes_to_execute": ["test_node"], + } + + graph.deserialize_state(persisted_state) + assert graph.state.task == "persisted task" + + # Execute graph to test persistence integration + result = await graph.invoke_async("Test persistence") + + # Verify execution completed + assert result.status == Status.COMPLETED + assert len(result.results) == 1 + assert "test_node" in result.results + + # Test state serialization after execution + final_state = graph.serialize_state() + assert final_state["status"] == "completed" + assert len(final_state["completed_nodes"]) == 1 + assert "test_node" in final_state["node_results"] diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 14a0ac1d6..e8a6a5f79 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -9,6 +9,7 @@ from strands.hooks.registry import HookRegistry from strands.multiagent.base import Status from strands.multiagent.swarm import SharedContext, Swarm, SwarmNode, SwarmResult, SwarmState +from strands.session.file_session_manager import FileSessionManager from strands.session.session_manager import SessionManager from strands.types._events import MultiAgentNodeStartEvent from strands.types.content import ContentBlock @@ -1098,3 +1099,53 @@ async def failing_execute_swarm(*args, **kwargs): # Verify the swarm status is FAILED assert swarm.state.completion_status == Status.FAILED + + +@pytest.mark.asyncio +async def test_swarm_persistence(mock_strands_tracer, mock_use_span): + """Test swarm persistence functionality.""" + # Create mock session manager + session_manager = Mock(spec=FileSessionManager) + session_manager.read_multi_agent.return_value = None + + # Create simple swarm with session manager + agent = create_mock_agent("test_agent") + swarm = Swarm([agent], session_manager=session_manager) + + # Test get_state_from_orchestrator + state = swarm.serialize_state() + assert state["type"] == "swarm" + assert state["id"] == "default_swarm" + assert "status" in state + assert "node_history" in state + assert "node_results" in state + assert "context" in state + + # Test apply_state_from_dict with persisted state + persisted_state = { + "status": "executing", + "node_history": [], + "node_results": {}, + "current_task": "persisted task", + "next_nodes_to_execute": ["test_agent"], + "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, + } + + swarm._from_dict(persisted_state) + assert swarm.state.task == "persisted task" + assert swarm.state.handoff_message == "test handoff" + assert swarm.shared_context.context["test_agent"]["key"] == "value" + + # Execute swarm to test persistence integration + result = await swarm.invoke_async("Test persistence") + + # Verify execution completed + assert result.status == Status.COMPLETED + assert len(result.results) == 1 + assert "test_agent" in result.results + + # Test state serialization after execution + final_state = swarm.serialize_state() + assert final_state["status"] == "completed" + assert len(final_state["node_history"]) == 1 + assert "test_agent" in final_state["node_results"] diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index a7335feb7..08343a554 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,6 @@ from typing import Any, AsyncIterator +from unittest.mock import patch +from uuid import uuid4 import pytest @@ -13,6 +15,7 @@ ) from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status from strands.multiagent.graph import GraphBuilder +from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -458,3 +461,127 @@ async def test_graph_metrics_accumulation(): # Verify accumulated metrics are sum of node metrics total_tokens = sum(node_result.accumulated_usage["totalTokens"] for node_result in result.results.values()) assert result.accumulated_usage["totalTokens"] == total_tokens, "Accumulated tokens don't match sum of node tokens" + + +@pytest.mark.asyncio +async def test_graph_interrupt_and_resume(): + """Test graph interruption and resume functionality with FileSessionManager.""" + + session_id = str(uuid4()) + + # Create real agents + agent1 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 1", name="agent1") + agent2 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 2", name="agent2") + agent3 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 3", name="agent3") + + session_manager = FileSessionManager(session_id=session_id) + + builder = GraphBuilder() + builder.add_node(agent1, "node1") + builder.add_node(agent2, "node2") + builder.add_node(agent3, "node3") + builder.add_edge("node1", "node2") + builder.add_edge("node2", "node3") + builder.set_entry_point("node1") + builder.set_session_manager(session_manager) + + graph = builder.build() + + # Mock agent2 to fail on first execution + async def failing_stream_async(*args, **kwargs): + raise Exception("Simulated failure in agent2") + yield # This line is never reached, but makes it an async generator + + with patch.object(agent2, "stream_async", side_effect=failing_stream_async): + try: + await graph.invoke_async("This is a test task, just do it shortly") + raise AssertionError("Expected exception was not raised") + except Exception as e: + assert "Simulated failure in agent2" in str(e) + + # Verify partial execution was persisted + persisted_state = session_manager.read_multi_agent(session_id, graph.id) + assert persisted_state is not None + assert persisted_state["type"] == "graph" + assert persisted_state["status"] == "failed" + assert len(persisted_state["completed_nodes"]) == 1 # Only node1 completed + assert "node1" in persisted_state["completed_nodes"] + assert "node2" in persisted_state["next_nodes_to_execute"] + assert "node2" in persisted_state["failed_nodes"] + + # Track execution count before resume + initial_execution_count = graph.state.execution_count + + # Execute graph again + result = await graph.invoke_async("Test task") + + # Verify successful completion + assert result.status == Status.COMPLETED + assert len(result.results) == 3 + + execution_order_ids = [node.node_id for node in result.execution_order] + assert execution_order_ids == ["node1", "node2", "node3"] + + # Verify only 2 additional nodes were executed + assert result.execution_count == initial_execution_count + 2 + + final_state = session_manager.read_multi_agent(session_id, graph.id) + assert final_state["status"] == "completed" + assert len(final_state["completed_nodes"]) == 3 + + # Clean up + session_manager.delete_session(session_id) + + +@pytest.mark.asyncio +async def test_self_loop_resume_from_persisted_state(tmp_path): + """Test resuming self-loop from persisted state where next node is itself.""" + + session_id = f"self_loop_resume_{uuid4()}" + session_manager = FileSessionManager(session_id=session_id, storage_dir=str(tmp_path)) + + counter_agent = Agent( + model="us.amazon.nova-pro-v1:0", + system_prompt="You are a counter. Just respond with 'Count: 1', 'Count: 2', Stop at 5.", + ) + + def should_continue_loop(state): + loop_executions = len([node for node in state.execution_order if node.node_id == "loop_node"]) + return loop_executions < 5 + + builder = GraphBuilder() + builder.add_node(counter_agent, "loop_node") + builder.add_edge("loop_node", "loop_node", condition=should_continue_loop) + builder.set_entry_point("loop_node") + builder.set_session_manager(session_manager) + builder.reset_on_revisit(True) + + graph = builder.build() + + call_count = 0 + original_stream = counter_agent.stream_async + + async def failing_after_two(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + async for event in original_stream(*args, **kwargs): + yield event + else: + raise Exception("Simulated failure after two executions") + + with patch.object(counter_agent, "stream_async", side_effect=failing_after_two): + try: + await graph.invoke_async("Count till 5") + except Exception as e: + assert "Simulated failure after two executions" in str(e) + + persisted_state = session_manager.read_multi_agent(session_id, graph.id) + assert persisted_state["status"] == "failed" + assert "loop_node" in persisted_state.get("failed_nodes") + assert len(persisted_state.get("execution_order")) == 2 + + result = await graph.invoke_async("Continue counting to 5") + assert result.status == Status.COMPLETED + assert len(result.execution_order) == 5 + assert all(node.node_id == "loop_node" for node in result.execution_order) diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index ae9129fbb..771030619 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,3 +1,6 @@ +from unittest.mock import patch +from uuid import uuid4 + import pytest from strands import Agent, tool @@ -10,7 +13,9 @@ BeforeToolCallEvent, MessageAddedEvent, ) +from strands.multiagent.base import Status from strands.multiagent.swarm import Swarm +from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock from tests.fixtures.mock_hook_provider import MockHookProvider @@ -319,3 +324,55 @@ async def test_swarm_get_agent_results_flattening(): assert len(agent_results) == 1 assert isinstance(agent_results[0], AgentResult) assert agent_results[0].message is not None + + +@pytest.mark.asyncio +async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent): + """Test swarm interruption after analyst_agent and resume functionality.""" + session_id = str(uuid4()) + + # Create session manager + session_manager = FileSessionManager(session_id=session_id) + + # Create swarm with session manager + swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager) + + # Mock analyst_agent's _invoke method to fail + async def failing_invoke(*args, **kwargs): + raise Exception("Simulated failure in analyst") + yield # This line is never reached, but makes it an async generator + + with patch.object(analyst_agent, "stream_async", side_effect=failing_invoke): + # First execution - should fail at analyst + result = await swarm.invoke_async("Research AI trends and create a brief report") + try: + assert result.status == Status.FAILED + except Exception as e: + assert "Simulated failure in analyst" in str(e) + + # Verify partial execution was persisted + persisted_state = session_manager.read_multi_agent(session_id, swarm.id) + assert persisted_state is not None + assert persisted_state["type"] == "swarm" + assert persisted_state["status"] == "failed" + assert len(persisted_state["node_history"]) == 1 # At least researcher executed + + # Track execution count before resume + initial_execution_count = len(persisted_state["node_history"]) + + # Execute swarm again - should automatically resume from saved state + result = await swarm.invoke_async("Research AI trends and create a brief report") + + # Verify successful completion + assert result.status == Status.COMPLETED + assert len(result.results) > 0 + + assert len(result.node_history) >= initial_execution_count + 1 + + node_names = [node.node_id for node in result.node_history] + assert "researcher" in node_names + # Either analyst or writer (or both) should have executed to complete the task + assert "analyst" in node_names or "writer" in node_names + + # Clean up + session_manager.delete_session(session_id)