189 changes: 179 additions & 10 deletions src/strands/multiagent/graph.py
@@ -26,6 +26,15 @@
from .._async import run_async
from ..agent import Agent
from ..agent.state import AgentState
from ..experimental.hooks.multiagent import (
AfterMultiAgentInvocationEvent,
AfterNodeCallEvent,
BeforeMultiAgentInvocationEvent,
BeforeNodeCallEvent,
MultiAgentInitializedEvent,
)
from ..hooks import HookProvider, HookRegistry
from ..session import SessionManager
from ..telemetry import get_tracer
from ..types._events import (
MultiAgentHandoffEvent,
@@ -40,6 +49,8 @@

logger = logging.getLogger(__name__)

_DEFAULT_GRAPH_ID = "default_graph"


@dataclass
class GraphState:
@@ -223,6 +234,9 @@ def __init__(self) -> None:
self._execution_timeout: Optional[float] = None
self._node_timeout: Optional[float] = None
self._reset_on_revisit: bool = False
self._id: str = _DEFAULT_GRAPH_ID
self._session_manager: Optional[SessionManager] = None
self._hooks: Optional[list[HookProvider]] = None

def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -313,6 +327,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
self._node_timeout = timeout
return self

def set_graph_id(self, graph_id: str) -> "GraphBuilder":
"""Set graph id.

Args:
graph_id: Unique identifier for the graph
"""
self._id = graph_id
return self

def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
"""Set session manager for the graph.

Args:
session_manager: SessionManager that persists and restores graph state
"""
self._session_manager = session_manager
return self

def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
"""Set hook providers for the graph.

Args:
hooks: Hook providers supplied by the caller
"""
self._hooks = hooks
return self
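
The new setters compose with the existing builder chain before build(); a minimal usage sketch, where agent_a, agent_b, my_session_manager, and my_hook_provider are placeholders for objects created elsewhere (any Agent instances, any SessionManager implementation, and any HookProvider):

# Illustrative use of the new builder setters; add_node/add_edge are the
# existing builder API.
builder = GraphBuilder()
start = builder.add_node(agent_a, "start")
review = builder.add_node(agent_b, "review")
builder.add_edge(start, review)

graph = (
    builder.set_graph_id("support_graph")
    .set_session_manager(my_session_manager)  # registered as a hook in Graph.__init__
    .set_hook_providers([my_hook_provider])   # appended to the graph's HookRegistry
    .build()
)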

def build(self) -> "Graph":
"""Build and validate the graph with configured settings."""
if not self.nodes:
@@ -338,6 +379,9 @@ def build(self) -> "Graph":
execution_timeout=self._execution_timeout,
node_timeout=self._node_timeout,
reset_on_revisit=self._reset_on_revisit,
session_manager=self._session_manager,
hooks=self._hooks,
id=self._id,
)

def _validate_graph(self) -> None:
@@ -365,6 +409,9 @@ def __init__(
execution_timeout: Optional[float] = None,
node_timeout: Optional[float] = None,
reset_on_revisit: bool = False,
session_manager: Optional[SessionManager] = None,
hooks: Optional[list[HookProvider]] = None,
id: str = _DEFAULT_GRAPH_ID,
) -> None:
"""Initialize Graph with execution limits and reset behavior.

@@ -376,6 +423,9 @@ def __init__(
execution_timeout: Total execution timeout in seconds (default: None - no limit)
node_timeout: Individual node timeout in seconds (default: None - no limit)
reset_on_revisit: Whether to reset node state when revisited (default: False)
session_manager: Session manager for persisting graph state and execution history (default: None)
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
id: Unique graph id (default: "default_graph")
"""
super().__init__()

@@ -391,6 +441,19 @@ def __init__(
self.reset_on_revisit = reset_on_revisit
self.state = GraphState()
self.tracer = get_tracer()
self.session_manager = session_manager
self.hooks = HookRegistry()
if self.session_manager:
self.hooks.add_hook(self.session_manager)
if hooks:
for hook in hooks:
self.hooks.add_hook(hook)

self._resume_next_nodes: list[GraphNode] = []
self._resume_from_session = False
self.id = id

self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
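
Hook providers passed in through hooks (or via GraphBuilder.set_hook_providers) receive the new multi-agent events. A minimal sketch, assuming the package-level import paths mirror the relative imports above and that HookProvider.register_hooks / HookRegistry.add_callback behave as in the existing agent-level hooks API; the node_id attribute on the events is an assumption inferred from the constructor calls in this module:

# Illustrative hook provider that logs node start/finish.
from strands.experimental.hooks.multiagent import AfterNodeCallEvent, BeforeNodeCallEvent
from strands.hooks import HookProvider, HookRegistry


class NodeTimingLogger(HookProvider):
    """Print every node the graph starts and finishes."""

    def register_hooks(self, registry: HookRegistry, **kwargs) -> None:
        registry.add_callback(BeforeNodeCallEvent, self._before_node)
        registry.add_callback(AfterNodeCallEvent, self._after_node)

    def _before_node(self, event: BeforeNodeCallEvent) -> None:
        print(f"starting node {event.node_id}")

    def _after_node(self, event: AfterNodeCallEvent) -> None:
        print(f"finished node {event.node_id}")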

def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -453,18 +516,25 @@ async def stream_async(
if invocation_state is None:
invocation_state = {}

self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state))

logger.debug("task=<%s> | starting graph execution", task)

# Initialize state
start_time = time.time()
self.state = GraphState(
status=Status.EXECUTING,
task=task,
total_nodes=len(self.nodes),
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
entry_points=list(self.entry_points),
start_time=start_time,
)
if not self._resume_from_session:
# Initialize state
self.state = GraphState(
status=Status.EXECUTING,
task=task,
total_nodes=len(self.nodes),
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
entry_points=list(self.entry_points),
start_time=start_time,
)
else:
self.state.status = Status.EXECUTING
self.state.start_time = start_time

span = self.tracer.start_multiagent_span(task, "graph")
with trace_api.use_span(span, end_on_exit=True):
@@ -499,6 +569,9 @@ async def stream_async(
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
self._resume_from_session = False
self._resume_next_nodes.clear()

def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
"""Validate graph nodes for duplicate instances."""
@@ -514,7 +587,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:

async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute graph and yield TypedEvent objects."""
ready_nodes = list(self.entry_points)
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)

while ready_nodes:
# Check execution limits before continuing
@@ -703,7 +776,9 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[

async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute a single node and yield TypedEvent objects."""
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state))

# Reset the node's state if reset_on_revisit is enabled, and it's being revisited
if self.reset_on_revisit and node in self.state.completed_nodes:
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
node.reset_executor_state()
@@ -844,6 +919,9 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
# Re-raise to stop graph execution (fail-fast behavior)
raise

finally:
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))

def _accumulate_metrics(self, node_result: NodeResult) -> None:
"""Accumulate metrics from a node result."""
self.state.accumulated_usage["inputTokens"] += node_result.accumulated_usage.get("inputTokens", 0)
@@ -928,3 +1006,94 @@ def _build_result(self) -> GraphResult:
edges=self.state.edges,
entry_points=self.state.entry_points,
)

def serialize_state(self) -> dict[str, Any]:
"""Serialize the current graph state to a dictionary."""
compute_nodes = self._compute_ready_nodes_for_resume()
next_nodes = [n.node_id for n in compute_nodes] if compute_nodes else []
return {
"type": "graph",
"id": self.id,
"status": self.state.status.value,
"completed_nodes": [n.node_id for n in self.state.completed_nodes],
"failed_nodes": [n.node_id for n in self.state.failed_nodes],
"node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()},
"next_nodes_to_execute": next_nodes,
"current_task": self.state.task,
"execution_order": [n.node_id for n in self.state.execution_order],
}
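
For a two-node graph ("start" -> "review") interrupted after the first node finishes, the persisted payload is shaped roughly as follows (values illustrative; the exact status string depends on the Status enum values):

# Example payload produced by serialize_state(); node_results abbreviated.
{
    "type": "graph",
    "id": "support_graph",
    "status": "executing",
    "completed_nodes": ["start"],
    "failed_nodes": [],
    "node_results": {"start": {}},  # NodeResult.to_dict() output, omitted here
    "next_nodes_to_execute": ["review"],
    "current_task": "triage the customer ticket",
    "execution_order": ["start"],
}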

def deserialize_state(self, payload: dict[str, Any]) -> None:
"""Restore graph state from a session dict and prepare for execution.

This method handles two scenarios:
1. If the graph execution ended (no next_nodes_to_execute, e.g. Completed, or Failed with dead-end nodes),
resets all nodes and graph state to allow re-execution from the beginning.
2. If the graph execution was interrupted mid-execution (has next_nodes_to_execute),
restores the persisted state and prepares to resume execution from the next ready nodes.

Args:
payload: Dictionary containing persisted state data including status,
completed nodes, results, and next nodes to execute.
"""
if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():
node.reset_executor_state()
# Reset graph state
self.state = GraphState()
self._resume_from_session = False
return
else:
self._from_dict(payload)
self._resume_from_session = True
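
A resume flow then pairs the two methods. A sketch, where load_graph_payload is a hypothetical helper that returns the dict previously produced by serialize_state() (in practice the registered SessionManager handles this persistence through the hooks registered above):

# Hypothetical resume flow for an already-constructed graph.
payload = load_graph_payload("support_graph")
graph.deserialize_state(payload)

# With next_nodes_to_execute present, stream_async() keeps the restored state
# and resumes from those nodes; with it absent, the state was reset and the
# call below starts again from the entry points.
result = graph("continue triaging the customer ticket")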

def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
if self.state.status == Status.PENDING:
return []
ready_nodes: list[GraphNode] = []
completed_nodes = set(self.state.completed_nodes)
for node in self.nodes.values():
if node in completed_nodes:
continue
incoming = [e for e in self.edges if e.to_node is node]
if not incoming:
ready_nodes.append(node)
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
ready_nodes.append(node)

return ready_nodes

def _from_dict(self, payload: dict[str, Any]) -> None:
self.state.status = Status(payload["status"])
# Hydrate completed nodes & results
raw_results = payload.get("node_results") or {}
results: dict[str, NodeResult] = {}
for node_id, entry in raw_results.items():
if node_id not in self.nodes:
continue
try:
results[node_id] = NodeResult.from_dict(entry)
except Exception:
logger.exception("Failed to hydrate NodeResult for node_id=%s", node_id)
raise
self.state.results = results

self.state.failed_nodes = set(
self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes
)

# Restore completed nodes from persisted data
completed_node_ids = payload.get("completed_nodes") or []
self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes}

# Execution order (only nodes that still exist)
order_node_ids = payload.get("execution_order") or []
self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes]

# Task
self.state.task = payload.get("current_task", self.state.task)

# next nodes to execute
next_nodes = [self.nodes[nid] for nid in (payload.get("next_nodes_to_execute") or []) if nid in self.nodes]
self._resume_next_nodes = next_nodes