Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Provides minimal foundation for multi-agent patterns (Swarm, Graph).
"""

import copy
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
Expand All @@ -22,6 +24,105 @@ class Status(Enum):
FAILED = "failed"


@dataclass
class MultiAgentNode:
"""Base class for nodes in multi-agent systems."""

node_id: str

def __hash__(self) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit we can remove the hash from SwarmNode since we can now inherit

"""Return hash for MultiAgentNode based on node_id."""
return hash(self.node_id)

def __eq__(self, other: Any) -> bool:
"""Return equality for MultiAgentNode based on node_id."""
if not isinstance(other, MultiAgentNode):
return False
return self.node_id == other.node_id


@dataclass
class SharedContext:
"""Shared context between multi-agent nodes.

This class provides a key-value store for sharing information across nodes
in multi-agent systems like Graph and Swarm. It validates that all values
are JSON serializable to ensure compatibility.
"""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: MultiAgentNode, key: str, value: Any) -> None:
"""Add context for a specific node.

Args:
node: The node object to add context for
key: The key to store the value under
value: The value to store (must be JSON serializable)

Raises:
ValueError: If key is invalid or value is not JSON serializable
"""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def get_context(self, node: MultiAgentNode, key: str | None = None) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'm set on this approach because of backwards compatibility.

If context were private _context then the deepcopy could make sense to me. Since its not, having a user access different objects though the two approaches feels wrong.

get_context(node, key) and context[node.node_id][key]

"""Get context for a specific node.

Args:
node: The node object to get context for
key: The specific key to retrieve (if None, returns all context for the node)

Returns:
The stored value, entire context dict for the node, or None if not found
"""
if node.node_id not in self.context:
return None if key else {}

if key is None:
return copy.deepcopy(self.context[node.node_id])
else:
value = self.context[node.node_id].get(key)
return copy.deepcopy(value) if value is not None else None

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class NodeResult:
"""Unified result from node execution - handles both Agent and nested MultiAgentBase results.
Expand Down
28 changes: 25 additions & 3 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..telemetry import get_tracer
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)

Expand All @@ -46,6 +46,7 @@ class GraphState:
task: The original input prompt/query provided to the graph execution.
This represents the actual work to be performed by the graph as a whole.
Entry point nodes receive this task as their input if they have no dependencies.
shared_context: Context shared between graph nodes for storing user-defined state.
"""

# Task (with default empty string)
Expand All @@ -61,6 +62,9 @@ class GraphState:
# Results
results: dict[str, NodeResult] = field(default_factory=dict)

# User-defined state shared across nodes
shared_context: "SharedContext" = field(default_factory=lambda: SharedContext())

# Accumulated metrics
accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0))
accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0))
Expand Down Expand Up @@ -126,7 +130,7 @@ def should_traverse(self, state: GraphState) -> bool:


@dataclass
class GraphNode:
class GraphNode(MultiAgentNode):
"""Represents a node in the graph.

The execution_status tracks the node's lifecycle within graph orchestration:
Expand All @@ -135,7 +139,6 @@ class GraphNode:
- COMPLETED/FAILED: Node finished executing (regardless of result quality)
"""

node_id: str
executor: Agent | MultiAgentBase
dependencies: set["GraphNode"] = field(default_factory=set)
execution_status: Status = Status.PENDING
Expand Down Expand Up @@ -389,6 +392,25 @@ def __init__(
self.state = GraphState()
self.tracer = get_tracer()

@property
def shared_context(self) -> SharedContext:
"""Access to the shared context for storing user-defined state across graph nodes.

Returns:
The SharedContext instance that can be used to store and retrieve
information that should be accessible to all nodes in the graph.

Example:
```python
graph = Graph(...)
node1 = graph.nodes["node1"]
node2 = graph.nodes["node2"]
graph.shared_context.add_context(node1, "file_reference", "/path/to/file")
graph.shared_context.get_context(node2, "file_reference")
```
"""
return self.state.shared_context
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay I need to take a look at this again more deeply.

What is strange to be is that we have different hierarchies for Swarm and Graph.

For Swarm we have SharedContext at the same level as SwarmState. For Graph we have SharedContext nested within GraphState.

Even if we kept it this way I'm not sure we need this convenience method to surface shared context at the GraphNode level. It feels strange to place the context on the state then surface it higher.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll come back tomorrow with a final decision on the path for this.


def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
"""Invoke the graph synchronously."""

Expand Down
60 changes: 7 additions & 53 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import asyncio
import copy
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -29,16 +28,15 @@
from ..tools.decorator import tool
from ..types.content import ContentBlock, Messages
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
from .base import MultiAgentBase, MultiAgentNode, MultiAgentResult, NodeResult, SharedContext, Status

logger = logging.getLogger(__name__)


@dataclass
class SwarmNode:
class SwarmNode(MultiAgentNode):
"""Represents a node (e.g. Agent) in the swarm."""

node_id: str
executor: Agent
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)
Expand Down Expand Up @@ -73,55 +71,6 @@ def reset_executor_state(self) -> None:
self.executor.state = AgentState(self._initial_state.get())


@dataclass
class SharedContext:
"""Shared context between swarm nodes."""

context: dict[str, dict[str, Any]] = field(default_factory=dict)

def add_context(self, node: SwarmNode, key: str, value: Any) -> None:
"""Add context."""
self._validate_key(key)
self._validate_json_serializable(value)

if node.node_id not in self.context:
self.context[node.node_id] = {}
self.context[node.node_id][key] = value

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e


@dataclass
class SwarmState:
"""Current state of swarm execution."""
Expand Down Expand Up @@ -654,3 +603,8 @@ def _build_result(self) -> SwarmResult:
execution_time=self.state.execution_time,
node_history=self.state.node_history,
)


# Backward compatibility aliases
# These ensure that existing imports continue to work
__all__ = ["SwarmNode", "SharedContext", "Status"]
Loading
Loading